
Error: dot support for x of rank 4 is not yet implemented

  • mic · Tech Community · 6 years ago

    I am using TensorFlow.js to run predictions with a model I trained in Keras. However, when I pass in my four-dimensional tensor, I get the following error:

    UnhandledPromiseRejectionWarning: Unhandled promise rejection (rejection id: 1): Error: dot support for x of rank 4 is not yet implemented: x shape = 32,1,1,100
    

    Edit 1

    The error is thrown by model.predict(noise_tensor). Most of the surrounding code is irrelevant, so here is the relevant part:

      noise_tensor.print(true)
    
      generated_images = model.predict(noise_tensor) // error occurs here
    

    Here is the printout of my 4D tensor:

    Tensor
      dtype: float32
      rank: 4
      shape: [64,1,1,100]
      values:
        [ [ [[0.3799773 , -0.0252707, 0.0118336 , ..., 0.1703698 , -0.0649208, 0.2152225 ],]],
    
    
          [ [[0.219656  , 0.2850143 , -0.1078744, ..., 0.1627689 , -0.0838831, -0.1112608],]],
    
    
          [ [[-0.1295149, -0.08308  , 0.1872116 , ..., -0.2033772, -0.4184959, -0.3357461],]],
    
    
         ...
          [ [[0.0029674 , 0.0422036 , 0.067896  , ..., 0.1368463 , 0.1122015 , -0.0395375],]],
    
    
          [ [[0.043546  , -0.0281712, 0.0898769 , ..., 0.205565  , 0.1444133 , 0.0067788 ],]],
    
    
          [ [[-0.1089588, -0.0161969, -0.0724337, ..., 0.1427118 , -0.2577117, 0.0013836 ],]]]
    

    Here is the summary of the Keras model:

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    dense_1 (Dense)              (None, 1, 1, 32768)       3309568
    _________________________________________________________________
    reshape_1 (Reshape)          (None, 8, 8, 512)         0
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 8, 8, 512)         2048
    _________________________________________________________________
    activation_1 (Activation)    (None, 8, 8, 512)         0
    _________________________________________________________________
    conv2d_transpose_1 (Conv2DTr (None, 16, 16, 256)       3277056
    _________________________________________________________________
    batch_normalization_2 (Batch (None, 16, 16, 256)       1024
    _________________________________________________________________
    activation_2 (Activation)    (None, 16, 16, 256)       0
    _________________________________________________________________
    conv2d_transpose_2 (Conv2DTr (None, 32, 32, 128)       819328
    _________________________________________________________________
    batch_normalization_3 (Batch (None, 32, 32, 128)       512
    _________________________________________________________________
    activation_3 (Activation)    (None, 32, 32, 128)       0
    _________________________________________________________________
    conv2d_transpose_3 (Conv2DTr (None, 64, 64, 64)        204864
    _________________________________________________________________
    batch_normalization_4 (Batch (None, 64, 64, 64)        256
    _________________________________________________________________
    activation_4 (Activation)    (None, 64, 64, 64)        0
    _________________________________________________________________
    conv2d_transpose_4 (Conv2DTr (None, 128, 128, 1)       1601
    _________________________________________________________________
    activation_5 (Activation)    (None, 128, 128, 1)       0
    =================================================================
    Total params: 7,616,257
    Trainable params: 7,614,337
    Non-trainable params: 1,920
    _________________________________________________________________
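
    As a sanity check on the summary, the parameter counts can be recomputed by hand. The dense_1 count is the telling one: 3,309,568 = 100 * 32768 + 32768, i.e. Keras applies the Dense kernel only to the last axis (size 100) of the (1, 1, 100) input, which is why its output shape stays rank 4. A minimal sketch in plain Python, assuming the standard per-layer formulas noted in the comments (the helper names are hypothetical):

```python
# Recompute the parameter counts shown in the Keras summary.
# Assumed formulas (standard Keras conventions):
#   Dense:              in_features * units + units (bias)
#   Conv2DTranspose:    kh * kw * in_channels * filters + filters (bias)
#   BatchNormalization: 4 * channels (gamma, beta are trainable;
#                       moving mean and moving variance are not)

def dense_params(in_features, units):
    return in_features * units + units

def conv_t_params(kh, kw, in_ch, filters):
    return kh * kw * in_ch * filters + filters

def bn_params(channels):
    return 4 * channels

# The Dense layer only sees the last axis (100) of the (1, 1, 100) input.
dense = dense_params(100, 8 * 8 * 512)
convs = [conv_t_params(5, 5, 512, 256),
         conv_t_params(5, 5, 256, 128),
         conv_t_params(5, 5, 128, 64),
         conv_t_params(5, 5, 64, 1)]
bns = [bn_params(c) for c in (512, 256, 128, 64)]

total = dense + sum(convs) + sum(bns)
non_trainable = sum(bns) // 2   # moving mean + moving variance
trainable = total - non_trainable

print(dense, convs, bns)
print(total, trainable, non_trainable)
```

    The totals match the summary (7,616,257 total, 1,920 non-trainable), which confirms that the Dense kernel has shape (100, 32768) regardless of the two leading singleton axes.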
    

    And the corresponding Python code:

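    from keras.models import Sequential
    from keras.layers import (Activation, BatchNormalization, Conv2DTranspose,
                              Dense, Reshape)
    from keras.optimizers import Adam
    
    def construct_generator():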
    
        generator = Sequential()
    
        generator.add(Dense(units=8 * 8 * 512,
                            kernel_initializer='glorot_uniform',
                            input_shape=(1, 1, 100)))
        generator.add(Reshape(target_shape=(8, 8, 512)))
        generator.add(BatchNormalization(momentum=0.5))
        generator.add(Activation('relu'))
    
        generator.add(Conv2DTranspose(filters=256, kernel_size=(5, 5),
                                      strides=(2, 2), padding='same',
                                      data_format='channels_last',
                                      kernel_initializer='glorot_uniform'))
        generator.add(BatchNormalization(momentum=0.5))
        generator.add(Activation('relu'))
    
        generator.add(Conv2DTranspose(filters=128, kernel_size=(5, 5),
                                      strides=(2, 2), padding='same',
                                      data_format='channels_last',
                                      kernel_initializer='glorot_uniform'))
        generator.add(BatchNormalization(momentum=0.5))
        generator.add(Activation('relu'))
    
        generator.add(Conv2DTranspose(filters=64, kernel_size=(5, 5),
                                      strides=(2, 2), padding='same',
                                      data_format='channels_last',
                                      kernel_initializer='glorot_uniform'))
        generator.add(BatchNormalization(momentum=0.5))
        generator.add(Activation('relu'))
    
        generator.add(Conv2DTranspose(filters=1, kernel_size=(5, 5),
                                      strides=(2, 2), padding='same',
                                      data_format='channels_last',
                                      kernel_initializer='glorot_uniform'))
        generator.add(Activation('tanh'))
    
        optimizer = Adam(lr=0.00015, beta_1=0.5)
        generator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=None)
    
        print('generator')
        generator.summary()
    
        return generator
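
    One possible workaround, in line with the answer below but not confirmed by the poster, is to declare the generator's input as input_shape=(100,) so that the Dense layer receives a rank-2 tensor; the Reshape that follows produces (8, 8, 512) either way. The shape-only sketch below (plain Python, no Keras; all helper names are hypothetical) traces both variants through the layers of construct_generator():

```python
# Shape-only trace of the generator; helper names are hypothetical.
# None stands for the batch dimension.

def dense_shape(shape, units):
    # Dense acts on the last axis only; leading axes (and the rank) are kept.
    return shape[:-1] + (units,)

def reshape_shape(shape, target):
    # Reshape keeps the batch axis and must preserve the element count.
    n_in = 1
    for d in shape[1:]:
        n_in *= d
    n_out = 1
    for d in target:
        n_out *= d
    if n_in != n_out:
        raise ValueError(f"cannot reshape {shape} to {target}")
    return (shape[0],) + target

def conv_t_same_stride2(shape, filters):
    # Conv2DTranspose with strides=(2, 2), padding='same', channels_last:
    # both spatial dims double and the channel axis becomes `filters`.
    b, h, w, _ = shape
    return (b, 2 * h, 2 * w, filters)

def trace(input_shape):
    x = (None,) + input_shape
    dense_out = dense_shape(x, 8 * 8 * 512)
    x = reshape_shape(dense_out, (8, 8, 512))
    for f in (256, 128, 64, 1):
        x = conv_t_same_stride2(x, f)
    return dense_out, x

original = trace((1, 1, 100))   # as in construct_generator()
flat = trace((100,))            # hypothetical rank-2 variant
print(original)
print(flat)
```

    Both variants end at (None, 128, 128, 1), but only the rank-2 variant hands the Dense layer a rank-2 input. With that variant, the noise would be fed as [batch, 100] (for example noise_tensor.reshape([64, 100]) in TensorFlow.js). Since the Dense kernel is (100, 32768) in both cases, the trained weights should in principle carry over, though that is an untested assumption here.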
    

    Edit 2

    This is a bug in TensorFlow.js. Future visitors can follow the GitHub thread here.

    1 answer

  • edkeveked · 6 years ago

    For now, the input must be of rank 1 or 2 for tf.dot to work.
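
    To see why a rank-2 dot is enough in principle, here is a pure-Python sketch (no TensorFlow; function names are hypothetical) showing that contracting the last axis of a rank-4 tensor gives the same result as collapsing the leading axes to rank 2, doing an ordinary matrix product, and restoring the shape. It illustrates the idea behind feeding [batch, 100] instead of [batch, 1, 1, 100]; it does not reproduce what TensorFlow.js does internally.

```python
import random

def matmul2d(a, b):
    # a: rows x k, b: k x n, both as lists of lists.
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row))
             for j in range(len(b[0]))] for row in a]

def dot_last_axis_rank4(x, w):
    # Direct contraction over the last axis of a rank-4 nested list.
    return [[[[sum(v * w[k][j] for k, v in enumerate(vec))
               for j in range(len(w[0]))]
              for vec in plane] for plane in block] for block in x]

def dot_via_rank2(x, w):
    b, h, wd = len(x), len(x[0]), len(x[0][0])
    # Collapse (b, h, wd, k) -> (b*h*wd, k).
    flat = [vec for block in x for plane in block for vec in plane]
    out = matmul2d(flat, w)
    # Restore (b*h*wd, n) -> (b, h, wd, n), in the same row order.
    it = iter(out)
    return [[[next(it) for _ in range(wd)] for _ in range(h)] for _ in range(b)]

random.seed(0)
# Integer data keeps the comparison exact.
x = [[[[random.randint(-5, 5) for _ in range(3)]]] for _ in range(2)]  # (2, 1, 1, 3)
w = [[random.randint(-5, 5) for _ in range(4)] for _ in range(3)]      # (3, 4)
direct = dot_last_axis_rank4(x, w)
via2d = dot_via_rank2(x, w)
print(direct == via2d)
```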
