I am trying to inspect the labels inside my tensorflow dataset. However, the values of the labels change to something unexpected after using take()
and skip()
, depending on whether I inspect the data or not. (It looks like within the labels some ones changed to zeros.) I do not see any way that my inspection function could change the dataset. What am I missing?
To reproduce the behaviour, change the LOOK_AT_DATA_TWICE
# python 3.9.4, tensorflow 2.5.0-rc1
import numpy as np
import tensorflow as tf
def inspect_dataset(ds, msg="", print_all=True):
sum_ones = 0
sum_zeros = 0
for (sig, label) in ds.as_numpy_iterator():
if print_all:
print(msg, label, np.histogram(label, bins=2)[0])
sum_ones += np.sum(label)
sum_zeros += np.sum(label - 1)
print(msg, "SUM of ones=", sum_ones)
print(msg, "SUM of zero=", sum_zeros)
all_pattern = np.random.random((4000, 1000))
all_labels = np.array(2000 * [0] + 2000 * [1])
print(f"all_labels.shape={all_labels.shape}, sum(all_labels)={np.sum(all_labels)}")
print(f"Creating dataset from labels hist: {np.histogram(all_labels, bins=2)[0]}")
complete_ds = tf.data.Dataset.from_tensor_slices((all_pattern, all_labels))
complete_ds = complete_ds.shuffle(len(all_labels))
LOOK_AT_DATA_TWICE = True # This changes the numbers output below
inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)
inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)
num_test_samples = int(validation_split * len(all_labels))
train_ds = complete_ds.skip(num_test_samples)
val_ds = complete_ds.take(num_test_samples)
inspect_dataset(train_ds, msg="train_ds in generation", print_all=False)
inspect_dataset(val_ds, msg="val_ds in generation", print_all=False)
Output with LOOK_AT_DATA_TWICE = True
all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 997
train_ds in generation SUM of zero= -1003
val_ds in generation SUM of ones= 988
val_ds in generation SUM of zero= -1012
Output with LOOK_AT_DATA_TWICE = False
all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 1031
train_ds in generation SUM of zero= -969
val_ds in generation SUM of ones= 1003
val_ds in generation SUM of zero= -997
When the dataset is exhausted (i.e, after you iterated through it once), it will redo all the operations. In your case, because you are shuffling, the shuffle for the first epoch will be different than the shuffling for the second.
What it means is that your training set and testing set are actually not consistent between epochs.
You can set reshuffle_each_iteration
to the call to shuffle to make the shuffle behave the same at each iteration. If you still want a different shuffle for your train set, you should call it again.
ds = tf.data.Dataset.from_tensor_slices(data)
shuffled_ds = ds.shuffle(reshuffle_each_iteration=False)
train_ds = shuffled_ds.take(train_size)
train_ds = train_ds.shuffle()
test_ds = shuffled_ds.skip(train_size)