Pytorch preferred way to copy a tensor
Solution 1:
TL;DR
Use .clone().detach() (or preferably .detach().clone()), as it is slightly faster and explicit about what it does.
If you first detach the tensor and then clone it, the computation path is not copied; the other way around, it is copied and then abandoned. Thus, .detach().clone() is very slightly more efficient. -- pytorch forums
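For instance, a quick sanity check (the variable names and values here are just illustrative) shows that the result of .detach().clone() is a fully independent copy:
import torch
# The copy carries the same values but no autograd history,
# and in-place edits on it do not touch the original.
x = torch.ones(3, requires_grad=True)
y = x.detach().clone()
print(y.requires_grad)  # False: y is outside the computation graph
y[0] = 42.0             # modify the copy in place...
print(x)                # ...the original is unchanged: tensor([1., 1., 1.], requires_grad=True)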
Using perfplot, I plotted the timing of various methods to copy a pytorch tensor.
y = tensor.new_tensor(x) # method a
y = x.clone().detach() # method b
y = torch.empty_like(x).copy_(x) # method c
y = torch.tensor(x) # method d
y = x.detach().clone() # method e
The x-axis is the size of the tensor created, the y-axis shows the time. The graph is on a linear scale. As you can clearly see, tensor() and new_tensor() take more time than the other three methods.
Note: In multiple runs, I noticed that out of b, c, e, any method can have the lowest time. The same is true for a and d. But methods b, c, e consistently have lower timings than a and d.
import torch
import perfplot

perfplot.show(
    setup=lambda n: torch.randn(n),
    kernels=[
        lambda a: a.new_tensor(a),
        lambda a: a.clone().detach(),
        lambda a: torch.empty_like(a).copy_(a),
        lambda a: torch.tensor(a),
        lambda a: a.detach().clone(),
    ],
    labels=["new_tensor()", "clone().detach()", "empty_like().copy()", "tensor()", "detach().clone()"],
    n_range=[2 ** k for k in range(15)],
    xlabel="len(a)",
    logx=False,
    logy=False,
    title='Timing comparison for copying a pytorch tensor',
)
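If you don't want to install perfplot, a rough cross-check with torch.utils.benchmark gives the same picture; the tensor size and iteration count below are arbitrary choices for illustration:
import torch
import torch.utils.benchmark as benchmark

x = torch.randn(10_000)
stmts = [
    "x.new_tensor(x)",
    "x.clone().detach()",
    "torch.empty_like(x).copy_(x)",
    "torch.tensor(x)",  # emits a UserWarning suggesting clone().detach()
    "x.detach().clone()",
]
for stmt in stmts:
    timer = benchmark.Timer(stmt=stmt, globals={"x": x, "torch": torch})
    print(timer.timeit(1000))  # prints the measured time per run for each statement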
Solution 2:
According to the Pytorch documentation, #a and #b are equivalent. It also says that
The equivalents using clone() and detach() are recommended.
So if you want to copy a tensor and detach it from the computation graph, you should use
y = x.clone().detach()
since it is the cleanest and most readable way. With all the other versions there is some hidden logic, and it is also not 100% clear what happens to the computation graph and gradient propagation.
Regarding #c: It seems a bit too complicated for what is actually done and could also introduce some overhead, but I am not sure about that.
Edit: It was asked in the comments why not just use .clone().
From the pytorch docs
Unlike copy_(), this function is recorded in the computation graph. Gradients propagating to the cloned tensor will propagate to the original tensor.
So while .clone() returns a copy of the data, it keeps the computation graph and records the clone operation in it. As mentioned, this means that gradients propagating to the cloned tensor will also propagate to the original tensor. This behavior can lead to errors and is not obvious. Because of these possible side effects, a tensor should only be copied via .clone() if this behavior is explicitly wanted. To avoid these side effects, .detach() is added to disconnect the computation graph from the cloned tensor.
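A minimal sketch of that difference (the tensor here is just illustrative):
import torch

# A bare .clone() stays in the graph: gradients flowing into the clone
# reach the original tensor.
x = torch.ones(3, requires_grad=True)
x.clone().sum().backward()
print(x.grad)            # tensor([1., 1., 1.]) -- the original received gradients

# .clone().detach() is cut off from the graph and receives no gradients.
y = x.clone().detach()
print(y.requires_grad)   # False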
Since in general one wants a clean copy that can't lead to unforeseen side effects, the preferred way to copy a tensor is .clone().detach().