代码之家  ›  专栏  ›  技术社区  ›  nad

Pytorch中显示输入大小不匹配的自定义LSTM模型

  •  0
  • nad  · 技术社区  · 5 年前

    我有一个自定义的双向LSTM模型,其中自定义部分是

    - extract the forward and backward last hidden state
    - concat those states
    - create a fully connected layer and pass it through softmax layer.
    

    代码如下所示:

    class customModel(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(customModel, self).__init__()
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.bilstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=False, bidirectional=True)
            self.fcl = nn.Linear(hidden_size, num_classes)
    
        def forward(self, x):
            # Set initial hidden and cell states 
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
    
            # Forward propagate LSTM
            out, hidden = self.bilstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)
            #concat hidden state of forward and backword
            fw_bilstm = out[-1, :, :self.hidden_size]
            bk_bilstm = out[0, :, :self.hidden_size]
            concat_fw_bw = torch.cat((fw_bilstm, bk_bilstm), dim = 1)
            fc = nn.Linear(concat_fw_bw, num_classes)
            x = F.relu(fc(x))
            return F.softmax(x)
    

    我使用以下参数和输入

    input_size = 2
    hidden_size = 32  
    num_layers = 1
    num_classes = 2
    
    input_embedding = [
        torch.FloatTensor([[-0.8264],  [0.2524]]),
        torch.FloatTensor([[-0.3259],  [0.3564]])
    ]
    

    然后我创建一个模型对象

    model = customModel(input_size, hidden_size, num_layers, num_classes)
    

    然后我使用如下:

    for item in input_embedding:
        print(item.size())
        for epoch in range(1):  
            pred = model(item)  
            print (pred)
    

    当我运行它时,我看到了这条线 out, hidden = self.bilstm(x, (h0, c0)) ,显示错误

    RuntimeError: input must have 3 dimensions, got 2
    
    

    我不知道为什么当我显式指定输入时,模型认为输入必须具有3维 input_size=2

    我错过了什么?

    0 回复  |  直到 5 年前
        1
  •  1
  •   Cedias    5 年前

    你好像少了一个( 批量 序列 )输入中的维度。

    两者之间有区别 nn.LSTM nn.LSTMCell . 前者——也就是您使用的那个——将整个序列作为输入。因此,它需要形状的三维输入(序列、批次、输入尺寸)。

    假设您希望以批处理的形式将这4个字母序列(您将其编码为一个热向量)作为输入:

    x0 = [a,b,c]
    x1 = [c,d,e]
    x2 = [e,f,g]
    x3 = [h,i,j]
    
    ### input.size() should give you the following:
    (3,4,8)
    
    • 这个 seq_len 参数是序列的大小:这里是3,
    • 这个 input_size 参数是每个输入向量的大小:这里,输入将是一个大小为8的单热向量,
    • 这个 batch 是您组合的序列数:这里有4个序列。

    注意:将批处理顺序放在第一位并设置 batch_first 是真的

    另外:如果没有提供(h_0,c_0),则h_0和c_0都默认为零,因此创建它们并不有用。