First I split the original tensor, and after some operations I want to combine the pieces back into the original shape, so that I get the original tensor as it was before splitting. I'm not sure whether I can simply reuse the old tensor in graph mode in TensorFlow.
Each of the four dimensions of tensor_a has a size of at least 2.
tensor_a = tf.split(tensor_c, split_into, axis=1)
# creating additional dimension
# some operations
tensor_a = tf.convert_to_tensor(tensor_a)
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
Answer
The code below is self-contained and shows a couple of ways of reconstructing the original tensor after splitting. I think the second approach, using tf.concat() instead of tf.convert_to_tensor(), is the neatest. I'm hoping the code is self-explanatory.
import tensorflow as tf

# Construct a test tensor, to be split and then reconstructed
tensor_c = tf.reshape(tf.constant([i for i in range(24)]), [2,6,2])
print("tensor_c")
print(tensor_c.numpy())

# Split it, as the question does
list_of_tensor_a = tf.split(tensor_c, 3, axis=1)
print("\nlist_of_tensor_a")
print([t.numpy() for t in list_of_tensor_a])

# Create a tensor shape (3, 2, 2, 2), as the question does
# This changes the original ordering of tensor_c. It was split on axis 1,
# and is now reassembled by creating a new axis 0
tensor_a = tf.convert_to_tensor(list_of_tensor_a)
print("\ntensor_a")
print(tensor_a.shape)
print(tensor_a.numpy())

# Reshape as in the question.
# Does not reconstruct tensor_c, since the ordering has been changed
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
print("\ntensor_b - incorrect reconstruction of tensor_c")
print(tensor_b.numpy())

# Correct reconstruction, first approach.
# Use tf.transpose() to restore the original order
tensor_b2 = tf.reshape(tf.transpose(tensor_a,[1,0,2,3]), (second, first * third, fourth))
print("\ntensor_b2 - correct reconstruction of tensor_c")
print(tensor_b2.numpy())

# Correct reconstruction, second (and neater) approach.
# Use tf.concat() instead of tf.convert_to_tensor()
tensor_b3 = tf.concat(list_of_tensor_a, axis=1)
print("\ntensor_b3 - correct reconstruction of tensor_c")
print(tensor_b3.numpy())
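
Since the question asks about graph mode, here is a minimal sketch of the tf.concat() approach wrapped in tf.function, which traces the Python function into a graph. The function name split_and_rebuild and the tf.identity() stand-in for "some operations" are assumptions made for illustration and are not from the question. With real operations in their place, the values would of course differ from tensor_c, but the shape and element ordering should still line up.

```python
import tensorflow as tf

@tf.function  # traces the Python function into a TensorFlow graph
def split_and_rebuild(tensor_c, num_splits=3):
    # Hypothetical helper: split along axis 1 into equally sized pieces
    pieces = tf.split(tensor_c, num_splits, axis=1)
    # Placeholder for "some operations"; tf.identity leaves values unchanged
    pieces = [tf.identity(p) for p in pieces]
    # Concatenate along the same axis to restore the original layout
    return tf.concat(pieces, axis=1)

tensor_c = tf.reshape(tf.range(24), [2, 6, 2])
rebuilt = split_and_rebuild(tensor_c)

# Compare element-wise with the original tensor
matches = bool(tf.reduce_all(tf.equal(rebuilt, tensor_c)).numpy())
print(rebuilt.shape)
print(matches)
```

Because tf.concat() joins the pieces along the same axis they were split on, there is no need to unpack tensor_a.shape or to transpose before reshaping.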