Applying non-torch function on loss before calling backward()?
In PyTorch you can do this by inheriting from `torch.autograd.Function`: all you need to do is implement your custom `forward()` and the corresponding `backward()` static methods. `forward()` may use any library internally, as long as it returns tensors. Because I don't know the function you intend to write, I'll demonstrate it by implementing the sine function in a way that works with automatic differentiation. Note that you need a way to compute the derivative of your function with respect to its input in order to implement the backward pass, because autograd cannot differentiate through code it does not record.
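The backward pass is just the chain rule. Write $y = f(x)$ for the custom function and $L$ for the loss. Autograd passes $\partial L / \partial y$ to `backward()` as `grad_out` and expects $\partial L / \partial x$ back. For the elementwise sine used below, the Jacobian is diagonal ($\partial y_j / \partial x_i = \cos(x_i)\,\delta_{ij}$), so the sum collapses to a single elementwise product:

```latex
\frac{\partial L}{\partial x_i}
  = \sum_j \frac{\partial L}{\partial y_j}\,\frac{\partial y_j}{\partial x_i}
  = \frac{\partial L}{\partial y_i}\,\cos(x_i),
\qquad y_i = \sin(x_i)
```

This is exactly the `grad_out * torch.cos(inp)` line in the code.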
```python
import torch


class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        """ compute forward pass of custom function """
        ctx.save_for_backward(inp)  # save activation for backward pass
        # compute forward pass; any other library works too, but convert the result back to a tensor
        return inp.sin()

    @staticmethod
    def backward(ctx, grad_out):
        """ compute product of output gradient with the
        Jacobian of your function evaluated at input """
        inp, = ctx.saved_tensors
        # propagate gradient; may also be computed with another library (convert back to a tensor)
        grad_inp = grad_out * torch.cos(inp)
        return grad_inp
```
To use it, bind `sin = MySin.apply` and call `sin(x)` on your input tensor. Don't instantiate `MySin` or call `forward()` directly.
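As a quick sanity check, the snippet below repeats the `MySin` class from above so it runs on its own, uses it on a small input, and compares the hand-written backward against finite differences with `torch.autograd.gradcheck`. The input values are arbitrary; double precision is used because `gradcheck` is unreliable in float32.

```python
import torch


class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)  # keep the input for the backward pass
        return inp.sin()

    @staticmethod
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        return grad_out * torch.cos(inp)  # chain rule: dL/dx = dL/dy * cos(x)


sin = MySin.apply

x = torch.linspace(-3.0, 3.0, 7, dtype=torch.double, requires_grad=True)
loss = sin(x).sum()
loss.backward()

# d/dx sum(sin(x)) = cos(x), so x.grad should equal cos(x)
print(x.grad)

# gradcheck returns True if the analytical backward matches numerical gradients
print(torch.autograd.gradcheck(sin, (x,)))
```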
There is also another worked example in the PyTorch documentation, in the "Extending torch.autograd" section.
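For the case in the title, where the function on the loss is not a torch function at all, here is a minimal sketch that computes both passes with NumPy. The class name `NumpySin` is hypothetical, and NumPy's `sin`/`cos` stand in for whatever non-torch function and derivative you actually have. Tensors must be detached before `.numpy()` (a tensor that requires grad refuses the conversion), and results are converted back to tensors on the original dtype and device.

```python
import numpy as np
import torch


class NumpySin(torch.autograd.Function):
    """Hypothetical example: sin computed outside torch, with a hand-written backward."""

    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        # Leave torch: detach (required before .numpy()) and move to CPU
        x_np = inp.detach().cpu().numpy()
        out_np = np.sin(x_np)  # stand-in for any non-torch function
        # Return a tensor with the input's dtype and device
        return torch.as_tensor(out_np, dtype=inp.dtype, device=inp.device)

    @staticmethod
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        x_np = inp.detach().cpu().numpy()
        # Derivative of the non-torch function, also computed outside torch
        g_np = grad_out.detach().cpu().numpy() * np.cos(x_np)
        return torch.as_tensor(g_np, dtype=grad_out.dtype, device=grad_out.device)


torch.manual_seed(0)
x = torch.randn(5, dtype=torch.double, requires_grad=True)
loss = NumpySin.apply(x.sum())  # non-torch function applied to a scalar loss
loss.backward()
print(x.grad)  # each entry equals cos(x.sum())
```

The round trip through NumPy forces a copy to the CPU, so on a GPU this is slower than a pure-torch implementation; it is only worth it when the function genuinely has no torch equivalent.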