代码之家  ›  专栏  ›  技术社区  ›  Ricardo Cruz

keras/tensorflow:将常数层连接到卷积

  •  1
  • Ricardo Cruz  · 技术社区  · 6 年前

    对于每个卷积激活映射,我想连接一层常量——更具体地说,我想连接一个网格网格。(这是为了复制uber的一篇论文。)

    例如,假设我有一个 (?, 256, 256, 32) ;然后我想连接一个形状的常量层 (?, 256, 256, 1) 是的。

    我就是这样做的:

    from keras import layers
    import tensorflow as tf
    import numpy as np
    
    input_layer = layers.Input((256, 256, 3))
    conv = layers.Conv2D(32, 3, padding='same')(input_layer)
    print('conv:', conv.shape)
    
    
    xx, yy = np.mgrid[:256, :256]  # [(256, 256), (256, 256)]
    xx = tf.constant(xx, np.float32)
    yy = tf.constant(yy, np.float32)
    
    xx = tf.reshape(xx, (-1, 256, 256, -1))
    yy = tf.reshape(yy, (-1, 256, 256, -1))
    print('xx:', xx.shape, 'yy:', yy.shape)
    
    concat = layers.Concatenate()([conv, xx, yy])
    print('concat:', concat.shape)
    
    conv2 = layers.Conv2D(32, 3, padding='same')(concat)
    print('conv2:', conv2.shape)
    

    但我得到了错误:

    conv: (?, 256, 256, 32)
    xx: (?, 256, 256, ?) yy: (?, 256, 256, ?)
    concat: (?, 256, 256, ?)
    Traceback (most recent call last):
    File "temp.py", line 21, in <module>
    conv2 = layers.Conv2D(32, 3, padding='same')(concat)
    [...]
    raise ValueError('The channel dimension of the inputs '
    ValueError: The channel dimension of the inputs should be defined. Found `None`.
    

    问题是我的常数层是 (?, 256, 256, ?) ,而不是 (?,256,256,1页) ,然后下一个卷积层出错。

    我尝试过其他事情,但没有成功。

    PS:我试图实现的文件已经 implemented here 是的。

    1 回复  |  直到 6 年前
        1
  •  1
  •   rvinas    6 年前

    问题是 tf.reshape 无法推断多个维度的形状(即使用 -1 对于多个维度,将导致未定义维度 ? )中。因为你想要 xx yy 成为 (?, 256, 256, 1) ,可以按如下方式重塑这些张量:

    xx = tf.reshape(xx, (-1, 256, 256, 1))
    yy = tf.reshape(yy, (-1, 256, 256, 1))
    

    生成的形状将是 (1, 256, 256, 1) 是的。现在, conv (?, 256, 256, 32) ,和 keras.layers.Concatenate 除了CutAT轴之外,需要所有输入的形状匹配。你可以用 tf.tile 重复张量 二十 沿着第一个维度以匹配批次大小:

    xx = tf.tile(xx, [tf.shape(conv)[0], 1, 1, 1])
    yy = tf.tile(yy, [tf.shape(conv)[0], 1, 1, 1])
    

    形状 二十 现在 (?,256,256,1页) ,张量可以连接,因为它们的第一个维度与批处理大小匹配。