
How to efficiently assign to a slice of a tensor in TensorFlow

I want to assign some values to slices of an input tensor in one of my models in TensorFlow 2.x (I am using 2.2 but am ready to accept a solution for 2.1). A non-working template of what I am trying to do is:

import tensorflow as tf
from tensorflow.keras.models import Model

class AddToEven(Model):
    def call(self, inputs):
        outputs = inputs
        outputs[:, ::2] += inputs[:, ::2]
        return outputs

Of course, when building this (AddToEven().build(tf.TensorShape([None, None]))), I get the following error:

TypeError: 'Tensor' object does not support item assignment

I can achieve this simple example via the following:

class AddToEvenScatter(Model):
    def call(self, inputs):
        n = tf.shape(inputs)[-1]
        # Indices 0, 2, 4, ... of the columns to update, shape (k, 1)
        update_indices = tf.range(0, n, delta=2)[:, None]
        # tensor_scatter_nd_add indexes the leading axis, so move the columns there
        scatter_nd_perm = [1, 0]
        inputs_reshaped = tf.transpose(inputs, scatter_nd_perm)
        outputs = tf.tensor_scatter_nd_add(
            inputs_reshaped,
            indices=update_indices,
            updates=inputs_reshaped[::2],
        )
        # Transpose back to (batch, n)
        outputs = tf.transpose(outputs, scatter_nd_perm)
        return outputs

You can sanity-check it with:

model = AddToEvenScatter()
model.build(tf.TensorShape([None, None]))
model(tf.ones([1, 10]))
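For reference, the update the question is after, outputs[:, ::2] += inputs[:, ::2], is a one-liner in NumPy because NumPy arrays are mutable. The sketch below (the function names are hypothetical) compares that NumPy reference with the scatter logic of AddToEvenScatter rewritten as a plain function, which is a simple way to check that the scatter version computes the intended result.

```python
import numpy as np
import tensorflow as tf


def add_to_even_numpy(x):
    # NumPy arrays are mutable, so the slice update from the question works as-is.
    out = np.array(x, copy=True)
    out[:, ::2] += x[:, ::2]
    return out


def add_to_even_scatter(inputs):
    # Same steps as AddToEvenScatter.call, without the Keras Model wrapper.
    n = tf.shape(inputs)[-1]
    update_indices = tf.range(0, n, delta=2)[:, None]  # shape (k, 1)
    transposed = tf.transpose(inputs, [1, 0])          # shape (n, batch)
    outputs = tf.tensor_scatter_nd_add(
        transposed, indices=update_indices, updates=transposed[::2]
    )
    return tf.transpose(outputs, [1, 0])               # back to (batch, n)


x = np.arange(10, dtype=np.float32).reshape(2, 5)
print(add_to_even_numpy(x))
print(add_to_even_scatter(tf.constant(x)).numpy())
```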

But as you can see it’s very complicated to write. And this is only for a static number of updates (here 1) on a 1D (+ batch size) tensor.

What I want to do is a bit more involved and I think writing it with tensor_scatter_nd_add is going to be a nightmare.

A lot of the current Q&As on the topic cover the case of variables but not tensors (see e.g. this or this). It is mentioned here that PyTorch does support this, so I am surprised not to see any recent response from TF team members on the topic. This answer doesn't really help me, because I would need some kind of mask generation, which is going to be awful as well.

The question is thus: how can I do slice assignment efficiently (computation-wise, memory-wise and code-wise) without tensor_scatter_nd_add? The catch is that I want this to be as dynamic as possible, meaning that the shape of the inputs could vary.

(For anyone curious I am trying to translate this code in tf).

This question was originally posted in a GitHub issue.


Answer

Here is another solution based on a binary mask.

class AddToEven(tf.keras.Model):
    """Solution based on a binary mask.
    - We add the mask to the inputs instead of multiplying by it."""

    def __init__(self):
        super(AddToEven, self).__init__()

    def build(self, input_shape):
        self.built = True  # Nothing to build: there are no variables or weights here.

    @tf.function
    def call(self, inputs):
        # tf.shape() gives the runtime width, so this also works when the static last dimension is None
        w = tf.shape(inputs)[-1]

        # 1-D mask generation for the w-axis (activate even indices only)
        m_w = tf.range(w)           # [0, 1, 2, ..., w-1]
        m_w = tf.equal(m_w % 2, 0)  # [True, False, True, ...] with dtype=tf.bool

        # Apply the 1-D mask to the 2-D input
        m_w = tf.expand_dims(m_w, axis=0)       # extend the dimensions to (1, W)
        m_w = tf.cast(m_w, dtype=inputs.dtype)  # convert the dtype before adding

        # Broadcasting adds the (1, W) mask to every row of the (H, W) input.
        # This kind of broadcast addition works in both TF and NumPy.
        return inputs + m_w

Sanity check:

# sanity-check as model
model = AddToEven()
model.build(tf.TensorShape([None, None]))
z = model(tf.zeros([2,4]))
print(z)

Result (with TF 2.1):

tf.Tensor(
[[1. 0. 1. 0.]
 [1. 0. 1. 0.]], shape=(2, 4), dtype=float32)
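The mask here adds a constant 1 to the even columns, while the question's own update is outputs[:, ::2] += inputs[:, ::2]. A minimal sketch of that update with the same mask idea, assuming 2-D float32 inputs (the function name add_to_even_mask is hypothetical), selects between the doubled and the original values with tf.where instead of adding the mask. Because the width is read with tf.shape at run time, an input_signature with None dimensions lets one traced graph serve inputs of any size.

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def add_to_even_mask(inputs):
    # Boolean mask over the last axis, True at even column indices; shape (W,).
    w = tf.shape(inputs)[-1]
    even = tf.equal(tf.range(w) % 2, 0)
    # tf.where broadcasts the (W,) mask against the (H, W) tensors and picks
    # inputs + inputs at even columns and the untouched inputs elsewhere.
    return tf.where(even, inputs + inputs, inputs)


x = tf.reshape(tf.range(10, dtype=tf.float32), [2, 5])
print(add_to_even_mask(x).numpy())
print(add_to_even_mask(tf.ones([3, 4])).numpy())  # different shape, same traced graph
```

Unlike the scatter version, no transpose or index tensor is needed, and the same pattern extends to any boolean mask you can build over the last axis.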

-------- Below is the previous answer --------

You need to create a tf.Variable in the build() method. A variable can still have a dynamic size if you declare its shape with None entries, e.g. shape=(None,). In the code below, I declared the shape as (None, None).

class AddToEven(tf.keras.Model):
    def __init__(self):
        super(AddToEven, self).__init__()

    def build(self, input_shape):
        # shape=(None, None) lets the variable hold 2-D inputs of any size
        self.v = tf.Variable(initial_value=tf.zeros((0, 0)), shape=(None, None), trainable=False, dtype=tf.float32)

    @tf.function
    def call(self, inputs):
        self.v.assign(inputs)                      # copy the input into the variable
        self.v[:, ::2].assign(self.v[:, ::2] + 1)  # sliced assignment is supported on variables
        return self.v.value()

I tested this code with TF 2.1.0 and TF 1.15.

# test
add_to_even = AddToEven()
z = add_to_even(tf.zeros((2,4)))
print(z)

Result:

tf.Tensor(
[[1. 0. 1. 0.]
 [1. 0. 1. 0.]], shape=(2, 4), dtype=float32)
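The previous answer relies on the fact that a tf.Variable, unlike a tf.Tensor, supports sliced assignment. A minimal standalone sketch of just that mechanism, outside any Keras model and using the question's update (adding the even columns of the input to themselves) rather than a constant 1:

```python
import tensorflow as tf

# A variable with an unknown shape, so it can hold 2-D inputs of any size.
v = tf.Variable(tf.zeros((0, 0)), shape=tf.TensorShape([None, None]), trainable=False)

x = tf.reshape(tf.range(8, dtype=tf.float32), [2, 4])
v.assign(x)                              # copy the input into the variable
v[:, ::2].assign(v[:, ::2] + x[:, ::2])  # in-place update of the even columns
print(v.numpy())
```

The trade-off is that the variable is stateful: every call overwrites its contents, which is why the answer copies the input in with assign() before updating the slice.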

P.S. There are other ways too, such as using tf.numpy_function() or generating a mask.
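A sketch of the tf.numpy_function() route, assuming 2-D inputs (the helper names are hypothetical): the slice update runs in NumPy, where item assignment is allowed. Keep in mind that gradients cannot flow through tf.numpy_function, and the wrapped Python function is not serialized with the graph.

```python
import tensorflow as tf


def _add_to_even_np(x):
    # Receives a NumPy array; copy it so the slice update does not touch the input.
    out = x.copy()
    out[:, ::2] += x[:, ::2]
    return out


def add_to_even_numpy_function(inputs):
    outputs = tf.numpy_function(_add_to_even_np, [inputs], Tout=inputs.dtype)
    # In graph mode numpy_function loses the static shape, so restore it from the input.
    outputs.set_shape(inputs.shape)
    return outputs


x = tf.reshape(tf.range(8, dtype=tf.float32), [2, 4])
print(add_to_even_numpy_function(x).numpy())
```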

User contributions licensed under: CC BY-SA