What is the difference between sample() and rsample()?
Using `rsample` allows for pathwise derivatives:
> The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the `rsample()` method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable.
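A minimal sketch, assuming a recent PyTorch with `torch.distributions.Normal`, that shows the practical difference. `Normal.sample()` draws its value under `torch.no_grad()`, so the result is detached from the graph. `rsample()` returns `loc + eps * scale`, which stays connected to the parameters. The names `mu`, `sigma`, `z_sample` and `z_rsample` are illustrative only.

```python
import torch
from torch.distributions import Normal

# Parameters we want gradients for
mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)
dist = Normal(mu, sigma)

z_sample = dist.sample()    # detached draw: no path back to mu or sigma
z_rsample = dist.rsample()  # reparameterized draw: mu + eps * sigma

print(z_sample.requires_grad, z_sample.grad_fn)  # False None
print(z_rsample.requires_grad)                   # True

# Backpropagating through the reparameterized draw works
z_rsample.backward()
print(mu.grad)  # tensor(1.)
```

Calling `backward()` on `z_sample` instead would raise a `RuntimeError`, because it has no `grad_fn`.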
`rsample`: sampling using the reparameterization trick.
Note the `eps` in the source code of `Normal.rsample`:
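```python
# Method of torch.distributions.Normal (excerpt, not standalone)
def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    # eps ~ N(0, 1): noise drawn without reference to loc or scale
    eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
    # Deterministic, differentiable transform of eps by loc and scale
    return self.loc + eps * self.scale
```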
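The same trick can be written by hand, without `torch.distributions`. This is a sketch, not PyTorch's code: `torch.randn` stands in for the private helper `_standard_normal`, and `loc`, `scale` and `eps` are illustrative tensors. Because `z = loc + eps * scale`, autograd should report dz/dloc = 1 and dz/dscale = eps for each element.

```python
import torch

torch.manual_seed(0)

loc = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
scale = torch.tensor([1.0, 0.3, 2.5], requires_grad=True)

# Parameter-free noise: eps ~ N(0, 1), created without any reference to loc/scale
eps = torch.randn(loc.shape)

# Deterministic, differentiable transform (the reparameterization)
z = loc + eps * scale

# Sum to get a scalar for backward(); each z_i depends only on loc_i and scale_i
z.sum().backward()

print(loc.grad)    # all ones: dz/dloc = 1
print(scale.grad)  # equal to eps: dz/dscale = eps
```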
Look at the return value: `mean + eps * standard deviation`.
`eps` does not depend on the parameters you want to differentiate with respect to.
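So you can freely backpropagate (i.e., differentiate) through the sample: `eps` does not change when the parameters change, so it acts as a constant, and the derivative of the sample is 1 with respect to the mean and `eps` with respect to the standard deviation.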
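A sketch of why this matters in practice: a pathwise (reparameterization) gradient estimate of an expectation. The objective E[z²] for z ~ N(mu, sigma²) is an illustrative choice, picked because its exact value mu² + sigma² has known gradients 2·mu and 2·sigma. The sample count and seed are arbitrary.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.5, requires_grad=True)

# Many reparameterized samples; each one is mu + eps * sigma
z = Normal(mu, sigma).rsample((200_000,))

# Monte Carlo estimate of E[z^2]; gradients flow through z into mu and sigma
loss = (z ** 2).mean()
loss.backward()

print(mu.grad)     # close to 2 * mu = 3.0
print(sigma.grad)  # close to 2 * sigma = 1.0
```

With `sample()` in place of `rsample()`, `z` would not require grad and `loss.backward()` would raise an error.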