When I load the model and start prediction with PyTorch, iterating through the epoch and batch loops ends with the error shown below. This is my validation function:
```python
def validate_epoch(net, val_loader, loss_type='CE'):
    net.train(False)
    running_loss = 0.0
    sm = nn.Softmax(dim=1)

    truth = []
    preds = []
    bar = tqdm(total=len(val_loader), desc='Processing', ncols=90)
    names_all = []
    n_batches = len(val_loader)
    for i, (batch, targets, names) in enumerate(val_loader):
        if loss_type == 'CE':
            labels = Variable(targets.float())
            inputs = Variable(batch)
        elif loss_type == 'MSE':
            labels = Variable(targets.float())
            inputs = Variable(batch)

        outputs = net(inputs)
        labels = labels.long()
        loss = criterion(outputs, labels)

        if loss_type == 'CE':
            probs = sm(outputs).data.cpu().numpy()
        elif loss_type == 'MSE':
            probs = outputs
            probs[probs < 0] = 0
            probs[probs > 4] = 4
            probs = probs.view(1, -1).squeeze(0).round().data.cpu().numpy()

        preds.append(probs)
        truth.append(targets.cpu().numpy())
        names_all.extend(names)

        running_loss += loss.item()
        bar.update(1)
        gc.collect()
    gc.collect()
    bar.close()
    if loss_type == 'CE':
        preds = np.vstack(preds)
    else:
        preds = np.hstack(preds)
    truth = np.hstack(truth)
    return running_loss / n_batches, preds, truth, names_all
```
And this is the main code where I call the validation function. The error occurs once the model is loaded and prediction starts on the test loader:
```python
criterion = nn.CrossEntropyLoss()
model.eval()
test_losses = []
test_mse = []
test_kappa = []
test_acc = []

test_started = time.time()

test_loss, probs, truth, file_names = validate_epoch(model, test_iterator)
```
As you can see, the traceback in the terminal shows this error:
```
ValueError                                Traceback (most recent call last)
<ipython-input-27-d2b4a1ca3852> in <module>
     12 test_started = time.time()
     13 
---> 14 test_loss, probs, truth, file_names = validate_epoch(model, test_iterator)
     15 preds = probs.argmax(1)
     16 

<ipython-input-25-34e29e0ff6ed> in validate_epoch(net, val_loader, loss_type)
      9     names_all = []
     10     n_batches = len(val_loader)
---> 11     for i, (batch, targets, names) in enumerate(val_loader):
     12         if loss_type == 'CE':
     13             labels = Variable(targets.float())

ValueError: not enough values to unpack (expected 3, got 2)
```
Answer
From the torchvision.datasets.ImageFolder documentation:
“Returns: (sample, target) where target is class_index of the target class.”
So, quite simply, each item of the dataset you're currently using is a tuple with 2 items, (sample, target), so the DataLoader built on it also yields batches with 2 items (images, labels). Trying to unpack such a batch into 3 variables raises this error. The correct line would be:
```python
for i, (batch, targets) in enumerate(val_loader):
```
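The error can be reproduced without PyTorch, using plain lists of tuples as stand-ins for DataLoader batches. The sketch below also shows one option if `validate_epoch` should accept both 2-item batches (plain ImageFolder) and 3-item batches (a dataset that also returns names): starred unpacking in a small helper. The helper name `split_batch` and the stand-in loaders are hypothetical, not part of PyTorch.

```python
def split_batch(items):
    """Split one batch into (inputs, targets, names).

    Hypothetical helper: accepts (inputs, targets), as produced for
    ImageFolder, or (inputs, targets, names); names is None when the
    dataset does not provide them.
    """
    inputs, targets, *rest = items
    names = rest[0] if rest else None
    return inputs, targets, names


# Stand-ins for loaders: ImageFolder-style batches (2 items) and batches with names (3 items).
two_item_loader = [("images_0", "labels_0"), ("images_1", "labels_1")]
three_item_loader = [("images_0", "labels_0", ["a.png", "b.png"])]

# Reproduces the original error: a 2-item batch cannot fill 3 target names.
try:
    for i, (batch, targets, names) in enumerate(two_item_loader):
        pass
except ValueError as err:
    print(err)  # not enough values to unpack (expected 3, got 2)

# The helper works with either batch layout.
for loader in (two_item_loader, three_item_loader):
    for i, items in enumerate(loader):
        batch, targets, names = split_batch(items)
        print(i, batch, targets, names)
```

With a plain ImageFolder, `names` simply stays `None`, so code such as `names_all.extend(names)` would need a guard in that case.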
If you really need the names (which I assume are the file paths of the images), you can define a new dataset class that inherits from ImageFolder and overrides its __getitem__ method to also return this information.