代码之家  ›  专栏  ›  技术社区  ›  John

tensorflow中两个相同秩张量之间的广播

  •  0
  • John  · 技术社区  · 7 年前

    x s 关于形状:

    > x.shape
    TensorShape([Dimension(None), Dimension(3), Dimension(5), Dimension(5)])
    > s.shape
    TensorShape([Dimension(None), Dimension(12), Dimension(5), Dimension(5)])
    

    我想在两人之间播放点积 x 通过维度 1 具体如下:

    > x_s.shape
    TensorShape([Dimension(None), Dimension(4), Dimension(5), Dimension(5)])
    

    x_s[i, 0, k, l] = sum([x[i, j, k, l] * s[i, j, k, l] for j in range (3)])
    x_s[i, 1, k, l] = sum([x[i, j-3, k, l] * s[i, j, k, l] for j in range (3, 6)])
    x_s[i, 2, k, l] = sum([x[i, j-6, k, l] * s[i, j, k, l] for j in range (6, 9)])
    x_s[i, 3, k, l] = sum([x[i, j-9, k, l] * s[i, j, k, l] for j in range (9, 12)])
    

    我有以下实现:

    s_t = tf.transpose(s, [0, 2, 3, 1]) # [None, 5, 5, 12]
    x_t = tf.transpose(x, [0, 2, 3, 1]) # [None, 5, 5, 3]
    x_t = tf.tile(x_t, [1, 1, 1, 4]) # [None, 5, 5, 12]
    
    x_s = x_t * s_t # [None, 5, 5, 12]
    x_s = tf.reshape(x_s, [tf.shape(x_s)[0], 5, 5, 4, 3]) # [None, 5, 5, 4, 3]
    x_s = tf.reduce_sum(x_s, axis=-1) # [None, 5, 5, 4]
    x_s = tf.transpose(x_s, [0, 3, 1, 2]) # [None, 4, 5, 5]
    

    我知道这在内存中效率不高,因为 tile . 而且 reshape 的, transpose element-wise reduce_sum s运算可能会影响较大张量的性能。有什么办法可以让它更干净吗?

    2 回复  |  直到 7 年前
        1
  •  1
  •   DomJack    7 年前

    你有证据表明 reshape s很贵?以下内容使用了重塑和标注广播:

    x_s = tf.reduce_sum(tf.reshape(s, (-1, 4, 3, 5, 5)) *
                        tf.expand_dims(x, axis=1), axis=2)
    
        2
  •  0
  •   Jie.Zhou    7 年前

    s 具有 tf.split 转化为四个张量,然后使用 tf.tensordot 为了得到最终结果,像这样

    splits = tf.split(s, [3] * 4, axis=1)
    splits = map(lambda split: tf.tensordot(split, x, axes=[[1], [1]]), splits)
    x_s = tf.stack(splits, axis=1)