Now, we are going to perform the same procedure as before, except that we will reset, regroup, and try a new algorithm: K-Nearest Neighbors (KNN).
Putting it all together
How to do it...
- Start by importing the model from sklearn, followed by a balanced (stratified) split:
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
The random_state parameter fixes the random seed that train_test_split uses to shuffle the data, which makes the split reproducible. In the preceding example, random_state is set to zero, but it can be set to any integer.
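To see what these two arguments do, here is a minimal sketch on made-up data. The arrays X_demo and y_demo are hypothetical stand-ins, not the X and y from the recipe. It checks that two calls with the same random_state return the same split, and that stratify keeps the class proportions in both parts.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data (not the recipe's X and y):
# 100 samples, 2 features, imbalanced labels (80 zeros, 20 ones).
rng = np.random.RandomState(42)
X_demo = rng.rand(100, 2)
y_demo = np.array([0] * 80 + [1] * 20)

# Two calls with the same random_state should give identical splits.
split_a = train_test_split(X_demo, y_demo, stratify=y_demo, random_state=0)
split_b = train_test_split(X_demo, y_demo, stratify=y_demo, random_state=0)
same_split = all(np.array_equal(a, b) for a, b in zip(split_a, split_b))

# With stratify=y_demo, the share of class 1 is preserved in both parts.
y_train_demo, y_test_demo = split_a[2], split_a[3]
train_fraction_pos = y_train_demo.mean()
test_fraction_pos = y_test_demo.mean()

print("identical splits:", same_split)
print("class 1 share in train:", train_fraction_pos)
print("class 1 share in test:", test_fraction_pos)
```

Changing random_state to another integer gives a different, but equally reproducible, split.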
- Construct two different KNN models by varying the n_neighbors parameter. Observe that the number of folds is now 10. Tenfold cross-validation is common in the machine learning community, particularly in data science competitions:
from sklearn.model_selection import cross_val_score
knn_3_clf = KNeighborsClassifier(n_neighbors = 3)
knn_5_clf = KNeighborsClassifier(n_neighbors = 5)
knn_3_scores = cross_val_score(knn_3_clf, X_train, y_train, cv=10)
knn_5_scores = cross_val_score(knn_5_clf, X_train, y_train, cv=10)
- Score and print out the scores for selection:
print("knn_3 mean scores: ", knn_3_scores.mean(), "knn_3 std: ", knn_3_scores.std())
print("knn_5 mean scores: ", knn_5_scores.mean(), "knn_5 std: ", knn_5_scores.std())
knn_3 mean scores: 0.798333333333 knn_3 std: 0.0908142181722
knn_5 mean scores: 0.806666666667 knn_5 std: 0.0559320575496
Both nearest neighbor models score similarly on average, yet the KNN with n_neighbors = 5 has a lower standard deviation across the folds (about 0.056 versus 0.091), so it is a bit more stable. This is an example of hyperparameter optimization, which we will examine closely throughout the book.
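It can help to see what cv=10 does under the hood. For a classifier and an integer cv, cross_val_score uses stratified K-fold splitting, so the following sketch reproduces its scores by hand. The data comes from make_classification and is purely illustrative. X_demo and y_demo are assumed names, not the recipe's data.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Illustrative synthetic data (an assumption, not the recipe's dataset).
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)

# Manual tenfold loop: fit on nine folds, score accuracy on the held-out fold.
manual_scores = []
for train_idx, test_idx in StratifiedKFold(n_splits=10).split(X_demo, y_demo):
    clf = KNeighborsClassifier(n_neighbors=5)
    clf.fit(X_demo[train_idx], y_demo[train_idx])
    manual_scores.append(clf.score(X_demo[test_idx], y_demo[test_idx]))
manual_scores = np.array(manual_scores)

# The one-line equivalent used in the recipe.
auto_scores = cross_val_score(KNeighborsClassifier(n_neighbors=5), X_demo, y_demo, cv=10)

print("manual:", manual_scores.mean(), manual_scores.std())
print("auto:  ", auto_scores.mean(), auto_scores.std())
```

The mean summarizes how well the model does on average, and the standard deviation across the ten folds is the stability measure compared above.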
There's more...
You could just as easily have run a simple loop to score several values of n_neighbors in one go:
all_scores = []
for n_neighbors in range(3, 9, 1):
    knn_clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    all_scores.append((n_neighbors, cross_val_score(knn_clf, X_train, y_train, cv=10).mean()))
sorted(all_scores, key=lambda x: x[1], reverse=True)
Its output suggests that n_neighbors = 4 is a good choice:
[(4, 0.85111111111111115),
(7, 0.82611111111111113),
(6, 0.82333333333333347),
(5, 0.80666666666666664),
(3, 0.79833333333333334),
(8, 0.79833333333333334)]
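The loop above keeps only the mean score. A small variation, sketched here on illustrative synthetic data (X_demo and y_demo are assumed names, not the recipe's data), also records the standard deviation so that stability can be compared alongside accuracy.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Illustrative synthetic data (an assumption, not the recipe's dataset).
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)

results = []
for n_neighbors in range(3, 9):
    knn_clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    scores = cross_val_score(knn_clf, X_demo, y_demo, cv=10)
    # Keep both the mean accuracy and its spread across the ten folds.
    results.append((n_neighbors, scores.mean(), scores.std()))

# Rank by mean score (highest first); break ties with the smaller spread.
ranked = sorted(results, key=lambda r: (-r[1], r[2]))
for n_neighbors, mean_score, std_score in ranked:
    print(n_neighbors, round(mean_score, 3), round(std_score, 3))
```

When the gap between two mean scores is smaller than the fold-to-fold standard deviation, as it is for several of the values listed above, the ranking should be read as a rough guide rather than a firm result.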