Model training
Model training is implemented in the fit() method. It takes the following parameters:
train_X : array_like, shape (n_samples, n_features)
    Training data.
train_Y : array_like, shape (n_samples, n_classes)
    Training labels.
val_X : array_like, shape (N, n_features), optional (default = None)
    Validation data.
val_Y : array_like, shape (N, n_classes), optional (default = None)
    Validation labels.
graph : tf.Graph, optional (default = None)
    TensorFlow Graph object.
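To make the expected shapes concrete, here is a minimal sketch that builds inputs satisfying them. It assumes NumPy arrays and purely synthetic random data; the sizes (100 samples, 8 features, 3 classes, 20 held-out validation rows) are arbitrary illustrative choices, not values from the original text.

```python
import numpy as np

rng = np.random.default_rng(0)

# Arbitrary illustrative sizes (assumption, not from the original text)
n_samples, n_features, n_classes = 100, 8, 3

# Features: shape (n_samples, n_features)
X = rng.random((n_samples, n_features)).astype(np.float32)

# Integer class labels, then one-hot rows: row i has a 1 in column labels[i]
labels = rng.integers(0, n_classes, size=n_samples)
Y = np.eye(n_classes, dtype=np.float32)[labels]

# Hold out the last 20 rows as validation data, so N = 20
train_X, val_X = X[:80], X[80:]
train_Y, val_Y = Y[:80], Y[80:]

print(train_X.shape, train_Y.shape, val_X.shape, val_Y.shape)
# (80, 8) (80, 3) (20, 8) (20, 3)
```

With a model instance in hand (call it `model`, a hypothetical name), these arrays would be passed as `model.fit(train_X, train_Y, val_X, val_Y)`; `val_X` and `val_Y` may also be omitted, since both default to None.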
Next, we look at the implementation of the fit() method, where the model is trained and saved to the path specified by model_path.
def fit(self, train_X, train_Y, val_X=None, val_Y=None, graph=None):
    # Labels must be one-hot encoded, i.e. 2-D with shape (n_samples, n_classes)
    if len(train_Y.shape) != 1:
        num_classes = train_Y.shape[1]
    else:
        raise Exception("Please convert the labels with one-hot encoding.")

    # Use the caller's graph if one is given, otherwise the instance's own graph
    g = graph if graph is not None else self.tf_graph
    with g.as_default():
        # Build model
        self.build_model(train_X.shape[1], num_classes)
        with...
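The label check at the top of fit() rejects 1-D label vectors and reads the number of classes from the second dimension of train_Y. The sketch below mirrors that check outside TensorFlow and shows one way to produce the one-hot labels the error message asks for. Both `infer_num_classes` and `to_one_hot` are hypothetical helper names introduced here for illustration; they are not part of the original class.

```python
import numpy as np


def infer_num_classes(train_Y):
    """Hypothetical mirror of the check in fit(): 1-D labels are rejected."""
    if len(train_Y.shape) != 1:
        return train_Y.shape[1]
    raise Exception("Please convert the labels with one-hot encoding.")


def to_one_hot(labels, num_classes):
    """Hypothetical helper: integer labels (n_samples,) -> one-hot (n_samples, num_classes)."""
    labels = np.asarray(labels)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Set a single 1 per row, in the column given by that row's label
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot


int_labels = np.array([0, 2, 1, 2])

# Integer labels are 1-D, so the check raises
try:
    infer_num_classes(int_labels)
except Exception as e:
    print(e)  # Please convert the labels with one-hot encoding.

# After one-hot encoding, the number of classes is read from shape[1]
train_Y = to_one_hot(int_labels, 3)
print(infer_num_classes(train_Y))  # 3
```

Reading num_classes from the label array's shape, rather than taking it as a separate argument, keeps the parameter list short, at the cost of requiring labels to be one-hot encoded before calling fit().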