Using integrated gradients to aid in understanding predictions
At the time of writing, two packages provide easy-to-use classes and methods for computing integrated gradients: the captum and shap libraries. In this tutorial, we will use captum, which is built for PyTorch models, so we will be working in PyTorch. Our goal is to explain the predictions of a state-of-the-art (SoTA) transformer model called DeBERTa on a multiclass text sentiment classification task. Let’s go through the use case step by step:
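Before applying a library to DeBERTa, it can help to see what integrated gradients actually computes. For each input feature i, the attribution is (x_i - x'_i) multiplied by the average gradient of the model output with respect to feature i, taken along the straight line from a baseline input x' to the real input x. Libraries such as captum approximate this path integral numerically. The sketch below is a from-scratch illustration on a toy logistic-regression "model" with an analytic gradient, assuming an all-zeros baseline and a midpoint Riemann sum; the names `model`, `model_grad`, and `integrated_gradients` are hypothetical and are not part of captum's API.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def model(x, w, b):
    """Hypothetical toy model: logistic-regression score for a single example."""
    return sigmoid(np.dot(w, x) + b)


def model_grad(x, w, b):
    """Analytic gradient of the toy model's output with respect to the input x."""
    p = model(x, w, b)
    return p * (1.0 - p) * w


def integrated_gradients(x, baseline, w, b, n_steps=300):
    """Approximate integrated gradients with a midpoint Riemann sum
    along the straight-line path from `baseline` to `x`."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps  # midpoints of [0, 1]
    total_grad = np.zeros_like(x, dtype=float)
    for alpha in alphas:
        point = baseline + alpha * (x - baseline)  # interpolated input
        total_grad += model_grad(point, w, b)
    avg_grad = total_grad / n_steps
    # Scale the averaged gradient by the input's distance from the baseline
    return (x - baseline) * avg_grad


# Assumed toy parameters and input, for illustration only
w = np.array([2.0, -1.0, 0.5])
b = -0.3
x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)

attr = integrated_gradients(x, baseline, w, b)
print("attributions:", attr)
# Completeness check: attributions should sum to f(x) - f(baseline)
print("sum:", attr.sum(), "f(x) - f(baseline):", model(x, w, b) - model(baseline, w, b))
```

The final print illustrates the completeness property of integrated gradients: the attributions add up (up to approximation error) to the difference between the model's output on the input and on the baseline. In captum, the corresponding class is IntegratedGradients in captum.attr, which computes the gradients with PyTorch autograd instead of an analytic formula.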
- First, let’s import the necessary libraries and methods:
from transformers import ( DebertaForSequenceClassification, EvalPrediction, DebertaConfig, DebertaTokenizer, Trainer, TrainingArguments, IntervalStrategy, EarlyStoppingCallback...