Why do we need to explicitly call zero_grad()?

We need to call zero_grad() explicitly because loss.backward() does not overwrite gradients. It adds the newly computed gradients to whatever is already stored in each parameter's .grad. After loss.backward() has computed the gradients, optimizer.step() uses them to perform a gradient-descent update. The gradients are not zeroed automatically because these two operations, loss.backward() and optimizer.step(), are separate, and optimizer.step() needs the gradients that were just computed.
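A minimal sketch of the hand-off between the two calls, assuming plain SGD without momentum. backward() fills w.grad, step() reads it to update w, and the gradient is still there afterwards. The parameter w and the toy loss are illustrative only.

```python
import torch

torch.manual_seed(0)
w = torch.nn.Parameter(torch.randn(3))
optimizer = torch.optim.SGD([w], lr=0.1)  # plain SGD, no momentum

loss = (w ** 2).sum()
loss.backward()                        # fills w.grad with d(loss)/dw = 2 * w
expected = w.detach() - 0.1 * w.grad   # what a plain SGD update should give
optimizer.step()                       # reads w.grad to update w in place

print(w.grad is not None)                     # step() did not clear the gradient
print(torch.allclose(w.detach(), expected))   # update matches w - lr * grad
```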

In addition, we sometimes want to accumulate gradients across several batches. To do that, we can call backward() several times and call optimizer.step() once.
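A sketch of gradient accumulation under assumed toy settings. The model, the random data, and the name accum_steps are hypothetical. Each loss is divided by accum_steps so that, with equal batch sizes and a mean-reduced loss, the accumulated gradient matches the gradient of the average loss over the whole group.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accum_steps = 4  # hypothetical: mini-batches per optimizer step
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]  # toy data

optimizer.zero_grad()
n_updates = 0
for i, (x, y) in enumerate(batches):
    loss = loss_fn(model(x), y) / accum_steps  # scale so the sum is an average
    loss.backward()                            # adds this batch's gradient to .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()       # one update from the accumulated gradient
        optimizer.zero_grad()  # start the next group from zero
        n_updates += 1
```

With 8 batches and accum_steps = 4, this performs two parameter updates instead of eight.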


I have a use case for the current setup in PyTorch.

If one is using a recurrent neural network (RNN) that makes a prediction at every time step, one might want a hyperparameter that controls how far back in time gradients are accumulated. Not zeroing the gradients at every time step allows one to use backpropagation through time (BPTT) in interesting and novel ways.
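One possible sketch of this idea, assuming a small nn.RNNCell with a linear prediction head and synthetic data. The hyperparameter k (time steps per update) and all sizes are hypothetical. A loss is computed and back-propagated at every step without zeroing, so gradients from all steps in the window add up in .grad. retain_graph=True is needed because later losses back-propagate through earlier steps' graphs again. After k steps, the optimizer updates the parameters and the hidden state is detached to truncate BPTT.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.RNNCell(input_size=3, hidden_size=4)
head = nn.Linear(4, 1)  # per-step prediction head
params = list(cell.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

k = 5                          # hypothetical hyperparameter: steps per update
T, batch = 20, 2
xs = torch.randn(T, batch, 3)  # synthetic inputs (time, batch, features)
ys = torch.randn(T, batch, 1)  # synthetic per-step targets

h = torch.zeros(batch, 4)
optimizer.zero_grad()
n_updates = 0
for t in range(T):
    h = cell(xs[t], h)
    loss = nn.functional.mse_loss(head(h), ys[t])
    # Keep the graph: the next step's loss back-propagates through it again.
    loss.backward(retain_graph=True)  # gradients from every step add up in .grad
    if (t + 1) % k == 0:
        optimizer.step()       # one update from k steps' accumulated gradients
        optimizer.zero_grad()
        h = h.detach()         # truncate BPTT: later steps stop here
        n_updates += 1
```

Re-traversing the retained graph at every step costs roughly O(k^2) backward work per window, which is one reason k is worth exposing as a tunable hyperparameter.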

For more on BPTT or RNNs, see the articles Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients and The Unreasonable Effectiveness of Recurrent Neural Networks.


A training step in PyTorch is a cycle:

  • the forward pass, where we compute the output y_hat from the input;
  • computing the loss, loss = loss_fn(y_hat, y);
  • loss.backward(), where we compute the gradients;
  • optimizer.step(), where we update the parameters.

Or in code:

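import torch
model = torch.nn.Linear(3, 1)                  # toy model for illustration
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 3), torch.randn(16, 1)  # toy data, reused for each mini-batch
for mb in range(10):  # 10 mini-batches
    y_pred = model(x)            # forward pass
    loss = loss_fn(y_pred, y)    # compute the loss
    optimizer.zero_grad()        # clear gradients left from the previous iteration
    loss.backward()              # compute new gradients
    optimizer.step()             # update the parameters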

If we did not clear the gradients, either after optimizer.step() (the natural place) or just before the next backward(), the gradients would accumulate. Here is an example showing accumulation:

import torch
w = torch.rand(5)
w.requires_grad_()
print(w) 
s = w.sum() 
s.backward()
print(w.grad) # tensor([1., 1., 1., 1., 1.])
s.backward()
print(w.grad) # tensor([2., 2., 2., 2., 2.])
s.backward()
print(w.grad) # tensor([3., 3., 3., 3., 3.])
s.backward()
print(w.grad) # tensor([4., 4., 4., 4., 4.])

loss.backward() has no option to control this. Here is the signature of the underlying function:

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)

None of these options zeroes the gradients, so you have to do it manually, as in the mini example above:

w.grad.zero_()
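A short sketch continuing the mini example, showing that zero_() resets the accumulated gradient so the next backward() starts from zero again. Each backward() here builds a fresh graph with w.sum(), so no retain_graph is needed.

```python
import torch

w = torch.rand(5, requires_grad=True)
w.sum().backward()
w.sum().backward()              # a new graph each time, but gradients still add up
grad_before = w.grad.clone()    # every element is 2
w.grad.zero_()                  # reset in place before the next backward()
w.sum().backward()
grad_after = w.grad.clone()     # every element is 1 again
```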

There was some discussion about having backward() zero the previous gradients automatically every time, with a proposed preserve_grads=True option to keep them, but this was never implemented.


Leaving the gradients in place before calling .step() is useful in case you'd like to accumulate the gradient across multiple batches (as others have mentioned).

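It can also be useful after calling .step() if you implement momentum for SGD yourself, or any other method that depends on the previous update's gradient. PyTorch's built-in optimizers do not rely on this, though: torch.optim.SGD, for example, keeps its momentum in its own per-parameter state (momentum_buffer), so clearing .grad does not lose that history.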
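A small sketch illustrating that point, assuming a recent PyTorch release. After one step with momentum, the optimizer's state holds a momentum_buffer. zero_grad() clears w.grad (setting it to None by default since PyTorch 2.0; older versions fill it with zeros), but the buffer is untouched. The parameter and loss are illustrative only.

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
optimizer = torch.optim.SGD([w], lr=0.1, momentum=0.9)

(w ** 2).sum().backward()  # w.grad = 2 * w, i.e. all 2s
optimizer.step()           # first step: the momentum buffer starts as a copy of the grad
optimizer.zero_grad()      # clears w.grad only

buf = optimizer.state[w]["momentum_buffer"]  # momentum history lives here, not in .grad
print(w.grad)  # None in PyTorch 2.x (zeros in older versions)
print(buf)     # still holds the first gradient, all 2s
```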