How do PyTorch's "Fold" and "Unfold" work?
unfold and fold are used to facilitate "sliding window" operations (like convolutions).
Suppose you want to apply a function foo to every 5x5 window in a feature map/image:
from torch.nn import functional as f
windows = f.unfold(x, kernel_size=5)
Now windows has shape (batch, 5*5*x.size(1), num_windows), where each column along the last dimension is one flattened 5x5 window taken across all channels. You can apply foo to windows:
processed = foo(windows)
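To make the shape concrete, here is a small sketch. The input size (batch of 2, 3 channels, 8x8) is an assumption chosen only for illustration, and the window count is worked out for the default stride of 1 with no padding.

```python
import torch
from torch.nn import functional as f

# Assumed example input: batch of 2, 3 channels, 8x8 feature map.
x = torch.randn(2, 3, 8, 8)

windows = f.unfold(x, kernel_size=5)

# With stride 1 and no padding, a 5x5 window fits in (8 - 5 + 1) positions
# along each spatial axis.
batch, channels, height, width = x.shape
num_windows = (height - 5 + 1) * (width - 5 + 1)

# Middle dimension: channels * 5 * 5 values per window.
print(windows.shape)  # torch.Size([2, 75, 16])
```

Note that foo has to return a tensor of the same shape if the result is going to be folded back with the same settings.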
Now you need to "fold" processed back to the original size of x:
out = f.fold(processed, x.shape[-2:], kernel_size=5)
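You need to take care of padding and kernel_size (and stride or dilation, if you set them), since they may affect your ability to "fold" processed back to the size of x. These settings determine the number of windows, and fold raises an error if processed does not contain the number of windows it expects for the requested output size.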
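A minimal sketch of this point, assuming an 8x8 input and a 3x3 kernel with padding=1 (both chosen only for illustration): passing the same settings to both calls gives back the size of x, while dropping the padding in fold makes the window count disagree.

```python
import torch
from torch.nn import functional as f

x = torch.randn(1, 3, 8, 8)  # assumed example input

# Hypothetical settings; the point is that both calls share them.
params = dict(kernel_size=3, padding=1, stride=1)

windows = f.unfold(x, **params)                 # shape (1, 27, 64)
out = f.fold(windows, x.shape[-2:], **params)   # shape (1, 3, 8, 8)
print(windows.shape, out.shape)

# Omitting the padding in fold changes the expected number of windows
# (6 * 6 = 36 instead of 64), so the call fails.
try:
    f.fold(windows, x.shape[-2:], kernel_size=3)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)
```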
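Moreover, fold sums over overlapping elements, so you might want to divide the output of fold by the number of windows covering each position. With stride 1 and no padding, that count equals the patch size (5*5 here) only for positions far enough from the borders; positions near the edges are covered by fewer windows.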
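One way to get the per-position count is to unfold and fold a tensor of ones with the same settings. The sketch below assumes a 12x12 input and uses the identity in place of foo, so the normalized result should match x again.

```python
import torch
from torch.nn import functional as f

x = torch.randn(2, 3, 12, 12)  # assumed example input

windows = f.unfold(x, kernel_size=5)
summed = f.fold(windows, x.shape[-2:], kernel_size=5)

# Folding unfolded ones counts how many windows cover each position.
ones = torch.ones_like(x)
counts = f.fold(f.unfold(ones, kernel_size=5), x.shape[-2:], kernel_size=5)

recovered = summed / counts
print(torch.allclose(recovered, x))

# A corner is covered by a single window, an interior position by 5*5.
corner = counts[0, 0, 0, 0].item()
interior = counts[0, 0, 5, 5].item()
print(corner, interior)
```

Dividing by a constant 25 instead would leave the border values scaled incorrectly.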
unfold imagines a tensor as a longer tensor with repeated columns/rows of values 'folded' on top of each other, which is then "unfolded":
- size determines how large the folds are
- step determines how often it is folded
E.g. for a 2x5 tensor, unfolding it with step=1 and patch size=2 across dim=1:
>>> import torch
>>> x = torch.tensor([[1, 2, 3, 4, 5],
...                   [6, 7, 8, 9, 10]])
>>> x.unfold(1, 2, 1)  # dimension=1, size=2, step=1
tensor([[[ 1,  2], [ 2,  3], [ 3,  4], [ 4,  5]],
        [[ 6,  7], [ 7,  8], [ 8,  9], [ 9, 10]]])
fold is roughly the opposite of this operation, but "overlapping" values are summed in the output.
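The summing can be seen on the first row of the example. This sketch is an illustration with some assumptions: it uses torch.nn.functional.unfold and fold, which expect a (batch, channels, height, width) layout rather than the Tensor.unfold method shown above, and it uses float values.

```python
import torch
from torch.nn import functional as f

# First row of the example, shaped (batch, channels, height, width).
row = torch.tensor([1., 2., 3., 4., 5.]).reshape(1, 1, 1, 5)

# Windows of width 2 with step 1: [1, 2], [2, 3], [3, 4], [4, 5].
windows = f.unfold(row, kernel_size=(1, 2))  # shape (1, 2, 4)

# Folding back sums the values wherever windows overlap.
folded = f.fold(windows, output_size=(1, 5), kernel_size=(1, 2))
print(folded.flatten())  # tensor([1., 4., 6., 8., 5.])
```

The end values 1 and 5 appear in one window each and come back unchanged, while 2, 3 and 4 each appear in two windows and come back doubled.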