Forecasting with a GRU using PyTorch
In this recipe, you will use the same train_model_pt function from the previous Forecasting with an RNN using PyTorch recipe. The function trains the model, captures the loss scores, evaluates the model, makes a forecast using the test set, and finally produces plots for further evaluation.
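As a reminder of what that helper does, here is a minimal sketch of a train, evaluate, and forecast loop of the kind described above. It is a hypothetical stand-in, not the recipe's train_model_pt. The name train_and_evaluate, its signature, the MSE loss, the Adam optimizer, and the toy sine-wave data are all assumptions made for illustration, and the plotting step is omitted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def train_and_evaluate(model, x_train, y_train, x_test, y_test, epochs=50, lr=0.01):
    """Hypothetical stand-in for train_model_pt (assumed signature).

    Trains full-batch, records train/test loss per epoch, and returns
    test-set predictions. Plotting is intentionally left out.
    """
    criterion = nn.MSELoss()  # assumed loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # assumed optimizer
    train_losses, test_losses = [], []
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(x_train), y_train)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        # Score the test set without tracking gradients
        model.eval()
        with torch.no_grad():
            test_losses.append(criterion(model(x_test), y_test).item())
    # Final forecast on the test inputs
    with torch.no_grad():
        preds = model(x_test)
    return train_losses, test_losses, preds

# Toy data (assumption): sliding windows over a sine wave, shape (samples, window, 1)
series = torch.sin(torch.linspace(0, 20, 200))
window = 10
x = torch.stack([series[i:i + window] for i in range(len(series) - window)]).unsqueeze(-1)
y = series[window:].unsqueeze(-1)
x_train, x_test = x[:150], x[150:]
y_train, y_test = y[:150], y[150:]

# Any nn.Module mapping (batch, window, 1) -> (batch, 1) works here
model = nn.Sequential(nn.Flatten(), nn.Linear(window, 1))
train_losses, test_losses, preds = train_and_evaluate(model, x_train, y_train, x_test, y_test)
```

Because the helper only relies on the model's forward pass, the same loop can train the RNN, LSTM, or GRU classes unchanged, which is why only the model class needs to be redefined.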
You will still need to define a new class for the GRU model.
How to do it...
- Create a GRU class that will inherit from the Module class. The setup will be similar to the RNN class, but unlike the LSTM class, you only handle one state (the hidden state):
class GRU(nn.Module):
    def __init__(self, input_size, output_size, n_features, n_layers):
        super(GRU, self).__init__()
        self.n_layers = n_layers
        self.hidden_dim = n_features
        self.gru...