
Execute a Python function at child process exit

  •  1
  •  Amit · asked 6 years ago

    I have a memoizing function wrapper with hit and miss counters. Because I can't rebind any local variables from within the wrapped function, I use a dictionary to count the hits and misses.
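
    To illustrate (a minimal sketch with made-up names, not the real code): assigning to a closed-over name makes it local to the inner function, which is why a plain counter variable fails, while mutating a dict works:

    def make_counter_broken():
        hits = 0
        def bump():
            hits += 1  # UnboundLocalError when called: assignment makes `hits` local
        return bump

    def make_counter():
        score = {"hits": 0}  # mutating the dict doesn't rebind any name
        def bump():
            score["hits"] += 1  # works (`nonlocal` would be the other Python 3 fix)
        return bump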

    The function runs in 1000 parallel processes on 48 cores, and it is called over a million times on each core, so I use a Manager.dict to share the score between processes.

    Just keeping the score triples my execution time, so I want to do something smarter: I want to keep a local counter, which is just a plain dict, and when the process exits, add that local score to the shared score dict managed by the Manager.
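
    The overhead is easy to see in isolation (a standalone sketch, not my actual benchmark): every operation on a Manager.dict goes through a proxy and costs an IPC round-trip to the manager process, so each += pays two round-trips:

    from multiprocessing import Manager
    import time

    if __name__ == '__main__':
        with Manager() as m:
            proxy = m.dict(hits=0)
            t0 = time.perf_counter()
            for _ in range(10_000):
                proxy["hits"] += 1  # __getitem__ + __setitem__: two IPC round-trips
            print("proxy dict:", time.perf_counter() - t0)

            plain = {"hits": 0}
            t0 = time.perf_counter()
            for _ in range(10_000):
                plain["hits"] += 1  # in-process, no IPC
            print("plain dict:", time.perf_counter() - t0)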

    Is there a way to execute a function when a child process exits? Something like atexit, but one that works for spawned children.

    Relevant code (note MAGICAL_AT_PROCESS_EXIT_CLASS, which is what I'm looking for):

    import atexit
    import pickle
    from functools import wraps
    from multiprocessing import Manager

    manager = Manager()
    
    global_score = manager.dict({
        "hits": 0,
        "misses": 0
    })
    
    def memoize(func):
        local_score = {
            "hits": 0,
            "misses": 0
        }
    
        cache = {}
    
        def process_exit_handler():
            global_score["hits"] += local_score["hits"]
            global_score["misses"] += local_score["misses"]
    
        MAGICAL_AT_PROCESS_EXIT_CLASS.register(process_exit_handler)
    
        @wraps(func)
        def wrap(*args):
            cache_key = pickle.dumps(args)
            if cache_key not in cache:
                local_score["misses"] += 1
                cache[cache_key] = func(*args)
            else:
                local_score["hits"] += 1
            return cache[cache_key]
    
        return wrap
    
    
    def exit_handler():
        print("Cache", global_score)
    
    atexit.register(exit_handler)
    

    (Yes, I know each process caches independently. Yes, that is the desired behavior.)

    Current workaround: I changed the wrapper method as follows. A local "open" counter tracks calls that are still in flight (e.g. recursive calls into the memoized function), and the local score is merged into the shared dict only when the outermost call returns:

    # Note: local_score is now initialized with an extra counter, i.e.
    # {"hits": 0, "misses": 0, "open": 0}, and `score` is the shared
    # Manager dict (called global_score above).
    @wraps(func)
    def wrap(*args):
        cache_key = pickle.dumps(args)
        if cache_key not in cache:
            local_score["misses"] += 1
            local_score["open"] += 1
            cache[cache_key] = func(*args)
            local_score["open"] -= 1
        else:
            local_score["hits"] += 1
    
        if local_score["open"] == 0:
            score["hits"] += local_score["hits"]
            score["misses"] += local_score["misses"]
            local_score["hits"] = 0
            local_score["misses"] = 0
    
        return cache[cache_key]
    

    1 Answer

  •  2
  •  Darkonaut · answered 6 years ago

    This is relatively easy to achieve by subclassing Process, but with multiprocessing.Pool it gets more complicated.


    There are two problems to solve:

    1. Make the child processes call the exit handler when a process terminates.
    2. Prevent the Pool from terminating the child processes before their exit handlers have completed.

    To keep fork as the start method for the child processes, I found it necessary to monkey-patch multiprocessing.pool.worker (with the 'spawn' start method, atexit could be used instead). The patch is a wrapper around worker which calls our custom at_exit function when worker returns, which happens when the process is about to exit.

    # at_exit_pool.py
    
    import os
    import threading
    from functools import wraps
    import multiprocessing.pool
    from multiprocessing.pool import worker, TERMINATE, Pool
    from multiprocessing import util, Barrier
    from functools import partial
    
    
    def finalized(worker):
        """Extend worker function with at_exit call."""
        @wraps(worker)
        def wrapper(*args, **kwargs):
            result = worker(*args, **kwargs)
            at_exit()  # <-- patch
            return result
        return wrapper
    
    
    worker = finalized(worker)
    multiprocessing.pool.worker = worker  # patch
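
    A quick sanity check (a hypothetical snippet, not part of the answer's code): the patch is applied as a side effect of importing at_exit_pool, so that import must happen before any pool is created; the final example further down satisfies this by importing PatientPool from at_exit_pool.

    import at_exit_pool  # noqa: F401 -- imported for its patching side effect
    import multiprocessing.pool

    # functools.wraps stored a reference to the original worker on the wrapper:
    print(hasattr(multiprocessing.pool.worker, '__wrapped__'))  # True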
    

    This solution also subclasses Pool to deal with both problems. PatientPool introduces two mandatory parameters, at_exit and at_exit_args. PatientPool piggybacks the initializer from the standard Pool to register the exit handler in the child processes. These are the functions which deal with registering the exit handler:

    # at_exit_pool.py
    
    def at_exit(func=None, barrier=None, *args):
        """Call at_exit function and wait on barrier."""
        func(*args)
        print(os.getpid(), 'barrier waiting')  # DEBUG
        barrier.wait()
    
    
    def register_at_exit(func, barrier, *args):
        """Register at_exit function."""
        global at_exit
        at_exit = partial(at_exit, func, barrier, *args)
    
    
    def combi_initializer(at_exit_args, initializer, initargs):
        """Piggyback initializer with register_at_exit."""
        if initializer:
            initializer(*initargs)
        register_at_exit(*at_exit_args)
    
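    The rebinding trick in register_at_exit, shown in isolation (a standalone sketch with placeholder arguments): after the partial is assigned, the global name at_exit carries func, barrier and the extra args pre-bound, so the patched worker can call at_exit() without arguments:

    from functools import partial

    def at_exit(func=None, *args):
        func(*args)

    at_exit = partial(at_exit, print, "child exiting")  # pre-bind func and args
    at_exit()  # -> prints: child exiting
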

    As you can see in at_exit, we're going to use a multiprocessing.Barrier. Using this synchronization primitive solves our second problem: it prevents the Pool from terminating the child processes before their exit handlers have finished.

    A barrier works by blocking every process that calls .wait() on it for as long as fewer than "parties" processes have called .wait().
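
    Here is that behavior in isolation (a standalone sketch, independent of the pool machinery):

    from multiprocessing import Barrier, Process

    def arrive(barrier, name):
        print(name, "waiting")
        barrier.wait()  # blocks until all three parties have called wait()
        print(name, "released")

    if __name__ == '__main__':
        b = Barrier(3)  # parties=3: two children + the parent
        children = [Process(target=arrive, args=(b, f"child-{i}")) for i in range(2)]
        for p in children:
            p.start()
        arrive(b, "parent")  # the parent is the third party
        for p in children:
            p.join()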

    PatientPool initializes such a barrier with the parties parameter set to the number of child processes + 1. The child processes call .wait() on this barrier as soon as they are done, and PatientPool itself also calls .wait() on it. This happens within the _terminate_pool method, which we override in Pool for this purpose. Doing so prevents the pool from terminating the child processes too early, because _terminate_pool only proceeds past the barrier once every child has called .wait().

    # at_exit_pool.py
    
    class PatientPool(Pool):
        """Pool class which awaits completion of exit handlers in child processes
        before terminating the processes."""
    
        def __init__(self, at_exit, at_exit_args=(), processes=None,
                     initializer=None, initargs=(), maxtasksperchild=None,
                     context=None):
            # changed--------------------------------------------------------------
            self._barrier = self._get_barrier(processes)
    
            at_exit_args = (at_exit, self._barrier) + at_exit_args
            initargs = (at_exit_args, initializer, initargs)
    
            super().__init__(
                processes, initializer=combi_initializer, initargs=initargs,
                maxtasksperchild=maxtasksperchild, context=context
            )
            # ---------------------------------------------------------------------
    
        @staticmethod
        def _get_barrier(processes):
            """Get Barrier object for use in _terminate_pool and
            child processes."""
            if processes is None:  # this will be repeated in super().__init__(...)
                processes = os.cpu_count() or 1
            if processes < 1:
                raise ValueError("Number of processes must be at least 1")
    
            return Barrier(processes + 1)
    
        def _terminate_pool(self, taskqueue, inqueue, outqueue, pool,
                            worker_handler, task_handler, result_handler, cache):
            """changed from classmethod to normal method"""
            # this is guaranteed to only be called once
            util.debug('finalizing pool')
    
            worker_handler._state = TERMINATE
            task_handler._state = TERMINATE
    
            util.debug('helping task handler/workers to finish')
            self.__class__._help_stuff_finish(inqueue, task_handler, len(pool))  # changed
    
            assert result_handler.is_alive() or len(cache) == 0
    
            result_handler._state = TERMINATE
            outqueue.put(None)  # sentinel
    
            # We must wait for the worker handler to exit before terminating
            # workers because we don't want workers to be restarted behind our back.
            util.debug('joining worker handler')
            if threading.current_thread() is not worker_handler:
                worker_handler.join()
    
            # patch ---------------------------------------------------------------
            print('_terminate_pool barrier waiting')  # DEBUG
            self._barrier.wait()  # <- blocks until all processes have called wait()
            print('_terminate_pool barrier crossed')  # DEBUG
            # ---------------------------------------------------------------------
    
            # Terminate workers which haven't already finished.
            if pool and hasattr(pool[0], 'terminate'):
                util.debug('terminating workers')
                for p in pool:
                    if p.exitcode is None:
                        p.terminate()
    
            util.debug('joining task handler')
            if threading.current_thread() is not task_handler:
                task_handler.join()
    
            util.debug('joining result handler')
            if threading.current_thread() is not result_handler:
                result_handler.join()
    
            if pool and hasattr(pool[0], 'terminate'):
                util.debug('joining pool workers')
                for p in pool:
                    if p.is_alive():
                        # worker has not yet exited
                        util.debug('cleaning up worker %d' % p.pid)
                        p.join()
    

    Now in your main module, all you have to do is switch Pool for PatientPool and pass the required at_exit arguments. In this example the exit handler appends the local score to a toml file. Note that local_score has to be a global variable so the exit handler can access it.

    import os
    from functools import wraps
    # from multiprocessing import log_to_stderr, set_start_method
    # import logging
    import toml
    from at_exit_pool import register_at_exit, PatientPool
    
    
    local_score = {
        "hits": 0,
        "misses": 0
    }
    
    
    def memoize(func):
    
        cache = {}
    
        @wraps(func)
        def wrap(*args):
            cache_key = str(args)  # ~14% faster than pickle.dumps(args)
            if cache_key not in cache:
                local_score["misses"] += 1
                cache[cache_key] = func(*args)
            else:
                local_score["hits"] += 1
            return cache[cache_key]
    
        return wrap
    
    
    @memoize
    def foo(x):
        for _ in range(int(x)):
            x - 1
        return x
    
    
    def dump_score(pathfile):
        with open(pathfile, 'a') as fh:
            toml.dump({str(os.getpid()): local_score}, fh)
    
    
    if __name__ == '__main__':
    
        # set_start_method('spawn')
        # logger = log_to_stderr()
        # logger.setLevel(logging.DEBUG)
    
        PATHFILE = 'score.toml'
        N_WORKERS = 4
    
        arguments = [10e6 + i for i in range(10)] * 5
        # print(arguments[:10])
    
        with PatientPool(at_exit=dump_score, at_exit_args=(PATHFILE,),
                         processes=N_WORKERS) as pool:
    
            results = pool.map(foo, arguments, chunksize=3)
            # print(results[:10])
    

    Running this example produces terminal output like the following, where "_terminate_pool barrier crossed" always executes last, while the order of the lines before it may vary:

    555 barrier waiting
    _terminate_pool barrier waiting
    554 barrier waiting
    556 barrier waiting
    557 barrier waiting
    _terminate_pool barrier crossed
    
    Process finished with exit code 0
    

    The resulting score.toml file, written by the exit handlers, looks like this:

    [555]
    hits = 3
    misses = 8
    [554]
    hits = 3
    misses = 9
    [556]
    hits = 2
    misses = 10
    [557]
    hits = 5
    misses = 10
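
    To recover the single combined score the question asked for, the per-PID sections can be merged after the run (a hypothetical post-processing step, not part of the answer):

    import toml

    scores = toml.load('score.toml')
    total = {
        "hits": sum(s["hits"] for s in scores.values()),
        "misses": sum(s["misses"] for s in scores.values()),
    }
    print(total)  # for the run above: {'hits': 13, 'misses': 37}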