I want to select the top-n elements of a 3-dimensional tensor, with the constraint that the picked elements are all unique. The elements are ranked by the 2nd column. In the example below I'm selecting the top 2, but I don't want duplicates among them.
Condition: no for loops or tf.map_fn().
Here are the input and the output that I want:
    input_tensor = tf.constant([
        [[2.0, 1.0], [2.0, 1.0], [3.0, 0.4], [1.0, 0.1]],
        [[44.0, 0.8], [22.0, 0.7], [11.0, 0.5], [11.0, 0.5]],
        [[5555.0, 0.8], [3333.0, 0.7], [4444.0, 0.4], [1111.0, 0.1]],
        [[444.0, 0.8], [333.0, 1.1], [333.0, 1.1], [111.0, 0.1]]
    ])
- This is what I'm getting right now, which I don't want:

    >>> TOPK = 2
    >>> topk_results = tf.gather(
    ...     input_tensor,
    ...     tf.math.top_k(input_tensor[:, :, 1], k=TOPK, sorted=True).indices,
    ...     batch_dims=1
    ... )
    >>> topk_results.numpy().tolist()
    [[[2.0, 1.0], [2.0, 1.0]],
     [[44.0, 0.8], [22.0, 0.7]],
     [[5555.0, 0.8], [3333.0, 0.7]],
     [[333.0, 1.1], [333.0, 1.1]]]
- Here is what I actually want:

    [[[2.0, 1.0], [3.0, 0.4]],        # [3.0, 0.4] is the 2nd highest element based on 2nd column
     [[44.0, 0.8], [22.0, 0.7]],
     [[5555.0, 0.8], [3333.0, 0.7]],
     [[333.0, 1.1], [444.0, 0.8]]]    # [444.0, 0.8] is the 2nd highest element based on 2nd column
Answer
This is one possible way to do it, although it costs extra work: it first sorts each group by the first column so that repeated elements end up next to each other and can be masked out.
    import tensorflow as tf
    import numpy as np

    # Input data
    k = 2
    input_tensor = tf.constant([
        [[2.0, 1.0], [2.0, 1.0], [3.0, 0.4], [1.0, 0.1]],
        [[44.0, 0.8], [22.0, 0.7], [11.0, 0.5], [11.0, 0.5]],
        [[5555.0, 0.8], [3333.0, 0.7], [4444.0, 0.4], [1111.0, 0.1]],
        [[444.0, 0.8], [333.0, 1.1], [333.0, 1.1], [111.0, 0.1]]
    ])

    # Sort by first column
    idx = tf.argsort(input_tensor[..., 0], axis=-1)
    s = tf.gather_nd(input_tensor, tf.expand_dims(idx, axis=-1), batch_dims=1)

    # Find repeated elements
    col1 = s[..., 0]
    col1_ext = tf.concat([col1[..., :1] - 1, col1], axis=-1)
    mask = tf.math.not_equal(col1_ext[..., 1:], col1_ext[..., :-1])

    # Replace value for repeated elements with "minus infinity"
    col2 = s[..., 1]
    col2_masked = tf.where(mask, col2, col2.dtype.min)

    # Get top-k results
    topk_idx = tf.math.top_k(col2_masked, k=k, sorted=True).indices
    topk_results = tf.gather(s, topk_idx, batch_dims=1)

    # Print
    with np.printoptions(suppress=True):
        print(topk_results.numpy())
    # [[[   2.     1. ]
    #   [   3.     0.4]]
    #
    #  [[  44.     0.8]
    #   [  22.     0.7]]
    #
    #  [[5555.     0.8]
    #   [3333.     0.7]]
    #
    #  [[ 333.     1.1]
    #   [ 444.     0.8]]]
Note there is a corner case: a group may contain fewer than k distinct elements. In that case this solution fills the remaining slots with repeated elements, which land at the end of the result regardless of their actual score, so the score order is broken.
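The corner case can be detected without loops by gathering the duplicate mask alongside the results. The following is only a sketch under assumptions: the function name `unique_top_k` and the extra `valid` output are hypothetical additions, not part of the answer above, and duplicates are identified by the first column only, as in the original code.

```python
import tensorflow as tf

def unique_top_k(x, k):
    """Hypothetical helper: top-k rows per group by column 1, skipping
    rows whose column-0 value repeats. Also returns a boolean `valid`
    tensor marking which output slots hold a genuinely unique row."""
    # Sort each group by the first column so repeats become adjacent
    idx = tf.argsort(x[..., 0], axis=-1)
    s = tf.gather(x, idx, batch_dims=1)

    # True for the first occurrence of each column-0 value
    col1 = s[..., 0]
    col1_ext = tf.concat([col1[..., :1] - 1, col1], axis=-1)
    mask = tf.math.not_equal(col1_ext[..., 1:], col1_ext[..., :-1])

    # Push repeated rows to the bottom of the ranking
    col2_masked = tf.where(mask, s[..., 1], s.dtype.min)
    topk_idx = tf.math.top_k(col2_masked, k=k, sorted=True).indices

    results = tf.gather(s, topk_idx, batch_dims=1)
    # False where a slot had to be filled with a masked duplicate
    valid = tf.gather(mask, topk_idx, batch_dims=1)
    return results, valid

# Second group has only one distinct element, so k=2 cannot be satisfied
example = tf.constant([
    [[5.0, 0.9], [5.0, 0.9], [5.0, 0.9], [7.0, 0.2]],
    [[1.0, 0.3], [1.0, 0.3], [1.0, 0.3], [1.0, 0.3]],
])
results, valid = unique_top_k(example, k=2)
print(valid.numpy().tolist())
# [[True, True], [True, False]]
```

Slots where `valid` is False hold repeated elements, so a caller could drop or overwrite them (for example with `tf.where`) instead of trusting their position in the ranking.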