I am working with time series models in TensorFlow. My dataset contains physics signals. I need to divide these signals into windows and give the sliced windows as input to my model.
Here is how I am reading the data and slicing it:
    import tensorflow as tf
    import numpy as np

    def _ds_slicer(data):
        # tf.split cuts each signal into 768 equal pieces,
        # so every window holds 24576 / 768 = 32 samples.
        win_len = 768
        return {"mix": tf.stack(tf.split(data["mix"], win_len)),
                "pure": tf.stack(tf.split(data["pure"], win_len))}

    dataset = tf.data.Dataset.from_tensor_slices({
        "mix": np.random.uniform(0, 1, [1000, 24576]),
        "pure": np.random.uniform(0, 1, [1000, 24576])
    })
    dataset = dataset.map(_ds_slicer)
    print(dataset.output_shapes)
    # {'mix': TensorShape([Dimension(768), Dimension(32)]), 'pure': TensorShape([Dimension(768), Dimension(32)])}
I want to reshape this dataset so that its output shapes are:
    # {'mix': TensorShape([Dimension(32)]), 'pure': TensorShape([Dimension(32)])}
The equivalent transformation in NumPy would be something like the following:
    signal = np.random.uniform(0, 1, [1000, 24576])
    sliced_sig = np.stack(np.split(signal, 768, axis=1), axis=1)
    print(sliced_sig.shape)  # (1000, 768, 32)
    sliced_sig = sliced_sig.reshape(-1, sliced_sig.shape[-1])
    print(sliced_sig.shape)  # (768000, 32)
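As a side note, the split/stack/reshape sequence above appears to give the same result as a single reshape, because each window is a run of 32 consecutive samples. Here is a minimal sketch that checks this on a smaller array (the 10-row size is an assumption made only to keep the check fast):

```python
import numpy as np

# Smaller stand-in for the (1000, 24576) array in the question.
signal = np.random.uniform(0, 1, [10, 24576])

# Split/stack route from the question: (10, 24576) -> (10, 768, 32) -> (7680, 32)
stacked = np.stack(np.split(signal, 768, axis=1), axis=1)
via_split = stacked.reshape(-1, stacked.shape[-1])

# Direct reshape: consecutive runs of 32 samples become rows.
via_reshape = signal.reshape(-1, 32)

print(via_split.shape)                         # (7680, 32)
print(np.array_equal(via_split, via_reshape))  # True
```

This only holds for non-overlapping windows taken in order along the last axis; overlapping windows would need a different approach.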
I thought of using tf.contrib.data.group_by_window as an input to dataset.apply() but couldn’t figure out exactly how to use it. Is there a way I can use any custom transformation to reshape the dataset?
Answer
I think you're just looking for the transformation tf.contrib.data.unbatch. This does exactly what you want:
    import numpy as np
    import tensorflow as tf

    x = np.zeros((1000, 768, 32))
    dataset = tf.data.Dataset.from_tensor_slices(x)
    print(dataset.output_shapes)  # (768, 32)
    dataset = dataset.apply(tf.contrib.data.unbatch())
    print(dataset.output_shapes)  # (32,)
From the documentation:
If elements of the dataset are shaped [B, a0, a1, …], where B may vary from element to element, then for each element in the dataset, the unbatched dataset will contain B consecutive elements of shape [a0, a1, …].
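To make the quoted behaviour concrete, here is a rough plain-Python picture of what unbatching means. This is only a conceptual sketch, not how TensorFlow implements it, and `unbatch_elements` is a hypothetical helper name:

```python
import numpy as np

def unbatch_elements(elements):
    """Yield every slice along axis 0 of each element, in order (hypothetical helper)."""
    for element in elements:
        for row in element:
            yield row

# Two elements with different leading sizes B (2 and 3), each row of shape (32,).
elements = [np.zeros((2, 32)), np.ones((3, 32))]
rows = list(unbatch_elements(elements))

print(len(rows))      # 5
print(rows[0].shape)  # (32,)
```

The leading dimension of each element disappears and its slices come out one after another, which is why B is allowed to differ between elements.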
Edit for TF 2.0
(Thanks @DavidParks)
From TF 2.0, you can use tf.data.Dataset.unbatch directly:
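    import numpy as np
    import tensorflow as tf

    x = np.zeros((1000, 768, 32))
    dataset = tf.data.Dataset.from_tensor_slices(x)
    # TF 2.x datasets no longer have output_shapes; use element_spec instead.
    print(dataset.element_spec.shape)  # (768, 32)
    dataset = dataset.unbatch()
    print(dataset.element_spec.shape)  # (32,)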
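For the dictionary dataset in the question, the same call should work on both keys at once, since "mix" and "pure" share the same leading dimension after slicing. Here is a minimal sketch, assuming TF 2.x; `NUM_SIGNALS` and `NUM_WINDOWS` are names I made up, and I use 4 signals instead of 1000 just to keep it quick:

```python
import numpy as np
import tensorflow as tf

NUM_SIGNALS = 4    # assumption: small stand-in for the 1000 signals in the question
NUM_WINDOWS = 768  # each 24576-sample signal becomes 768 windows of 32 samples

def _ds_slicer(data):
    # Same slicing as in the question: (24576,) -> (768, 32) for each key.
    return {"mix": tf.stack(tf.split(data["mix"], NUM_WINDOWS)),
            "pure": tf.stack(tf.split(data["pure"], NUM_WINDOWS))}

dataset = tf.data.Dataset.from_tensor_slices({
    "mix": np.random.uniform(0, 1, [NUM_SIGNALS, 24576]),
    "pure": np.random.uniform(0, 1, [NUM_SIGNALS, 24576]),
})

# Slice every signal into windows, then flatten the window axis away.
dataset = dataset.map(_ds_slicer).unbatch()

shapes = {key: spec.shape.as_list() for key, spec in dataset.element_spec.items()}
print(shapes)  # {'mix': [32], 'pure': [32]}
```

Each element is now one window per key, so a later `dataset.batch(...)` would batch windows rather than whole signals.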