torch.no_grad() and detach() combined
I have encountered many code fragments like the following for choosing an action, which mix `torch.no_grad` and `detach` (where `actor` is some actor and `SomeDistribution` is your preferred distribution), and I'm wondering whether they make sense:
```python
def f():
    with torch.no_grad():
        x = actor(observation)
    dist = SomeDistribution(x)
    sample = dist.sample()
    return sample.detach()
```
Isn't the use of `detach` in the return statement unnecessary, given that `x` already has `requires_grad` set to `False`, so all computations using `x` should already be detached from the graph? Or do the computations after the `torch.no_grad` wrapper somehow end up on the graph again, so that we need to detach them once more at the end (in which case it seems to me that `no_grad` would be unnecessary)?
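For concreteness, here is how I checked that `x` has `requires_grad` set to `False` (using a plain `nn.Linear` as a stand-in for `actor`, since the real actor isn't shown):

```python
import torch

# Stand-ins for the actor and observation in the snippet above (assumptions).
actor = torch.nn.Linear(4, 2)
observation = torch.randn(4)

with torch.no_grad():
    x = actor(observation)

# Inside torch.no_grad(), operations do not record a computation graph.
print(x.requires_grad)  # False
print(x.grad_fn)        # None
```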
Also, if I'm right, I suppose that instead of omitting `detach` one could just as well omit `torch.no_grad` and end up with the same functionality but worse performance, so `torch.no_grad` is to be preferred?
While it may be redundant, that depends on the internals of `actor` and `SomeDistribution`. In general, there are three cases I can think of where `detach` would be necessary in this code. Since you've already observed that `x` has `requires_grad` set to `False`, cases 2 and 3 don't apply to your specific situation.
1. If `SomeDistribution` has internal parameters (leaf tensors with `requires_grad=True`), then `dist.sample()` may result in a computation graph connecting `sample` to those parameters. Without detaching, that computation graph, including those parameters, would be unnecessarily kept in memory after returning.
2. The default behavior within a `torch.no_grad` context is for tensor operations to return results with `requires_grad` set to `False`. However, if `actor(observation)` for some reason explicitly sets `requires_grad` of its return value to `True` before returning, then a computation graph may be created connecting `x` to `sample`. Without detaching, that computation graph, including `x`, would be unnecessarily kept in memory after returning.
3. This one seems even more unlikely, but if `actor(observation)` actually just returns a reference to `observation`, and `observation.requires_grad` is `True`, then a computation graph all the way from `observation` to `sample` may be constructed during `dist.sample()`.
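Case 1 can be illustrated with a toy distribution whose `sample()` is reparameterized through its own learnable parameter (a sketch; `ShiftedNormal` is a made-up class, not a real PyTorch distribution):

```python
import torch

class ShiftedNormal:
    """Toy distribution (hypothetical) with an internal learnable parameter."""

    def __init__(self, x):
        self.x = x
        # Leaf tensor with requires_grad=True, as in case 1.
        self.shift = torch.nn.Parameter(torch.zeros_like(x))

    def sample(self):
        # The sample is computed from self.shift, so autograd records a graph
        # from the sample back to that parameter.
        return self.x + self.shift + torch.randn_like(self.x)

with torch.no_grad():
    x = torch.randn(3)  # stand-in for actor(observation); requires_grad=False

dist = ShiftedNormal(x)
sample = dist.sample()

print(sample.requires_grad)           # True: graph connects sample to dist.shift
print(sample.detach().requires_grad)  # False: detach severs that graph
```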
As for the suggestion of removing the `no_grad` context in lieu of `detach`: this may result in the construction of a computation graph connecting `observation` (if it requires gradients) and the parameters of `actor` to `x`, and the parameters of the distribution (if it has any) to `sample`. That graph would be discarded after `detach`, but it does take time and memory to create, so there may be a performance penalty.
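A minimal sketch of that alternative (again assuming a plain `nn.Linear` as the actor): without `no_grad`, the forward pass records a graph from the actor's parameters to `x`, which `detach` then throws away:

```python
import torch

# Hypothetical stand-ins for the actor and observation.
actor = torch.nn.Linear(4, 2)
observation = torch.randn(4)

x = actor(observation)        # without no_grad, a graph is constructed here
print(x.grad_fn is not None)  # True: x is connected to actor's parameters

out = x.detach()              # the graph is dropped again
print(out.requires_grad)      # False
```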
In conclusion, it's safer to do both `no_grad` and `detach`, though the necessity of either depends on the internals of the distribution and the actor.