I want to select the top-n elements of a 3-dimensional tensor, with the constraint that the picked elements are all unique. The elements are ranked by the 2nd column, and I'm selecting the top 2 in the example below, but I don't want duplicates among them.
Condition: no for loops or tf.map_fn().
Here is the input and the desired output:
input_tensor = tf.constant([
    [[2.0, 1.0],
     [2.0, 1.0],
     [3.0, 0.4],
     [1.0, 0.1]],
    [[44.0, 0.8],
     [22.0, 0.7],
     [11.0, 0.5],
     [11.0, 0.5]],
    [[5555.0, 0.8],
     [3333.0, 0.7],
     [4444.0, 0.4],
     [1111.0, 0.1]],
    [[444.0, 0.8],
     [333.0, 1.1],
     [333.0, 1.1],
     [111.0, 0.1]]
])
- This is what I'm getting right now, which is not what I want:
>> TOPK = 2
>> topk_results = tf.gather(
       input_tensor,
       tf.math.top_k(input_tensor[:, :, 1], k=TOPK, sorted=True).indices,
       batch_dims=1
   )
>> topk_results.numpy().tolist()
[[[2.0, 1.0], [2.0, 1.0]],
 [[44.0, 0.8], [22.0, 0.7]],
 [[5555.0, 0.8], [3333.0, 0.7]],
 [[333.0, 1.1], [333.0, 1.1]]]
- Here is what I actually want:
[[[2.0, 1.0], [3.0, 0.4]],       # [3.0, 0.4] is the 2nd highest element based on 2nd column
 [[44.0, 0.8], [22.0, 0.7]],
 [[5555.0, 0.8], [3333.0, 0.7]],
 [[333.0, 1.1], [444.0, 0.8]]]   # [444.0, 0.8] is the 2nd highest element based on 2nd column
Answer
Here is one possible way to do it, although it requires extra work because it first sorts each group by the first column.
import tensorflow as tf
import numpy as np

# Input data
k = 2
input_tensor = tf.constant([
    [[2.0, 1.0],
     [2.0, 1.0],
     [3.0, 0.4],
     [1.0, 0.1]],
    [[44.0, 0.8],
     [22.0, 0.7],
     [11.0, 0.5],
     [11.0, 0.5]],
    [[5555.0, 0.8],
     [3333.0, 0.7],
     [4444.0, 0.4],
     [1111.0, 0.1]],
    [[444.0, 0.8],
     [333.0, 1.1],
     [333.0, 1.1],
     [111.0, 0.1]]
])
# Sort by first column
idx = tf.argsort(input_tensor[..., 0], axis=-1)
s = tf.gather_nd(input_tensor, tf.expand_dims(idx, axis=-1), batch_dims=1)
# Find repeated elements
col1 = s[..., 0]
col1_ext = tf.concat([col1[..., :1] - 1, col1], axis=-1)
mask = tf.math.not_equal(col1_ext[..., 1:], col1_ext[..., :-1])
# Replace value for repeated elements with "minus infinity"
col2 = s[..., 1]
col2_masked = tf.where(mask, col2, col2.dtype.min)
# Get top-k results
topk_idx = tf.math.top_k(col2_masked, k=k, sorted=True).indices
topk_results = tf.gather(s, topk_idx, batch_dims=1)
# Print
with np.printoptions(suppress=True):
    print(topk_results.numpy())
# [[[   2.     1. ]
#   [   3.     0.4]]
#
#  [[  44.     0.8]
#   [  22.     0.7]]
#
#  [[5555.     0.8]
#   [3333.     0.7]]
#
#  [[ 333.     1.1]
#   [ 444.     0.8]]]
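As a way to sanity-check the vectorized code above on small inputs, here is a minimal loop-based reference in plain Python. It is only a sketch for verification (it deliberately ignores the no-loops condition), and it assumes, like the answer, that two rows count as duplicates when their first-column values are equal. The function name `reference_unique_topk` is made up for this illustration.

```python
def reference_unique_topk(groups, k):
    """Loop-based reference: per group, take up to k rows with the highest
    score (2nd column), skipping rows whose first-column value was already picked."""
    results = []
    for rows in groups:
        seen = set()
        picked = []
        # sorted() is stable, so rows with equal scores keep their input order
        for value, score in sorted(rows, key=lambda row: row[1], reverse=True):
            if value in seen:
                continue  # duplicate of an already picked element
            seen.add(value)
            picked.append([value, score])
            if len(picked) == k:
                break
        results.append(picked)
    return results


groups = [
    [[2.0, 1.0], [2.0, 1.0], [3.0, 0.4], [1.0, 0.1]],
    [[44.0, 0.8], [22.0, 0.7], [11.0, 0.5], [11.0, 0.5]],
    [[5555.0, 0.8], [3333.0, 0.7], [4444.0, 0.4], [1111.0, 0.1]],
    [[444.0, 0.8], [333.0, 1.1], [333.0, 1.1], [111.0, 0.1]],
]
result = reference_unique_topk(groups, k=2)
print(result)
```

Unlike the TensorFlow version, this reference returns fewer than k rows for a group that does not contain k distinct elements, so the two only agree when every group has at least k distinct values.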
Note that there is a corner case: a group that contains fewer than k distinct elements. In that case, this solution fills the remaining slots with the repeated elements at the end, which would break the score order.
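The corner case can be reproduced with a small made-up group that has only two distinct values while k = 3. The sketch below repeats the same steps as the answer and additionally gathers `mask` at the selected indices. That extra `valid` tensor is not part of the original solution; it is just one possible way to tell genuine picks from duplicate fillers.

```python
import tensorflow as tf

k = 3
# Hypothetical group with only two distinct first-column values (5.0 and 7.0)
group = tf.constant([[[5.0, 0.5],
                      [5.0, 0.5],
                      [7.0, 0.25]]])

# Same steps as in the answer: sort by first column, mask repeats, take top-k
idx = tf.argsort(group[..., 0], axis=-1)
s = tf.gather_nd(group, tf.expand_dims(idx, axis=-1), batch_dims=1)
col1 = s[..., 0]
col1_ext = tf.concat([col1[..., :1] - 1, col1], axis=-1)
mask = tf.math.not_equal(col1_ext[..., 1:], col1_ext[..., :-1])
col2 = s[..., 1]
col2_masked = tf.where(mask, col2, col2.dtype.min)
topk_idx = tf.math.top_k(col2_masked, k=k, sorted=True).indices
topk_results = tf.gather(s, topk_idx, batch_dims=1)

# Extra step (an assumption, not in the original answer):
# True where the selected row is a first occurrence, False for duplicate fillers
valid = tf.gather(mask, topk_idx, batch_dims=1)

print(topk_results.numpy().tolist())
# [[[5.0, 0.5], [7.0, 0.25], [5.0, 0.5]]]
print(valid.numpy().tolist())
# [[True, True, False]]
```

The last row of the result is the repeated element, and its score (0.5) is higher than the one before it (0.25), which is the broken score order mentioned above. With a flag like `valid`, the caller could ignore or overwrite those slots.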