
py_func can't handle a list of 9 or more items

  • Kracekumar · 6 years ago

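    I want to use Dataset.map to feed two values into a function that returns a list of values. When the returned list contains more than eight elements, py_func cannot handle the type. When the list returned from map_func contains eight elements, there is no problem.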

    TensorFlow version is 1.4.1, running on the Trisquel distribution.

    Success case

    import tensorflow as tf
    
    def gen_range(groups=4, limit=1000):
        jump = limit // groups  # integer division keeps start/stop as ints
        start, stop = 0, 0
        while stop != limit:
            stop = start + jump
            yield start, stop
            start = stop
    
    def bridge(x, y):
        return [[[x] * 4, [y] * 4]]
    
    with tf.Session() as sess:
        iterator = tf.data.Dataset.from_generator(gen_range, (tf.int32, tf.int32)).map(
            lambda x, y: tf.py_func(bridge, [x, y], [tf.int32]), num_parallel_calls=2).\
            make_one_shot_iterator()
        next_element = iterator.get_next()  # build the op once, outside the loop
        while True:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:  # raised once the generator is exhausted
                break
    
    Output
    (array([[  0,   0,   0,   0],
       [250, 250, 250, 250]], dtype=int32),)
    (array([[250, 250, 250, 250],
       [500, 500, 500, 500]], dtype=int32),)
    (array([[500, 500, 500, 500],
       [750, 750, 750, 750]], dtype=int32),)
    2018-01-09 01:56:12.871943: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: StopIteration: Iteration finished.
    (array([[ 750,  750,  750,  750],
       [1000, 1000, 1000, 1000]], dtype=int32),)
    

    Failure case

    def bridge(x, y):
        return [[[x] * 5, [y] * 4]]
    
    with tf.Session() as sess:
        dataset = tf.data.Dataset.from_generator(gen_range, (tf.int32, tf.int32)).map(
            lambda x, y: tf.py_func(bridge, [x, y], [tf.int32]), num_parallel_calls=2).\
            make_one_shot_iterator()
        init = tf.global_variables_initializer()
        sess.run(init)
        while True:
            print(sess.run(dataset.get_next()))
    
    Output
    2018-01-09 01:56:41.201683: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
    2018-01-09 01:56:41.201960: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
     [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
     2018-01-09 01:56:41.202002: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
    ---------------------------------------------------------------------------
    UnimplementedError                        Traceback (most recent call last)
    /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
    1322     try:
    -> 1323       return fn(*args)
       1324     except errors.OpError as e:
    
     /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
     1301                                    feed_dict, fetch_list, target_list,
     -> 1302                                    status, run_metadata)
     1303 
    
     /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
     472             compat.as_text(c_api.TF_Message(self.status.status)),
     --> 473             c_api.TF_GetCode(self.status.status))
         474     # Delete the underlying status object from memory otherwise it stays alive
    
     UnimplementedError: Unsupported object type list
     [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
     [[Node: IteratorGetNext_95 = IteratorGetNext[output_shapes=
     [<unknown>], output_types=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_35)]]
    
     During handling of the above exception, another exception occurred:
    
     UnimplementedError                        Traceback (most recent call last)
     <ipython-input-120-120ecc56d75d> in <module>()
           4     sess.run(init)
           5     while True:
     ----> 6         print(sess.run(dataset.get_next()))
           7 
           8 
    
     /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
         887     try:
         888       result = self._run(None, fetches, feed_dict, options_ptr,
     --> 889                          run_metadata_ptr)
         890       if run_metadata:
         891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
    
     /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
        1118     if final_fetches or final_targets or (handle and feed_dict_tensor):
        1119       results = self._do_run(handle, final_targets, final_fetches,
     -> 1120                              feed_dict_tensor, options, run_metadata)
        1121     else:
        1122       results = []
    
     /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
     1315     if handle is None:
     1316       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
     -> 1317                            options, run_metadata)
        1318     else:
        1319       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
    
     /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
        1334         except KeyError:
        1335           pass
     -> 1336       raise type(e)(node_def, op, message)
        1337 
        1338   def _extend_graph(self):
    
     UnimplementedError: Unsupported object type list
         [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
         [[Node: IteratorGetNext_95 = IteratorGetNext[output_shapes=[<unknown>], output_types=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_35)]]
    
  • Olivier Moindrot · 6 years ago

    The problem is not the size; it is that what you return cannot be converted to a tensor.

    When you return [[[x] * 4, [y] * 4]], it can be converted to a tensor of shape (1, 2, 4):

    import tensorflow as tf
    
    x, y = 1, 2  # example values
    res = tf.constant([[[x] * 4, [y] * 4]])
    print(res.get_shape())  # prints (1, 2, 4)
    

    When you return [[[x] * 5, [y] * 4]], for x=1, y=2 you get:

    [[[1, 1, 1, 1, 1],
      [2, 2, 2, 2]
    ]]
    

    This cannot be converted to a tensor, because the first and second rows have different lengths (5 and 4).
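To make the rule concrete, here is a minimal pure-Python sketch of shape inference for nested lists. `infer_shape` is a hypothetical helper written for illustration, not TensorFlow's implementation; it only models the rule that all siblings at each depth must have the same shape. It also shows that the number of elements plays no role: a row of 20 works as long as every row has 20.

```python
def infer_shape(nested):
    """Return the shape of a rectangular nested list, or raise ValueError.

    Hypothetical, simplified model of the rule: every element at a given
    depth must itself have the same shape. Not TensorFlow's actual code.
    """
    if not isinstance(nested, (list, tuple)):
        return ()  # a scalar has an empty shape
    if len(nested) == 0:
        return (0,)
    child_shapes = [infer_shape(child) for child in nested]
    first = child_shapes[0]
    for i, shape in enumerate(child_shapes[1:], start=1):
        if shape != first:
            raise ValueError(
                "ragged nested list: element 0 has shape %s but element %d has shape %s"
                % (list(first), i, list(shape)))
    return (len(nested),) + first


x, y = 1, 2
print(infer_shape([[[x] * 4, [y] * 4]]))    # (1, 2, 4)
print(infer_shape([[[x] * 20, [y] * 20]]))  # (1, 2, 20): size is not the issue
try:
    infer_shape([[[x] * 5, [y] * 4]])
except ValueError as err:
    print(err)  # rows of length 5 and 4 do not match
```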


    You can trigger a similar error if you try:

    res = tf.constant([[1, 2], [3]])
    

    Argument must be a dense tensor: [[1, 2], [3]] - got shape [2], but wanted [2, 2].

    TensorFlow cannot infer a single shape for a nested list whose rows have different lengths.
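If the rows genuinely need different lengths, one option is to pad them to a common length before returning, so the result is rectangular again. This is a sketch under assumptions: `pad_rows` is a hypothetical helper, the modified `bridge` is illustrative, and padding with 0 is an arbitrary choice, so pick a fill value your downstream code can recognize. Inside `tf.py_func` the arguments arrive as NumPy values, so you may also need to make the padded result's dtype match the one declared in `Tout` (here `tf.int32`); this pure-Python sketch does not handle that.

```python
def pad_rows(rows, fill=0):
    """Right-pad each row with `fill` so all rows share the longest length.

    Hypothetical helper: makes a ragged list of rows rectangular.
    """
    width = max(len(row) for row in rows)
    return [list(row) + [fill] * (width - len(row)) for row in rows]


def bridge(x, y):
    # Same ragged rows as the failure case (lengths 5 and 4), padded to length 5,
    # so the result has the rectangular shape (1, 2, 5).
    return [pad_rows([[x] * 5, [y] * 4])]


print(bridge(1, 2))  # [[[1, 1, 1, 1, 1], [2, 2, 2, 2, 0]]]
```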