我试图实现
experience replay memory
与
tf.estimator.Estimator
应用程序编程接口。但是,我不确定哪种方法是获得至少适用于所有模式的结果的最佳方法(
TRAIN
,
EVALUATE
,
PREDICT
)我尝试了以下方法:
-
使用
tf.Variable
,这会导致批处理和输入管道出现问题(我无法在测试或预测阶段输入自定义经验)
目前尝试:
-
在
tf.Graph
.在每次运行后使用
tf.train.SessionRunHook
. 将体验载入
tf.data.Dataset.from_generator()
在培训和测试期间。自己管理国家。
我在几点上都失败了,开始相信tf.estimator.estimator API没有提供必要的接口来方便地写下这一点。
一些代码(第一种方法,由于它是固定用于exp切片的,所以批处理大小失败,因此我不能使用该模型进行预测或评估):
def model_fn(self, features, labels, mode, params):
batch_size = features["matrix"].get_shape()[0].value
# get prev_exp
if mode == tf.estimator.ModeKeys.TRAIN:
erm = tf.get_variable("erm", shape=[30000, 10], initializer=tf.constant_initializer(self.erm.initial_train_erm()), trainable=False)
prev_exp = tf.slice(erm, [features["index"][0], 0], [batch_size, 10])
# model
pred = model(features["matrix"], prev_exp, params)
不过,最好是将erm放在功能dict中,然后在图形外部管理erm,并用sessionrunhook写回我最新的经验。有更好的方法吗?还是我错过了什么?