Assign on a tf.Variable tensor slice
Solution 1:
TensorFlow Variable objects have limited support for updating slices, through the tf.scatter_update(), tf.scatter_add(), and tf.scatter_sub() ops. Each of these ops takes three arguments: a variable; a vector of indices into the variable's 0th dimension, where each index selects one contiguous slice (for example, one row of a matrix) to be mutated; and a tensor of values holding, for each of those indices, the new contents of the slice (scatter_update) or the amount to add to it (scatter_add) or subtract from it (scatter_sub).
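To make those semantics concrete, here is a minimal pure-Python model of what the three ops do to the rows of a 2-D variable. It is not TensorFlow code: the "variable" is a plain list of row lists, and the functions scatter_update, scatter_add, and scatter_sub below are hypothetical stand-ins that only mimic the row-indexed behavior described above, with one row of updates per index.

```python
from typing import List

Matrix = List[List[float]]


def _check(indices: List[int], updates: Matrix) -> None:
    # The model expects exactly one update row per index in the 0th dimension.
    if len(indices) != len(updates):
        raise ValueError("need one update row per index")


def scatter_update(var: Matrix, indices: List[int], updates: Matrix) -> Matrix:
    """Model of scatter_update: overwrite row var[i] with updates[k]."""
    _check(indices, updates)
    for k, i in enumerate(indices):
        var[i] = list(updates[k])
    return var  # var is mutated in place and also returned


def scatter_add(var: Matrix, indices: List[int], updates: Matrix) -> Matrix:
    """Model of scatter_add: add updates[k] element-wise to row var[i]."""
    _check(indices, updates)
    for k, i in enumerate(indices):
        var[i] = [a + b for a, b in zip(var[i], updates[k])]
    return var


def scatter_sub(var: Matrix, indices: List[int], updates: Matrix) -> Matrix:
    """Model of scatter_sub: subtract updates[k] element-wise from row var[i]."""
    _check(indices, updates)
    for k, i in enumerate(indices):
        var[i] = [a - b for a, b in zip(var[i], updates[k])]
    return var


state = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
scatter_update(state, [0], [[5.0, 5.0]])              # row 0 replaced
scatter_add(state, [1, 2], [[1.0, 1.0], [1.0, 1.0]])  # rows 1 and 2 incremented
scatter_sub(state, [2], [[0.5, 0.5]])                 # row 2 decremented
print(state)  # [[5.0, 5.0], [2.0, 2.0], [2.5, 2.5]]
```

Rows that are not named in the index vector are left untouched, which is what makes these ops usable as slice assignments.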
To update a single row of the variable, use tf.scatter_update(). For example, to update the 0th row of state:
updated_state = tf.scatter_update(
    state, [0], 0.9 * prev_state + 0.1 * (tf.matmul(inputs, weights) + biases))
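The value written into row 0 above is a 0.9/0.1 blend of the previous state and an affine transform of the inputs. The sketch below does the same arithmetic for one row in plain Python, assuming (hypothetically) that prev_state has length 2, inputs is a single row of length 3, weights is 3x2, and biases has length 2; the numbers and the helper names matvec_row and blended_row are made up for illustration.

```python
from typing import List

Row = List[float]


def matvec_row(x: Row, w: List[Row]) -> Row:
    # Row vector x (length n) times matrix w (n x m) gives a row of length m.
    return [sum(x[k] * w[k][j] for k in range(len(x))) for j in range(len(w[0]))]


def blended_row(prev_state: Row, inputs: Row, weights: List[Row], biases: Row) -> Row:
    """0.9 * prev_state + 0.1 * (inputs @ weights + biases), element by element."""
    affine = [v + b for v, b in zip(matvec_row(inputs, weights), biases)]
    return [0.9 * p + 0.1 * a for p, a in zip(prev_state, affine)]


prev_state = [1.0, 2.0]
inputs = [1.0, 0.0, 2.0]
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
biases = [0.0, 1.0]
row0 = blended_row(prev_state, inputs, weights, biases)  # approximately [1.2, 2.1]
```

Here inputs @ weights is [3.0, 2.0], adding biases gives [3.0, 3.0], and blending with [1.0, 2.0] gives roughly [1.2, 2.1].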
To chain multiple updates, pass the mutable updated_state tensor returned by tf.scatter_update() back in as the target of the next call:
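for i in range(1, BATCH_SIZE):  # range instead of xrange, so this also runs on Python 3
    # Row i depends on row i-1, which the previous scatter_update just wrote.
    updated_state = tf.scatter_update(
        updated_state, [i], 0.9 * updated_state[i-1, :] + ...)
# Carry the last row over as the starting state for the next batch.
prev_state = prev_state.assign(updated_state[BATCH_SIZE-1, :])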
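Finally, run updated_state.op to apply all of the updates to state:
sess.run(updated_state.op)  # or sess.run(updated_state) to also fetch the result
The prev_state assignment is not a dependency of updated_state.op, so it only takes effect if you fetch it as well, for example with sess.run([updated_state.op, prev_state]).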
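For intuition about what the chained loop computes, here is a pure-Python model of the whole sequence: row 0 blends the carried-over prev_state, each later row blends the row written just before it, and the last row becomes the new prev_state. The answer elides the per-row input term with "..."; the list drive below is a hypothetical stand-in for that term (before its 0.1 scaling), and blend and chained_updates are illustrative names, not TensorFlow APIs.

```python
from typing import List, Tuple

Row = List[float]


def blend(previous: Row, drive_row: Row) -> Row:
    # 0.9 * previous row + 0.1 * drive term, as in the scatter_update calls.
    return [0.9 * p + 0.1 * d for p, d in zip(previous, drive_row)]


def chained_updates(prev_state: Row, drive: List[Row]) -> Tuple[List[Row], Row]:
    """Model of the scatter_update chain plus the prev_state assignment.

    drive[i] is a hypothetical stand-in for the elided per-row input term.
    Returns (all rows of state, new prev_state).
    """
    rows: List[Row] = []
    previous = prev_state        # row 0 blends the carried-over prev_state
    for drive_row in drive:      # later rows blend the row written just before
        previous = blend(previous, drive_row)
        rows.append(previous)
    # Mirrors prev_state.assign(updated_state[BATCH_SIZE-1, :]).
    return rows, list(rows[-1])


rows, new_prev_state = chained_updates([0.0], [[10.0], [10.0], [10.0]])
# rows is approximately [[1.0], [1.9], [2.71]]
```

The strictly sequential dependency of each row on the previous one is why the TensorFlow version needs one scatter_update per row.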
PS. You might find it more efficient to use tf.scan() to compute all of the intermediate states, and to store only prev_state in a variable.
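tf.scan(fn, elems, initializer=...) applies fn(accumulator, element) along the first dimension of elems and stacks the per-element results, which is exactly the shape of the recurrence above. The sketch below uses Python's itertools.accumulate (Python 3.8+ for the initial argument) as a stand-in for tf.scan; it is not TensorFlow code, and drive is a hypothetical stand-in for the per-row input term.

```python
from itertools import accumulate
from typing import List

Row = List[float]


def step(state: Row, drive_row: Row) -> Row:
    # Same recurrence as the loop: 0.9 * previous state + 0.1 * drive term.
    return [0.9 * s + 0.1 * d for s, d in zip(state, drive_row)]


def scan_states(prev_state: Row, drive: List[Row]) -> List[Row]:
    """Pure-Python analogue of tf.scan(step, drive, initializer=prev_state).

    accumulate(..., initial=x) yields x first, while tf.scan returns only the
    result for each element, so the initializer is dropped.
    """
    return list(accumulate(drive, step, initial=prev_state))[1:]


states = scan_states([0.0], [[10.0], [10.0], [10.0]])
new_prev_state = states[-1]  # only this row would need to live in a Variable
```

With a real tf.scan, step would use TensorFlow ops on tensors, and the stacked states would be an ordinary tensor rather than a variable, so no scatter ops are needed at all.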