
Assign on a tf.Variable tensor slice

I am trying to do the following:

state[0,:] = state[0,:].assign( 0.9*prev_state + 0.1*( tf.matmul(inputs, weights) + biases ) )
for i in xrange(1,BATCH_SIZE):
  state[i,:] = state

Solution 1:

TensorFlow Variable objects have limited support for updating slices, through the tf.scatter_update(), tf.scatter_add(), and tf.scatter_sub() ops. Each of these ops takes a variable, a vector of slice indices (indices into the 0th dimension of the variable, each selecting a contiguous slice to be mutated), and a tensor of values (the values to write into, add to, or subtract from the variable at the corresponding slice indices). The values tensor must have shape indices.shape + variable.shape[1:].
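The row semantics can be sketched without TensorFlow. The helpers below, scatter_update and scatter_add, are hypothetical pure-Python stand-ins that operate on a list of rows; they are not TensorFlow's API and only mimic the documented behavior: each index selects one row, and the updates must supply exactly one row of matching width per index.

```python
from typing import List, Sequence

Row = List[float]


def _check_shapes(ref: List[Row], indices: Sequence[int], updates: Sequence[Row]) -> None:
    # Mirrors the rule updates.shape == indices.shape + ref.shape[1:] for 1-D indices.
    if len(updates) != len(indices):
        raise ValueError("need exactly one update row per index")
    for idx, row in zip(indices, updates):
        if len(row) != len(ref[idx]):
            raise ValueError("update row width must match the variable's row width")


def scatter_update(ref: List[Row], indices: Sequence[int], updates: Sequence[Row]) -> List[Row]:
    """Hypothetical stand-in for tf.scatter_update: overwrite the selected rows in place."""
    _check_shapes(ref, indices, updates)
    for idx, row in zip(indices, updates):
        ref[idx] = list(row)
    return ref  # returned so calls can be chained, like the mutable TF output


def scatter_add(ref: List[Row], indices: Sequence[int], updates: Sequence[Row]) -> List[Row]:
    """Hypothetical stand-in for tf.scatter_add: add into the selected rows in place."""
    _check_shapes(ref, indices, updates)
    for idx, row in zip(indices, updates):
        ref[idx] = [a + b for a, b in zip(ref[idx], row)]
    return ref
```

For example, starting from a 3x2 zero matrix, scatter_update(state, [0], [[1.0, 2.0]]) followed by scatter_add(state, [0, 2], [[1.0, 1.0], [5.0, 5.0]]) leaves state as [[2.0, 3.0], [0.0, 0.0], [5.0, 5.0]]; the shape check runs before any row is modified.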

To update a single row of the variable, you can use tf.scatter_update(). For example, to update the 0th row of state, you would do:

# Overwrite row 0 of `state` with the first step of the recurrence.
updated_state = tf.scatter_update(
    state, [0], 0.9 * prev_state + 0.1 * (tf.matmul(inputs, weights) + biases))

To chain further updates, pass the mutable updated_state tensor returned by tf.scatter_update() as the variable argument of the next call:

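for i in range(1, BATCH_SIZE):
  # A scalar index writes a single row, so the update has shape state.shape[1:].
  updated_state = tf.scatter_update(
      updated_state, i, 0.9 * updated_state[i-1, :] + ...)

# Carry the last row over as prev_state.
prev_state = prev_state.assign(updated_state[BATCH_SIZE-1, :])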

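Finally, run updated_state.op to apply all of the updates to state. The prev_state assignment is a separate op that updated_state.op does not depend on, so run it as well:

sess.run([updated_state.op, prev_state.op])  # or `sess.run([updated_state, prev_state])` to fetch the results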
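As a hedged way to check what the chained updates compute, here is a framework-free reference. It assumes that the elided term in the loop is 0.1 times row i of tf.matmul(inputs, weights) + biases, mirroring the expression used for row 0; recurrent_rows and preact are hypothetical names, and plain lists of floats stand in for tensors.

```python
from typing import List, Tuple

Row = List[float]


def recurrent_rows(prev_state: Row, preact: List[Row], decay: float = 0.9) -> Tuple[List[Row], Row]:
    """Hypothetical reference for the row-by-row recurrence built by the scatter chain.

    Row i is decay * (row i-1) + (1 - decay) * preact[i], where row -1 is
    prev_state and preact stands in for tf.matmul(inputs, weights) + biases.
    Returns the full state matrix and the new prev_state (its last row).
    """
    state: List[Row] = []
    carry = prev_state
    for x in preact:
        if len(x) != len(carry):
            raise ValueError("every pre-activation row must match prev_state's width")
        carry = [decay * c + (1.0 - decay) * v for c, v in zip(carry, x)]
        state.append(carry)
    # Mirrors prev_state.assign(updated_state[BATCH_SIZE-1, :]); with no rows,
    # prev_state is returned unchanged.
    return state, carry
```

With the default decay=0.9 this reproduces the 0.9/0.1 weights from the question. Comparing its output with the fetched updated_state on a small batch is a quick way to confirm that the row indices and shapes in the scatter chain are right.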

P.S. You may find it more efficient to compute the intermediate states with tf.scan(), and to materialize only prev_state in a variable.
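tf.scan() threads an accumulator through the leading dimension of its input and stacks every intermediate result, calling its function as fn(accumulator, element). The same idea can be sketched in plain Python with itertools.accumulate. This is an analogy rather than TensorFlow code: it assumes the recurrence state[i] = 0.9 * state[i-1] + 0.1 * preact[i] taken from the question's first row, and step, scan_states and preact are hypothetical names.

```python
from itertools import accumulate
from typing import List

Row = List[float]


def step(prev: Row, x: Row) -> Row:
    """One recurrence step; (accumulator, element) is the argument order tf.scan uses."""
    return [0.9 * p + 0.1 * v for p, v in zip(prev, x)]


def scan_states(prev_state: Row, preact: List[Row]) -> List[Row]:
    """Hypothetical scan over the rows of preact, returning every intermediate state."""
    # With initial=..., accumulate yields the initializer first; drop it so there is
    # one output row per input row, like the stacked result of tf.scan.
    return list(accumulate(preact, step, initial=prev_state))[1:]
```

In TensorFlow the corresponding call would look roughly like tf.scan(step, preact, initializer=prev_state) with a tensor-valued step. Only the last row (states[-1], or prev_state itself when the batch is empty) then needs to be written back to a variable, which avoids one scatter op per row.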

