
Problem with data cast on the GPU in PyTorch

I'm trying to build an image classifier, but I'm having a problem moving the data to the GPU.

```python
def train(train_loader, net, epoch):

  # Training mode
  net.train()

  start = time.time()

  epoch_loss  = []
  pred_list, label_list = [], []

  for batch in train_loader:

    #Batch cast on the GPU
    input, label = batch
    input.to(args['device'])
    label.to(args['device'])

    #Forward
    ypred = net(input)
    loss = criterion(ypred, label)
    epoch_loss.append(loss.cpu().data)

    _, pred = torch.max(ypred, axis=1)
    pred_list.append(pred.cpu().numpy())
    label_list.append(label.cpu().numpy())

    #Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  epoch_loss = np.asarray(epoch_loss)
  pred_list  = np.asarray(pred_list).ravel()
  label_list  = np.asarray(label_list).ravel()

  acc = accuracy_score(pred_list, label_list)

  end = time.time()
  print('#################### Train ####################')
  print('Epoch %d, Loss: %.4f +/- %.4f, Acc: %.2f, Time: %.2f' % (epoch, epoch_loss.mean(),
  epoch_loss.std(), acc*100, end-start))

  return epoch_loss.mean()


for epoch in range(args['epoch_num']):
  train(train_loader, net, epoch)
  break #Testing
```

The model is already on CUDA, but I get an error that says:

Input type is torch.FloatTensor and not torch.cuda.FloatTensor

What's the problem with `input.to(args['device'])`?


Answer

UPDATE: According to the OP, an additional `data.to(device)` before the training loop caused this issue.
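Separately from the update, the question's loop has a detail worth checking: `Tensor.to()` does not modify a tensor in place. It returns a converted tensor, so a bare `input.to(args['device'])` leaves `input` where it was. The sketch below shows this with a dtype conversion so that it runs without a GPU; the same rule applies to device moves. `nn.Module.to()` behaves differently: it converts the module's parameters in place and returns the module itself, which is why `model.to(device)` works without reassignment.

```python
import torch
import torch.nn as nn

# Tensor.to() is not in-place: it returns a tensor with the requested
# dtype/device and leaves the original tensor untouched.
t = torch.zeros(3)                  # float32 tensor on the CPU
t.to(torch.float64)                 # result is discarded, t is unchanged
print(t.dtype)                      # torch.float32

t = t.to(torch.float64)             # reassign to keep the converted tensor
print(t.dtype)                      # torch.float64

# nn.Module.to() converts the module's parameters in place
# and returns the same module object.
layer = nn.Linear(3, 2)
same = layer.to(torch.float64)
print(same is layer)                # True
print(layer.weight.dtype)           # torch.float64
```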

You are probably getting a string like `0` or `cuda` from `args['device']`; you should create a proper `torch.device` instead:

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # pass your args['device']
```

Then use `device` to move the model to the GPU:

```python
model.to(device)
```
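If `args['device']` can hold values such as `0`, `'0'`, `'cuda'` or `'cpu'`, one option is to normalize it before building the device. The helper below is a hypothetical sketch: the name `resolve_device` and the accepted formats are assumptions about what the config might contain, not part of PyTorch. It reads a bare index as a CUDA ordinal and falls back to the CPU when CUDA is not available.

```python
import torch


def resolve_device(value):
    """Hypothetical helper: turn 0, '0', 'cuda', 'cuda:1', 'cpu' or a
    torch.device into a torch.device, falling back to the CPU when CUDA
    is unavailable. The accepted formats are an assumption."""
    if isinstance(value, torch.device):
        device = value
    elif isinstance(value, int) or (isinstance(value, str) and value.isdigit()):
        # A bare index such as 0 or '0' is treated as a CUDA device ordinal.
        device = torch.device(f"cuda:{int(value)}")
    else:
        # Strings such as 'cuda', 'cuda:1' or 'cpu' are passed through.
        device = torch.device(value)

    if device.type == "cuda" and not torch.cuda.is_available():
        return torch.device("cpu")
    return device
```

With such a helper, `device = resolve_device(args['device'])` could replace the `torch.device(...)` line, and the rest of the code stays the same.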

Then move each batch to the same device inside the training loop:

```python
for batch, (data, label) in enumerate(train_loader):

    # Batch cast on the GPU; Tensor.to() returns a new tensor, so assign the result
    data = data.to(device)
    label = label.to(device)
```
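Putting the pieces together, here is a minimal end-to-end sketch of one training epoch with the assignments in place. It assumes a synthetic `TensorDataset` and a single `nn.Linear` layer as stand-ins for the question's image data and `net` (neither comes from the original post); it runs on the GPU when one is available and on the CPU otherwise.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic stand-in for train_loader: 64 samples, 10 features, 3 classes.
features = torch.randn(64, 10)
labels = torch.randint(0, 3, (64,))
train_loader = DataLoader(TensorDataset(features, labels), batch_size=16)

net = nn.Linear(10, 3).to(device)      # Module.to moves the parameters in place
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

net.train()
losses = []
for data, label in train_loader:
    # Reassign: Tensor.to returns the moved tensor instead of changing data.
    data = data.to(device)
    label = label.to(device)

    # Forward
    ypred = net(data)
    loss = criterion(ypred, label)

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"batches: {len(losses)}, last loss: {losses[-1]:.4f}")
```

Because the inputs, labels and model parameters all end up on `device`, the forward pass no longer mixes CPU and CUDA tensors.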

User contributions licensed under: CC BY-SA