
ValueError when trying to feed image data into DNNRegressor

  • Shawn Walton  · Tech Community  · 6 years ago

    I'm having trouble getting DNNRegressor to accept image data. Running the code below produces the following error:

    ValueError: Cannot reshape a tensor with 147456 elements to shape [384,442368] (169869312 elements) for 'dnn/input_from_feature_columns/input_layer/image/Reshape' (op: 'Reshape') with input shapes: [384,384,1], [2] and with input tensors computed as partial shapes: input[1] = [384,442368].
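
The element counts in this message can be decomposed with plain arithmetic. The check below is plain Python, not TensorFlow. It assumes a single 384 x 384 image decoded with channels=1 (as in load_image) and the 'image' column shape (384, 384, 3) declared in the code. With no batch dimension, the leading 384 of the image is taken as the batch size, which matches input[1] = [384,442368] in the error.

```python
from math import prod

# One decoded image: decode_jpeg(..., channels=1) yields a single channel.
image_shape = (384, 384, 1)
# Shape declared for the 'image' numeric_column.
column_shape = (384, 384, 3)

actual_elements = prod(image_shape)      # elements the tensor really holds
per_example = prod(column_shape)         # elements the column expects per example
# Without a batch axis, the first dimension (384) is read as the batch size.
inferred_batch = image_shape[0]
requested_elements = inferred_batch * per_example

print(actual_elements)     # 147456
print(per_example)         # 442368
print(requested_elements)  # 169869312
```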
    

    import os
    import os.path
    
    import tensorflow as tf
    
    SPLIT_PERCENTAGE = 0.8
    
    # snip snip
    
    # ids is a List of strings
    # filenames is a List of filenames of image files on the disk
    # labels is a List of int scores
    
    estimator = tf.estimator.DNNRegressor(
        feature_columns=[
            tf.feature_column.numeric_column('image', shape=(384, 384, 3)),
        ],
        hidden_units=[1024, 512, 256],
        model_dir=output_dir,
    )
    
    estimator.train(input_fn=lambda: input_fn(False, ids, filenames, labels))
    
    def input_fn(is_training, ids, filenames, labels):
        id_tensor = tf.constant(ids, dtype=tf.string)
        filenames_tensor = tf.constant(filenames, dtype=tf.string)
        labels_tensor = tf.constant(labels, dtype=tf.float32)
    
        ds = tf.data.Dataset.from_tensor_slices(((id_tensor, filenames_tensor), labels_tensor))
        print(ds)
        ds = ds.take(int(len(labels) * SPLIT_PERCENTAGE)) if is_training else ds.skip(int(len(labels) * SPLIT_PERCENTAGE))
        ds = ds.map(load_image)
    
        iterator = ds.make_one_shot_iterator()
        features, labels = iterator.get_next()
    
        return features, labels
    
    def load_image(id_file, score):
        _, filename = id_file
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=1)
        image_converted = tf.image.convert_image_dtype(image_decoded, tf.float16)
        image_resized = tf.image.resize_image_with_crop_or_pad(image_converted, 384, 384)
    
        return {'image': image_resized}, [tf.log(score)]
    

    I suspect this has something to do with how I declared my feature column, but this example

  • Shawn Walton  · 6 years ago

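    A .batch() call on the tf.data.Dataset is *required* for the estimator to consume it, even if it is only .batch(1). The estimator's input layer treats the first dimension of each feature as the batch dimension, so an unbatched [384, 384, 1] image is read as a batch of 384 rows, which is where the target shape [384, 442368] in the error comes from.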
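
To see why batching matters, here is a simplified plain-Python model of the reshape the input layer performs on a numeric column: it takes the first dimension of the incoming tensor as the batch size and flattens each example to the column's declared element count. The helper flatten_for_column is hypothetical and is not TensorFlow's implementation. Under this model, .batch(1) alone is not quite enough for the code in the question: the column declares shape=(384, 384, 3) while load_image decodes with channels=1, so the two must also be made to agree (for example, by declaring shape=(384, 384, 1)).

```python
from math import prod


def flatten_for_column(tensor_shape, column_shape):
    """Hypothetical, simplified model of how a numeric column is flattened.

    The first dimension of tensor_shape is treated as the batch size and each
    example is reshaped to prod(column_shape) elements, giving a 2-D shape.
    Raises ValueError when the element counts disagree, as a reshape would.
    """
    batch_size = tensor_shape[0]
    target = (batch_size, prod(column_shape))
    if prod(tensor_shape) != prod(target):
        raise ValueError(
            f"Cannot reshape {prod(tensor_shape)} elements to shape {list(target)}"
        )
    return target


def try_flatten(label, tensor_shape, column_shape):
    """Print the resulting 2-D shape, or the error, for one scenario."""
    try:
        print(label, flatten_for_column(tensor_shape, column_shape))
    except ValueError as err:
        print(label, "error:", err)


# No .batch(): one [384, 384, 1] image, so 384 is mistaken for the batch size.
try_flatten("unbatched:", (384, 384, 1), (384, 384, 3))
# .batch(1) adds a batch axis, but 1 channel still disagrees with (384, 384, 3).
try_flatten("batch(1), 3-channel column:", (1, 384, 384, 1), (384, 384, 3))
# .batch(1) plus a column shape that matches channels=1.
try_flatten("batch(1), 1-channel column:", (1, 384, 384, 1), (384, 384, 1))
```

Running it prints an error for the first two scenarios and (1, 147456) for the third; with .batch(32) the same matching column would give (32, 147456).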