An image to describe my question
From my point of view, in every iteration the computational graph is constructed at the first arrow (the forward pass) and then used and deleted at the second arrow (the backward pass). So why does it tell me:
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
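For context, the error itself is easy to reproduce in isolation (a minimal sketch, not taken from my code below): once .backward() has run, the graph's saved tensors are freed, and backpropagating through that same graph again raises exactly this message.

import torch

# Minimal sketch: backpropagating through the same graph twice.
x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
y.backward()    # the graph's saved tensors are freed here
# y.backward()  # would raise: RuntimeError: Trying to backward through the graph a second time ...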
Here is my code:
def train(num_epoch=10, len_vocab=1, num_hidden=256, embedding_dim=8):
    data = get_data()
    model = MyRNN(len_vocab, num_hidden, embedding_dim)
    if os.path.exists('QingBinLi'):
        model.load_state_dict(torch.load('QingBinLi'))
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
    loss_for_draw = []
    for epoch in range(num_epoch + 1):
        h = torch.randn(1, 1, num_hidden)
        loss_average = []
        for i in range(data.shape[-2]):
            optimizer.zero_grad()
            # I think my computational graph will be constructed here
            pre, h = model(data[:, :, i, :], h)
            pre = pre.unsqueeze(0).unsqueeze(0)
            loss = criterion(pre, data[:, :, i+1, :])
            loss_average.append(loss)
            # I think every time the backward pass will delete the computational graph.
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            optimizer.step()
            print(f"finish {i+1} times")
        loss_for_draw.append(sum(loss_average) / len(loss_average))
        torch.save(model.state_dict(), 'QingBinLi')
        print(f'now epoch:{epoch}, loss = {loss_for_draw[-1]}')
    return loss_for_draw
class MyRNN(nn.Module):
    def __init__(self, len_vocab, num_hidden=256, embedding_dim=8):
        super(MyRNN, self).__init__()
        self.rnn = nn.RNN(embedding_dim, num_hidden)
        self.num_directions = 1
        self.output_model = nn.Linear(num_hidden, embedding_dim)

    def forward(self, x, h):
        y, h = self.rnn(x, h)
        output = self.output_model(y.reshape((-1)))
        return output, h
So, if I'm right, it shouldn't tell me "Trying to backward through the graph a second time".
So where did I go wrong?
Answer
The variables h and data still carry gradient history: the hidden state h returned by the model has a grad_fn pointing into the previous iteration's computational graph, so the next loss.backward() tries to backpropagate through a graph whose saved tensors have already been freed. So we must add two lines to detach them:

h = h.detach()
data = data.detach()
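Applied to the train() function above, that looks roughly like this (a sketch using the original variable names; data is detached once after loading, and h is detached at the start of every step so each iteration builds a fresh graph):

# inside train(): detach data once after loading, in case it carries gradient history
data = get_data().detach()

# inner loop over time steps
for i in range(data.shape[-2]):
    optimizer.zero_grad()
    h = h.detach()                 # cut the link to the previous iteration's (already freed) graph
    pre, h = model(data[:, :, i, :], h)
    pre = pre.unsqueeze(0).unsqueeze(0)
    loss = criterion(pre, data[:, :, i + 1, :])
    loss.backward()                # now backpropagates only through this iteration's graph
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
    optimizer.step()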