代码之家  ›  专栏  ›  技术社区  ›  bluesummers

tf。数据数据集。padded\u批处理填充每个功能

  •  11
  • bluesummers  · 技术社区  · 6 年前

    我有一个 tf.data.Dataset 包含3个不同功能的实例

    • label 这是一个标量
    • sequence_feature 这是一个标量序列
    • seq_of_seqs_feature 这是一个序列特征

    我正在尝试使用 tf.data.Dataset.padded_batch() 生成填充数据作为我的模型的输入,我想以不同的方式填充每个特性。

    批次示例:

    [{'label': 24,
      'sequence_feature': [1, 2],
      'seq_of_seqs_feature': [[11.1, 22.2],
                              [33.3, 44.4]]},
     {'label': 32,
      'sequence_feature': [3, 4, 5],
      'seq_of_seqs_feature': [[55.55, 66.66]]}]
    

    预期输出:

    [{'label': 24,
      'sequence_feature': [1, 2, 0],
      'seq_of_seqs_feature': [[11.1, 22.2],
                              [33.3, 44.4]]},
     {'label': 32,
      'sequence_feature': [3, 4, 5],
      'seq_of_seqs_feature': [[55.55, 66.66],
                               0.0, 0.0    ]}]
    

    正如您所看到的 标签 不应填充功能,并且 sequence\u功能 seq\u of\u seqs\u功能 应使用给定批次中相应的最长条目进行填充。

    1 回复  |  直到 6 年前
        1
  •  19
  •   mrry    6 年前

    这个 tf.data.Dataset.padded_batch() 方法允许您指定 padded_shapes 对于结果批次的每个组件(特征)。例如,如果调用了输入数据集 ds :

    padded_ds = ds.padded_batch(
        BATCH_SIZE,
        padded_shapes={
            'label': [],                          # Scalar elements, no padding.
            'sequence_feature': [None],           # Vector elements, padded to longest.
            'seq_of_seqs_feature': [None, None],  # Matrix elements, padded to longest
        })                                        # in each dimension.
    

    请注意 padded\u形状 参数的结构与输入数据集的元素相同,因此在本例中,它需要一个字典,其中的键与功能名称匹配。