How to efficiently assign to a slice of a tensor in TensorFlow

I want to assign values to slices of an input tensor in one of my models in TensorFlow 2.x (I am using 2.2 but would accept a solution for 2.1). A non-working template of what I am trying to do is:

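Here is a minimal sketch of this kind of template. It assumes, for illustration only, that AddToEven is meant to double the even-indexed columns, i.e. outputs[:, ::2] += inputs[:, ::2] in NumPy notation. The function names are hypothetical. A tf.Tensor is immutable and has no item assignment, so the in-place update raises a TypeError:

```python
import tensorflow as tf


def add_to_even_naive(inputs):
    """Hypothetical body of AddToEven.call: double the even-indexed columns."""
    outputs = inputs
    # tf.Tensor is immutable (no __setitem__), so this line raises TypeError.
    outputs[:, ::2] += inputs[:, ::2]
    return outputs


def fails_with_type_error(inputs):
    """Return True if the naive in-place update raises TypeError."""
    try:
        add_to_even_naive(inputs)
    except TypeError:
        return True
    return False


x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(fails_with_type_error(x))  # True
```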

Of course, when building this (AddToEven().build(tf.TensorShape([None, None]))), I get the following error:


I can implement this simple example with tf.tensor_scatter_nd_add as follows:

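Here is a minimal sketch of the tf.tensor_scatter_nd_add route for a (batch, n) input. It assumes the same illustrative semantics (double the even-indexed columns), and the function names are hypothetical. The column axis is moved to the front so that each scatter index addresses a whole column, which keeps the number of updates dynamic (ceil(n / 2)). The result is checked against an in-place NumPy update:

```python
import numpy as np
import tensorflow as tf


def add_to_even_scatter(inputs):
    """Double the even-indexed columns of a (batch, n) tensor via tf.tensor_scatter_nd_add."""
    n = tf.shape(inputs)[1]
    # Put the column axis first: each index row then selects one full column.
    transposed = tf.transpose(inputs, [1, 0])           # (n, batch)
    indices = tf.range(0, n, delta=2)[:, tf.newaxis]     # (ceil(n/2), 1)
    updates = transposed[::2]                            # (ceil(n/2), batch)
    summed = tf.tensor_scatter_nd_add(transposed, indices, updates)
    return tf.transpose(summed, [1, 0])                  # back to (batch, n)


# Both batch size and n are left dynamic in the traced signature.
add_to_even_fn = tf.function(
    add_to_even_scatter,
    input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)],
)

# Sanity check against NumPy, where in-place slice assignment is allowed.
x = np.arange(10, dtype=np.float32).reshape(2, 5)
expected = x.copy()
expected[:, ::2] += x[:, ::2]
np.testing.assert_allclose(add_to_even_fn(x).numpy(), expected)
```

Even in this toy case, the transposes and index construction are needed just to express "add to every other column". That is the verbosity being complained about here.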

(you can sanity-check it with:


)

But as you can see, it is very complicated to write, and this only handles a static number of updates (here, 1) on a 1D tensor (plus the batch dimension).

What I actually want to do is more involved, and I think writing it with tensor_scatter_nd_add would be a nightmare.

Many of the existing Q&As on this topic cover variables but not tensors (see e.g. this or this). It is mentioned here that PyTorch does support this, so I am surprised not to see any recent response from TF team members on the topic. This answer doesn't really help me, because it requires some kind of mask generation, which is going to be just as awful.

The question is thus: how can I do slice assignment efficiently (in terms of computation, memory and code) without tensor_scatter_nd_add? The catch is that I want this to be as dynamic as possible, meaning that the shape of the inputs can vary.

(For anyone curious, I am trying to translate this code to TensorFlow.)

This question was originally posted in a GitHub issue.


Answer

Here is another solution based on a binary mask.

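Here is a minimal sketch of a binary-mask approach. It assumes the same illustrative semantics (double the even-indexed columns), and the layer name AddToEvenMask is hypothetical. The mask is built with tf.range over the dynamic last dimension. tf.where, which broadcasts its arguments in TF 2.x, then picks the new values wherever the mask is True. This is the functional equivalent of outputs[..., ::2] = new_values[..., ::2]:

```python
import numpy as np
import tensorflow as tf


class AddToEvenMask(tf.keras.layers.Layer):
    """Hypothetical layer: double the even-indexed entries of the last axis."""

    def call(self, inputs):
        n = tf.shape(inputs)[-1]
        # Boolean mask over the last axis: True at positions 0, 2, 4, ...
        even = tf.equal(tf.range(n) % 2, 0)    # shape (n,)
        new_values = inputs + inputs            # values to "assign" at masked positions
        # tf.where broadcasts the (n,) mask against the (..., n) tensors.
        return tf.where(even, new_values, inputs)


layer = AddToEvenMask()
x = np.arange(10, dtype=np.float32).reshape(2, 5)
print(layer(tf.constant(x)).numpy())
```

The trade-off is that new_values is computed for the whole tensor and then selected, rather than only at the updated positions. In exchange, no index bookkeeping is needed and any input size works.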

Here is a sanity check:


The result (with TF 2.1) looks like this:


——– Below is the previous answer ——–

You need to create a tf.Variable in the build() method. It can also have a dynamic size by passing shape=(None,). In the code below, I specified the input shape as (None, None).

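Here is a minimal sketch of the variable-based approach. It assumes the same illustrative semantics (double the even-indexed columns) and float32 inputs. The layer and attribute names are hypothetical. The variable is created once in build() with shape=tf.TensorShape([None, None]), so it can later hold inputs of any (batch, n) size. The sliced assign() that tf.Variable supports then performs the update:

```python
import numpy as np
import tensorflow as tf


class AddToEvenVar(tf.keras.layers.Layer):
    """Hypothetical layer: slice assignment through a tf.Variable made in build()."""

    def build(self, input_shape):
        # An unknown shape lets the variable hold inputs of any (batch, n) size.
        self.buffer = tf.Variable(
            tf.zeros((0, 0)),
            shape=tf.TensorShape([None, None]),
            trainable=False,
        )
        super().build(input_shape)

    def call(self, inputs):
        self.buffer.assign(inputs)
        # Unlike tf.Tensor, tf.Variable supports sliced assignment.
        self.buffer[:, ::2].assign(inputs[:, ::2] + inputs[:, ::2])
        return self.buffer.read_value()


layer = AddToEvenVar()
x = np.arange(10, dtype=np.float32).reshape(2, 5)
print(layer(tf.constant(x)).numpy())
```

One caveat worth checking for your use case: assignment ops are not differentiable, so the output read from the variable is not connected to inputs in the gradient tape.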

I tested this code with TF 2.1.0 and TF 1.15.


Result:


P.S. There are other ways too, such as using tf.numpy_function() or generating a mask.
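Here is a minimal sketch of the tf.numpy_function() route. It again assumes the illustrative semantics of doubling the even-indexed columns, and the helper names are hypothetical. Inside the wrapped function the data is a plain NumPy array, so ordinary in-place slice assignment works. The trade-offs:

- You cannot take gradients through tf.numpy_function.
- The Python body is not serialized with the graph.
- The static output shape is lost, which is why set_shape() is called.

```python
import numpy as np
import tensorflow as tf


def _add_to_even_np(x):
    """Plain NumPy body: copy the input, then update the even-indexed columns in place."""
    y = np.array(x, copy=True)
    y[:, ::2] += x[:, ::2]
    return y


def add_to_even_numpy_function(inputs):
    out = tf.numpy_function(_add_to_even_np, [inputs], inputs.dtype)
    # numpy_function cannot infer the output shape; restore what we know.
    out.set_shape(inputs.shape)
    return out


x = tf.constant(np.arange(10, dtype=np.float32).reshape(2, 5))
print(add_to_even_numpy_function(x).numpy())
```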

User contributions licensed under: CC BY-SA