Torch backward does not return a tensor

To set up the problem: I have a matrix X as input and a matrix Y as output. Y is obtained by multiplying X with a matrix W and passing the result through an exponential. From what I understand of torch.backward() with a gradient argument, the formula should be the following:

dL/dX = dL/dY · dY/dX

Yet dy_over_dx, being a Jacobian, should be a tensor of shape (2, 4, 2, 3) (I mean, not the usual n-by-n matrix).
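
A quick way to see that shape, just as a sanity-check sketch using torch.autograd.functional.jacobian on the same X and W as in the code below:

import torch

X = torch.tensor([[2., 1., -3.], [-3., 4., 2.]])
W = torch.tensor([[3., 2., 1., -1.], [2., 1., 3., 2.], [3., 2., 1., -2.]])

# Full Jacobian of Y = exp(X @ W) with respect to X
J = torch.autograd.functional.jacobian(lambda x: torch.exp(x @ W), X)
print(J.shape)  # torch.Size([2, 4, 2, 3]): output shape (2, 4) x input shape (2, 3)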

import torch

X = torch.tensor([[2., 1., -3.], [-3., 4., 2.]], requires_grad=True)
W = torch.tensor([[3., 2., 1., -1.], [2., 1., 3., 2.], [3., 2., 1., -2.]], requires_grad=True)

Y = torch.exp(torch.matmul(X, W))
Y.retain_grad()
print(Y)

dL_over_dy = torch.tensor(  [[2,3,-3,9],[-8,1,4,6]])
print(dL_over_dy, dL_over_dy.shape)

Y.backward(dL_over_dy)
print(X.grad)
tensor([[3.6788e-01, 3.6788e-01, 7.3891e+00, 4.0343e+02],
        [1.4841e+02, 7.3891e+00, 5.9874e+04, 1.0966e+03]],
       grad_fn=<ExpBackward>)
tensor([[ 2,  3, -3,  9],
        [-8,  1,  4,  6]]) torch.Size([2, 4])
tensor([[ -3648.6118,   7197.7920,  -7279.4707],
        [229369.6250, 729282.0625, 222789.8281]])

Next, if I look at the gradient of Y, which I suppose is dy_over_dx, I obtain the following. What am I not understanding here?

print(Y.grad)
tensor([[ 2.,  3., -3.,  9.],
        [-8.,  1.,  4.,  6.]])

What you're looking at here is Y.grad, which is dL/dY, i.e. none other than dL_over_dy.

To help clarify, let Z = X @ W (@ is equivalent to matmul) and Y = exp(Z). Then, applying the chain rule:

  • Y.grad = dL/dY

  • Z.grad = dL/dZ = dL/dY * dY/dZ (element-wise product), where dY/dZ = exp(Z) = Y

  • X.grad = dL/dX = dL/dZ @ dZ/dX, where dZ/dX = d(X@W)/dX = W.T


Here is the implementation:

X = torch.tensor([[ 2., 1., -3], 
                  [ -3, 4., 2.]], requires_grad=True)

W = torch.tensor([[ 3., 2., 1., -1], 
                  [ 2., 1., 3., 2.], 
                  [ 3., 2., 1., -2]], requires_grad=True)

Z = torch.matmul(X, W)
Z.retain_grad()
Y = torch.exp(Z)

dL_over_dy = torch.tensor([[ 2., 3., -3, 9.],
                           [ -8, 1., 4., 6.]])

Y.backward(dL_over_dy)

Then we have

>>> dL_over_Z = dL_over_dy * Y
>>> dL_over_Z
tensor([[ 7.3576e-01,  1.1036e+00, -2.2167e+01,  3.6309e+03],
        [-1.1873e+03,  7.3891e+00,  2.3950e+05,  6.5798e+03]],
       grad_fn=<MulBackward0>)

>>> dL_over_X = dL_over_Z @ W.T
>>> dL_over_X
tensor([[ -3648.6118,   7197.7920,  -7279.4707],
        [229369.6250, 729282.0625, 222789.8281]], grad_fn=<MmBackward0>)