I am trying to inspect the labels inside my TensorFlow dataset. However, the values of the labels change to something unexpected after using take() and skip(), depending on whether I inspect the data beforehand or not. (It looks as if some of the ones among the labels changed to zeros.) I do not see any way that my inspection function could change the dataset. What am I missing?
To reproduce the behaviour, change the LOOK_AT_DATA_TWICE variable.
# python 3.9.4, tensorflow 2.5.0-rc1
import numpy as np
import tensorflow as tf

tf.random.set_seed(42)


def inspect_dataset(ds, msg="", print_all=True):
    # Iterate once over the dataset and tally the labels.
    # Note: each zero-label contributes -1 to sum_zeros, hence the negative totals.
    sum_ones = 0
    sum_zeros = 0
    for (sig, label) in ds.as_numpy_iterator():
        if print_all:
            print(msg, label, np.histogram(label, bins=2)[0])
        sum_ones += np.sum(label)
        sum_zeros += np.sum(label - 1)
    print(msg, "SUM of ones=", sum_ones)
    print(msg, "SUM of zero=", sum_zeros)


all_pattern = np.random.random((4000, 1000))
all_labels = np.array(2000 * [0] + 2000 * [1])

print(f"all_pattern.shape={all_pattern.shape}")
print(f"all_labels.shape={all_labels.shape}, sum(all_labels)={np.sum(all_labels)}")
print(f"Creating dataset from labels hist: {np.histogram(all_labels, bins=2)[0]}")

complete_ds = tf.data.Dataset.from_tensor_slices((all_pattern, all_labels))
complete_ds = complete_ds.shuffle(len(all_labels))

LOOK_AT_DATA_TWICE = True  # This changes the numbers output below
if LOOK_AT_DATA_TWICE:
    inspect_dataset(complete_ds, msg="complete_ds in generation", print_all=False)
inspect_dataset(complete_ds, msg="complete_ds in generation", print_all=False)

validation_split = 0.5
num_test_samples = int(validation_split * len(all_labels))
train_ds = complete_ds.skip(num_test_samples)
val_ds = complete_ds.take(num_test_samples)

inspect_dataset(train_ds, msg="train_ds in generation", print_all=False)
inspect_dataset(val_ds, msg="val_ds in generation", print_all=False)
Output with LOOK_AT_DATA_TWICE = True:
all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
complete_ds in generation SUM of ones= 2000
complete_ds in generation SUM of zero= -2000
complete_ds in generation SUM of ones= 2000
complete_ds in generation SUM of zero= -2000
train_ds in generation SUM of ones= 997
train_ds in generation SUM of zero= -1003
val_ds in generation SUM of ones= 988
val_ds in generation SUM of zero= -1012
Output with LOOK_AT_DATA_TWICE = False:
all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
complete_ds in generation SUM of ones= 2000
complete_ds in generation SUM of zero= -2000
train_ds in generation SUM of ones= 1031
train_ds in generation SUM of zero= -969
val_ds in generation SUM of ones= 1003
val_ds in generation SUM of zero= -997
Answer
When the dataset is exhausted (i.e., after you have iterated through it once), it will redo all the operations in its pipeline. In your case, because you are shuffling, the shuffle for the first pass will be different from the shuffle for the second pass.
This means that skip() and take() each consume the pipeline independently, so your training set and test set are drawn from different shufflings: they are not consistent with each other or between epochs, and their samples can overlap.
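Here is a minimal sketch of the effect on a toy dataset (assuming TensorFlow 2.x eager mode; the names are illustrative): iterating the same shuffled dataset twice yields two different orders, because reshuffle_each_iteration defaults to True.

import tensorflow as tf

ds = tf.data.Dataset.range(10).shuffle(10)  # reshuffle_each_iteration defaults to True
print(list(ds.as_numpy_iterator()))  # one order
print(list(ds.as_numpy_iterator()))  # a different order on the second pass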
You can pass reshuffle_each_iteration=False to the call to shuffle() to make the shuffle behave the same on every iteration. If you still want a different shuffle of your train set each epoch, call shuffle() on it again afterwards.
ds = tf.data.Dataset.from_tensor_slices(data)
# buffer_size is a required argument of shuffle(); using the full dataset
# size gives a uniform shuffle
shuffled_ds = ds.shuffle(buffer_size=len(ds), reshuffle_each_iteration=False)
train_ds = shuffled_ds.take(train_size)
train_ds = train_ds.shuffle(buffer_size=train_size)  # fresh shuffle of the train set each epoch
test_ds = shuffled_ds.skip(train_size)
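Alternatively, the split can be done in NumPy before building any tf.data pipeline, which makes the two sets disjoint by construction. This is only a sketch, not part of the original answer; it reuses all_pattern and all_labels from the question, and the seed is arbitrary:

import numpy as np
import tensorflow as tf

# Shuffle the indices once in NumPy, then slice; the two datasets can
# never exchange samples, no matter how often they are iterated.
rng = np.random.default_rng(42)  # arbitrary seed
indices = rng.permutation(len(all_labels))
split = int(0.5 * len(all_labels))
val_idx, train_idx = indices[:split], indices[split:]

train_ds = tf.data.Dataset.from_tensor_slices(
    (all_pattern[train_idx], all_labels[train_idx])
).shuffle(len(train_idx))  # per-epoch reshuffling is now harmless
val_ds = tf.data.Dataset.from_tensor_slices((all_pattern[val_idx], all_labels[val_idx]))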