
How to avoid losing past runs in Keras when a crash happens [duplicate]

  •  0
  • user25004  · Tech Community  · 6 years ago

    2 replies  |  latest 6 years ago
        1
  •  2
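  •   nuric    6 years ago

    I use a custom callback that stores the last epoch, the best monitored value, the losses, etc. alongside the weights, so that training can be resumed later: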

    import json
    import socket

    from keras.callbacks import ModelCheckpoint


    class StatefulCheckpoint(ModelCheckpoint):
      """Save extra checkpoint data to resume training."""
      def __init__(self, weight_file, state_file=None, **kwargs):
        """Save the state (epoch etc.) alongside the weights."""
        super().__init__(weight_file, **kwargs)
        self.state_f = state_file
        self.state = dict()
        # Assumption: the original does not show where self.hostname is set
        self.hostname = socket.gethostname()
        if self.state_f:
          # Load the last state if any
          try:
            with open(self.state_f, 'r') as f:
              self.state = json.load(f)
            self.best = self.state['best']
          except Exception as e: # pylint: disable=broad-except
            print("Skipping last state:", e)
    
      def on_epoch_end(self, epoch, logs=None):
        """Saves training state as well as weights."""
        super().on_epoch_end(epoch, logs)
        if self.state_f:
          state = {'epoch': epoch+1, 'best': self.best,
                   'hostname': self.hostname}
          state.update(logs or {})  # logs may be None
          state.update(self.params)
          with open(self.state_f, 'w') as f:
            # default=float converts NumPy scalars that json rejects
            json.dump(state, f, default=float)
    
      def get_last_epoch(self, initial_epoch=0):
        """Return last saved epoch if any, or return default argument."""
        return self.state.get('epoch', initial_epoch)
    

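    Because the state is only written at the end of each epoch, this works well only if your epochs take a reasonable amount of time, e.g. 1 hour, but it is clean and consistent with the Keras API.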
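
    One gap the callback leaves open is a crash in the middle of rewriting the state file, which could leave truncated JSON behind. Below is a minimal sketch of one way to guard against that, using only the standard library. The helper names `save_state_atomically` and `load_state` are hypothetical and not part of Keras or of the answer above.

```python
import json
import os
import tempfile


def save_state_atomically(path, state):
    """Write state as JSON to a temp file, then rename it over path."""
    directory = os.path.dirname(os.path.abspath(path))
    # The temp file lives in the same directory so the rename stays
    # on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f)
            f.flush()
            os.fsync(f.fileno())
        # A successful rename is atomic on POSIX, so readers see either
        # the old file or the new one, never a partial write.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_state(path, default=None):
    """Return the saved state dict, or default if missing or unreadable."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except (OSError, ValueError):  # JSONDecodeError subclasses ValueError
        return {} if default is None else default
```

    On restart, something like `load_state("state.json").get("epoch", 0)` (the file name is an assumption) could be passed as `initial_epoch` to `model.fit` after restoring the weight file with `model.load_weights`.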

        2
  •  1
  •   shayaan    6 years ago

    Python has great logging utilities, and pickle is useful for serializing models.
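
    As a rough illustration of that suggestion, the sketch below logs per-epoch metrics to a file with `logging` and saves a snapshot with `pickle`. The helper names, the file names and the snapshot layout are assumptions made for the example. A plain dict stands in for the model state here, since whether a given Keras model object can be pickled directly depends on the Keras version.

```python
import logging
import pickle


def make_logger(log_path):
    """Return a logger that appends timestamped records to log_path.

    Call once per process; each call attaches another file handler.
    """
    logger = logging.getLogger("training")
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(log_path)
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(handler)
    return logger


def save_snapshot(path, snapshot):
    """Pickle a snapshot (e.g. epoch number and weights) to path."""
    with open(path, "wb") as f:
        pickle.dump(snapshot, f)


def load_snapshot(path):
    """Load a snapshot written by save_snapshot (trusted files only)."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

    A training loop could call `logger.info("epoch %d loss %.4f", epoch, loss)` after each epoch and `save_snapshot` at whatever interval is affordable. The log then shows how far a crashed run got, and the last snapshot gives a point to restart from.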