What is the difference between sample() and rsample()?

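rsample() supports pathwise derivatives, whereas sample() does not: a value drawn with sample() is detached from the distribution's parameters, so no gradient flows back to them. The torch.distributions documentation describes it as follows: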

The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the rsample() method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable.
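To make the difference concrete, here is a minimal sketch assuming a scalar Normal distribution whose loc and scale are tensors with requires_grad=True. The parameter values are arbitrary and the variable names are only for illustration.

```python
import torch
from torch.distributions import Normal

# Parameters we would like gradients for (values chosen arbitrarily).
loc = torch.tensor(0.0, requires_grad=True)
scale = torch.tensor(1.0, requires_grad=True)
dist = Normal(loc, scale)

x_plain = dist.sample()     # drawn without gradient tracking
x_reparam = dist.rsample()  # loc + eps * scale, still attached to the graph

print(x_plain.requires_grad)    # False
print(x_reparam.requires_grad)  # True

# Backpropagating through the reparameterized sample populates the gradients.
x_reparam.backward()
print(loc.grad)  # tensor(1.), since d(loc + eps * scale)/d(loc) = 1
```

Calling backward() on x_plain instead would raise an error, because that tensor has no graph connecting it to loc or scale.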


rsample() draws a sample using the reparameterization trick.

You can see this in the source code of Normal.rsample(), where the noise term is called eps:

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + eps * self.scale

Look at the return value: mean + eps * standard deviation, where eps is drawn from a standard normal distribution.

eps does not depend on the parameters you want to differentiate with respect to.

So you can freely backpropagate (i.e., differentiate) through the sample, because eps does not change when the parameters change.
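The same idea can be checked without autograd. The following is a plain-Python sketch, not PyTorch's implementation: reparam_sample, mu, sigma and h are hypothetical names, and the derivatives are estimated with finite differences. Because eps is drawn once and held fixed, the sample is an ordinary deterministic function of the parameters.

```python
import random

random.seed(0)
eps = random.gauss(0.0, 1.0)  # parameter-free noise, drawn once and then held fixed

def reparam_sample(mu, sigma):
    """Deterministic function of the parameters for the fixed eps above."""
    return mu + eps * sigma

mu, sigma, h = 1.0, 2.0, 1e-6  # arbitrary parameter values and step size

# Central finite differences with respect to each parameter.
d_mu = (reparam_sample(mu + h, sigma) - reparam_sample(mu - h, sigma)) / (2 * h)
d_sigma = (reparam_sample(mu, sigma + h) - reparam_sample(mu, sigma - h)) / (2 * h)

print(d_mu)          # approximately 1
print(d_sigma, eps)  # the two values approximately agree
```

The derivative with respect to the mean is 1 and the derivative with respect to the standard deviation is eps itself, which is what autograd computes when you backpropagate through rsample().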