Tracking model metrics
The default metric for the text classification model in the PyTorch lightning-flash package is Accuracy. To switch to the F1 score (the harmonic mean of precision and recall), which is a very common metric for measuring a classifier's performance, we need to change the configuration of the classifier model before we start the model training process. Let's learn how to make this change and then use MLflow's non-auto-logging API to log the metrics:
- When defining the classifier variable, instead of using the default metric, we pass a torchmetrics.F1 metric object through the metrics argument, as follows:
  classifier_model = TextClassifier(backbone="prajjwal1/bert-tiny", num_classes=datamodule.num_classes, metrics=torchmetrics.F1(datamodule.num_classes))
This uses the F1 module, one of the built-in metrics of torchmetrics, and passes it the number of classes in the data we need to classify as a parameter (torchmetrics 0.7 and later name this class F1Score). This...
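To make the F1 definition above concrete, here is a minimal sketch in plain Python that computes a binary F1 score from raw prediction counts. It is not how torchmetrics implements the metric. The function name f1_score and the example counts are hypothetical and chosen only for illustration, and the multi-class averaging that torchmetrics applies is not covered here.

```python
def f1_score(true_positives: int, false_positives: int, false_negatives: int) -> float:
    """Binary F1: the harmonic mean of precision and recall (illustrative helper)."""
    predicted_positive = true_positives + false_positives
    actual_positive = true_positives + false_negatives
    # Precision or recall is undefined without any positives; report 0.0 by convention.
    if predicted_positive == 0 or actual_positive == 0:
        return 0.0
    precision = true_positives / predicted_positive
    recall = true_positives / actual_positive
    # With no true positives both terms are zero; avoid dividing by zero.
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Hypothetical counts: 8 correct positive predictions, 2 false alarms, 4 misses.
score = f1_score(true_positives=8, false_positives=2, false_negatives=4)
print(f"F1 = {score:.4f}")
```

Because the harmonic mean is pulled toward the smaller of its two inputs, a classifier cannot reach a high F1 score by maximizing precision or recall alone, which is one reason it is often preferred over accuracy on imbalanced data.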