This can be achieved relatively easily by subclassing Process, but with multiprocessing.Pool it gets more complicated. Pool leaves two problems to solve:

- make the child processes call an exit handler when a process is about to terminate
- prevent Pool from terminating the child processes before their exit handlers have completed
To use forking as the start method for the child processes, I found it necessary to monkey-patch multiprocessing.pool.worker. With 'spawn' as the start method we could use atexit instead (a sketch follows the patch below). The patch is a wrapper around worker which calls our custom at_exit function when the worker returns, which happens when the process is about to exit.
# at_exit_pool.py
import os
import threading
from functools import wraps, partial
import multiprocessing.pool
from multiprocessing.pool import worker, TERMINATE, Pool
from multiprocessing import util, Barrier


def finalized(worker):
    """Extend worker function with at_exit call."""
    @wraps(worker)
    def wrapper(*args, **kwargs):
        result = worker(*args, **kwargs)
        at_exit()  # <-- patch
        return result
    return wrapper


worker = finalized(worker)
multiprocessing.pool.worker = worker  # patch
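Here is the atexit alternative mentioned above as a minimal sketch; it is my own illustration, not part of the solution, and the module and handler names are made up. With 'fork' a child exits via os._exit(), which skips atexit handlers; with 'spawn' the child exits through a normal interpreter shutdown, so handlers registered with the standard atexit module do run:

# spawn_atexit.py -- hedged sketch, assuming 'spawn' as start method
import atexit
import os
from multiprocessing import Pool, set_start_method


def dump(path):
    """Hypothetical exit handler; runs when the spawned worker exits."""
    print(os.getpid(), 'handler writing to', path)


def initializer(path):
    # Registered per child; with 'fork' this would be skipped on exit.
    atexit.register(dump, path)


if __name__ == '__main__':
    set_start_method('spawn')
    pool = Pool(2, initializer=initializer, initargs=('score.toml',))
    pool.map(abs, range(10))
    pool.close()  # let the workers exit on their own...
    pool.join()   # ...and wait for them, so the handlers can finish

Note that close() + join() matters here: terminate() (which the with-block calls on exit) could still kill a worker in the middle of its handler, which is exactly the race the barrier introduced below prevents.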
The solution also subclasses Pool to deal with both problems. PatientPool introduces the two mandatory arguments at_exit and at_exit_args, where at_exit takes the exit-handler function. PatientPool piggybacks on initializer from the standard Pool to register the exit handler in the child processes. These are the functions dealing with that registration:
# at_exit_pool.py

def at_exit(func=None, barrier=None, *args):
    """Call at_exit function and wait on barrier."""
    func(*args)
    print(os.getpid(), 'barrier waiting')  # DEBUG
    barrier.wait()


def register_at_exit(func, barrier, *args):
    """Register at_exit function."""
    global at_exit
    at_exit = partial(at_exit, func, barrier, *args)


def combi_initializer(at_exit_args, initializer, initargs):
    """Piggyback initializer with register_at_exit."""
    if initializer:
        initializer(*initargs)
    register_at_exit(*at_exit_args)
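The global rebinding in register_at_exit is what makes the bare at_exit() call in the patched worker work: the module-level name at_exit is replaced by a functools.partial that already carries the handler, the barrier and any extra arguments. A stripped-down illustration of the same trick (hypothetical names, not part of the solution):

from functools import partial


def announce(func=None, tag=None):
    """Stand-in for at_exit: it needs arguments it won't get at call time."""
    func(tag)


# Rebind the global name to a pre-filled version, as register_at_exit does:
announce = partial(announce, print, 'bye')

announce()  # prints 'bye' -- now callable with no arguments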
As you can see in at_exit, we're going to use a multiprocessing.Barrier shared between the pool and its child processes. This barrier works by blocking every process that calls .wait() on it until a 'parties' number of processes has called .wait().
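Before looking at PatientPool itself, here is a tiny standalone demo of that blocking behaviour (my own illustration, not part of the solution): two children plus the parent make three parties, and nobody gets past .wait() until all three have arrived.

# barrier_demo.py -- standalone illustration of multiprocessing.Barrier
from multiprocessing import Barrier, Process


def child(barrier):
    print('child waiting')
    barrier.wait()  # blocks until all three parties have arrived
    print('child released')


if __name__ == '__main__':
    barrier = Barrier(3)  # parties = 2 children + 1 parent
    children = [Process(target=child, args=(barrier,)) for _ in range(2)]
    for c in children:
        c.start()
    barrier.wait()  # parent is the third party; everyone crosses together
    for c in children:
        c.join()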
PatientPool sets the parties parameter of this barrier to the number of child processes + 1. The child processes call .wait() on the barrier as soon as they are done with at_exit, and PatientPool itself also calls .wait() on it. This happens within _terminate_pool, a method we override in Pool for that purpose. Doing so prevents the pool from terminating the child processes prematurely, since all processes calling .wait() are released only once the last one has arrived.
# at_exit_pool.py

class PatientPool(Pool):
    """Pool class which awaits completion of exit handlers in child processes
    before terminating the processes."""

    def __init__(self, at_exit, at_exit_args=(), processes=None,
                 initializer=None, initargs=(), maxtasksperchild=None,
                 context=None):
        # changed--------------------------------------------------------------
        self._barrier = self._get_barrier(processes)
        at_exit_args = (at_exit, self._barrier) + at_exit_args
        initargs = (at_exit_args, initializer, initargs)
        super().__init__(
            processes, initializer=combi_initializer, initargs=initargs,
            maxtasksperchild=maxtasksperchild, context=context
        )
        # ---------------------------------------------------------------------

    @staticmethod
    def _get_barrier(processes):
        """Get Barrier object for use in _terminate_pool and
        child processes."""
        if processes is None:  # this will be repeated in super().__init__(...)
            processes = os.cpu_count() or 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")
        return Barrier(processes + 1)

    def _terminate_pool(self, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        """changed from classmethod to normal method"""
        # this is guaranteed to only be called once
        util.debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        util.debug('helping task handler/workers to finish')
        self.__class__._help_stuff_finish(inqueue, task_handler, len(pool))  # changed

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our back.
        util.debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join()

        # patch ---------------------------------------------------------------
        print('_terminate_pool barrier waiting')  # DEBUG
        self._barrier.wait()  # <- blocks until all processes have called wait()
        print('_terminate_pool barrier crossed')  # DEBUG
        # ---------------------------------------------------------------------

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            util.debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        util.debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join()

        util.debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join()

        if pool and hasattr(pool[0], 'terminate'):
            util.debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    util.debug('cleaning up worker %d' % p.pid)
                    p.join()
Now for an example of how to use PatientPool: dump_score is the function we pass as the at_exit handler. Note that local_score has to be a global variable so that the exit handler can access it.
import os
from functools import wraps
# from multiprocessing import log_to_stderr, set_start_method
# import logging
import toml

from at_exit_pool import register_at_exit, PatientPool


local_score = {
    "hits": 0,
    "misses": 0
}


def memoize(func):
    cache = {}
    @wraps(func)
    def wrap(*args):
        cache_key = str(args)  # ~14% faster than pickle.dumps(args)
        if cache_key not in cache:
            local_score["misses"] += 1
            cache[cache_key] = func(*args)
        else:
            local_score["hits"] += 1
        return cache[cache_key]
    return wrap


@memoize
def foo(x):
    for _ in range(int(x)):
        x - 1
    return x


def dump_score(pathfile):
    with open(pathfile, 'a') as fh:
        toml.dump({str(os.getpid()): local_score}, fh)


if __name__ == '__main__':

    # set_start_method('spawn')
    # logger = log_to_stderr()
    # logger.setLevel(logging.DEBUG)

    PATHFILE = 'score.toml'
    N_WORKERS = 4

    arguments = [10e6 + i for i in range(10)] * 5
    # print(arguments[:10])

    with PatientPool(at_exit=dump_score, at_exit_args=(PATHFILE,),
                     processes=N_WORKERS) as pool:

        results = pool.map(foo, arguments, chunksize=3)
        # print(results[:10])
Running this example produces terminal output like the following, where '_terminate_pool barrier crossed' always executes last, while the order of the lines before it can vary:
555 barrier waiting
_terminate_pool barrier waiting
554 barrier waiting
556 barrier waiting
557 barrier waiting
_terminate_pool barrier crossed
Process finished with exit code 0
score.toml will then contain something like this:

[555]
hits = 3
misses = 8
[554]
hits = 3
misses = 9
[556]
hits = 2
misses = 10
[557]
hits = 5
misses = 10

Since every child process keeps its own cache, the per-process hits and misses sum to the 50 calls made in total, and the same argument can count as a miss in several processes.