
Dimension error with tf.scatter_nd in TensorFlow/Keras

  BAE · Tech Community · 5 years ago

    My code:

    reshape_out = Reshape((21, 3), input_shape=(21*3,), name='reshape_to_21_3')(output3d)
    
    def proj_output_shape(shp):
        return (None, 32, 32, 1)
    
    def f(x):
        import tensorflow as tf
        batch_size = K.shape(x)[0]
        print('x.shape={0}'.format(x.shape))
    
        idx = K.cast(x[:, :, 0:2]*15.5+15.5, "int32")
        print('idx.shape={0}'.format(idx.shape))
    
        # z = mysparse_to_dense(idx, (K.shape(x)[0], 32, 32), 1.0, 0.0, name='sparse_tensor')
        updates = tf.ones([batch_size, 21])
        print('updates.shape={0}'.format(updates.shape))
    
        #shape = tf.Variable(np.array([batch_size, 32, 32]))
        #print('shape.shape={0}'.format(shape))
    
        z = tf.scatter_nd(indices=idx,
                          updates=updates,
                          shape=(batch_size, 32, 32),
                          name='cool')
    
        print('z={0}'.format(z))
        #z = tf.add(z, z)
        #z = tf.sparse_add(tf.zeros(z.dense_shape), z)
        z = K.reshape(z, (K.shape(x)[0], 32, 32, 1))
        print('z.shape={0}'.format(z.shape), z)
    
        fil = make_kernel(1.0)
        fil = K.reshape(fil, (5, 5, 1, 1))
        print('fil.shape={0}'.format(fil.shape), fil)
    
        r = K.conv2d(z,kernel=fil, padding='same', data_format="channels_last")
        print('r.shape={0}'.format(r.shape), r)
    
        return r
    

    Output:

    x.shape=(?, 21, 3)
    idx.shape=(?, 21, 2)
    updates.shape=(?, 21)
    

    Error:

    ValueError: The inner 1 dimensions of output.shape=[?,?,?] must match the inner 0 dimensions of updates.shape=[?,21]: Shapes must be equal rank, but are 1 and 0 for 'projection_4/cool' (op: 'ScatterNd') with input shapes: [?,21,2], [?,21], [3].
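The message follows from the shape rule documented for tf.scatter_nd: `updates.shape` must equal `indices.shape[:-1] + shape[indices.shape[-1]:]`. Here `idx` has shape (?, 21, 2), so each 2-vector addresses only the first two output dimensions (batch, row). The x pixel coordinate therefore ends up being used as a batch index, and each update would have to be a whole row of 32 values, giving an expected updates shape of (?, 21, 32) instead of (?, 21). The pure-Python sketch below only evaluates that rule for the two index layouts; the helper name is hypothetical and a batch size of 4 is assumed in place of `?`.

```python
def scatter_nd_updates_shape(indices_shape, output_shape):
    """Hypothetical helper: the updates shape tf.scatter_nd expects.

    The last dimension of `indices` (the index depth) selects that many
    leading dimensions of the output; every update must then fill the
    remaining trailing dimensions.
    """
    index_depth = indices_shape[-1]
    if index_depth > len(output_shape):
        raise ValueError("index depth exceeds the rank of the output shape")
    return tuple(indices_shape[:-1]) + tuple(output_shape[index_depth:])


B = 4  # assumed concrete batch size standing in for "?"

# The question's call: 2-element indices into a rank-3 output.
print(scatter_nd_updates_shape((B, 21, 2), (B, 32, 32)))  # (4, 21, 32)

# With a batch index prepended to every point: 3-element indices.
print(scatter_nd_updates_shape((B, 21, 3), (B, 32, 32)))  # (4, 21)
```

So the existing `updates` of shape (?, 21) would only fit if every index carried all three coordinates (batch, row, col).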
    

    How can I fix this? Thanks.
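One common way to make the shapes agree is to prepend each point's batch index, so every index becomes a full (batch, row, col) triple with shape (batch, 21, 3) and the (batch, 21) updates match. This is a minimal sketch under stated assumptions: TensorFlow 2.x, x/y coordinates in `x[:, :, 0:2]` lying in [-1, 1] (as the `*15.5+15.5` mapping in the question implies), and a hypothetical helper name `scatter_points_to_grid`. The clipping step is an added safeguard, not part of the original code.

```python
import tensorflow as tf


def scatter_points_to_grid(x, grid_size=32):
    """Hypothetical helper: rasterize (batch, num_points, 3) points into
    a (batch, grid_size, grid_size, 1) count image."""
    batch_size = tf.shape(x)[0]
    num_points = tf.shape(x)[1]
    half = (grid_size - 1) / 2.0  # 15.5 for a 32x32 grid, as in the question

    # (batch, num_points, 2) integer pixel coordinates (row, col).
    xy = tf.cast(x[:, :, 0:2] * half + half, tf.int32)
    # Safeguard (assumption): keep indices inside the grid, since
    # tf.scatter_nd rejects out-of-range indices on CPU.
    xy = tf.clip_by_value(xy, 0, grid_size - 1)

    # (batch, num_points, 1): the batch index of every point.
    b = tf.tile(tf.range(batch_size)[:, tf.newaxis, tf.newaxis],
                [1, num_points, 1])

    idx = tf.concat([b, xy], axis=-1)        # (batch, num_points, 3)
    updates = tf.ones(tf.shape(idx)[:-1])    # (batch, num_points), float32

    z = tf.scatter_nd(idx, updates,
                      shape=tf.stack([batch_size, grid_size, grid_size]))
    return z[..., tf.newaxis]                # (batch, grid, grid, 1)


points = tf.random.uniform((2, 21, 3), minval=-1.0, maxval=1.0)
print(scatter_points_to_grid(points).shape)  # (2, 32, 32, 1)
```

Inside the question's `f(x)`, this would replace the `idx`/`updates`/`tf.scatter_nd`/`K.reshape` lines, and the `K.conv2d` step could stay as it is; in a Keras model it could be wrapped in a `Lambda` layer. Note that `tf.scatter_nd` sums updates that land on the same pixel, so two points mapping to one cell produce 2.0; `tf.minimum(z, 1.0)` would turn the result into a binary mask if that is what's wanted.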
