Saving and restoring a model in TensorFlow
In the previous section, we built a graph and trained it. What about making predictions on the held-out test set? The problem is that we did not save the model parameters. Once the preceding statements finish executing and we exit the tf.Session environment, all the variables and the memory allocated to them are freed.
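The following minimal sketch shows why unsaved parameters are lost. It assumes TensorFlow's `tf.compat.v1` graph/session API (so it also runs under TensorFlow 2.x) and uses a hypothetical toy variable named `weight` in place of real model parameters. The variable is initialized and read in one session, but a second session on the same graph does not see that value: reading it before re-initializing raises `tf.errors.FailedPreconditionError`.

```python
import tensorflow as tf

# Assumption: TF 1.x-style graph/session code through tf.compat.v1,
# so the sketch also runs on TensorFlow 2.x.
tf1 = tf.compat.v1

g = tf.Graph()
with g.as_default():
    # Hypothetical toy variable standing in for trained model parameters.
    w = tf1.Variable(3.0, name='weight')
    init_op = tf1.global_variables_initializer()

# First session: initialize the variable and read its value.
with tf1.Session(graph=g) as sess:
    sess.run(init_op)
    first_value = sess.run(w)

# Second session on the same graph: variable values live in a session,
# so the value from the first session is gone and reading fails.
lost_after_exit = False
with tf1.Session(graph=g) as sess:
    try:
        sess.run(w)
    except tf.errors.FailedPreconditionError:
        lost_after_exit = True

print(first_value, lost_after_exit)
```

The graph (the structure of the computation) survives, but the variable values do not, which is exactly what saving a model has to preserve.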
One solution is to feed the test set to the model as soon as training finishes. However, this is not a good approach: deep neural network models are typically trained over multiple hours, days, or even weeks, and we would have to repeat that training every time we wanted to make new predictions.
The best approach is to save the trained model for future use. For this purpose, we need to add a new node to the graph: an instance of the tf.train.Saver class, which we call saver.
The following statement shows how to add more nodes to an existing graph; in this case, we are adding saver to the graph g:
>>> with g.as_default():
...     saver = tf.train.Saver()
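A complete save-and-restore round trip might look like the minimal sketch below. It assumes TensorFlow's `tf.compat.v1` API (so it also runs under TensorFlow 2.x), a hypothetical toy variable `weight` in place of real model parameters, an `assign_add` op standing in for a training step, and a hypothetical checkpoint prefix `trained-model` inside a temporary directory. `saver.save` writes the current variable values to disk before the session closes, and `saver.restore` loads them into a fresh session.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TF 1.x-style graph/session code via tf.compat.v1.
tf1 = tf.compat.v1

g = tf.Graph()
with g.as_default():
    # Hypothetical toy variable standing in for the model parameters.
    w = tf1.Variable([1.0, 2.0], name='weight')
    # Stand-in for a training step: change the variable's value.
    update_op = w.assign_add([0.5, 0.5])
    init_op = tf1.global_variables_initializer()
    # The Saver is created after the variables, so it tracks all of them.
    saver = tf1.train.Saver()

# Hypothetical checkpoint location in a temporary directory.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'trained-model')

# Session 1: "train" and write a checkpoint before the session closes.
with tf1.Session(graph=g) as sess:
    sess.run(init_op)
    sess.run(update_op)
    trained_w = sess.run(w)
    save_path = saver.save(sess, ckpt_prefix)

# Session 2: restore the saved values instead of running the initializer.
with tf1.Session(graph=g) as sess:
    saver.restore(sess, save_path)
    restored_w = sess.run(w)

print(trained_w, restored_w)
```

Because `restore` assigns every tracked variable from the checkpoint, the second session never needs to run `init_op`; the restored values are the trained ones rather than the initial `[1.0, 2.0]`.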
Next, we can...