The easiest way to create an Estimator is to convert a Keras model. After the model has been compiled, call tf.keras.estimator.model_to_estimator():
estimator = tf.keras.estimator.model_to_estimator(model, model_dir='./estimator_dir')
The model_dir argument specifies the directory where the model's checkpoints will be saved. As mentioned earlier, Estimators save checkpoints for our models automatically.
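As a minimal, self-contained sketch of this conversion, the following snippet builds a small placeholder classifier, compiles it, and converts it. It assumes a TensorFlow 2.x release that still ships the Estimator API. The layer sizes, the optimizer, and the temporary directory are illustrative assumptions, not necessarily the model defined earlier in this chapter:

```python
import tempfile

import tensorflow as tf

# Hypothetical stand-in model: any compiled Keras model can be converted.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# The model must be compiled before conversion.
model.compile(optimizer='sgd',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Assumption: a throwaway temporary directory stands in for './estimator_dir'.
model_dir = tempfile.mkdtemp()

# Convert the compiled Keras model; checkpoints will be written to model_dir.
estimator = tf.keras.estimator.model_to_estimator(model, model_dir=model_dir)

# The Estimator remembers where it stores its checkpoints.
print(estimator.model_dir)
```

Pointing model_dir at a persistent path instead of a temporary one should let a later run resume from the saved checkpoints.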
Training an Estimator requires an input function, that is, a function that returns data in a specific format. One of the accepted formats is a TensorFlow dataset. The dataset API is described in depth in Chapter 7, Training on Complex and Scarce Datasets. For now, we'll define the following function, which returns the dataset defined in the first part of this chapter in the correct format, in batches of 32 samples:
BATCH_SIZE = 32
def train_input_fn():
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train...