
Dropping a dimension of size greater than one from a tensor in TensorFlow

First I split the original tensor. Then, after some operations, I want to combine the pieces back into the original shape, i.e. recover the original tensor as it was before splitting. I'm not sure I can simply reuse the old tensor in graph mode in TensorFlow.

Each of the four dimensions of tensor_a has a size of at least 2.

import tensorflow as tf

tensor_a = tf.split(tensor_c, split_into, axis=1)  # returns a list of split_into tensors

# some operations

tensor_a = tf.convert_to_tensor(tensor_a)  # packs the list along a new axis 0 (now 4-D)
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
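To make the shapes in this snippet concrete, here is a minimal sketch with an arbitrarily chosen input shape of (4, 6, 3) and split_into = 3 (both are assumptions, not taken from the question). It shows that tf.split on its own only returns a list of 3-D pieces; the extra dimension appears when tf.convert_to_tensor packs that list along a new axis 0, which gives the same result as tf.stack(pieces, axis=0). The final reshape then matches tensor_c's shape, but not its element order.

```python
import tensorflow as tf

# Hypothetical input: any 3-D tensor whose axis 1 is divisible by split_into
tensor_c = tf.reshape(tf.range(4 * 6 * 3), [4, 6, 3])
split_into = 3

# tf.split returns a list of split_into tensors, each still 3-D
pieces = tf.split(tensor_c, split_into, axis=1)
print(len(pieces), pieces[0].shape)  # 3 (4, 2, 3)

# Converting the list packs the pieces along a NEW axis 0 (same as tf.stack)
tensor_a = tf.convert_to_tensor(pieces)
print(tensor_a.shape)  # (3, 4, 2, 3)

# The question's reshape restores tensor_c's shape, but not its element order
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
print(tensor_b.shape)  # (4, 6, 3)
```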



Answer

The code below is self-contained and shows two ways of reconstructing the original tensor after splitting. I think the second approach, which uses tf.concat() instead of tf.convert_to_tensor(), is the neatest. I hope the code is self-explanatory.

import tensorflow as tf

# Construct a test tensor, to be split and then reconstructed
tensor_c = tf.reshape(tf.constant([i for i in range(24)]), [2,6,2])
print("tensor_c")
print(tensor_c.numpy())

# Split it, as the question does
list_of_tensor_a = tf.split(tensor_c, 3, axis=1)
print("\nlist_of_tensor_a")
print([t.numpy() for t in list_of_tensor_a])

# Create a tensor of shape (3, 2, 2, 2), as the question does
# This changes the original ordering of tensor_c. It was split on axis 1,
# and is now reassembled by creating a new axis 0
tensor_a = tf.convert_to_tensor(list_of_tensor_a)
print("\ntensor_a")
print(tensor_a.shape)
print(tensor_a.numpy())

# Reshape as in the question. 
# Does not reconstruct tensor_c, since the ordering has been changed
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
print("\ntensor_b - incorrect reconstruction of tensor_c")
print(tensor_b.numpy())

# Correct reconstruction, first approach. 
# Use tf.transpose() to restore the original order
tensor_b2 = tf.reshape(tf.transpose(tensor_a,[1,0,2,3]), (second, first * third, fourth))
print("\ntensor_b2 - correct reconstruction of tensor_c")
print(tensor_b2.numpy())

# Correct reconstruction, second (and neater) approach. 
# Use tf.concat() instead of tf.convert_to_tensor()
tensor_b3 = tf.concat(list_of_tensor_a, axis=1)
print("\ntensor_b3 - correct reconstruction of tensor_c")
print(tensor_b3.numpy())
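Two further points, sketched below as a hedged addition that is not part of the original answer. First, the transpose approach can be written without tf.transpose() by stacking the pieces along axis 1 instead of axis 0, so that a plain reshape restores tensor_c's order. Second, since the question mentions graph mode: the split/concat pair also works inside a tf.function, and tf.concat() recovers the original shape from the pieces alone. The function name split_process_merge and the "double each piece" step are hypothetical stand-ins for the question's "some operations".

```python
import tensorflow as tf

tensor_c = tf.reshape(tf.range(24), [2, 6, 2])
pieces = tf.split(tensor_c, 3, axis=1)

# Variant of the transpose approach: stack along axis 1 (after the first
# axis) instead of axis 0, so a plain reshape restores the original order
stacked = tf.stack(pieces, axis=1)                    # shape (2, 3, 2, 2)
tensor_b4 = tf.reshape(stacked, tf.shape(tensor_c))   # shape (2, 6, 2)

@tf.function  # traced into a TensorFlow graph
def split_process_merge(x):
    # Split along axis 1, as in the question
    parts = tf.split(x, 3, axis=1)
    # Hypothetical placeholder for "some operations": double each piece
    parts = [p * 2 for p in parts]
    # Rejoin along the axis that was split; the original tensor is not needed
    return tf.concat(parts, axis=1)

doubled = split_process_merge(tensor_c)
print(doubled.shape)  # (2, 6, 2)
```

Because tf.split() is called with an integer here, it yields a Python list even while tracing, so the list comprehension and tf.concat() behave the same in graph mode as in eager mode.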
User contributions licensed under: CC BY-SA