Various algorithms can be used for text classification. You can build a classifier in scikit-learn with just a few lines of code. Let's start with the imports:
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression as LR
from sklearn.pipeline import Pipeline
Let's dissect this code, line by line.
The initial lines are simple imports. We bring in the CountVectorizer and TfidfTransformer we have already seen, import the fairly well-known logistic regression model and rename it LR, and finally import Pipeline, which the scikit-learn documentation describes as follows:
"Sequentially apply a list of transforms and a final estimator. Intermediate steps of the pipeline must be "transforms", that is, they must implement fit and transform methods. The final estimator only needs to implement fit."
Scikit-learn pipelines are, logically, lists of operations that are applied one after another. First come the two operations we have already seen, CountVectorizer() and TfidfTransformer(), followed by LR(). The pipeline is created with Pipeline(...), but nothing is executed at that point. Execution happens only when we call the fit() function on the Pipeline object:
text_lr_clf = Pipeline([('vect', CountVectorizer()), ('tfidf', TfidfTransformer()), ('clf', LR())])
text_lr_clf = text_lr_clf.fit(twenty_train.data, twenty_train.target)
When fit() is called, the pipeline fits and transforms the data with every step except the last one. For the last object, our logistic regression classifier, only its fit() function is called. These transforms and classifiers are also referred to as estimators:
"All estimators in a pipeline, except the last one, must be transformers (that is, they must have a transform method). The last estimator may be any type (transformer, classifier, and so on)."
Let's calculate the accuracy of this model on the test data. For calculating the mean over a large number of values, we will be using a scientific library called numpy:
import numpy as np
lr_predicted = text_lr_clf.predict(twenty_test.data)
lr_clf_accuracy = np.mean(lr_predicted == twenty_test.target) * 100.
print(f'Test Accuracy is {lr_clf_accuracy}')
This prints out the following output:
Test Accuracy is 82.79341476367499
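As an aside, scikit-learn also ships a metric that computes the same number, if you prefer not to take the mean yourself:
from sklearn.metrics import accuracy_score
# should match the value we computed with np.mean above
print(accuracy_score(twenty_test.target, lr_predicted) * 100.)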
We used the LR default parameters here. We can later optimize these using grid search (GridSearchCV) or randomized search (RandomizedSearchCV) to improve the accuracy even more.
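As a minimal sketch of what that tuning could look like, the following searches over two pipeline parameters; the grid values here are illustrative choices of ours, not recommendations:
from sklearn.model_selection import GridSearchCV

param_grid = {
    'vect__ngram_range': [(1, 1), (1, 2)],  # unigrams vs. unigrams + bigrams
    'clf__C': [0.1, 1.0, 10.0],             # inverse regularization strength
}
search = GridSearchCV(text_lr_clf, param_grid, cv=3)
search.fit(twenty_train.data, twenty_train.target)
print(search.best_params_, search.best_score_)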
If you're going to remember only one thing from this section, remember to try a linear model such as logistic regression. They are often quite good for sparse, high-dimensional data such as text represented as bag-of-words or TF-IDF features.
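Because the pipeline isolates the classifier in a single step, trying another linear model is a one-line change. The following sketch swaps in LinearSVC, which is our choice of example here, not something the preceding code used:
from sklearn.svm import LinearSVC

# same pipeline as before, with only the final estimator replaced
text_svm_clf = Pipeline([('vect', CountVectorizer()),
                         ('tfidf', TfidfTransformer()),
                         ('clf', LinearSVC())])
text_svm_clf = text_svm_clf.fit(twenty_train.data, twenty_train.target)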
In addition to accuracy, it is useful to understand which categories of text are being confused for which other categories. This is captured in what is called a confusion matrix.
The following code uses the same variables we used to calculate the test accuracy for finding out the confusion matrix:
from sklearn.metrics import confusion_matrix
cf = confusion_matrix(y_true=twenty_test.target, y_pred=lr_predicted)
print(cf)
This prints a giant grid of numbers, which is not very interpretable. Let's try pretty-printing it using the print-json hack:
import json
print(json.dumps(cf.tolist(), indent=2))
This prints the following output:
[
  [
    236,
    2,
    0,
    0,
    1,
    1,
    3,
    0,
    3,
    3,
    1,
    1,
    2,
    9,
    2,
    35,
    3,
    4,
    1,
    12
  ],
  ...
  [
    38,
    4,
    0,
    0,
    0,
    0,
    4,
    0,
    0,
    2,
    2,
    0,
    0,
    8,
    3,
    48,
    17,
    2,
    9,
    114
  ]
]
This is slightly better. We now understand that this is a 20 × 20 grid of numbers. However, interpreting these numbers is a tedious task unless we label them or bring some visualization into this game. As a quick intermediate step, we can attach the newsgroup names to the rows and columns; this is a small sketch that assumes pandas is installed:
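import pandas as pd

# rows are the true labels, columns are the predicted labels
cf_df = pd.DataFrame(cf,
                     index=twenty_test.target_names,
                     columns=twenty_test.target_names)
print(cf_df)
Even with labels, a picture is easier to scan than 400 numbers, so let's visualize the matrix next: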
# this line ensures that the plot is rendered inside the Jupyter notebook we used for testing this code
%matplotlib inline
import seaborn as sns
import matplotlib.pyplot as plt
plt.figure(figsize=(20, 10))
ax = sns.heatmap(cf, annot=True, fmt="d", linewidths=.5, center=90, vmax=200)
# plt.show()  # optional, uncomment if the plot does not show
This gives us the following amazing plot:
This plot highlights the information of interest to us through color. For instance, the light diagonal from the upper-left corner to the lower-right corner shows everything we got right. The other cells are darker the more often we confused the corresponding pair of classes. For instance, 97 samples of one class were wrongly tagged, which is immediately visible from the dark cell in row 18, column 16.
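If you would rather locate such hotspots programmatically than by eye, a small numpy sketch (our addition, reusing the cf matrix and the np import from above) does the trick:
off_diag = cf.copy()
np.fill_diagonal(off_diag, 0)  # ignore the correct predictions on the diagonal
row, col = np.unravel_index(off_diag.argmax(), off_diag.shape)
print(f'{off_diag[row, col]} samples of {twenty_test.target_names[row]} '
      f'were predicted as {twenty_test.target_names[col]}')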
We will dive deeper into both parts of this section, model interpretation and data visualization, later in this book.