
Behavior of .map in TensorFlow

I’m trying to take variable length tensors and split them up into tensors of length 4, discarding any extra elements (if the length is not divisible by four).

I’ve therefore written the following function:

def batches_of_four(tokens):
  token_length = tokens.shape[0]

  splits = token_length // 4

  tokens = tokens[0 : splits * 4]

  return tf.split(tokens, num_or_size_splits=splits)

dataset =
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7]]))


This produces the output [<tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4], dtype=int32)>], as expected.

If I now run the same function using

for item in dataset.map(batches_of_four):
  print(item)

I instead get the following error

    File "<ipython-input-173-a09c55117ea2>", line 5, in batches_of_four  *
        splits = token_length // 4

    TypeError: unsupported operand type(s) for //: 'NoneType' and 'int'

I see that this is because token_length is None, but I don’t understand why. I assume this has something to do with graph vs eager execution, but the function works if I call it outside of .map even if I annotate it with @tf.function.

Why is the behavior different inside .map? (Also: is there any better way of writing the batches_of_four function?)



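Inside .map your function is traced as a graph against the dataset's element spec, and for variable-length rows that spec has shape (None,), so tokens.shape[0] is None. When you call the function directly, even with @tf.function, it is traced for the concrete tensor you pass in, whose static shape is known. You should use tf.shape to get the dynamic shape of a tensor in graph mode: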

token_length = tf.shape(tokens)[0]
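
To make the difference concrete, here is a minimal sketch of static versus dynamic shape. It assumes TensorFlow 2.x and uses tf.function with an input_signature of shape [None] to imitate how a mapped function is traced for variable-length elements. The name inspect_length is made up for illustration.

```python
import tensorflow as tf

# Trace with an unknown first dimension, similar to what happens for
# variable-length elements inside .map (assumption: TensorFlow 2.x).
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
def inspect_length(tokens):
    # Static shape: a Python value fixed at trace time (None here).
    static_is_unknown = tokens.shape[0] is None
    # Dynamic shape: a scalar tensor computed when the graph runs.
    dynamic_length = tf.shape(tokens)[0]
    return tf.constant(static_is_unknown), dynamic_length

static_is_unknown, dynamic_length = inspect_length(tf.constant([1, 2, 3, 4, 5]))
print(bool(static_is_unknown), int(dynamic_length))
```

The static shape is unknown while tracing, but tf.shape still yields the real length once the graph runs.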

Another problem is passing a scalar tensor as num_or_size_splits in graph mode. tf.split needs the number of splits to be known when the graph is built, so that won’t work either.

Try this:

import tensorflow as tf

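def split_data(data, chunk_size):
    # tf.shape returns the dynamic length; data.shape[0] is None inside .map.
    length = tf.shape(data)[0]
    # Drop the trailing elements that do not fill a whole chunk.
    x = data[:(length // chunk_size) * chunk_size]
    ta = tf.TensorArray(dtype=data.dtype, size=0, dynamic_size=True)

    def cond(i, m, n):
        # Keep looping while there are elements left to consume.
        return tf.less(i, tf.shape(m)[0])

    def body(i, m, n):
        # Append the next chunk, then advance the index by one chunk.
        n = n.write(n.size(), m[i:i + chunk_size])
        return tf.add(i, chunk_size), m, n

    _, _, out = tf.while_loop(cond, body, loop_vars=[tf.constant(0), x, ta])
    return out.stack()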

chunk_size = 4

dataset =
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]])).map(lambda x: split_data(x, 4)).flat_map(

for item in dataset:
  print(item)

tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([5 6 7 8], shape=(4,), dtype=int32)

And see my other answer here.
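
As for a simpler way to write the function, one possible alternative (a different technique from the while-loop above, not the answer's method) is to trim the tensor and reshape it into rows of four. This is only a sketch: chunks_via_reshape is a hypothetical name, and the dataset is assumed to be built with tf.data.Dataset.from_tensor_slices, which the snippets above do not show.

```python
import tensorflow as tf

def chunks_via_reshape(tokens, chunk_size=4):
    # Hypothetical helper: trim to a multiple of chunk_size, then reshape.
    length = tf.shape(tokens)[0]                   # dynamic length, safe in graph mode
    usable = (length // chunk_size) * chunk_size   # drop the incomplete tail
    return tf.reshape(tokens[:usable], [-1, chunk_size])

# Assumption: the dataset is created with from_tensor_slices on a ragged tensor.
dataset = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]]))

# Each element becomes a (num_chunks, 4) tensor; unbatch() yields one chunk at a time.
chunks = dataset.map(chunks_via_reshape).unbatch()
result = [chunk.numpy().tolist() for chunk in chunks]
print(result)
```

Because the reshape target uses -1 for the row count, nothing in the graph depends on knowing the number of chunks ahead of time.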

User contributions licensed under: CC BY-SA