Tensorflow TensorArray Simple Example

A small example on how to use Tensorflow TensorArray.

import numpy as np
import tensorflow as tf    

matrix = tf.placeholder(tf.int32, shape=(5, 3), name="input_matrix")
matrix_rows = tf.shape(matrix)[0] #should be 5

# each element of the tensor_array corresponds to each row of the matrix
ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows)

init_state = (0, ta)

#Also can write as :-   condition = lambda i, _: i < matrix_rows
def condition(i,ta):
    return (i < matrix_rows)

#Also can write as :-   body = lambda i, ta: (i + 1, ta.write(i, matrix[i] * (i+1)))
def body(i,ta):    
    ta = ta.write(i, matrix[i] * (i+1)) # at index i of the tensor_array, write (i+1) * matrix_row[i]  
    i = tf.add(i,1) # do this for all the elements of the tensor_array
    return i,ta

n, ta_final = tf.while_loop(condition, body, init_state)

#get the final result
ta_final_result = ta_final.stack()

#run the graph
with tf.Session() as sess:
    # print the output of ta_final_result
    a,b = sess.run([n,ta_final_result], feed_dict={matrix: np.ones(shape=(5,3), dtype=np.int32)})
    print('no of loops completed = ',a)
    print('Final content of tensorarray = ',b)
    

Output:

    no of loops completed =  5
    Final content of tensorarray =  [[1 1 1]
     [2 2 2]
     [3 3 3]
     [4 4 4]
     [5 5 5]]
Written on September 12, 2016