How to help tqdm figure out the total in a custom iterator

I’m implementing my own iterator. tqdm does not show a progress bar, as it does not know the total number of elements. I don’t want to use “total=” as it looks ugly. Rather I would prefer to add something to my iterator that tqdm can use to figure out the total.

class Batches:
    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            target = minibatch[:, :, self.target_input]
            self.pos += 1
            return minibatch, target
        else:
            raise StopIteration

    def __len__(self):
        return self.batches.len()

Is this even possible? What should I add to the above code?

I am using tqdm like this:

for minibatch, target in tqdm(Batches(test, target_input)):

    output = lstm(minibatch)
    loss = criterion(output, target)
    writer.add_scalar('loss', loss, tensorboard_step)

Answer

The original question states:

I don’t want to use “total=” as it looks ugly. Rather I would prefer to add something to my iterator that tqdm can use to figure out the total.

However, the currently accepted answer explicitly recommends passing total:

with tqdm(total=len(my_iterable)) as progress_bar:

In fact, that example is more complicated than necessary, since the original question did not ask for manual updating of the bar. Hence,

for i in tqdm(my_iterable, total=my_total):
    do_something()

is actually sufficient already (as the author, @emem, already noted in a comment).


This question is relatively old (4 years at the time of writing this), yet looking at tqdm’s code, one can see that from the very beginning (8 years ago at the time of writing this) the behavior has been to default to total = len(iterable) when total is not given.

Thus, the correct answer to the question is to implement __len__, which the original example already does. Hence, it should in principle already work.

A complete toy example to test this behavior follows (note the comment above the __len__ method):

from time import sleep
from tqdm import tqdm


class Iter:

    def __init__(self, n=10):
        self.n = n
        self.iter = iter(range(n))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    # commenting out the next two lines disables the bar
    # because tqdm then does not know the total number of elements:
    def __len__(self):
        return self.n


it = Iter()
for i in tqdm(it):
    sleep(0.2)

This is what tqdm does when total is not given:

try:
    total = len(iterable)
except (TypeError, AttributeError):
    total = None

Since we do not know exactly what @Duane used as batches, I suspect this is just a well-hidden typo (self.batches.len()), which raises an AttributeError that is caught silently within tqdm, so the total stays unknown.
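To see how such a typo can go unnoticed, here is a minimal sketch that reproduces the try/except quoted above without tqdm itself. It assumes batches is a plain list; BrokenLen and guess_total are hypothetical names introduced only for this illustration and are not part of tqdm.

```python
class BrokenLen:
    """Mirrors the question's __len__, which calls a .len() method."""

    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        # Assumption: batches is a list. Lists have no .len() method,
        # so this line raises AttributeError when len() is called.
        return self.batches.len()


def guess_total(iterable):
    """Hypothetical helper reproducing the try/except quoted above."""
    try:
        return len(iterable)
    except (TypeError, AttributeError):
        return None


broken_total = guess_total(BrokenLen([1, 2, 3]))    # AttributeError swallowed
list_total = guess_total([1, 2, 3])                 # len() works
generator_total = guess_total(x for x in range(3))  # TypeError swallowed

print(broken_total, list_total, generator_total)
```

Because the exception is swallowed, the only visible symptom is a bar without a total. Calling len() on the iterator directly, outside of tqdm, would surface the error immediately.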

If batches is just a sequence type, then this was probably the intended definition:

    def __len__(self):
        return len(self.batches)

The definition of __next__ (using len(self.batches)) also points in this direction.
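Putting this together, a corrected version of the class from the question might look like the sketch below. It assumes batches is a sequence that supports the built-in len(); the placeholder strings at the bottom are made up for the check and only exercise __len__, since real minibatches would need to support the [:, :, i] indexing used in __next__.

```python
class Batches:
    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            # Assumes each minibatch supports 3-D indexing (array or tensor).
            target = minibatch[:, :, self.target_input]
            self.pos += 1
            return minibatch, target
        raise StopIteration

    def __len__(self):
        # Built-in len() instead of the .len() method call from the question.
        return len(self.batches)


# Placeholder data: only len() is exercised here, so plain strings suffice.
placeholder_batches = ["batch0", "batch1", "batch2"]
n_batches = len(Batches(placeholder_batches, target_input=0))
print(n_batches)
```

With __len__ working, tqdm(Batches(test, target_input)) should pick up the total on its own, without total= at the call site.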

User contributions licensed under: CC BY-SA