I think I understand the usage of `src_key_padding_mask` thanks to Difference between src_mask and src_key_padding_mask. However, I was expecting `src_key_padding_mask` to make the output zero (or negative infinity) at the masked positions. I'm just wondering whether I'm using it correctly, or whether the snippet below needs modification.
Note that I'm aware I should use positional encoding; I'm leaving it out here on purpose.
import random
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
random.seed(42)
torch.manual_seed(42)
DIM = 5
BATCH = 2
# variable-length sequences, padded to the batch max length
x = [torch.randn(random.randint(1, 3), DIM) for _ in range(BATCH)]
# True where a position is padding: shape (max_len, BATCH)
mask = pad_sequence([torch.LongTensor([1] * len(elem)) for elem in x]) == 0
padded_x = pad_sequence(x)  # shape (max_len, BATCH, DIM)
encoder_layer = nn.TransformerEncoderLayer(d_model=DIM, nhead=1)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2).eval()
# The outputs of the following two calls are the same except at the masked
# positions, where I was expecting zeros:
out1 = encoder(padded_x, src_key_padding_mask=mask.T)  # mask.T: (BATCH, max_len)
out2 = encoder(padded_x)
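For what it's worth, here is a minimal sketch of what the mask does at the attention level (using a single `nn.MultiheadAttention` layer rather than the full encoder, and a hand-made mask, so the names and shapes here are my own setup): `key_padding_mask` zeroes the attention *weights* over padded keys, but every query position, including a padded one, still produces an output row. So the encoder output at masked positions is not expected to be zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

DIM = 4
attn = nn.MultiheadAttention(embed_dim=DIM, num_heads=1).eval()

seq_len, batch = 3, 1
x = torch.randn(seq_len, batch, DIM)  # (seq_len, batch, DIM), batch_first=False
# True marks a padded position; here the last key is padding: shape (batch, seq_len)
key_padding_mask = torch.tensor([[False, False, True]])

with torch.no_grad():
    out, weights = attn(x, x, x, key_padding_mask=key_padding_mask)

# Attention weights over the padded key (index 2) are exactly zero for every query:
print(weights[0, :, 2])  # -> tensor([0., 0., 0.])
# ...but the output still has a row for every position, padded or not:
print(out.shape)  # -> torch.Size([3, 1, 4])
```

In other words, the mask controls which *keys* can be attended to; it doesn't suppress the query rows themselves, which is why `out1` and `out2` above agree everywhere except where the attended keys differ.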