Splitting data
After loading data, splitting it is a crucial step. This recipe will explain why we need to split data, as well as how to do it.
Getting ready
Why do we need to split data? An ML model is quite like a student.
You provide a student with many lectures and exercises, with or without the answers. But more often than not, students are evaluated on a completely new problem. To pass, they cannot simply memorize the exercises and their solutions; they must understand the underlying concepts and methods.
An ML model is no different: you train the model on training data and then evaluate it on test data. This way, you make sure the model fully understands the task and generalizes well to new, unseen data.
So, the dataset is usually split into train and test sets:
- The train set must be as large as possible to give as many samples as possible to the model
- The test set must be large enough to be statistically significant in evaluating the model
Typical splits range from 80%/20% for rather small datasets (for example, hundreds of samples) to 99%/1% for very large datasets (for example, millions of samples or more).
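To see why the ratio can shrink as the dataset grows, here is a rough, hypothetical illustration (the dataset sizes are made up, not from the recipe) of the absolute test-set sizes these ratios produce:

```python
# Hypothetical dataset sizes to illustrate the split ratios above
small_n, large_n = 500, 5_000_000

small_test = int(small_n * 0.20)   # 80%/20% split for a small dataset
large_test = int(large_n * 0.01)   # 99%/1% split for a very large dataset

print(small_test)  # 100 test samples
print(large_test)  # 50,000 test samples: 1% is already plenty at this scale
```

Even a 1% test set of a very large dataset contains tens of thousands of samples, which is more than enough for a statistically significant evaluation.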
For this recipe and the others in this chapter, it is assumed that the code has been executed in the same notebook as the previous recipe since each recipe reuses the code from the previous ones.
How to do it…
Here are the steps to try out this recipe:
- You can split the data rather easily with scikit-learn and the `train_test_split()` function:

```python
# Import the train_test_split function
from sklearn.model_selection import train_test_split

# Split the data
X_train, X_test, y_train, y_test = train_test_split(
    df.drop(columns=['Survived']), df['Survived'],
    test_size=0.2, stratify=df['Survived'],
    random_state=0)
```
This function uses the following parameters as input:

- `X`: All columns but the `'Survived'` label
- `y`: The `'Survived'` label column
- `test_size`: This is `0.2`, which means the test set will hold 20% of the samples and the train set the remaining 80%
- `stratify`: This specifies the `'Survived'` column to ensure the same label balance is used in both splits
- `random_state`: `0`; any fixed integer ensures reproducibility
It returns the following outputs:

- `X_train`: The train split of `X`
- `X_test`: The test split of `X`
- `y_train`: The train split of `y`, associated with `X_train`
- `y_test`: The test split of `y`, associated with `X_test`
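To make the four outputs concrete, here is a minimal, self-contained sketch on a small made-up DataFrame (the columns are hypothetical stand-ins, not the recipe's actual data), checking that the splits have the expected sizes and stay aligned:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in DataFrame: 10 rows, two features, a binary label
df = pd.DataFrame({
    'Age': range(10),
    'Fare': [f * 1.5 for f in range(10)],
    'Survived': [0, 1] * 5,
})

X_train, X_test, y_train, y_test = train_test_split(
    df.drop(columns=['Survived']), df['Survived'],
    test_size=0.2, stratify=df['Survived'], random_state=0)

# 8 train rows and 2 test rows, and X/y stay aligned by index
assert X_train.shape == (8, 2) and X_test.shape == (2, 2)
assert (X_train.index == y_train.index).all()
assert (X_test.index == y_test.index).all()
```

Because the splits preserve the original DataFrame indices, each row of `X_train` keeps its matching label in `y_train`, and likewise for the test split.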
Note

The `stratify` option is not mandatory, but it can be critical to ensure a balanced split of any qualitative feature, not just the labels, especially in the case of imbalanced data.
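The effect of `stratify` can be checked directly. In this sketch (with made-up, deliberately imbalanced labels), both splits keep the original 90/10 class ratio:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 90% class 0, 10% class 1
y = pd.Series([0] * 90 + [1] * 10)
X = pd.DataFrame({'feature': range(100)})

_, _, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

# Both splits preserve the original 90/10 label balance
print(y_train.value_counts(normalize=True))  # 0: 0.9, 1: 0.1
print(y_test.value_counts(normalize=True))   # 0: 0.9, 1: 0.1
```

Without `stratify`, a random split of such data could easily leave only one or zero positive samples in the test set, making the evaluation unreliable.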
This split should be done as early as possible in the data processing so that you avoid any potential data leakage. From now on, all preprocessing will be computed on the train set, and only then applied to the test set, in agreement with Figure 2.2.
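The train-then-apply pattern can be sketched as follows; `StandardScaler` here stands in for any preprocessing step, and the data is made up for illustration:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Hypothetical feature matrix
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))

X_train, X_test = train_test_split(X, test_size=0.2, random_state=0)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # fit on the train set only
X_test_scaled = scaler.transform(X_test)        # reuse the train statistics

# The scaler's mean and std come from X_train alone, so nothing about
# the test set leaks into the preprocessing
assert np.allclose(X_train_scaled.mean(axis=0), 0.0, atol=1e-9)
```

Fitting the scaler on the full dataset instead would let test-set statistics influence the transformation, which is exactly the kind of leakage the early split is meant to prevent.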
See also
See the official documentation for the `train_test_split` function: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html.