When I run my code:
import tensorflow as tf
import numpy as np

A = np.array([
    [0,1,0,1,1,0,0,0,0,1],
    [0,1,0,1,1,0,0,0,0,0],
    [0,1,0,1,0,0,0,0,0,1]
])

sliced = A[:, -1]

bool_tensor = tf.math.equal(sliced, 0)

with tf.compat.v1.Session() as tfs:
    print('run(bool_tensor) : ',tfs.run(bool_tensor))
    print(tf.cond(bool_tensor, lambda: 999, lambda: -999))
I get:
run(bool_tensor) : [False True False]
ValueError: Shape must be rank 0 but is rank 1 for 'cond/Switch' (op: 'Switch') with input shapes: [3], [3].
But I want the second print to show a Tensor that evaluates to: [-999 999 -999]
I have looked at other posts but could not find a solution.
Thank you.
P.S. I use TensorFlow 1.
Answer
Try using tf.where:
import tensorflow as tf
import numpy as np

A = np.array([
    [0,1,0,1,1,0,0,0,0,1],
    [0,1,0,1,1,0,0,0,0,0],
    [0,1,0,1,0,0,0,0,0,1]
])

sliced = A[:, -1]
bool_tensor = tf.math.equal(sliced, 0)
with tf.compat.v1.Session() as tfs:
    print('run(bool_tensor) : ', tfs.run(bool_tensor))
    print(tfs.run(tf.where(
        bool_tensor,
        tf.repeat([999], repeats=tf.shape(bool_tensor)[0]),
        tf.repeat([-999], repeats=tf.shape(bool_tensor)[0]))))
Output:
run(bool_tensor) : [False True False]
[-999 999 -999]
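For context on why this works: the error message says tf.cond needs a rank-0 (scalar) predicate and picks one whole branch, whereas the goal here is a choice made separately for each element of bool_tensor. The sketch below shows that per-element selection in plain NumPy, as a way to check the expected values without a session. It is only an illustration and assumes the input is a constant NumPy array, as in the question; np.where stands in for tf.where and is not a replacement when a Tensor is required. The names mask and selected are made up for this example.

```python
import numpy as np

A = np.array([
    [0, 1, 0, 1, 1, 0, 0, 0, 0, 1],
    [0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
    [0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
])

sliced = A[:, -1]   # last column of each row: [1, 0, 1]
mask = sliced == 0  # elementwise comparison: [False, True, False]

# Per-element selection: 999 where mask is True, -999 elsewhere.
# np.where broadcasts the two scalars against mask.
selected = np.where(mask, 999, -999)

print(selected.tolist())  # [-999, 999, -999]
```

Unlike np.where, the answer's code expands 999 and -999 with tf.repeat so that both inputs to tf.where have the same length as bool_tensor.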