How to compute mean/max of HuggingFace Transformers BERT token embeddings with attention mask?
For max pooling, simply multiplying by attention_mask is not enough: it zeroes the padded positions, but a zero can still win the max over genuinely negative embedding values. Instead, fill the padded positions with a large negative value before taking the max (note that torch.max along a dimension returns a (values, indices) namedtuple):

masked = token_embeddings.masked_fill(attention_mask.unsqueeze(-1) == 0, -1e9)
max_pooled = masked.max(dim=1).values
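As a quick sanity check (synthetic tensors, no model needed), you can see what would go wrong if you only multiplied by the mask:

import torch

token_embeddings = torch.tensor([[[0.5, -0.2],
                                  [-1.0, -3.0],   # real token with all-negative values
                                  [0.7, 0.9]]])   # padding position (should be ignored)
attention_mask = torch.tensor([[1, 1, 0]])

masked = token_embeddings.masked_fill(attention_mask.unsqueeze(-1) == 0, -1e9)
print(masked.max(dim=1).values)  # tensor([[ 0.5000, -0.2000]])
# Multiplying instead would turn the padding row into zeros, and 0.0 would
# wrongly beat -0.2 in the second dimension.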
For mean pooling, zero out the padded positions first (otherwise their embeddings leak into the sum), then sum along the token axis and divide by the number of real tokens:

summed = (token_embeddings * attention_mask.unsqueeze(-1)).sum(dim=1)
mean_pooled = summed / attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
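Putting both together, here is a minimal end-to-end sketch. The model name bert-base-uncased and the example sentences are just placeholders; token_embeddings comes from the model's last_hidden_state:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

encoded = tokenizer(["Hello world", "A longer example sentence"],
                    padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    token_embeddings = model(**encoded).last_hidden_state  # (batch, seq_len, hidden)

mask = encoded["attention_mask"].unsqueeze(-1)              # (batch, seq_len, 1)

# Max pooling: padded positions get -1e9, so they can never win the max
max_pooled = token_embeddings.masked_fill(mask == 0, -1e9).max(dim=1).values

# Mean pooling: divide the masked sum by the count of real tokens
mean_pooled = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

print(max_pooled.shape, mean_pooled.shape)  # torch.Size([2, 768]) each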