我想用
Dataset.map
将两个值传给一个函数,该函数返回一个由值组成的列表。当返回的列表包含八个以上的元素时,
py_func
无法处理该类型;而当
map_func
返回的列表恰好包含八个元素时,则没有问题。
TensorFlow 版本为
1.4.1
运行环境为
Trisquel
发行版。
成功案例
import tensorflow as tf
def gen_range(groups=4, limit=1000):
    """Yield consecutive ``(start, stop)`` bounds that split the span 0..limit.

    The span is cut into ``groups`` contiguous pieces whose boundaries are
    computed with integer arithmetic, so integer arguments always produce
    integer bounds and the final ``stop`` is exactly ``limit``. Pieces that
    would be empty (``start == stop``) are skipped, which means nothing is
    yielded when ``limit`` is 0.

    Args:
        groups: Number of pieces to split the span into; must be >= 1.
        limit: End of the span. When it is not evenly divisible by
            ``groups`` the piece sizes differ by at most one.

    Yields:
        ``(start, stop)`` tuples where each ``start`` equals the previous
        ``stop``.

    Raises:
        ValueError: If ``groups`` is less than 1.
    """
    if groups < 1:
        raise ValueError("groups must be a positive integer")
    start = 0
    for index in range(1, groups + 1):
        # Derive each boundary directly from ``limit`` with floor division
        # instead of repeatedly adding a float step: an accumulated float can
        # miss ``limit`` (e.g. ten steps of 0.1 never sum to exactly 1), and
        # an equality-based loop would then never terminate.
        stop = limit * index // groups
        if stop != start:
            yield start, stop
        start = stop
def bridge(x, y):
    """Return a one-element list wrapping two equally long rows.

    The first row repeats ``x`` four times and the second row repeats ``y``
    four times, so the wrapped value is a rectangular 2x4 nested list.
    """
    width = 4
    row_x = [x for _ in range(width)]
    row_y = [y for _ in range(width)]
    return [[row_x, row_y]]
with tf.Session() as sess:
    # NOTE: despite its name, ``dataset`` ends up holding the one-shot
    # iterator, because ``make_one_shot_iterator()`` is the last call in the
    # chain. The name is kept so other code referring to it keeps working.
    dataset = tf.data.Dataset.from_generator(gen_range, (tf.int32, tf.int32)).map(
        lambda x, y: tf.py_func(bridge, [x, y], [tf.int32]), num_parallel_calls=2).\
        make_one_shot_iterator()
    # Build the get_next op once. Calling get_next() inside the loop would add
    # a new IteratorGetNext node to the graph on every iteration.
    next_element = dataset.get_next()
    init = tf.global_variables_initializer()
    sess.run(init)
    try:
        while True:
            print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        # The generator is exhausted; this is the normal end of the loop.
        pass
输出
(array([[ 0, 0, 0, 0],
[250, 250, 250, 250]], dtype=int32),)
(array([[250, 250, 250, 250],
[500, 500, 500, 500]], dtype=int32),)
(array([[500, 500, 500, 500],
[750, 750, 750, 750]], dtype=int32),)
2018-01-09 01:56:12.871943: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: StopIteration: Iteration finished.
(array([[ 750, 750, 750, 750],
[1000, 1000, 1000, 1000]], dtype=int32),)
失败案例
def bridge(x, y):
    """Return a one-element list wrapping two rows of unequal length.

    The first row repeats ``x`` five times while the second row repeats ``y``
    only four times, so the wrapped value is a ragged nested list.
    """
    # NOTE(review): the rows are 5 and 4 items long, so they do not form a
    # rectangular array. Presumably this raggedness (rather than the total
    # element count) is what py_func rejects as "Unsupported object type
    # list" - confirm.
    row_x = [x for _ in range(5)]
    row_y = [y for _ in range(4)]
    return [[row_x, row_y]]
with tf.Session() as sess:
    # NOTE: despite its name, ``dataset`` ends up holding the one-shot
    # iterator, because ``make_one_shot_iterator()`` is the last call in the
    # chain. The name is kept so other code referring to it keeps working.
    dataset = tf.data.Dataset.from_generator(gen_range, (tf.int32, tf.int32)).map(
        lambda x, y: tf.py_func(bridge, [x, y], [tf.int32]), num_parallel_calls=2).\
        make_one_shot_iterator()
    # Build the get_next op once. Calling get_next() inside the loop would add
    # a new IteratorGetNext node to the graph on every iteration.
    next_element = dataset.get_next()
    init = tf.global_variables_initializer()
    sess.run(init)
    try:
        while True:
            print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        # Only normal end-of-data is swallowed here; any other error raised
        # while evaluating the pipeline still propagates to the caller.
        pass
输出
2018-01-09 01:56:41.201683: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
2018-01-09 01:56:41.201960: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
[[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
2018-01-09 01:56:41.202002: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
---------------------------------------------------------------------------
UnimplementedError Traceback (most recent call last)
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1322 try:
-> 1323 return fn(*args)
1324 except errors.OpError as e:
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
1301 feed_dict, fetch_list, target_list,
-> 1302 status, run_metadata)
1303
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
472 compat.as_text(c_api.TF_Message(self.status.status)),
--> 473 c_api.TF_GetCode(self.status.status))
474 # Delete the underlying status object from memory otherwise it stays alive
UnimplementedError: Unsupported object type list
[[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
[[Node: IteratorGetNext_95 = IteratorGetNext[output_shapes=[<unknown>], output_types=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_35)]]
During handling of the above exception, another exception occurred:
UnimplementedError Traceback (most recent call last)
<ipython-input-120-120ecc56d75d> in <module>()
4 sess.run(init)
5 while True:
----> 6 print(sess.run(dataset.get_next()))
7
8
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
887 try:
888 result = self._run(None, fetches, feed_dict, options_ptr,
--> 889 run_metadata_ptr)
890 if run_metadata:
891 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1118 if final_fetches or final_targets or (handle and feed_dict_tensor):
1119 results = self._do_run(handle, final_targets, final_fetches,
-> 1120 feed_dict_tensor, options, run_metadata)
1121 else:
1122 results = []
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1315 if handle is None:
1316 return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317 options, run_metadata)
1318 else:
1319 return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1334 except KeyError:
1335 pass
-> 1336 raise type(e)(node_def, op, message)
1337
1338 def _extend_graph(self):
UnimplementedError: Unsupported object type list
[[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
[[Node: IteratorGetNext_95 = IteratorGetNext[output_shapes=[<unknown>], output_types=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_35)]]