代码之家  ›  专栏  ›  技术社区  ›  Ujjwal

TensorFlow对象检测API中的历元结束检测

  •  1
  • Ujjwal  · 技术社区  · 6 年前

    如何在TF对象检测API中检测一个历元的结束(即通过数据集完成一次完整扫描)?这对于在自定义检测模型中进行一些簿记或内部处理(即重置一些权重)可能很有用

    0 回复  |  直到 6 年前
        1
  •  0
  •   Chelaru Adrian    4 年前

    您可能想要实现 tf.estimator.SessionRunHook

    为此,需要编辑model\u lib。py在 tf.estimator.TrainSpec 通过添加挂钩参数,或创建自己的训练文件,并在将其传递给tf之前覆盖train\u spec。估计员。train\u和\u评估。

    使用添加到Tensorflow对象检测API的ProfilerHook的示例: (SessionRunHook应类似)

    config = tf.estimator.RunConfig(model_dir=model_dir, save_checkpoints_steps=save_checkpoints_steps,
                                save_checkpoints_secs=save_checkpoints_secs, keep_checkpoint_max=keep_checkpoint_max,
                                log_step_count_steps=log_step_count_steps)
    
    train_and_eval_dict = model_lib.create_estimator_and_inputs(
          run_config=config,
          hparams=model_hparams.create_hparams(hparams_overrides),
          pipeline_config_path=pipeline_config_path,
          config_override = cfg_override,
          train_steps=num_train_steps,
          sample_1_of_n_eval_examples=sample_1_of_n_eval_examples,
          sample_1_of_n_eval_on_train_examples=sample_1_of_n_eval_on_train_examples,
          save_final_config=save_final_config)
    
    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fns = train_and_eval_dict['eval_input_fns']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    predict_input_fn = train_and_eval_dict['predict_input_fn']
    train_steps = train_and_eval_dict['train_steps']
    
    train_spec, eval_specs = model_lib.create_train_and_eval_specs(
      train_input_fn,
      eval_input_fns,
      eval_on_train_input_fn,
      predict_input_fn,
      train_steps,
      eval_on_train_data=False)
    
    profile_hook = tf.train.ProfilerHook(save_steps=profiler_save_step, save_secs=None, output_dir=profiler_output_dir, 
                                         show_dataflow=True, show_memory=True)
    
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn
                                        ,max_steps=train_steps
                                        ,hooks=[profile_hook])
    
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])