代码之家  ›  专栏  ›  技术社区  ›  Jsevillamol

Keras-基于用户输入的早期浇灌

  •  3
  • Jsevillamol  · 技术社区  · 6 年前

    我想知道是否有一种简单的方法可以根据用户输入而不是任何特定指标的监控来创建一种在喀拉斯早期停止的触发方法。

    例如,我想发送一个键盘信号到执行培训的进程,以便它离开 fit_generator 函数并执行其余代码。

    有什么想法吗?

    编辑:根据@ankurgoel的回答,我编写了以下代码:

    # Monitors the SIGINT (ctrl + C) to safely stop training when it is sent
    flag = False
    class TerminateOnFlag(Callback):
        """Callback that terminates training when the flag is raised.
        """
        def on_batch_end(self, batch, logs=None):
            if flag:    
                self.model.stop_training = True
    
    def handler(signum, frame):
        logging.info('SIGINT signal received. Training will finish after this epoch')
        global flag
        flag = True
    
    signal.signal(signal.SIGINT, handler) # We assign a specific handler for the SIGINT signal
    terminateOnFlag = TerminateOnFlag()
    callbacks.append(terminateOnFlag)
    

    在哪里? callbacks 是我输入的回调列表 菲特发生器 .

    在训练期间,当我发送 SIGINT 信号真的我明白了 SIGINT signal received. Training will finish after this epoch 但是,当时代结束时,什么也不会发生。怎么回事?

    1 回复  |  直到 6 年前
        1
  •  3
  •   Ankur Goel    6 年前

    您可以考虑以下方法:

    使用一个全局变量,初始化0 使用信号处理器,

    当python进程接收到信号(中断)时,其值从0更改为1。

    在keras中使用自定义回调,以在更改此变量值时停止培训。

    class TerminateOnFlag(Callback):
    """Callback that terminates training when flag=1 is encountered.
    """
    
    def on_batch_end(self, batch, logs=None):
        if flag==1:    
            self.model.stop_training = True
    

    原始回拨可从以下网址获得: https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L251

    您仍然需要检查是否可以提供自定义回调来适应\生成器,而不是标准回调。

    以下是信号处理程序的代码:

    迎风:

    import signal, os
    
    def handler(signum, frame):
        print('Signal handler called with signal', signum)
        raise OSError("Couldn't open device!")
    
    signal.signal(signal.CTRL_C_EVENT, handler) # only in python version 3.2
    

    Linux:

    import signal, os
    
    def handler(signum, frame):
        print('Signal handler called with signal', signum)
        raise OSError("Couldn't open device!")
    
    signal.signal(signal.SIGINT, handler)