How to use tf.repeat() to replicate a specific column/row/slice?

This thread explains well the use of tf.repeat() as a TensorFlow alternative to np.repeat(). One thing I was unable to figure out: in np.repeat(), a specific column/row/slice can be replicated by supplying a repeat count for each index along the axis, e.g.

import numpy as np

x = np.array([[1, 2], [3, 4]])
np.repeat(x, [1, 2], axis=0)
# Returns: array([[1, 2],
#                 [3, 4],
#                 [3, 4]])

Is there a TensorFlow alternative to this functionality of np.repeat()?

Answer

You can pass one count per row to the repeats parameter of tf.repeat:

import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])
x = tf.repeat(x, repeats=[1, 2], axis=0)
print(x)
# tf.Tensor(
# [[1 2]
#  [3 4]
#  [3 4]], shape=(3, 2), dtype=int32)

This keeps the first row once and repeats the second row twice.
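To make the per-row semantics of a repeats list concrete, here is a plain-Python sketch (no TensorFlow needed) of what tf.repeat(x, repeats, axis=0) computes on a 2-D input. The helper name repeat_rows is hypothetical and only models the axis=0 case; it does not model TensorFlow's broadcasting of a length-1 repeats list.

```python
def repeat_rows(rows, repeats):
    """Plain-Python model of tf.repeat(x, repeats, axis=0) on a list of rows.

    Hypothetical helper for illustration only; it is not a TensorFlow API.
    """
    if isinstance(repeats, int):
        # a single count applies to every row, like tf.repeat(x, 2, axis=0)
        repeats = [repeats] * len(rows)
    if len(repeats) != len(rows):
        # this sketch requires one count per row (no broadcasting of length-1 lists)
        raise ValueError("need one repeat count per row")
    out = []
    for row, count in zip(rows, repeats):
        # emit a fresh copy of this row `count` times before moving on
        out.extend(list(row) for _ in range(count))
    return out


print(repeat_rows([[1, 2], [3, 4]], [1, 2]))  # [[1, 2], [3, 4], [3, 4]]
```

A count of 0 simply drops that row, which is also how np.repeat behaves.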

Alternatively, slice off the row you want to replicate, repeat only that slice, and join the pieces back together with tf.concat:

import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])
# keep row 0 as-is, repeat row 1 twice, then stack the pieces along axis 0
x = tf.concat([x[:1], tf.repeat(x[1:], 2, axis=0)], axis=0)
print(x)
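The slicing approach hardcodes which row is replicated. To replicate an arbitrary row, one option is to build a repeats list that is 1 everywhere except at the chosen index and pass it to tf.repeat. The plain-Python sketch below shows that construction; single_row_repeats is a hypothetical helper, not part of TensorFlow.

```python
def single_row_repeats(num_rows, index, times):
    """Repeats list that keeps every row once, except row `index`, kept `times` times.

    Hypothetical helper; its result is meant for the `repeats` argument of
    tf.repeat(x, repeats=..., axis=0).
    """
    if not 0 <= index < num_rows:
        raise IndexError("index out of range")
    repeats = [1] * num_rows
    repeats[index] = times
    return repeats


print(single_row_repeats(2, 1, 2))  # [1, 2]
print(single_row_repeats(4, 2, 3))  # [1, 1, 3, 1]
```

With x from the question, tf.repeat(x, repeats=single_row_repeats(2, 1, 2), axis=0) should produce the same tensor as the tf.concat version.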

For TensorFlow 1.14.0, where tf.repeat may not be available, tf.tile can be used instead:

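import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])
# tile the one-row slice x[1:] twice along axis 0 (and once along axis 1)
x = tf.concat([x[:1], tf.tile(x[1:], multiples=[2, 1])], axis=0)

# TF 1.x builds a graph by default, so evaluate the tensor in a session
with tf.Session() as sess:
    print(sess.run(x))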
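One caveat: tf.tile repeats a whole block, whereas tf.repeat repeats each row in place, so the tile-based version matches only because x[1:] holds a single row. The plain-Python sketch below contrasts the row orders the two ops produce, written as the row indices one could pass to tf.gather(x, indices, axis=0). Since tf.gather is also available in 1.x, gathering with explicitly computed indices is another possible workaround on 1.14. The helper names repeat_indices and tile_indices are hypothetical.

```python
def repeat_indices(repeats):
    """Row order of tf.repeat(x, repeats, axis=0): row i appears repeats[i] times consecutively."""
    return [i for i, count in enumerate(repeats) for _ in range(count)]


def tile_indices(num_rows, multiple):
    """Row order of tf.tile(x, [multiple, 1]) on a 2-D x: the whole block, `multiple` times."""
    return [i for _ in range(multiple) for i in range(num_rows)]


print(repeat_indices([2, 2]))  # [0, 0, 1, 1]
print(tile_indices(2, 2))      # [0, 1, 0, 1]
print(repeat_indices([1, 2]))  # [0, 1, 1], i.e. tf.gather(x, [0, 1, 1]) for the question's x
```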
User contributions licensed under: CC BY-SA