I want to train a single variational autoencoder model (or even a standard autoencoder) jointly over many datasets (e.g. MNIST, CIFAR, SVHN, etc., with all images resized to the same input shape). Here is the VAE tutorial in TensorFlow which I am using as a starting point: https://www.tensorflow.org/tutorials/generative/cvae.
For training the model, I would want to sample (choose) a dataset from my set of datasets and then obtain a batch of images from that dataset at each gradient update step in the training loop. I could combine all the datasets into one big dataset, but I want to leverage the fact that all images in a given batch come from the same dataset as side information (I'm still figuring out this part, but the details aren't too important since my question focuses on the data pipeline).
I am not sure how exactly to go about the data pipeline setup. The tutorial specifies the dataset pipeline as follows:
train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(train_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
                .shuffle(test_size).batch(batch_size))
where train_images and test_images are the processed MNIST data. So it creates a TensorFlow dataset, shuffles the entire dataset, and batches the data into batches of size batch_size. In my case, I assume I would want to create a separate train_dataset/test_dataset for each dataset in my set of datasets (e.g. cifar_train_dataset/cifar_test_dataset, mnist_train_dataset/mnist_test_dataset, etc.).
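For concreteness, here is roughly what I have in mind for the per-dataset pipelines (just a sketch, untested; the common 32x32x3 target shape, the grayscale-to-RGB conversion, the shuffle buffer size, and loading only MNIST and CIFAR-10 via tf.keras.datasets are all arbitrary choices on my part):

import tensorflow as tf

TARGET_SHAPE = (32, 32)   # common spatial size for all sources (arbitrary choice)
BATCH_SIZE = 64

def make_pipeline(images, batch_size=BATCH_SIZE):
    """Build a shuffled, batched, repeating pipeline from a raw uint8 image array."""
    if images.ndim == 3:              # grayscale arrays like MNIST: add a channel axis
        images = images[..., None]
    ds = tf.data.Dataset.from_tensor_slices(images)

    def preprocess(x):
        x = tf.cast(x, tf.float32) / 255.0
        if x.shape[-1] == 1:          # force a common 3-channel format
            x = tf.image.grayscale_to_rgb(x)
        return tf.image.resize(x, TARGET_SHAPE)

    return ds.map(preprocess).shuffle(10_000).batch(batch_size).repeat()

# one pipeline per source, keyed by name so the name can double as side information
(mnist_train, _), (mnist_test, _) = tf.keras.datasets.mnist.load_data()
(cifar_train, _), (cifar_test, _) = tf.keras.datasets.cifar10.load_data()

train_datasets = {'mnist': make_pipeline(mnist_train), 'cifar': make_pipeline(cifar_train)}
test_datasets = {'mnist': make_pipeline(mnist_test), 'cifar': make_pipeline(cifar_test)}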
When it comes to training, they specify the procedure as follows:
for epoch in range(1, epochs + 1):
    for train_x in train_dataset:
        train_step(model, train_x, optimizer)

    loss = tf.keras.metrics.Mean()
    for test_x in test_dataset:
        loss(compute_loss(model, test_x))
    elbo = -loss.result()
    print('Epoch: {}, Test set ELBO: {}'.format(epoch, elbo))
Instead of specifying epochs, I could just specify a total number of training iterations/steps (e.g. 500,000). Within each training step, I would want to sample a dataset from the set of datasets (assuming equal probabilities) instead of assuming a single training dataset as above.
Now comes the part I'm not sure about. The line

for train_x in train_dataset:

is a loop that iterates over the entire dataset in batches. Instead, I would just want to obtain a single batch of images from the dataset I have sampled, make a model update, and repeat the process. However, I am not sure whether specifying the datasets as I have described above provides this flexibility. Is there any way to index into / obtain a single batch, as opposed to iterating over all the batches?
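Concretely, I'm imagining something like the following inside the training loop, though I'm not sure whether this is correct or efficient with tf.data (train_datasets here is the dict of per-source pipelines from my sketch above):

import numpy as np

# sample a source uniformly, then try to pull exactly one batch from it
name = np.random.choice(list(train_datasets.keys()))
batch = next(iter(train_datasets[name]))   # does this restart the dataset every step?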
In summary, I want to train a single model over multiple datasets by sampling a batch of images from a given dataset at each training step when making model updates. I am completely open to other suggestions and approaches that address this problem. Thanks!
Answer
If I understand your question correctly, you want to control the number of batches that you pull from your train and test sets, instead of iterating over them completely before doing an update. You can turn your dataset into an iterator by wrapping it in iter() and then call next() on it to grab the next batch.
Example:
import numpy as np
import tensorflow as tf

# fake mnist data
train_imgs = tf.random.normal([100, 28, 28, 1])
test_imgs = tf.random.normal([100, 28, 28, 1])
train_labels = tf.one_hot(
    tf.random.uniform([100,], minval=0, maxval=10, dtype=tf.int64), 10)
test_labels = tf.one_hot(
    tf.random.uniform([100,], minval=0, maxval=10, dtype=tf.int64), 10)

# create train/test dataset
train_ds = tf.data.Dataset.from_tensor_slices((train_imgs, train_labels))
train_ds = train_ds.repeat().shuffle(1 << 6).batch(8)
test_ds = tf.data.Dataset.from_tensor_slices((test_imgs, test_labels))
test_ds = test_ds.repeat().shuffle(1 << 6).batch(8)

# simple mnist network
x_in = tf.keras.Input((28, 28, 1))
x = tf.keras.layers.Flatten()(x_in)
x = tf.keras.layers.Dense(100)(x)
x_out = tf.keras.layers.Dense(10)(x)

# simple mnist model
model = tf.keras.Model(x_in, x_out)

# make datasets iterators
train_iter = iter(train_ds)
test_iter = iter(test_ds)

# loss
def xent_loss(y_true, y_pred):
    ce = tf.keras.losses.CategoricalCrossentropy()
    return ce(y_true, y_pred)

# simple training loop where you control the batches per epoch
# for your train and test datasets
NUM_EPOCHS = 10
NUM_TRAIN_BATCHES_PER_EPOCH = 20
NUM_TEST_BATCHES_PER_EPOCH = 5

for epoch in range(NUM_EPOCHS):
    train_losses = []
    # train
    for _ in range(NUM_TRAIN_BATCHES_PER_EPOCH):
        X_train, y_train = next(train_iter)
        y_hat = model(X_train)
        loss = xent_loss(y_train, y_hat)
        train_losses.append(loss)
        # do gradient update ...

    # report train loss
    print(f"epoch: {epoch}\ttrain_loss: {np.mean(train_losses):.4f}")
    train_losses = []

    # validate
    test_losses = []
    for _ in range(NUM_TEST_BATCHES_PER_EPOCH):
        X_test, y_test = next(test_iter)
        y_hat = model(X_test)
        loss = xent_loss(y_test, y_hat)
        test_losses.append(loss)

    # report validation loss
    print(f"epoch: {epoch}\ttest_loss: {np.mean(test_losses):.4f}")
    test_losses = []
    print('-' * 40)

# epoch: 0    train_loss: 7.3092
# epoch: 0    test_loss: 7.3427
# ----------------------------------------
# epoch: 1    train_loss: 6.8050
# epoch: 1    test_loss: 8.4867
# ----------------------------------------
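To adapt this to your multi-dataset setting, keep one persistent iterator per source (created once, outside the loop) and sample which one to draw from at each step. Here is a rough sketch along those lines; the fake per-source pipelines, the uniform sampling, and the step count are placeholders, and the actual update would be the train_step from the tutorial you linked:

import numpy as np
import tensorflow as tf

# stand-ins for your per-source pipelines; each yields batches of a common shape
def fake_pipeline():
    imgs = tf.random.normal([100, 32, 32, 3])
    return tf.data.Dataset.from_tensor_slices(imgs).repeat().shuffle(1 << 6).batch(8)

train_datasets = {'mnist': fake_pipeline(), 'cifar': fake_pipeline(), 'svhn': fake_pipeline()}

# one persistent iterator per source, created once outside the training loop
train_iters = {name: iter(ds) for name, ds in train_datasets.items()}
dataset_names = list(train_iters.keys())

TOTAL_STEPS = 1000   # e.g. 500000 in your case

for step in range(TOTAL_STEPS):
    # sample a source uniformly at random for this update
    name = np.random.choice(dataset_names)
    # pull a single batch from that source; repeat() above means it never runs out
    batch = next(train_iters[name])
    # `name` tells you which dataset the batch came from (your side information)
    # train_step(model, batch, optimizer)   # your VAE update goes here

If you would rather keep everything inside tf.data, another option is tf.data.experimental.sample_from_datasets: if you batch each source dataset before passing it in, every element it yields is a full batch drawn from a single source, although you would need to map a dataset id into each pipeline beforehand if you want to keep the source as side information.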