代码之家  ›  专栏  ›  技术社区  ›  guillaumekln

如何使用tf。tf中数据的可初始化迭代器。估计器的输入fn?

  •  10
  • guillaumekln  · 技术社区  · 7 年前

    我想用一个 tf.estimator.Estimator tf.data 应用程序编程接口。

    我有这样的想法:

    def model_fn(features, labels, params, mode):
      # Defines model's ops.
      # Initializes with tf.train.Scaffold.
      # Returns an tf.estimator.EstimatorSpec.
    
    def input_fn():
      dataset = tf.data.TextLineDataset("test.txt")
      # map, shuffle, padded_batch, etc.
    
      iterator = dataset.make_initializable_iterator()
    
      return iterator.get_next()
    
    estimator = tf.estimator.Estimator(model_fn)
    estimator.train(input_fn)
    

    因为我不能使用 make_one_shot_iterator 对于我的用例,我的问题是 input_fn model_fn tf.train.Scaffold 初始化本地操作)。

    input_fn = iterator.get_next

    1 回复  |  直到 6 年前
        1
  •  13
  •   guillaumekln    6 年前

    从TensorFlow 1.5开始,可以 input_fn 返回a tf.data.Dataset ,例如:

    def input_fn():
      dataset = tf.data.TextLineDataset("test.txt")
      # map, shuffle, padded_batch, etc.
      return dataset
    

    看见 c294fcfd .


    对于以前的版本,可以在 tf.GraphKeys.TABLE_INITIALIZERS

    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)