
Drop a row in a tensor if the sum of the elements is lower than some threshold

How can I drop rows in a tensor if the sum of the elements in each row is lower than the threshold -1? For example:

tensor = tf.random.normal((3, 3))
tf.Tensor(
[[ 0.506158    0.53865975 -0.40939444]
 [ 0.4917719  -0.1575156   1.2308844 ]
 [ 0.08580616 -1.1503975  -2.252681  ]], shape=(3, 3), dtype=float32)

Since the sum of the last row is smaller than -1, I need to remove it and get a tensor of shape (2, 3):

tf.Tensor(
[[ 0.506158    0.53865975 -0.40939444]
 [ 0.4917719  -0.1575156   1.2308844 ]], shape=(2, 3), dtype=float32)

I know how to use tf.reduce_sum, but I don't know how to delete rows from a tensor. Something like pandas' df.drop would be nice.
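For reference, the per-row sums that decide which rows to drop can be computed with tf.reduce_sum over axis=1. This is a minimal sketch using the fixed values from the example above (tf.random.normal would give different numbers on each run):

```python
import tensorflow as tf

# Values copied from the example tensor above, so the result is reproducible
tensor = tf.constant([
    [0.506158, 0.53865975, -0.40939444],
    [0.4917719, -0.1575156, 1.2308844],
    [0.08580616, -1.1503975, -2.252681],
])

# axis=1 sums across the columns, producing one sum per row (shape (3,))
row_sums = tf.reduce_sum(tensor, axis=1)
print(row_sums.numpy())  # approximately [0.635, 1.565, -3.317]
```

Only the last row's sum (about -3.317) falls below -1, which is why it is the one to remove.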


Answer

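tf.boolean_mask is all you need: compute the per-row sums with tf.reduce_sum, turn the threshold comparison into a boolean mask, and apply that mask along axis 0.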

import tensorflow as tf

tensor = tf.constant([
    [ 0.506158,    0.53865975, -0.40939444],
    [ 0.4917719,  -0.1575156,   1.2308844 ],
    [ 0.08580616, -1.1503975,  -2.252681  ],
])

# Keep rows whose sum is not lower than -1 (i.e. drop rows with sum < -1)
mask = tf.reduce_sum(tensor, axis=1) >= -1
# <tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>

# Select only the rows where mask is True (axis=0 applies the mask to rows)
tf.boolean_mask(
    tensor=tensor,
    mask=mask,
    axis=0
)
# <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
# array([[ 0.506158  ,  0.53865975, -0.40939444],
#        [ 0.4917719 , -0.1575156 ,  1.2308844 ]], dtype=float32)>
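If you need this in several places, it can be wrapped in a small function. The sketch below is a hypothetical helper (`drop_rows_below` is not a TensorFlow API) and swaps in tf.where + tf.gather instead of tf.boolean_mask; both select the same rows along axis 0, and the gather form also gives you the indices of the kept rows if you need them.

```python
import tensorflow as tf

def drop_rows_below(tensor, threshold):
    """Hypothetical helper: remove rows of a 2-D tensor whose sum is lower than `threshold`."""
    row_sums = tf.reduce_sum(tensor, axis=1)
    # tf.where with one argument returns the indices of True entries, shape (k, 1);
    # squeeze the last axis to get a flat index vector of shape (k,)
    keep_idx = tf.squeeze(tf.where(row_sums >= threshold), axis=1)
    # Gather the kept rows, preserving their original order
    return tf.gather(tensor, keep_idx, axis=0)

tensor = tf.constant([
    [0.506158, 0.53865975, -0.40939444],
    [0.4917719, -0.1575156, 1.2308844],
    [0.08580616, -1.1503975, -2.252681],
])

result = drop_rows_below(tensor, -1.0)
print(result.shape)  # (2, 3)
```

If every row falls below the threshold, the result is an empty tensor with shape (0, number_of_columns) rather than an error.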

User contributions licensed under: CC BY-SA