Reading a TensorFlow Dataset changes the behaviour of `take()` and `skip()`




I am trying to inspect the labels inside my TensorFlow dataset. However, after using take() and skip(), the label counts change to something unexpected, and the result depends on whether I have inspected the data beforehand. (It looks as if some ones within the labels had changed to zeros.) I do not see how my inspection function could change the dataset. What am I missing?

To reproduce the behaviour, change the LOOK_AT_DATA_TWICE variable.

# python 3.9.4, tensorflow 2.5.0-rc1
import numpy as np
import tensorflow as tf

tf.random.set_seed(42)


def inspect_dataset(ds, msg="", print_all=True):
    sum_ones = 0
    sum_zeros = 0
    for (sig, label) in ds.as_numpy_iterator():
        if print_all:
            print(msg, label, np.histogram(label, bins=2)[0])
        sum_ones += np.sum(label)
        sum_zeros += np.sum(label - 1)

    print(msg, "SUM of ones=", sum_ones)
    print(msg, "SUM of zero=", sum_zeros)


all_pattern = np.random.random((4000, 1000))
all_labels = np.array(2000 * [0] + 2000 * [1])

print(f"all_pattern.shape={all_pattern.shape}")
print(f"all_labels.shape={all_labels.shape}, sum(all_labels)={np.sum(all_labels)}")
print(f"Creating dataset from labels hist: {np.histogram(all_labels, bins=2)[0]}")

complete_ds = tf.data.Dataset.from_tensor_slices((all_pattern, all_labels))
complete_ds = complete_ds.shuffle(len(all_labels))

LOOK_AT_DATA_TWICE = True  # This changes the numbers output below
if LOOK_AT_DATA_TWICE:
    inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)
inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)

validation_split=0.5
num_test_samples = int(validation_split * len(all_labels))
train_ds = complete_ds.skip(num_test_samples)
val_ds = complete_ds.take(num_test_samples)

inspect_dataset(train_ds, msg="train_ds in generation", print_all=False)
inspect_dataset(val_ds, msg="val_ds in generation", print_all=False)

Output with LOOK_AT_DATA_TWICE = True:

all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]

complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 997
train_ds in generation SUM of zero= -1003
val_ds in generation SUM of ones= 988
val_ds in generation SUM of zero= -1012

Output with LOOK_AT_DATA_TWICE = False:

all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
   
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 1031
train_ds in generation SUM of zero= -969
val_ds in generation SUM of ones= 1003
val_ds in generation SUM of zero= -997

Answer

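When the dataset is exhausted (i.e., after you have iterated through it once), the next iteration redoes all the operations. In your case, because you are shuffling, the shuffle for the first pass will be different from the shuffle for the second.

This means that your training set and testing set are not consistent between epochs. It also means that take() and skip() each iterate over a differently shuffled order, so the two splits are not complementary: some samples can land in both and others in neither.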
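A minimal sketch of this effect on a toy dataset, assuming TensorFlow 2.x in eager mode. The ten integers and the variable names are illustrative stand-ins, not part of the original question. Because every element is unique, any value that shows up in both lists proves that take() and skip() saw different shuffles.

```python
import tensorflow as tf

# Ten distinct integers stand in for the real samples (illustrative data).
# reshuffle_each_iteration defaults to True, so each pass is shuffled anew.
ds = tf.data.Dataset.range(10).shuffle(10)

first_half = ds.take(5)
second_half = ds.skip(5)

# Each list below triggers its own iteration over `ds`, hence its own shuffle.
taken = [int(x) for x in first_half.as_numpy_iterator()]
skipped = [int(x) for x in second_half.as_numpy_iterator()]

overlap = sorted(set(taken) & set(skipped))
print("take(5):", taken)
print("skip(5):", skipped)
print("elements in both halves:", overlap)
```

On most runs the last line lists several elements, although an empty overlap is possible by chance.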

You can pass reshuffle_each_iteration=False to the call to shuffle so that the shuffle produces the same order at each iteration. If you still want your train set shuffled differently on each epoch, call shuffle again on the train set only.

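import tensorflow as tf

data = list(range(100))  # placeholder data; substitute your own
train_size = 80          # placeholder split size
ds = tf.data.Dataset.from_tensor_slices(data)
# shuffle() requires buffer_size; the full length gives a uniform shuffle
shuffled_ds = ds.shuffle(buffer_size=len(data), reshuffle_each_iteration=False)
train_ds = shuffled_ds.take(train_size)
train_ds = train_ds.shuffle(buffer_size=train_size)  # reshuffles train only
test_ds = shuffled_ds.skip(train_size)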
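To check that the split is now consistent, one option is the sketch below. It again assumes TensorFlow 2.x in eager mode, and the sizes `n` and `train_size` are illustrative. It verifies that the two splits are disjoint, that together they cover every element, and that iterating the train split a second time yields the same elements in the same order.

```python
import tensorflow as tf

n, train_size = 10, 6  # illustrative sizes, not from the original post

# Shuffle once; the order is then fixed for every later iteration.
shuffled = tf.data.Dataset.range(n).shuffle(n, reshuffle_each_iteration=False)

train_ds = shuffled.take(train_size)
test_ds = shuffled.skip(train_size)

train = [int(x) for x in train_ds.as_numpy_iterator()]
test = [int(x) for x in test_ds.as_numpy_iterator()]
train_again = [int(x) for x in train_ds.as_numpy_iterator()]

print("disjoint:", set(train).isdisjoint(test))
print("complete:", sorted(train + test) == list(range(n)))
print("stable across iterations:", train == train_again)
```

If the train split is shuffled again afterwards, as in the answer's snippet, its order changes per epoch but its membership does not.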


Source: stackoverflow