Training a model
Once data has been fully cleaned and prepared, it is fairly easy to train a model thanks to scikit-learn. In this recipe, before training a logistic regression model on the Titanic dataset, we will quickly recap the ML paradigm and the different types of ML we can use.
Getting ready
If you were asked how to differentiate a car from a truck, you might be tempted to list rules based on the number of wheels, size, weight, and so on. By doing so, you would provide a set of explicit rules that allows anyone to identify a car and a truck as different types of vehicles.
Traditional programming is not so different. While developing algorithms, programmers often build explicit rules, which allow them to map from data input (for example, a vehicle) to answers (for example, a car). We can summarize this paradigm as data + rules = answers.
If we were to train an ML model to discriminate cars from trucks, we would use another strategy: we would feed an ML algorithm with many pieces of data and their associated answers, expecting the model to learn the correct rules by itself. This different approach can be summarized as data + answers = rules. This paradigm difference is summarized in Figure 2.4. As small as it might seem to ML practitioners, it changes everything in terms of regularization:
Figure 2.4 – Comparing traditional programming with ML algorithms
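To make the paradigm difference concrete, here is a minimal, hypothetical sketch (the feature values and the threshold are made up purely for illustration): the first function hard-codes a rule, following data + rules = answers, while the second lets a scikit-learn classifier infer its own rule from labeled examples, following data + answers = rules:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Traditional programming: data + rules = answers.
# The rule (an arbitrary weight threshold) is written by hand.
def classify_vehicle(weight_kg):
    return "truck" if weight_kg > 3500 else "car"

print(classify_vehicle(1200))  # 'car'

# Machine learning: data + answers = rules.
# We provide data (weights) and answers (labels); the model learns the rule.
X = np.array([[1000], [1300], [1500], [4000], [7500], [12000]])  # weight in kg
y = np.array(["car", "car", "car", "truck", "truck", "truck"])

model = DecisionTreeClassifier().fit(X, y)
print(model.predict([[1200]]))  # ['car'], with a threshold learned from the data
```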
Regularizing traditional algorithms is conceptually straightforward. For example, what if the rules defining a truck overlap with the definition of a bus? If so, we can simply add another rule, such as the fact that buses have lots of windows.
Regularization in ML, on the other hand, is intrinsically implicit. What if, in this case, the model does not discriminate between buses and trucks?
- Should we add more data?
- Is the model complex enough to capture such a difference?
- Is it underfitting or overfitting?
This fundamental property of ML makes regularization complex.
ML can be applied to many tasks. Anyone who uses ML knows there is not just one type of ML model.
Arguably, most ML models fall into three main categories:
- Supervised learning
- Unsupervised learning
- Reinforcement learning
As is usually the case with such categories, the real landscape is more complex, with sub-categories and methods that overlap several categories, but this is beyond the scope of this book.
This book will focus on regularization for supervised learning. In supervised learning, the problem is usually quite easy to specify: we have input features, X (for example, apartment surface), and labels, y (for example, apartment price). The goal is to train a model so that it’s robust enough to predict y, given X.
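As a minimal sketch of this setup, with entirely made-up numbers, X could hold apartment surfaces and y the corresponding prices, and a scikit-learn regressor would learn to predict y given X:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Input features X (apartment surface in square meters) and labels y (price).
# The values are invented purely for illustration.
X = np.array([[30], [45], [60], [75], [90]])
y = np.array([150_000, 210_000, 280_000, 340_000, 400_000])

# Train a model to predict y given X
reg = LinearRegression().fit(X, y)
print(reg.predict([[50]]))  # estimated price for a 50 m2 apartment
```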
The two major types of supervised learning tasks are classification and regression:
- Classification: The labels are made of qualitative data. For example, the task is to predict between two or more classes, such as car, bus, and truck.
- Regression: The labels are made of quantitative data. For example, the task is to predict a numerical value, such as an apartment price.
Again, the line can be blurry: some tasks with quantitative labels can still be solved with classification, while other tasks can be framed as either classification or regression. See Figure 2.5:
Figure 2.5 – Regression versus classification
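To illustrate how the same data can support both types of task, here is a hypothetical sketch reusing the apartment example: predicting the price itself is a regression task, while predicting a qualitative label derived from it (for example, whether the price exceeds 300,000) is a classification task. The data and the threshold are made up for illustration:

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression

# Made-up data: apartment surface (m2) and price
X = np.array([[30], [45], [60], [75], [90], [110]])
price = np.array([150_000, 210_000, 280_000, 340_000, 400_000, 520_000])

# Regression: predict the price itself (quantitative label)
reg = LinearRegression().fit(X, price)

# Classification: predict a qualitative label derived from the same data,
# here whether the apartment costs more than 300,000
is_expensive = (price > 300_000).astype(int)
clf = LogisticRegression(max_iter=1_000).fit(X, is_expensive)

print(reg.predict([[80]]))  # an estimated price
print(clf.predict([[80]]))  # 0 or 1: below or above the threshold
```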
How to do it…
Assuming we want to train a logistic regression model (which will be explained properly in the next chapter), the scikit-learn library provides the LogisticRegression class, along with the fit() and predict() methods. Let's learn how to use it:
- Import the LogisticRegression class:
from sklearn.linear_model import LogisticRegression
- Instantiate a LogisticRegression object:
# Instantiate the model
lr = LogisticRegression()
- Fit the model on the training set:
# Fit on the training data
lr.fit(X_train, y_train)
- Optionally, compute predictions by using that model on the test set:
# Compute and store predictions on the test data
y_pred = lr.predict(X_test)
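Here, X_train, y_train, X_test, and y_test are assumed to come from the data preparation recipes earlier in this chapter. As a self-contained sketch, the following uses made-up, already-preprocessed features standing in for the Titanic data, and shows the whole sequence from splitting the data to computing the accuracy of the predictions with the model's score() method:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Made-up, already-preprocessed features (e.g. passenger class, encoded sex,
# fare) and survival labels, standing in for the prepared Titanic data
X = np.array([
    [3, 0, 7.25], [1, 1, 71.28], [3, 1, 7.92], [1, 1, 53.10],
    [3, 0, 8.05], [2, 0, 13.00], [1, 0, 51.86], [3, 1, 11.13],
])
y = np.array([0, 1, 1, 1, 0, 0, 0, 1])

# Split the data, then fit and predict as in the steps above
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)
lr = LogisticRegression()
lr.fit(X_train, y_train)
y_pred = lr.predict(X_test)

# score() returns the accuracy of the model on the held-out test data
print(lr.score(X_test, y_test))
```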
See also
Even though more details will be provided in the next chapter, you might be interested in looking at the documentation of the LogisticRegression class: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html.