Loops in symbolic computing
The Python for loop can be used outside the symbolic graph, as in a normal Python program. But a traditional Python for loop outside the graph isn't compiled, so it will not be optimized by the parallel execution and linear algebra libraries, cannot be automatically differentiated, and introduces costly data transfers if the computation subgraph has been optimized for the GPU.
That is why a symbolic operator, T.scan, is designed to create a for loop as an operator inside the graph. Theano will embed the loop into the graph structure, and the whole loop is compiled on the target architecture together with the rest of the computation graph. Its signature is as follows:
def scan(fn, sequences=None, outputs_info=None, non_sequences=None,
         n_steps=None, truncate_gradient=-1, go_backwards=False,
         mode=None, name=None, profile=False, allow_gc=None, strict=False)
The scan operator is very useful to implement array loops, reductions, maps, multi-dimensional derivatives such as the Jacobian or Hessian, and recurrences.
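For instance, the Jacobian of y = x ** 2 can be computed by looping over the output indices; this is a minimal sketch following the recipe from the Theano documentation (the sequences and non_sequences parameters are explained below):
>>> x = T.vector("x")
>>> y = x ** 2
>>> # one gradient row per output element; y and x are passed as fixed inputs
>>> J, updates = theano.scan(lambda i, y, x: T.grad(y[i], x),
...                          sequences=T.arange(y.shape[0]),
...                          non_sequences=[y, x])
>>> f = theano.function([x], J, updates=updates)
>>> f(numpy.array([4, 4]).astype(theano.config.floatX))
array([[ 8.,  0.],
       [ 0.,  8.]], dtype=float32)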
The scan operator runs the fn function repeatedly for n_steps steps. If n_steps is None, the operator determines the number of steps from the length of the sequences.
Note
The step function fn is a function that builds a symbolic graph, and that function will only get called once. However, that graph is then compiled into another Theano function that is called repeatedly. Some users try to pass a compiled Theano function as fn, which is not possible.
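Here is a minimal sketch of a scan driven purely by n_steps, with no sequences at all: a loop that doubles its state at each step. The outputs_info parameter, which supplies the initial state, is covered below, and floatX is assumed to be float32 as in the other examples:
>>> k = T.iscalar("k")
>>> # no sequences: the number of iterations comes from n_steps only
>>> results, updates = theano.scan(lambda prior: 2 * prior,
...                                outputs_info=T.constant(1.0),
...                                n_steps=k)
>>> f = theano.function([k], results, updates=updates)
>>> f(5)
array([  2.,   4.,   8.,  16.,  32.], dtype=float32)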
Sequences are the lists of input variables to loop over. The number of steps will correspond to the shortest sequence in the list. Let's have a look:
>>> a = T.matrix()
>>> def fn(x): return x + 1
>>> results, updates = theano.scan(fn, sequences=a)
>>> f = theano.function([a], results, updates=updates)
>>> f(numpy.ones((2,3)).astype(theano.config.floatX))
array([[ 2.,  2.,  2.],
       [ 2.,  2.,  2.]], dtype=float32)
The scan operator has run the function against all elements in the input tensor, a, and kept the same shape as the input tensor, (2,3).
Note
It is good practice to add the updates returned by theano.scan to the theano.function, even if these updates are empty.
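Updates are non-empty, for example, when the step function draws random numbers: the state of the random generator is advanced through the updates, so dropping them would make every call return the same samples. A minimal sketch:
>>> from theano.tensor.shared_randomstreams import RandomStreams
>>> srng = RandomStreams(seed=1234)
>>> # each step draws two uniform samples; the generator state update
>>> # is collected in the updates dictionary returned by scan
>>> values, updates = theano.scan(lambda: srng.uniform((2,)), n_steps=3)
>>> sample = theano.function([], values, updates=updates)
Each call to sample() then returns a fresh (3, 2) array of samples; without the updates, the generator state would never advance.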
The arguments given to the fn function can be much more complicated. T.scan will call the fn function at each step with the following argument list, in the following order:
fn( sequences (if any), prior results (if needed), non-sequences (if any) )
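For instance, with two sequences, one recurrent output, and one non-sequence, the step function receives its arguments in exactly that order; a minimal sketch (the variable names are only illustrative):
>>> x = T.vector("x")
>>> w = T.vector("w")
>>> c = T.scalar("c")
>>> # argument order: sequence elements, then the prior result, then non-sequences
>>> def step(x_t, w_t, prior, c):
...     return c * prior + x_t * w_t
>>> results, updates = theano.scan(step, sequences=[x, w],
...                                outputs_info=T.constant(0.0),
...                                non_sequences=c)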
As shown in the following figure, three arrows are directed towards the fn step function and represent the three types of possible input at each time step in the loop:
If specified, the outputs_info parameter gives the initial state from which to start the recurrence. The parameter name is not very descriptive, but the initial state also provides the shape information for the states: the last state, and all the other states, will have the same shape as the initial one. The initial state can be seen as the first output; the final output will be an array of states.
For example, to compute the cumulative sum of a vector, starting from an initial sum of 0, use this code:
>>> a = T.vector()
>>> s0 = T.scalar("s0")
>>> def fn(current_element, prior):
...     return prior + current_element
>>> results, updates = theano.scan(fn=fn, outputs_info=s0, sequences=a)
>>> f = theano.function([a, s0], results, updates=updates)
>>> f([0, 3, 5], 0)
array([ 0.,  3.,  8.], dtype=float32)
When outputs_info is set, the first dimension of the outputs_info and sequence variables is the time step, and the second dimension is the dimensionality of the data at each time step. In particular, outputs_info holds as many previous time steps as are required to compute the first step.
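For instance, a recurrence that needs the two previous outputs, such as the Fibonacci sequence, declares taps=[-2, -1] in outputs_info and supplies an initial state that holds two time steps; a minimal sketch, again assuming floatX is float32:
>>> s0 = T.vector("s0")  # the two previous values needed by the first step
>>> results, updates = theano.scan(lambda f_m2, f_m1: f_m2 + f_m1,
...                                outputs_info=[dict(initial=s0, taps=[-2, -1])],
...                                n_steps=10)
>>> f = theano.function([s0], results, updates=updates)
>>> f(numpy.array([0, 1]).astype(theano.config.floatX))
array([  1.,   2.,   3.,   5.,   8.,  13.,  21.,  34.,  55.,  89.], dtype=float32)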
Here is the cumulative sum example again, but with a vector at each time step instead of a scalar for the input data:
>>> a = T.matrix()
>>> s0 = T.scalar("s0")
>>> def fn(current_element, prior):
...     return prior + current_element.sum()
>>> results, updates = theano.scan(fn=fn, outputs_info=s0, sequences=a)
>>> f = theano.function([a, s0], results, updates=updates)
>>> f(numpy.ones((20,5)).astype(theano.config.floatX), 0)
array([   5.,   10.,   15.,   20.,   25.,   30.,   35.,   40.,   45.,
         50.,   55.,   60.,   65.,   70.,   75.,   80.,   85.,   90.,
         95.,  100.], dtype=float32)
Twenty steps along the rows (the time dimension) have accumulated the sum of all elements. Note that the initial state (here 0) given by the outputs_info argument is not part of the output sequence.
The recurrent function, fn, may be provided with some fixed data, independent of the step in the loop, thanks to the non_sequences scan parameter:
>>> a = T.vector()
>>> s0 = T.scalar("s0")
>>> def fn(current_element, prior, non_seq):
...     return non_seq * prior + current_element
>>> results, updates = theano.scan(fn=fn, n_steps=10, sequences=a,
...                                outputs_info=T.constant(0.0),
...                                non_sequences=s0)
>>> f = theano.function([a, s0], results, updates=updates)
>>> f(numpy.ones((20)).astype(theano.config.floatX), 5)
array([  1.00000000e+00,   6.00000000e+00,   3.10000000e+01,
         1.56000000e+02,   7.81000000e+02,   3.90600000e+03,
         1.95310000e+04,   9.76560000e+04,   4.88281000e+05,
         2.44140600e+06], dtype=float32)
At each step, it multiplies the prior value by 5 and adds the new element.
Note that T.scan in the optimized graph on GPU does not execute different iterations of the loop in parallel, even in the absence of recurrence.