代码之家  ›  专栏  ›  技术社区  ›  Sean

如果Transformer接收一批不同的句子作为输入,掩码如何在Transformer中工作?

  •  0
  • Sean  · 技术社区  · 5 年前

    我目前正在研究Transformer模型的PyTorch实现,并有一个问题。

    现在,我已经对模型进行了编码,使其能够分批接收源和目标句子对。这些句子使用预先制作的词汇表中各自的索引进行编码。例如:

    [[3,  2,  1, 23, 13, 50, 541, 0],
     [3, 24, 13,  0,  0,  0,   0, 0],
     [3, 98,  2,  4,  1,  23, 25, 4]]
    

    哪里 0 是填充索引。

    我的问题是,如果这些句子是分批输入的,我们应该如何使用掩蔽机制。我想我之所以感到困惑,是因为我知道面具看起来像:

    [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]
    

    这样我们就可以强制解码器只处理下一个序列。在运行模型时,我们是否迭代地将这个掩码应用于同一个句子?例如,如果我们使用我上面给出的第一句话:

    # Iteration 1
    [3, 0, 0, 0, 0, 0, 0, 0]
    
    # Iteration 2
    [3, 2, 0, 0, 0, 0, 0, 0]
    
    .
    .
    .
    

    因此,我们可以在每个位置对每批中的每个句子进行预测?

    0 回复  |  直到 5 年前
        1
  •  0
  •   daveboat    5 年前

    你的解码器批处理掩码是一个较低的三角形掩码(你有),按元素排列,并带有一个填充掩码,在标题不是填充值的情况下是正确的。以下是一些玩具代码(主要来自 https://github.com/SamLynnEvans/Transformer )以生成这样的掩模为例:

    import numpy as np
    import torch
    
    def lower_triangular_mask(size):
        """
        Create a lower triangular mask
        """
    
        lt_mask = np.triu(np.ones((1, size, size)), k=1)
        lt_mask = torch.from_numpy(lt_mask) == 0
    
        return lt_mask
    
    def create_mask(caption, pad_value):
        """
        Creates the transformer decode mask
        """
    
        # create pad mask
        pad_mask = (caption != pad_value).unsqueeze(-2)
    
        # create lower triangular mask
        size = caption.size(1)
        lt_mask = lower_triangular_mask(size)
    
        # return the bitwise AND of the two masks
        return pad_mask & lt_mask
    
    if __name__ == '__main__':
        torch.manual_seed(0)
    
        # Here, we generate some random sequences, with an assigned pad value of 1
        pad_value = 1
        caption = torch.randint(2, 10, size=(2, 5))
        caption[0, 3] = pad_value
        caption[0, 4] = pad_value
    
        print(caption)
    
        mask = create_mask(caption, pad_value)
        print(mask.size())
        print(mask)
    

    上述代码返回

    tensor([[6, 9, 7, 1, 1],
            [5, 5, 9, 3, 5]])
    torch.Size([2, 5, 5])
    tensor([[[ True, False, False, False, False],
             [ True,  True, False, False, False],
             [ True,  True,  True, False, False],
             [ True,  True,  True, False, False],
             [ True,  True,  True, False, False]],
    
            [[ True, False, False, False, False],
             [ True,  True, False, False, False],
             [ True,  True,  True, False, False],
             [ True,  True,  True,  True, False],
             [ True,  True,  True,  True,  True]]])
    

    第一个字幕比第二个字幕短,导致一个掩码,不允许变换器在序列结束后出现。