Context: I am trying to make a GAN to generate images from a large dataset, and have been running into OOM issues when loading the training data. To work around this, I am trying to pass in a list of file paths and read them in as images only when needed.
Issue: I do not know how to parse the file name out of the tensor itself. If anyone has any insight on how to convert the tensor back to a list, or how to iterate through the tensor, I would appreciate it. Or, if this is a bad way to solve the problem, please let me know.
Relevant code snippets:
Generating the data (note: `make_file_list()` returns a list of file names for all the images I want to read in):

```python
data = make_file_list(base_dir)
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train(train_dataset, EPOCHS)
```
training function:
```python
def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []
    for epoch in range(epochs):
        start = time.time()
        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)
```
train step:
```python
@tf.function
def train_step(image_files):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    images = [load_img(filepath) for filepath in image_files]
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
```
Error:
```
line 250, in train_step  *
    images = [load_img(filepath) for filepath in image_files]
OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
```
Answer
Remove the `@tf.function` decorator on your `train_step`. If you decorate `train_step` with `@tf.function`, TensorFlow will try to convert the Python code inside it into an execution graph instead of operating in eager mode. Execution graphs offer a speedup, but they also restrict which operations can be performed (as the error states), and iterating over a symbolic tensor with a plain Python loop is one of the unsupported features.
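The restriction is easy to reproduce in isolation. A minimal sketch (not the question's code) showing that the same list comprehension works in eager mode but fails during tracing under `@tf.function`:

```python
import tensorflow as tf

def eager_fn(t):
    # Eager mode: iterating a tensor yields its rows, so this works.
    return [x + 1.0 for x in t]

@tf.function
def graph_fn(t):
    # Graph mode: t is symbolic during tracing, so Python iteration
    # raises OperatorNotAllowedInGraphError.
    return [x + 1.0 for x in t]

t = tf.constant([1.0, 2.0, 3.0])
eager_fn(t)   # fine
try:
    graph_fn(t)
except Exception as e:
    print(type(e).__name__)
```

This is the same failure mode as the `load_img` comprehension in the question's `train_step`.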
To keep `@tf.function` on `train_step`, you can instead do the iterating and loading in your `train` function first, then pass the already-loaded images as an argument to `train_step` rather than loading them inside it:
```python
def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []
    for epoch in range(epochs):
        start = time.time()
        for image_batch in dataset:
            images = [load_img(filepath) for filepath in image_batch]
            gen_loss, disc_loss = train_step(images)

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        ....
```
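On the "is this a bad way to solve this" part of the question: a common alternative is to do the loading inside the `tf.data` pipeline itself with graph-compatible ops (`tf.io.read_file` / `tf.image.decode_image`), so only the current batch of images is ever decoded in memory and the loading can overlap with training. A sketch under the question's setup; `IMG_SIZE` and the body of this `load_img` are assumptions, not the asker's code:

```python
import tensorflow as tf

IMG_SIZE = 64  # assumed target size; adjust to your model's input

def load_img(filepath):
    # tf.io / tf.image ops are graph-compatible, unlike a Python loader.
    raw = tf.io.read_file(filepath)
    img = tf.image.decode_image(raw, channels=3, expand_animations=False)
    img.set_shape([None, None, 3])
    img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
    return img / 127.5 - 1.0  # scale to [-1, 1] for a tanh generator

def build_dataset(file_list, batch_size, buffer_size):
    # Shuffle file paths (cheap), then decode images lazily per batch.
    return (
        tf.data.Dataset.from_tensor_slices(file_list)
        .shuffle(buffer_size)
        .map(load_img, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
```

With this, `train_step` receives ready image batches and can keep its `@tf.function` decorator unchanged, e.g. `build_dataset(make_file_list(base_dir), BATCH_SIZE, BUFFER_SIZE)` in place of the original `from_tensor_slices(...).shuffle(...).batch(...)` chain.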