
How can I select the top-n elements from a tensor without repeating elements?

I want to select the top-n elements of a 3-dimensional tensor such that the selected elements are all unique. The elements are ranked by their 2nd column. In the example below I'm selecting the top 2, but I don't want any duplicates among them.

  • Condition: No for loops or tf.map_fn()

  • Here is the input tensor:

import tensorflow as tf

input_tensor = tf.constant([
  [[2.0, 1.0],
   [2.0, 1.0],
   [3.0, 0.4],
   [1.0, 0.1]],
  [[44.0, 0.8],
   [22.0, 0.7],
   [11.0, 0.5],
   [11.0, 0.5]],
  [[5555.0, 0.8],
   [3333.0, 0.7],
   [4444.0, 0.4],
   [1111.0, 0.1]],
  [[444.0, 0.8],
   [333.0, 1.1],
   [333.0, 1.1],
   [111.0, 0.1]]
])
  • This is what I'm getting right now, which is not what I want:
>>> TOPK = 2
>>> topk_results = tf.gather(
...     input_tensor,
...     tf.math.top_k(input_tensor[:, :, 1], k=TOPK, sorted=True).indices,
...     batch_dims=1
... )
>>> topk_results.numpy().tolist()
[[[2.0, 1.0], [2.0, 1.0]],
 [[44.0, 0.8], [22.0, 0.7]],
 [[5555.0, 0.8], [3333.0, 0.7]],
 [[333.0, 1.1], [333.0, 1.1]]]
  • Here is what I actually want:
[[[2.0, 1.0], [3.0, 0.4]],       # [3.0, 0.4] is the 2nd highest element based on 2nd column
 [[44.0, 0.8], [22.0, 0.7]],   
 [[5555.0, 0.8], [3333.0, 0.7]],
 [[333.0, 1.1], [444.0, 0.8]]]   # [444.0, 0.8] is the 2nd highest element based on 2nd column
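To make the intended selection rule explicit, here is a small loop-based NumPy reference. It deliberately ignores the no-loop constraint and is only meant for checking a vectorized solution against. It assumes that "duplicate" means an identical row and that every group contains at least k distinct rows; the function name `topk_unique_reference` is hypothetical.

```python
import numpy as np

def topk_unique_reference(x, k):
    """Hypothetical loop-based reference (not a vectorized solution).

    For each group, walk the rows in descending order of column 1 and
    keep a row only if an identical row has not been kept already.
    Assumes every group has at least k distinct rows.
    """
    out = []
    for group in np.asarray(x):
        # Highest score first; stable sort keeps the earlier row on ties.
        order = np.argsort(-group[:, 1], kind="stable")
        picked = []
        for i in order:
            row = group[i]
            if not any(np.array_equal(row, p) for p in picked):
                picked.append(row)
            if len(picked) == k:
                break
        out.append(picked)
    return np.array(out)
```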


Answer

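This is one possible way to do it, although it does extra work because it first sorts every group by its first column so that duplicate rows become adjacent and can be masked out before calling top_k.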

import tensorflow as tf
import numpy as np

# Input data
k = 2
input_tensor = tf.constant([
  [[2.0, 1.0],
   [2.0, 1.0],
   [3.0, 0.4],
   [1.0, 0.1]],
  [[44.0, 0.8],
   [22.0, 0.7],
   [11.0, 0.5],
   [11.0, 0.5]],
  [[5555.0, 0.8],
   [3333.0, 0.7],
   [4444.0, 0.4],
   [1111.0, 0.1]],
  [[444.0, 0.8],
   [333.0, 1.1],
   [333.0, 1.1],
   [111.0, 0.1]]
])
# Sort each group by its first column so that duplicate rows end up next to each other
idx = tf.argsort(input_tensor[..., 0], axis=-1)
s = tf.gather_nd(input_tensor, tf.expand_dims(idx, axis=-1), batch_dims=1)
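# Find repeated elements: a row is repeated if its first column equals the previous row's.
# Prepending col1[..., :1] - 1 (always different from the first value) keeps the first row unmasked.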
col1 = s[..., 0]
col1_ext = tf.concat([col1[..., :1] - 1, col1], axis=-1)
mask = tf.math.not_equal(col1_ext[..., 1:], col1_ext[..., :-1])
# Replace the score of repeated elements with the lowest representable value (effectively "minus infinity")
col2 = s[..., 1]
col2_masked = tf.where(mask, col2, col2.dtype.min)
# Get top-k results
topk_idx = tf.math.top_k(col2_masked, k=k, sorted=True).indices
topk_results = tf.gather(s, topk_idx, batch_dims=1)
# Print
with np.printoptions(suppress=True):
    print(topk_results.numpy())
# [[[   2.     1. ]
#   [   3.     0.4]]
# 
#  [[  44.     0.8]
#   [  22.     0.7]]
# 
#  [[5555.     0.8]
#   [3333.     0.7]]
# 
#  [[ 333.     1.1]
#   [ 444.     0.8]]]
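To see what the intermediate values look like, here is the same sort, mask and top-k sequence re-expressed in NumPy for the last group only. It is a sketch for inspection, not the answer's TensorFlow code: `np.argsort` stands in for `tf.argsort`/`tf.math.top_k`, and prepending a literal `True` replaces the `col1[..., :1] - 1` sentinel trick (both keep the first row unmasked).

```python
import numpy as np

group = np.array([[444.0, 0.8],
                  [333.0, 1.1],
                  [333.0, 1.1],
                  [111.0, 0.1]], dtype=np.float32)
k = 2

# 1. Sort rows by the first column so duplicates become neighbours.
s = group[np.argsort(group[:, 0], kind="stable")]

# 2. A row is "new" if its first column differs from the previous row's;
#    the first row is always new.
col1 = s[:, 0]
mask = np.concatenate([[True], col1[1:] != col1[:-1]])

# 3. Push the scores of repeated rows down to the lowest float32 value.
col2_masked = np.where(mask, s[:, 1], np.finfo(np.float32).min)

# 4. Take the k highest masked scores, highest first.
topk_idx = np.argsort(-col2_masked, kind="stable")[:k]
result = s[topk_idx]

print(mask)    # which sorted rows survive the duplicate check
print(result)  # the selected rows
```

Here the sorted rows are `111, 333, 333, 444`, so the mask is `[True, True, False, True]` and the two surviving rows with the highest scores are `[333, 1.1]` and `[444, 0.8]`, matching the last group of the output above.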

Note one corner case: if a group contains fewer than k distinct elements, top_k has to pick some of the masked entries, so repeated elements show up again at the end of the result. Since their real scores can be higher than those of the rows before them, the score order is broken in that case.
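One way to at least detect this corner case is to check whether a selected score is still the sentinel. The sketch below is a vectorized NumPy re-expression of the same sort-then-mask steps (`np.take_along_axis` standing in for `tf.gather` with `batch_dims=1`) that also returns a `valid` flag per pick. This is not part of the original answer: the function name `topk_unique` and the `valid` output are hypothetical, and a genuine score equal to the float32 minimum would also be flagged. In the TensorFlow code, the analogous check would compare `tf.math.top_k(col2_masked, k=k).values` against `col2.dtype.min`.

```python
import numpy as np

def topk_unique(x, k):
    """Hypothetical vectorized sort-then-mask top-k over axis 1.

    Returns (rows, valid). valid[b, j] is False when pick j of group b is a
    repeated row that had to be taken because the group has fewer than k
    distinct first-column values.
    """
    x = np.asarray(x, dtype=np.float32)
    # Sort each group by its first column so duplicates are adjacent.
    idx = np.argsort(x[..., 0], axis=-1, kind="stable")
    s = np.take_along_axis(x, idx[..., None], axis=1)
    # Mark rows whose first column differs from the previous row's.
    col1 = s[..., 0]
    first = np.ones(col1.shape[:-1] + (1,), dtype=bool)
    mask = np.concatenate([first, col1[..., 1:] != col1[..., :-1]], axis=-1)
    # Replace the scores of repeated rows with the lowest float32 value.
    lowest = np.finfo(np.float32).min
    scores = np.where(mask, s[..., 1], lowest)
    # Top-k by masked score, highest first (ties keep the lower index).
    topk_idx = np.argsort(-scores, axis=-1, kind="stable")[..., :k]
    rows = np.take_along_axis(s, topk_idx[..., None], axis=1)
    valid = np.take_along_axis(scores, topk_idx, axis=-1) > lowest
    return rows, valid
```

For example, `topk_unique([[[7, 0.9], [7, 0.9], [7, 0.9], [1, 0.2]]], 3)` selects `[7, 0.9]`, `[1, 0.2]` and then the repeated `[7, 0.9]`, with `valid` equal to `[[True, True, False]]`. The third pick's real score (0.9) is higher than the 0.2 before it, which is exactly the ordering problem described above.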

User contributions licensed under: CC BY-SA