If I have the following dataset:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
When I use batch_size=2, I get [[1,2], [3,4], [5,6]].
However, I would like to get the following output:
[[1,2,1,2], [3,4,3,4], [5,6,5,6]]
Basically, I want to repeat each batch 2x along the batch dimension and use the result as the new batch. Obviously, this is a toy example. In a real case, if I have a batch of shape (64, 300), I would like to make a batch of shape (128, 300).
Answer
You can do this by defining a function that concatenates each batch with itself along axis 0, and applying it with map after batching:
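import tensorflow as tf

def double_input(x):
    # Concatenate the batch with itself along axis 0 (the batch axis).
    x = tf.concat([x, x], axis=0)
    return x

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
dataset = dataset.batch(2)            # batch first: [1 2], [3 4], [5 6]
dataset = dataset.map(double_input)   # then double each batch

# take(-1) takes every element of the dataset.
for x in dataset.take(-1):
    print(x)

# Output:
# tf.Tensor([1 2 1 2], shape=(4,), dtype=int32)
# tf.Tensor([3 4 3 4], shape=(4,), dtype=int32)
# tf.Tensor([5 6 5 6], shape=(4,), dtype=int32)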
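The same approach should carry over to the 2-D case from the question, where a (64, 300) batch becomes (128, 300). The sketch below uses much smaller, made-up shapes to keep it readable: the `features` tensor, the batch size of 4, and the name `double_batch` are illustrative assumptions, not part of the original question.

```python
import tensorflow as tf

def double_batch(x):
    # Concatenate along axis 0 so only the batch dimension doubles;
    # the feature dimension is left unchanged.
    return tf.concat([x, x], axis=0)

# Hypothetical stand-in data: 8 examples with 3 features each.
features = tf.reshape(tf.range(24), (8, 3))

dataset = (
    tf.data.Dataset.from_tensor_slices(features)
    .batch(4, drop_remainder=True)  # each batch has shape (4, 3)
    .map(double_batch)              # each batch now has shape (8, 3)
)

shapes = [batch.shape.as_list() for batch in dataset]
print(shapes)
```

For a 2-D batch, `tf.tile(x, [2, 1])` should give the same result as the concatenation here. Note that if the last batch is smaller than the others (no `drop_remainder=True`), it is doubled as well, so it will still differ in size from the rest.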