scikit-learn estimators
Estimators are scikit-learn's abstraction for classification, allowing for the standardized implementation of a large number of classification algorithms. Estimators have the following two main functions:
fit(): This performs the training of the algorithm and sets its internal parameters. It takes two inputs: the training sample dataset and the corresponding classes for those samples.
predict(): This predicts the class of the testing samples that are given as input. This function returns an array with the prediction for each input testing sample.
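As a minimal sketch of the fit() and predict() workflow described above, the following trains an estimator on a tiny dataset and predicts the class of two new samples. The data values and variable names are made up for illustration, and KNeighborsClassifier is simply one convenient estimator to demonstrate with.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Toy training data (invented for this example): one feature, two classes.
X_train = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])

# fit() takes the training samples and their corresponding classes.
estimator = KNeighborsClassifier(n_neighbors=3)
estimator.fit(X_train, y_train)

# predict() takes testing samples and returns one prediction per sample.
X_test = np.array([[1.5], [10.5]])
y_pred = estimator.predict(X_test)
print(y_pred)  # [0 1]
```

The first test sample sits among the class 0 training points and the second among the class 1 points, so each prediction matches its nearest group.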
Most scikit-learn estimators use NumPy arrays or a related format for input and output.
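To illustrate the input and output formats, the short sketch below passes plain Python lists (one example of a related, array-like format) to an estimator and inspects what comes back. The sample values and class labels are hypothetical; the layout assumed here is one row per sample and one column per feature.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Nested lists are converted to arrays internally: 4 samples, 2 features each.
X_list = [[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]]
y_list = ["a", "a", "b", "b"]
print(np.asarray(X_list).shape)  # (4, 2)

estimator = KNeighborsClassifier(n_neighbors=1)
estimator.fit(X_list, y_list)

# The predictions come back as a NumPy array, one entry per testing sample.
predictions = estimator.predict([[0.05, 0.1], [5.1, 5.0]])
print(type(predictions), predictions.shape)
```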
There are a large number of estimators in scikit-learn. These include support vector machines (SVM), random forests, and neural networks. Many of these algorithms will be used in later chapters. In this chapter, we will use a different estimator from scikit-learn: nearest neighbor.
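Because every estimator exposes the same fit() and predict() functions, one algorithm can be swapped for another without changing the surrounding code. The sketch below shows this with three of the estimators named above; the toy data, the parameter values, and the results dictionary are assumptions chosen only for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC

# Toy data (invented for this example): two well-separated groups.
X_train = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])
X_test = np.array([[1.0], [11.0]])

estimators = {
    "support vector machine": SVC(),
    "random forest": RandomForestClassifier(n_estimators=25, random_state=0),
    "nearest neighbor": KNeighborsClassifier(n_neighbors=3),
}

# The same two calls work for every estimator in the dictionary.
results = {}
for name, estimator in estimators.items():
    estimator.fit(X_train, y_train)
    results[name] = estimator.predict(X_test)
    print(name, results[name])
```

On data this cleanly separated, all three estimators should agree; on real datasets their predictions will generally differ.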
Note
For this chapter, you will...