Reading in file names from a tensor in Tensorflow

Context: I am trying to build a GAN that generates images from a large dataset, and I have been running into out-of-memory (OOM) errors when loading the training data. To work around this, I am trying to pass in a list of file paths and read them in as images only when needed.

Issue: I do not know how to parse the file names out of the tensor itself. If anyone has any insight into how to convert the tensor back to a list, or how to iterate through the tensor, I would appreciate it. Or, if this is a bad way to solve the problem, please let me know.

Relevant code snippets:

Generating the data (note: make_file_list() returns a list of file names for all the images I want to read in):

data = make_file_list(base_dir)
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train(train_dataset, EPOCHS)
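
A minimal sketch of what this dataset actually yields (using illustrative path strings, not the real data): from_tensor_slices on a list of paths produces batches of tf.string tensors, not Python strings, which is why train_step later receives a symbolic batch of file names rather than a list it can iterate over.

```python
import tensorflow as tf

# Illustrative file names; the real list comes from make_file_list().
paths = ["img_0.png", "img_1.png", "img_2.png"]

ds = tf.data.Dataset.from_tensor_slices(paths).batch(2)

# Each element of the dataset is a rank-1 tf.string tensor of file names.
first = next(iter(ds))
print(first.dtype)  # tf.string
print(first.shape)  # (2,)
```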

training function:

def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)

train step:

@tf.function
def train_step(image_files):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    images = [load_img(filepath) for filepath in image_files]

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

Error:

line 250, in train_step  *
        images = [load_img(filepath) for filepath in image_files]

    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.


Answer

Remove the @tf.function decorator from your train_step. If you decorate train_step with @tf.function, TensorFlow will try to convert the Python code inside it into an execution graph instead of running it in eager mode. Execution graphs offer a speedup, but they also put constraints on which operations can be performed (as the error states), and iterating over a symbolic tensor is one of the operations that is not allowed.
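
A minimal demonstration of the difference (with made-up file names): the same list comprehension works in eager mode, where each element of the tensor is a concrete value, but fails when traced into a graph, where the tensor is a symbolic placeholder with no concrete elements.

```python
import tensorflow as tf

# Eager mode: iterating over a tensor is fine.
def eager_step(file_names):
    return [name for name in file_names]

# Graph mode: the same iteration raises OperatorNotAllowedInGraphError
# at trace time, exactly as in the question.
@tf.function
def graph_step(file_names):
    return [name for name in file_names]

names = tf.constant(["a.png", "b.png"])
print(len(eager_step(names)))  # 2

try:
    graph_step(names)
except Exception as err:
    print(type(err).__name__)  # OperatorNotAllowedInGraphError
```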

Alternatively, to keep @tf.function on train_step, you can do the iterating and loading in your train function first, then pass the already loaded images as an argument to train_step instead of trying to load the images inside train_step:

def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            images = [load_img(filepath) for filepath in image_batch]
            gen_loss, disc_loss = train_step(images)

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        ....
User contributions licensed under: CC BY-SA