Yet another way to debug TensorFlow models is to insert conditional asserts. The tf.Assert() function takes a condition and, if the condition is false, prints the given list of tensors and raises tf.errors.InvalidArgumentError.
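As a minimal sketch of this behavior, the snippet below builds an assertion whose condition is false and runs it directly in a session. It assumes the TF1-style API accessed through tensorflow.compat.v1, so it also runs on a TensorFlow 2 installation. The tensor values and the op name check_values are illustrative only. The session raises tf.errors.InvalidArgumentError, and the error message carries the values of the tensors passed in data.

```python
import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
    # Illustrative input: one element is negative
    values = tf.constant([3, -1, 4], dtype=tf.int32)
    # The condition evaluates to False because -1 < 0
    check = tf.Assert(tf.reduce_all(values >= 0), [values], name="check_values")

error_message = None
with tf.Session(graph=graph) as sess:
    try:
        # Running the assert op itself triggers the check
        sess.run(check)
    except tf.errors.InvalidArgumentError as err:
        # The message includes the summarized contents of `values`
        error_message = err.message

print(error_message)
```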
- The tf.Assert() function has the following signature:
tf.Assert(
    condition,
    data,
    summarize=None,
    name=None
)
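- Unlike tf.Print(), which returns its input tensor and can therefore sit inline in the data path of the graph, an assert operation produces no output that other operations consume, so nothing in the model pulls it into execution. To make sure that the tf.Assert() operation is executed, we need to add it as a control dependency. For example, let us define an assertion that checks that all the inputs are non-negative: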
assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x, 0)), [x])
- Add assert_op to the dependencies when defining the model, as follows:
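with tf.control_dependencies([assert_op]):
    # x is the input layer; define the layers built on x here, so that
    # every op created inside this block runs only after assert_op passes
    ...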
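Putting the two steps together, here is a minimal sketch that assumes the TF1-style API via tensorflow.compat.v1 (so it also runs on TensorFlow 2). The three-feature placeholder, the multiplication standing in for the model's first layer, and the sample batches are illustrative assumptions. A single assert_op is created, and two outputs are built from x: one ignores the assertion, and one is built under tf.control_dependencies(). Only the guarded output fails on negative input, which is why the dependency is needed.

```python
import numpy as np
import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
    # Input layer: a batch of 3-feature rows (shape chosen for illustration)
    x = tf.placeholder(tf.float32, shape=[None, 3], name="x")

    # Assertion: every element of x must be non-negative;
    # summarize caps how many entries of x appear in the error message
    assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x, 0)), [x],
                          summarize=6, name="x_non_negative")

    # Unguarded output: assert_op exists in the graph, but nothing depends
    # on it, so fetching y_unguarded never executes the check
    y_unguarded = x * 2.0

    # Guarded output: ops created inside this block run only after assert_op
    with tf.control_dependencies([assert_op]):
        y_guarded = x * 2.0

bad_batch = np.array([[1.0, -2.0, 3.0]], dtype=np.float32)
good_batch = np.abs(bad_batch)

with tf.Session(graph=graph) as sess:
    # Computed even though the input violates the assertion
    unguarded_out = sess.run(y_unguarded, feed_dict={x: bad_batch})
    # Valid input passes the assertion
    guarded_ok = sess.run(y_guarded, feed_dict={x: good_batch})
    # Invalid input makes the guarded output fail
    assert_fired = False
    try:
        sess.run(y_guarded, feed_dict={x: bad_batch})
    except tf.errors.InvalidArgumentError as err:
        assert_fired = True
        print(err.message)
```

Note that tf.control_dependencies() only affects operations created inside the with block. If the block would merely pass an existing tensor through, wrapping it in tf.identity() creates a new op that carries the dependency.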