代码之家  ›  专栏  ›  技术社区  ›  ytrewq

PyTorch:使用numpy数组为GRU/LSTM手动设置权重参数

  •  8
  • ytrewq  · 技术社区  · 6 年前

    我有一个numpy数组,用于在文档中定义形状的参数( https://pytorch.org/docs/stable/nn.html#torch.nn.GRU ).

    它似乎工作,但我不确定返回值是否正确。

    这是用numpy参数填充GRU/LSTM的正确方法吗?

    gru = nn.GRU(input_size, hidden_size, num_layers,
                  bias=True, batch_first=False, dropout=dropout, bidirectional=bidirectional)
    
    def set_nn_wih(layer, parameter_name, w, l0=True):
        param = getattr(layer, parameter_name)
        if l0:
            for i in range(3*hidden_size):
                param.data[i] = w[i*input_size:(i+1)*input_size]
        else:
            for i in range(3*hidden_size):
                param.data[i] = w[i*num_directions*hidden_size:(i+1)*num_directions*hidden_size]
    
    def set_nn_whh(layer, parameter_name, w):
        param = getattr(layer, parameter_name)
        for i in range(3*hidden_size):
            param.data[i] = w[i*hidden_size:(i+1)*hidden_size]
    
    l0=True
    
    for i in range(num_directions):
        for j in range(num_layers):
            if j == 0:
                wih = w0[i, :, :3*input_size]
                whh = w0[i, :, 3*input_size:]  # check
                l0=True
            else:
                wih = w[j-1, i, :, :num_directions*3*hidden_size]
                whh = w[j-1, i, :, num_directions*3*hidden_size:]
                l0=False
    
            if i == 0:
                set_nn_wih(
                    gru, "weight_ih_l{}".format(j), torch.from_numpy(wih.flatten()),l0)
                set_nn_whh(
                    gru, "weight_hh_l{}".format(j), torch.from_numpy(whh.flatten()))
            else:
                set_nn_wih(
                    gru, "weight_ih_l{}_reverse".format(j), torch.from_numpy(wih.flatten()),l0)
                set_nn_whh(
                    gru, "weight_hh_l{}_reverse".format(j), torch.from_numpy(whh.flatten()))
    
    y, hn = gru(x_t, h_t)
    

    numpy数组定义如下:

    rng = np.random.RandomState(313)
    w0 = rng.randn(num_directions, hidden_size, 3*(input_size +
                   hidden_size)).astype(np.float32)
    w = rng.randn(max(1, num_layers-1), num_directions, hidden_size,
                  3*(num_directions*hidden_size + hidden_size)).astype(np.float32)
    
    1 回复  |  直到 6 年前
        1
  •  20
  •   cleros    5 年前

    这是个好问题,你已经给出了一个像样的答案。然而,它重新发明了轮子-有一个非常优雅的Pytorch内部例行程序,将允许你做同样的不需要太多的努力-和一个适用于任何网络。

    这里的核心概念是PyTorch的 state_dict . 州词典有效地包含了 parameters 由树形结构组织给出的关系式 nn.Modules 以及它们的子模块等。

    如果您只希望代码使用 ,然后尝试这一行(如果 dict 国家法令 ):

    `model.load_state_dict(dict, strict=False)`
    

    strict=False 如果你想加载 只有一些参数值 .

    答案很长-包括介绍PyTorch的 国家法令

    input_size = hidden_size = 2 这样我就可以打印整个州的记录了):

    rnn = torch.nn.GRU(2, 2, 1)
    rnn.state_dict()
    # Out[10]: 
    #     OrderedDict([('weight_ih_l0', tensor([[-0.0023, -0.0460],
    #                         [ 0.3373,  0.0070],
    #                         [ 0.0745, -0.5345],
    #                         [ 0.5347, -0.2373],
    #                         [-0.2217, -0.2824],
    #                         [-0.2983,  0.4771]])),
    #                 ('weight_hh_l0', tensor([[-0.2837, -0.0571],
    #                         [-0.1820,  0.6963],
    #                         [ 0.4978, -0.6342],
    #                         [ 0.0366,  0.2156],
    #                         [ 0.5009,  0.4382],
    #                         [-0.7012, -0.5157]])),
    #                 ('bias_ih_l0',
    #                 tensor([-0.2158, -0.6643, -0.3505, -0.0959, -0.5332, -0.6209])),
    #                 ('bias_hh_l0',
    #                 tensor([-0.1845,  0.4075, -0.1721, -0.4893, -0.2427,  0.3973]))])
    

    所以 国家法令

    class MLP(torch.nn.Module):      
        def __init__(self):
            torch.nn.Module.__init__(self)
            self.lin_a = torch.nn.Linear(2, 2)
            self.lin_b = torch.nn.Linear(2, 2)
    
    
    mlp = MLP()
    mlp.state_dict()
    #    Out[23]: 
    #        OrderedDict([('lin_a.weight', tensor([[-0.2914,  0.0791],
    #                            [-0.1167,  0.6591]])),
    #                    ('lin_a.bias', tensor([-0.2745, -0.1614])),
    #                    ('lin_b.weight', tensor([[-0.4634, -0.2649],
    #                            [ 0.4552,  0.3812]])),
    #                    ('lin_b.bias', tensor([ 0.0273, -0.1283]))])
    
    
    class NestedMLP(torch.nn.Module):
        def __init__(self):
            torch.nn.Module.__init__(self)
            self.mlp_a = MLP()
            self.mlp_b = MLP()
    
    
    n_mlp = NestedMLP()
    n_mlp.state_dict()
    #   Out[26]: 
    #        OrderedDict([('mlp_a.lin_a.weight', tensor([[ 0.2543,  0.3412],
    #                            [-0.1984, -0.3235]])),
    #                    ('mlp_a.lin_a.bias', tensor([ 0.2480, -0.0631])),
    #                    ('mlp_a.lin_b.weight', tensor([[-0.4575, -0.6072],
    #                            [-0.0100,  0.5887]])),
    #                    ('mlp_a.lin_b.bias', tensor([-0.3116,  0.5603])),
    #                    ('mlp_b.lin_a.weight', tensor([[ 0.3722,  0.6940],
    #                            [-0.5120,  0.5414]])),
    #                    ('mlp_b.lin_a.bias', tensor([0.3604, 0.0316])),
    #                    ('mlp_b.lin_b.weight', tensor([[-0.5571,  0.0830],
    #                            [ 0.5230, -0.1020]])),
    #                    ('mlp_b.lin_b.bias', tensor([ 0.2156, -0.2930]))])
    

    所以-如果你不想提取状态dict,而是改变它-从而改变网络的参数呢?使用 nn.Module.load_state_dict(state_dict, strict=True) ( link to the docs ) 此方法允许您用任意值加载整个state目录 只要键(即参数名)正确,值(即参数)正确 torch.tensors 形状正确。 如果 strict kwarg设置为 True (默认设置),加载的dict必须与原始状态dict完全匹配,参数值除外。也就是说,每个参数必须有一个新值。

    对于上面的GRU示例,我们需要一个正确大小的张量(以及正确的设备,顺便说一句) 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0' . 因为我们有时只想装 一些 值(正如我认为你想做的),我们可以设置 严格的 False -然后我们可以只加载部分状态dict,例如只包含 'weight_ih_l0' .

    作为一个实用的建议,我只需要创建一个模型,然后打印state dict(或者至少是一个键列表和相应的张量大小)

    print([k, v.shape for k, v in model.state_dict().items()])
    

    这将告诉您要更改的参数的确切名称。然后,您只需使用相应的参数名和张量创建一个state dict,并加载它:

    from dollections import OrderedDict
    new_state_dict = OrderedDict({'tensor_name_retrieved_from_original_dict': new_tensor_value})
    model.load_state_dict(new_state_dict, strict=False)
    
        2
  •  5
  •   Leo Brueggeman    5 年前

    model.state_dict()["your_weight_names_here"][:] = torch.Tensor(your_numpy_array)