代码之家  ›  专栏  ›  技术社区  ›  Chocolate

TensorFlow-使用估计器API实现经验重放内存

  •  0
  • Chocolate  · 技术社区  · 6 年前

    我试图实现 experience replay memory tf.estimator.Estimator 应用程序编程接口。但是,我不确定哪种方法是获得至少适用于所有模式的结果的最佳方法( TRAIN , EVALUATE , PREDICT )我尝试了以下方法:

    • 使用 tf.Variable ,这会导致批处理和输入管道出现问题(我无法在测试或预测阶段输入自定义经验)

    目前尝试:

    • tf.Graph .在每次运行后使用 tf.train.SessionRunHook . 将体验载入 tf.data.Dataset.from_generator() 在培训和测试期间。自己管理国家。

    我在几点上都失败了,开始相信tf.estimator.estimator API没有提供必要的接口来方便地写下这一点。

    一些代码(第一种方法,由于它是固定用于exp切片的,所以批处理大小失败,因此我不能使用该模型进行预测或评估):

     def model_fn(self, features, labels, mode, params):
        batch_size = features["matrix"].get_shape()[0].value
    
        # get prev_exp
        if mode == tf.estimator.ModeKeys.TRAIN:
            erm = tf.get_variable("erm", shape=[30000, 10], initializer=tf.constant_initializer(self.erm.initial_train_erm()), trainable=False)
            prev_exp = tf.slice(erm, [features["index"][0], 0], [batch_size, 10])
    
        # model
        pred = model(features["matrix"], prev_exp, params) 
    

    不过,最好是将erm放在功能dict中,然后在图形外部管理erm,并用sessionrunhook写回我最新的经验。有更好的方法吗?还是我错过了什么?

    1 回复  |  直到 6 年前
        1
  •  0
  •   Chocolate    6 年前

    我通过在图形外部实现erm解决了我的问题,将它从_Generator()反馈到输入管道中,并使用sessionrunhook进行写操作。是的,挺无聊的,但它起作用了。