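Your decoder batch mask is the element-wise AND of a lower triangular mask (which you already have) and a padding mask, which is True wherever the caption is not the pad value. Here is some toy code (mostly taken from https://github.com/SamLynnEvans/Transformer) that generates such a mask: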
import numpy as np
import torch


def lower_triangular_mask(size):
    """
    Create a lower triangular mask
    """
    lt_mask = np.triu(np.ones((1, size, size)), k=1)
    lt_mask = torch.from_numpy(lt_mask) == 0
    return lt_mask


def create_mask(caption, pad_value):
    """
    Creates the transformer decode mask
    """
    # create pad mask
    pad_mask = (caption != pad_value).unsqueeze(-2)
    # create lower triangular mask
    size = caption.size(1)
    lt_mask = lower_triangular_mask(size)
    # return the bitwise AND of the two masks
    return pad_mask & lt_mask


if __name__ == '__main__':
    torch.manual_seed(0)
    # Here, we generate some random sequences, with an assigned pad value of 1
    pad_value = 1
    caption = torch.randint(2, 10, size=(2, 5))
    caption[0, 3] = pad_value
    caption[0, 4] = pad_value
    print(caption)
    mask = create_mask(caption, pad_value)
    print(mask.size())
    print(mask)
The code above prints:
tensor([[6, 9, 7, 1, 1],
        [5, 5, 9, 3, 5]])
torch.Size([2, 5, 5])
tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True, False, False]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])
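The first caption is shorter than the second, so its mask additionally stops the transformer from attending to positions past the end of the sequence (the padded columns are False in every row).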
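For completeness, here is a rough sketch of how a mask like this is typically consumed inside scaled dot-product attention: positions where the mask is False are set to -inf before the softmax, so they receive zero attention weight. The function name `masked_attention_weights`, the random `q` and `k` tensors, and the feature size of 8 are my own placeholders for illustration, not something taken from the linked repository. The mask is rebuilt here with `torch.tril` only so the snippet runs on its own.

```python
import math
import torch


def masked_attention_weights(q, k, mask):
    """Hypothetical helper: q, k are (batch, seq_len, d_k), mask is a bool (batch, seq_len, seq_len)."""
    # scaled dot-product scores, shape (batch, seq_len, seq_len)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # positions where the mask is False must not be attended to
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1)


# same captions as in the example above, with pad value 1
pad_value = 1
caption = torch.tensor([[6, 9, 7, 1, 1],
                        [5, 5, 9, 3, 5]])

# rebuild the decoder mask: padding mask AND lower triangular mask
pad_mask = (caption != pad_value).unsqueeze(-2)   # (2, 1, 5)
lt_mask = torch.tril(torch.ones(1, 5, 5)).bool()  # (1, 5, 5)
mask = pad_mask & lt_mask                         # (2, 5, 5)

# random queries and keys standing in for real decoder activations (assumption)
torch.manual_seed(0)
q = torch.randn(2, 5, 8)
k = torch.randn(2, 5, 8)

weights = masked_attention_weights(q, k, mask)
print(weights.shape)
print(weights[0])
```

Every row of this mask keeps at least the first position, so the softmax never sees a row that is entirely -inf. If a caption could start with the pad value, that case would need separate handling.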