Asserting on conditions with tf.Assert()
Yet another way to debug TensorFlow models is to insert conditional asserts. The tf.Assert()
function takes a condition and a list of tensors. If the condition is false, it prints the given tensors and throws tf.errors.InvalidArgumentError.
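The behavior just described can be seen by running an assert op directly. This is a minimal sketch, assuming a TensorFlow installation that provides the tf.compat.v1 graph-mode API (the text itself uses the TF 1.x API). The placeholder shape and the input values are illustrative only.

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: TF1-style graph mode through tf.compat.v1.
tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    # Illustrative input: a batch of rows with 3 features each.
    x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
    # Condition: every element of x is non-negative. If it is False, the op
    # fails and reports up to `summarize` entries of each tensor in `data`.
    assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x, 0.0)), [x], summarize=3)

with tf.Session(graph=graph) as sess:
    # Condition holds: running the assert op succeeds silently.
    sess.run(assert_op, feed_dict={x: np.ones((2, 3), np.float32)})
    try:
        # Condition fails: the op raises InvalidArgumentError.
        sess.run(assert_op, feed_dict={x: np.array([[1.0, -2.0, 3.0]], np.float32)})
        assert_failed = False
    except tf.errors.InvalidArgumentError as err:
        assert_failed = True
        # The error message carries the summarized values of the data tensors.
        print(err.message)
```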
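- The tf.Assert() function has the following signature:
tf.Assert(condition, data, summarize=None, name=None)
Here, data is the list of tensors to print when condition is false, and summarize sets how many entries of each tensor are printed.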
- Unlike the tf.Print() function, which is an identity operation that passes its input through and therefore lies in the path of the graph, an assert operation does not lie in the path of the graph. To make sure that the tf.Assert() operation gets executed, we need to add it as a control dependency of the operations that should run after it. For example, let us define an assertion to check that all the inputs are non-negative:
assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x, 0)), [x])
- Add assert_op to the control dependencies at the time of defining the model, as follows:
with tf.control_dependencies([assert_op]):
    # x is input layer
    layer = x
    # add hidden layers
    for i in range(num_layers):
        layer = tf.nn.relu(tf.matmul(layer, w[i]) + b[i])
    # add output layer
    layer...
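The following sketch illustrates why the control dependency matters. It makes the same tf.compat.v1 graph-mode assumption as before and builds a small network of the kind described above twice from the same weights: once outside tf.control_dependencies([assert_op]) and once inside it. The layer sizes, the all-ones weights, and the helper build_model are hypothetical, chosen only so the outputs are easy to check by hand. Feeding a negative input shows that only the guarded output triggers the assertion.

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: TF1-style graph mode through tf.compat.v1.
tf.disable_eager_execution()

num_layers = 2                   # hypothetical number of hidden layers
n_in, n_hidden, n_out = 3, 4, 1  # hypothetical layer sizes

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, n_in], name="x")
    sizes = [n_in] + [n_hidden] * num_layers
    # Hypothetical all-ones weights and zero biases, for easy hand-checking.
    w = [tf.Variable(tf.ones([sizes[i], sizes[i + 1]])) for i in range(num_layers)]
    b = [tf.Variable(tf.zeros([sizes[i + 1]])) for i in range(num_layers)]
    w_out = tf.Variable(tf.ones([n_hidden, n_out]))
    b_out = tf.Variable(tf.zeros([n_out]))

    # Assert that all inputs are non-negative.
    assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x, 0.0)), [x])

    def build_model(inputs):
        """Hypothetical helper: ReLU hidden layers plus a linear output layer."""
        layer = inputs
        for i in range(num_layers):
            layer = tf.nn.relu(tf.matmul(layer, w[i]) + b[i])
        return tf.matmul(layer, w_out) + b_out

    # Unguarded: no op depends on assert_op, so it is never executed.
    unguarded = build_model(x)
    # Guarded: ops created in this block get assert_op as a control input.
    with tf.control_dependencies([assert_op]):
        guarded = build_model(x)

    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    bad_input = np.array([[1.0, -2.0, 3.0]], np.float32)

    # Runs without error even though the input violates the assertion.
    unguarded_value = sess.run(unguarded, feed_dict={x: bad_input})

    # Fetching the guarded output forces assert_op to run first, so it fails.
    try:
        sess.run(guarded, feed_dict={x: bad_input})
        guarded_raised = False
    except tf.errors.InvalidArgumentError:
        guarded_raised = True

    # With valid (non-negative) input, the guarded model runs normally.
    good_value = sess.run(guarded, feed_dict={x: np.ones((1, n_in), np.float32)})
```

With these weights, the input [1, -2, 3] produces hidden activations of 2 and then 8, so the unguarded output is 4 × 8 = 32. Putting only the first layer inside the control-dependency block would also be enough, because every later op depends on it through its data inputs. Wrapping the whole model simply follows the pattern used in the text.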