What is the equivalent of PyTorch's BoolTensor in TensorFlow 2.x?

Is there an equivalent of PyTorch's BoolTensor in TensorFlow? I have the following PyTorch code that I want to migrate to TensorFlow:

done_mask = torch.BoolTensor(dones.values).to(device)
next_state_values[done_mask] = 0.0

Solution 1:

What is dones? Assuming it is a tensor of 0s and 1s, you can convert it to a boolean tensor like this:

tf.cast(dones, tf.bool)
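A minimal sketch of that cast, assuming dones holds 0/1 flags (in the question it comes from dones.values, so a plain NumPy array is used here as a stand-in; the values are made up for illustration):

```python
import numpy as np
import tensorflow as tf

# Hypothetical 0/1 termination flags, standing in for dones.values.
dones = np.array([0, 1, 0, 1], dtype=np.int64)

# tf.cast maps 0 to False and any non-zero value to True.
done_mask = tf.cast(dones, tf.bool)

print(done_mask.dtype)
print(done_mask.numpy())
```

Note that tf.cast treats every non-zero value as True, so a flag like 2 would also count as done.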

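However, you can't use the resulting mask for assignment the way the PyTorch code does, because TensorFlow tensors are immutable and do not support item assignment such as next_state_values[done_mask] = 0.0.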
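A small sketch of that limitation, plus one alternative that this answer does not itself use: tf.where(condition, x, y) builds a new tensor that takes x where the condition is True and y elsewhere. The tensor values are hypothetical, chosen only for illustration.

```python
import tensorflow as tf

# Hypothetical example values.
next_state_values = tf.constant([1.5, 2.0, -0.5, 3.0])
done_mask = tf.constant([False, True, False, True])

# tf.Tensor has no __setitem__, so PyTorch-style masked assignment fails.
try:
    next_state_values[done_mask] = 0.0
except TypeError as err:
    print("item assignment not supported:", err)

# Build a new tensor instead: 0.0 where done, the original value elsewhere.
next_state_values = tf.where(
    done_mask, tf.zeros_like(next_state_values), next_state_values
)
print(next_state_values.numpy())
```

Because the result is a new tensor, the name is simply rebound, which is the usual pattern in TensorFlow.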

The approach I recommend is to multiply by a mask of 1s and 0s (1 where the transition is not done, 0 where it is):

next_state_values *= tf.cast(dones != 1, next_state_values.dtype)
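A runnable sketch of the multiplicative mask with made-up values, assuming dones is a NumPy 0/1 array of the same shape as next_state_values:

```python
import numpy as np
import tensorflow as tf

# Hypothetical example data; in the question these come from a replay batch.
next_state_values = tf.constant([1.5, 2.0, -0.5, 3.0])
dones = np.array([0, 1, 0, 1])

# dones != 1 is True for transitions that are not done; casting it to the
# value dtype gives a 1.0/0.0 mask that zeroes out the done entries.
keep = tf.cast(dones != 1, next_state_values.dtype)
next_state_values *= keep
print(next_state_values.numpy())
```

One caveat of multiplying: if a masked entry is inf or NaN, 0 * inf and 0 * NaN are both NaN, so the entry does not become 0.0. In that case tf.where is the safer choice.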

Another way, which I don't recommend because it can cause problems when computing gradients, is to use tf.tensor_scatter_nd_update. For your case, that would be:

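indices = tf.where(dones == 1)  # coordinates of the done entries, shape [N, rank]
next_state_values = tf.tensor_scatter_nd_update(
    next_state_values, indices,
    tf.zeros(tf.shape(indices)[:1], dtype=next_state_values.dtype))  # one 0.0 per index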
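A self-contained sketch of the scatter approach with hypothetical values, assuming dones and next_state_values have the same shape so each index row addresses a single element:

```python
import tensorflow as tf

# Hypothetical example data.
next_state_values = tf.constant([1.5, 2.0, -0.5, 3.0])
dones = tf.constant([0, 1, 0, 1])

# With one argument, tf.where returns the coordinates of the True entries
# as an int64 tensor of shape [num_true, rank].
indices = tf.where(dones == 1)

# One update value per index row, in the same dtype as the target tensor.
updates = tf.zeros(tf.shape(indices)[:1], dtype=next_state_values.dtype)

next_state_values = tf.tensor_scatter_nd_update(next_state_values, indices, updates)
print(next_state_values.numpy())
```

The updates must match the dtype of the tensor being updated, which is why the zeros are created with next_state_values.dtype rather than the default float32.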