Training a Temporal Fusion Transformer with GluonTS
The Temporal Fusion Transformer (TFT) is an attention-based architecture developed at Google. It combines recurrent layers, which learn temporal relationships at different scales, with self-attention layers that make the model interpretable. TFTs also use variable selection networks for feature selection, gating layers to suppress unnecessary components, and a quantile loss as the training objective, which lets them produce prediction intervals rather than just point forecasts.
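The quantile (pinball) loss penalizes under- and over-predictions asymmetrically: at quantile level q, under-predictions are weighted by q and over-predictions by 1 - q, which pushes each output toward its target quantile. Here is a minimal NumPy sketch for illustration; the function name and signature are ours, not part of GluonTS:

import numpy as np

def quantile_loss(y_true, y_pred, q):
    """Pinball loss for a single quantile level q in (0, 1)."""
    error = y_true - y_pred
    # Under-prediction (error > 0) costs q * error;
    # over-prediction (error < 0) costs (1 - q) * |error|.
    return np.mean(np.maximum(q * error, (q - 1) * error))

For instance, with q = 0.9, over-predicting by 2 units costs 0.2, while under-predicting by the same margin costs 1.8, so the optimal prediction sits near the 90th percentile of the target distribution.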
In this section, we walk through training a TFT model and performing inference with it using the GluonTS framework.
Getting ready
Ensure you have the GluonTS library and the PyTorch backend installed in your environment; a typical installation is shown below.
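Both packages are available from PyPI. The following command is a sketch; the torch extra pulls in the PyTorch dependencies, but you may want to pin versions to match your environment:

pip install "gluonts[torch]"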
We'll use the nn5_daily_without_missing dataset from the GluonTS repository as a working example:
from gluonts.dataset.common import ListDataset, FieldName
from gluonts.dataset.repository.datasets import get_dataset

# Download (or load from cache) the NN5 daily dataset.
dataset = get_dataset("nn5_daily_without_missing", regenerate=False)

# The original snippet is truncated here; this completion rebuilds the
# training set from the repository entries using the standard fields.
train_ds = ListDataset(
    [{FieldName.START: e[FieldName.START], FieldName.TARGET: e[FieldName.TARGET]}
     for e in dataset.train],
    freq=dataset.metadata.freq,
)
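With the data prepared, training follows GluonTS's usual estimator/predictor pattern. The following is a minimal sketch assuming the PyTorch TemporalFusionTransformerEstimator (gluonts.torch.model.tft); constructor arguments differ across GluonTS versions, so treat the hyperparameters below as placeholders rather than recommended settings:

from gluonts.evaluation import make_evaluation_predictions
from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

# Configure the estimator; freq and prediction_length come from the metadata.
estimator = TemporalFusionTransformerEstimator(
    freq=dataset.metadata.freq,
    prediction_length=dataset.metadata.prediction_length,
    trainer_kwargs={"max_epochs": 5},  # short run, for illustration only
)

# train() fits the network and returns a predictor for inference.
predictor = estimator.train(train_ds)

# Draw forecasts over the held-out windows of the test split.
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test, predictor=predictor, num_samples=100
)
forecasts = list(forecast_it)

Each element of forecasts exposes the predicted quantiles (for example, forecasts[0].quantile(0.9)), which is how the quantile loss surfaces as prediction intervals at inference time.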