How to check if two Torch tensors or matrices are equal?
Solution 1:
torch.eq(a, b)
eq() implements the == operator. It compares each element in a with b (if b is a value), or each element in a with its corresponding element in b (if b is a tensor), and returns a boolean tensor rather than a single True/False.
Alternative from @deltheil, which reduces the element-wise result to a single boolean:
torch.all(tens_a.eq(tens_b))
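A minimal sketch of both forms, using made-up example tensors: torch.eq against a tensor and against a scalar, then the torch.all reduction. torch.all returns a 0-dim tensor, so .item() is used here to get a plain Python bool.

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[1., 1.], [4., 4.]])

# Element-wise against a tensor of the same shape: returns a boolean tensor
elementwise = torch.eq(a, b)          # same as a == b
# Element-wise against a scalar: every element of a is compared with 2.0
against_scalar = torch.eq(a, 2.0)

# torch.all reduces to a 0-dim boolean tensor; .item() turns it into a Python bool
all_equal = torch.all(a.eq(b)).item()
same_all = torch.all(a.eq(a.clone())).item()

print(elementwise)
print(against_scalar)
print(all_equal, same_all)
```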
Solution 2:
The following worked for me:
torch.equal(tensorA, tensorB)
From the documentation: True if two tensors have the same size and elements, False otherwise.
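One practical difference from the torch.all(tens_a.eq(tens_b)) alternative, shown in a small sketch with illustrative tensors: eq broadcasts, so tensors of different but broadcastable shapes can pass the element-wise check, while torch.equal also requires the sizes to match. torch.equal returns a plain Python bool.

```python
import torch

x = torch.ones(2, 2)
y = torch.ones(2)  # different shape, but broadcastable to (2, 2)

# torch.eq broadcasts y across the rows of x, so every comparison is True
broadcast_all = torch.all(torch.eq(x, y)).item()
# torch.equal requires identical sizes, so the shape mismatch gives False
strict = torch.equal(x, y)
# Same size and same elements: True
same = torch.equal(x, x.clone())

print(broadcast_all, strict, same)
```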
Solution 3:
To compare tensors you can work element-wise with torch.eq:
torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# tensor([[True, False], [False, True]])
Or use torch.equal to compare the whole tensor exactly:
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# False
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
# True
But at some point there may be small differences you would like to ignore. For instance, the floats 1.0 and 1.0000000001 are very close, and you may want to consider them equal. For that kind of comparison there is torch.allclose, which checks |a - b| <= atol + rtol * |b| element-wise (defaults: rtol=1e-05, atol=1e-08).
torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
# True (in float32, 2.000000001 rounds to exactly 2.0, so torch.equal would also return True here)
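A sketch of how the tolerances behave, with example values chosen so the difference survives float32 rounding. It also shows torch.isclose (the element-wise version of allclose) and the equal_nan flag, since NaN never compares equal by default.

```python
import torch

a = torch.tensor([1.0, 100.0])
b = torch.tensor([1.0 + 1e-6, 100.1])

# Defaults rtol=1e-05, atol=1e-08: |a - b| <= atol + rtol * |b| element-wise
per_element = torch.isclose(a, b)              # [True, False]
default_tol = torch.allclose(a, b)             # False: 100.0 vs 100.1 is too far apart
loose_tol = torch.allclose(a, b, rtol=1e-2)    # True with a 1% relative tolerance

# In float32 (the default dtype), 2.000000001 rounds to exactly 2.0
rounded = torch.equal(torch.tensor(2.0), torch.tensor(2.000000001))

# NaN is never close to NaN unless equal_nan=True
nan_a = torch.tensor([float("nan")])
nan_b = torch.tensor([float("nan")])
nan_default = torch.allclose(nan_a, nan_b)
nan_equal = torch.allclose(nan_a, nan_b, equal_nan=True)

print(per_element, default_tol, loose_tol, rounded, nan_default, nan_equal)
```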
It can also be important to check how many elements are equal element-wise, compared with the total number of elements. If you have two tensors dt1 and dt2, you get the number of elements of dt1 with dt1.nelement(), and with this formula you get the fraction of equal elements (multiply by 100 for a percentage):
print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())
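A worked version of that formula with example tensors, assuming dt1 and dt2 have the same shape (if eq broadcasts, dt1.nelement() may no longer be the number of comparisons). Casting the boolean mask to float and taking the mean gives the same fraction.

```python
import torch

dt1 = torch.tensor([[1., 2.], [3., 4.]])
dt2 = torch.tensor([[1., 1.], [4., 4.]])

# Count matching positions and divide by the total number of elements
matches = torch.sum(torch.eq(dt1, dt2)).item()   # 2
fraction = matches / dt1.nelement()              # 0.5

# Equivalent one-liner: cast the boolean mask to float and take the mean
fraction_mean = (dt1 == dt2).float().mean().item()

print(f"{fraction * 100:.1f}% of elements are equal")  # 50.0% of elements are equal
```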
Solution 4:
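Try this if you want to ignore small precision differences, which are common with floats. Note that a 1e-12 threshold is only meaningful for float64 tensors: for float32 values near 1.0, neighbouring representable numbers are about 1e-7 apart, so a larger threshold is needed there.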
torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))
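A small sketch of that float32 caveat. The helper name within is hypothetical; it is the same check written with operators (abs(a - b) < tol for every element). torch.nextafter is used to build the next representable float32 above 1.0, so the two tensors differ by the smallest possible amount.

```python
import torch

tens_a = torch.tensor([1.0])                            # float32 by default
tens_b = torch.nextafter(tens_a, torch.tensor([2.0]))   # next float32 above 1.0

def within(a, b, tol):
    # Hypothetical helper: True if |a - b| < tol for every element
    return torch.all(torch.abs(a - b) < tol).item()

gap = (tens_b - tens_a).item()          # 2**-23, about 1.19e-07
strict = within(tens_a, tens_b, 1e-12)  # False: no nonzero float32 gap near 1.0 is that small
loose = within(tens_a, tens_b, 1e-6)    # True
# torch.allclose combines a relative and an absolute tolerance instead
close = torch.allclose(tens_a, tens_b)  # True

print(gap, strict, loose, close)
```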