First I split the original tensor, and then, after some operations, I want to combine the pieces back into a tensor with the original shape and the original element order, as it was before splitting. I'm not sure whether I can just reuse the old tensor in graph mode in TensorFlow.
Each of the four dimensions of tensor_a has a size of at least 2.
```python
tensor_a = tf.split(tensor_c, split_into, axis=1)  # creating additional dimension

# some operations

tensor_a = tf.convert_to_tensor(tensor_a)  # stacks the list along a new axis 0
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
```
Answer
The code below is self-contained and shows a couple of ways of reconstructing the original tensor after splitting. I think the second approach, using tf.concat() instead of tf.convert_to_tensor(), is the neatest. I'm hoping the code is self-explanatory.
```python
import tensorflow as tf

# Construct a test tensor, to be split and then reconstructed
tensor_c = tf.reshape(tf.constant([i for i in range(24)]), [2, 6, 2])
print("tensor_c")
print(tensor_c.numpy())

# Split it, as the question does
list_of_tensor_a = tf.split(tensor_c, 3, axis=1)
print("\nlist_of_tensor_a")
print([t.numpy() for t in list_of_tensor_a])

# Create a tensor of shape (3, 2, 2, 2), as the question does
# This changes the original ordering of tensor_c. It was split on axis 1,
# and is now reassembled by creating a new axis 0
tensor_a = tf.convert_to_tensor(list_of_tensor_a)
print("\ntensor_a")
print(tensor_a.shape)
print(tensor_a.numpy())

# Reshape as in the question.
# Does not reconstruct tensor_c, since the ordering has been changed
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
print("\ntensor_b - incorrect reconstruction of tensor_c")
print(tensor_b.numpy())

# Correct reconstruction, first approach.
# Use tf.transpose() to restore the original order
tensor_b2 = tf.reshape(tf.transpose(tensor_a, [1, 0, 2, 3]), (second, first * third, fourth))
print("\ntensor_b2 - correct reconstruction of tensor_c")
print(tensor_b2.numpy())

# Correct reconstruction, second (and neater) approach.
# Use tf.concat() instead of tf.convert_to_tensor()
tensor_b3 = tf.concat(list_of_tensor_a, axis=1)
print("\ntensor_b3 - correct reconstruction of tensor_c")
print(tensor_b3.numpy())
```
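Since the question also asks about graph mode, here is a rough sketch of how both reconstruction approaches might look inside a `tf.function`. The function names `split_process_merge` and `merge_stacked` are made up for illustration, and doubling each piece is only a stand-in for the unspecified "some operations". The second function reads the shape with `tf.shape()` rather than `tensor_a.shape`, on the assumption that some static dimensions may be unknown while tracing.

```python
import tensorflow as tf


@tf.function  # traces the Python function into a graph
def split_process_merge(tensor_c, num_splits=3):
    # Split along axis 1 into a list of equally sized pieces.
    pieces = tf.split(tensor_c, num_splits, axis=1)
    # Hypothetical placeholder for "some operations": double each piece.
    pieces = [p * 2 for p in pieces]
    # Concatenating along the same axis restores the original layout.
    return tf.concat(pieces, axis=1)


@tf.function
def merge_stacked(tensor_a):
    # tensor_a holds the stacked pieces: (num_splits, dim0, chunk, dim2).
    # Move the split axis next to the axis it was cut from.
    swapped = tf.transpose(tensor_a, [1, 0, 2, 3])
    # Dynamic shape, usable even if static dimensions are unknown.
    shape = tf.shape(swapped)
    # Merge the split axis and the chunk axis back into one axis.
    return tf.reshape(swapped, [shape[0], shape[1] * shape[2], shape[3]])


tensor_c = tf.reshape(tf.range(24), [2, 6, 2])

# Approach with tf.concat(): shape and element order survive the round trip.
result = split_process_merge(tensor_c)
same_shape = result.shape == tensor_c.shape
same_values = bool(tf.reduce_all(result == tensor_c * 2))
print(same_shape, same_values)

# Approach with a stacked 4-D tensor: transpose before reshaping.
stacked = tf.stack(tf.split(tensor_c, 3, axis=1))
merged = merge_stacked(stacked)
print(bool(tf.reduce_all(merged == tensor_c)))
```

If the intermediate operations do not need a single 4-D tensor, keeping the pieces as a list and calling `tf.concat()` avoids the transpose entirely.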