这是我找到的使用
tf.cond
的解决方法。
为了从每个
tfrecord
文件中取出 2 个样本,我使用了
zip
方法(属于
tf.data.Dataset
API),如下所示:
def load_train_sewa_tfrecords(filenames_train, train_batch_size):
    """Build a zipped, batched input pipeline over several training TFRecord files.

    One ``TFRecordDataset`` is created per file, parsed with
    ``_parse_train_function`` and batched with ``train_batch_size``. The
    per-file datasets are zipped, so each fetch yields one batch from every
    file; matching fields are then concatenated along axis 0.

    NOTE(review): the indentation of this snippet was lost, so the exact
    extent of the two ``tf.name_scope`` blocks is reconstructed -- confirm
    against the original if graph op names matter.

    Args:
        filenames_train: iterable of TFRecord file paths.
        train_batch_size: batch size applied to each per-file dataset.

    Returns:
        A tuple ``(names, detected, arousal, valence, liking, istalkings,
        images, iterator_train_all)``: the seven concatenated tensors and the
        initializable iterator that feeds them.
    """
    with tf.name_scope('TFRecordsTrain'):
        per_file_datasets = tuple(
            tf.data.TFRecordDataset(path)
            .map(_parse_train_function)
            .batch(train_batch_size)
            for path in filenames_train
        )
        dataset_train_all = tf.data.Dataset.zip(per_file_datasets)
        iterator_train_all = dataset_train_all.make_initializable_iterator()

    with tf.name_scope('inputs_train'):
        next_batch = iterator_train_all.get_next(name='next_batch')

        # Position of each field inside a parsed example; the same string is
        # used as the name of the resulting concat op.
        field_order = (
            'names',
            'detected',
            'arousal',
            'valence',
            'liking',
            'istalkings',
            'images',
        )
        merged = [
            tf.concat(
                [file_batch[position] for file_batch in next_batch],
                axis=0,
                name=field,
            )
            for position, field in enumerate(field_order)
        ]

    names, detected, arousal, valence, liking, istalkings, images = merged
    return (
        names,
        detected,
        arousal,
        valence,
        liking,
        istalkings,
        images,
        iterator_train_all,
    )
对于开发(devel)集,我会再写一个类似的方法;或者也可以修改该方法的传入参数,这样同一个方法就可以调用两次……(这不是问题)。
然后:
# Build the devel pipeline first, then the train pipeline; each loader returns
# seven tensors plus the initializable iterator that feeds them.
(
    names_dev,
    detected_dev,
    arousal_dev,
    valence_dev,
    liking_dev,
    istalkings_dev,
    images_dev,
    iterator_dev_all,
) = load_devel_sewa_tfrecords(filenames_dev, sewa_batch_size)
(
    names_train,
    detected_train,
    arousal_train,
    valence_train,
    liking_train,
    istalkings_train,
    images_train,
    iterator_train_all,
) = load_train_sewa_tfrecords(filenames_train, sewa_batch_size)

images_train = pre_process_sewa_images(images_train)
images_dev = pre_process_sewa_images(images_dev)


def return_train_sewa():
    """Return the training tensors; used as the true branch of ``tf.cond``."""
    return (
        names_train,
        detected_train,
        arousal_train,
        valence_train,
        liking_train,
        istalkings_train,
        images_train,
    )


def return_dev_sewa():
    """Return the devel tensors; used as the false branch of ``tf.cond``."""
    return (
        names_dev,
        detected_dev,
        arousal_dev,
        valence_dev,
        liking_dev,
        istalkings_dev,
        images_dev,
    )


# phase_train selects which set of tensors is exposed downstream. The tensors
# themselves are created outside the branch functions above.
(
    names,
    detected,
    arousal,
    valence,
    liking,
    istalkings,
    images_sewa,
) = tf.cond(phase_train, return_train_sewa, return_dev_sewa)
sewa_inputs = []
sess = tf.Session()
import numpy as np

# NOTE(review): the indentation of this snippet was lost; the nesting below
# (train pass followed by devel pass inside every epoch) is reconstructed --
# confirm against the original.
for e in range(epochs):
    # Both iterators are initialised before fetching, even though only one
    # branch of the cond is returned per run.
    sess.run(iterator_train_all.initializer)
    sess.run(iterator_dev_all.initializer)

    # ---- training pass -------------------------------------------------
    i = 0
    total = 0
    try:
        while True:
            i += 1
            names_np, detected_np, arousal_np, valence_np, liking_np, istalkings_np = \
                sess.run([names, detected, arousal, valence, liking, istalkings],
                         feed_dict={phase_train: True})
            total += np.shape(names_np)[0]
            print("total =", total, " | i =", i)
    except tf.errors.OutOfRangeError:
        # Only iterator exhaustion ends the pass; any other error (or a
        # KeyboardInterrupt) now propagates instead of being silently
        # reported as the end of the data.
        print("end of train...")

    # ---- devel pass ----------------------------------------------------
    i_d = 0
    total_d = 0
    # Re-initialise both iterators so the devel pass starts from the first
    # record again.
    sess.run(iterator_train_all.initializer)
    sess.run(iterator_dev_all.initializer)
    try:
        while True:
            i_d += 1
            names_np, detected_np, arousal_np, valence_np, liking_np, istalkings_np = \
                sess.run([names, detected, arousal, valence, liking, istalkings],
                         feed_dict={phase_train: False})
            total_d += np.shape(names_np)[0]
            print("total_d =", total_d, " | i_d =", i_d)
            print(names_np)
    except tf.errors.OutOfRangeError:
        print("End of devel")
请注意,必须先运行两个初始化操作
sess.run(iterator_train_all.initializer)
和
sess.run(iterator_dev_all.initializer)
,然后才能执行
sess.run([names....])
。因为我猜测,使用
tf.cond
时,训练样本和验证样本都会被取出,只不过
tf.cond
只会返回其中之一,具体返回哪一个由
phase_train
这个占位符(placeholder)决定,它表示我们处于训练模式还是测试模式。(两个分支所返回的张量都是在分支函数之外创建的,因此无论条件取值如何,它们都会被求值。)
证明:当我把
names = tf.Print(input_=[names], data=[names], message='dev names')
插入到
load_devel_sewa_tfrecords
中、返回语句之前时,控制台会打印出:
dev names['Devel_01' 'Devel_01' 'Devel_02'...]
也就是说,在评估训练数据集的同时,TensorFlow 也在评估 devel 数据集;但是
tf.cond
最终输出的只是与训练数据集相关的 tfrecords。
希望这个答案有帮助!!