- Use a singleton pool so the worker pool is created once and reused, which prevents lengthy restarts of the pool; as a consequence, the pool-size arguments rP and nP are gone (sketched below).

- Removed qbar and TqdmMeter.
- Wrap the chunked function with functools.wraps for better error messages (see the second sketch below).
Author: Wim Pomp
Date: 2023-09-08 17:19:35 +02:00
Parent: 92f162b7d5
Commit: f3302f6dba
6 changed files with 246 additions and 246 deletions
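
A minimal sketch of the singleton idea behind the new PoolSingleton, assuming a much-simplified class (the real one also manages queues, an event and a manager), so this is illustrative rather than the library's code:

import multiprocessing

class PoolSingleton:
    instance = None

    def __new__(cls):
        # reuse the cached instance as long as it is alive, so the expensive
        # worker pool is started at most once instead of once per ParPool
        if cls.instance is None or not cls.instance.is_alive:
            new = super().__new__(cls)
            new.pool = multiprocessing.Pool(multiprocessing.cpu_count())
            new.is_alive = True
            cls.instance = new
        return cls.instance

    def close(self):
        # closing invalidates the cached instance; the next call to
        # PoolSingleton() builds a fresh pool
        if self.is_alive:
            self.is_alive = False
            self.pool.close()
            self.pool.join()
        type(self).instance = None

if __name__ == '__main__':
    assert PoolSingleton() is PoolSingleton()  # the same pool object is reused
    PoolSingleton().close()

And a sketch of why wrapping the chunked function helps: functools.wraps copies the target function's name (and docstring) onto the chunk helper, so tracebacks and progress output mention the user's function rather than chunk_fun. Here make_chunk_fun is a hypothetical stand-in for the inline definition inside pmap:

from functools import wraps

def make_chunk_fun(fun):
    @wraps(fun)
    def chunk_fun(iterator, *args, **kwargs):
        # apply fun to every item of one chunk and collect the results
        return [fun(i, *args, **kwargs) for i in iterator]
    return chunk_fun

def square(x):
    return x * x

assert make_chunk_fun(square).__name__ == 'square'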


@@ -1,19 +1,15 @@
import multiprocessing
from os import getpid
from tqdm.auto import tqdm
from traceback import format_exc
from collections import OrderedDict
from warnings import warn
from contextlib import ExitStack
from functools import wraps
from os import getpid
from traceback import format_exc
from warnings import warn
from tqdm.auto import tqdm
from .pickler import Pickler, dumps, loads
try:
from javabridge import kill_vm
except ImportError:
kill_vm = lambda: None
cpu_count = int(multiprocessing.cpu_count())
@@ -110,43 +106,6 @@ class ExternalBar:
self.callback(n)
class TqdmMeter(tqdm):
""" Overload tqdm to make a special version of tqdm functioning as a meter. """
def __init__(self, iterable=None, desc=None, total=None, *args, **kwargs):
self._n = 0
self._total = total
self.disable = False
if 'bar_format' not in kwargs and len(args) < 16:
kwargs['bar_format'] = '{desc}{bar}{n}/{total}'
super().__init__(iterable, desc, total, *args, **kwargs)
@property
def n(self):
return self._n
@n.setter
def n(self, value):
if not value == self.n:
self._n = int(value)
self.refresh()
@property
def total(self):
return self._total
@total.setter
def total(self, value):
self._total = value
if hasattr(self, 'container'):
self.container.children[1].max = value
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
if not self.leave:
self.n = self.total
super().__exit__(exc_type, exc_value, traceback)
class Hasher:
def __init__(self, obj, hsh=None):
if hsh is not None:
@@ -207,7 +166,8 @@ class Task:
args = HashDescriptor()
kwargs = HashDescriptor()
def __init__(self, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
def __init__(self, pool_id, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
self.pool_id = pool_id
self.fun = fun or (lambda *args, **kwargs: None)
self.args = args or ()
self.kwargs = kwargs or {}
@@ -219,10 +179,11 @@ class Task:
def __reduce__(self):
if self.done:
return self.__class__, (None, None, None, self.handle, None, self.done, dumps(self.result, recurse=True))
return self.__class__, (self.pool_id, None, None, None, self.handle, None, self.done,
dumps(self.result, recurse=True))
else:
return self.__class__, (self._fun, self._args, self._kwargs, self.handle, dumps(self.n, recurse=True),
self.done)
return self.__class__, (self.pool_id, self._fun, self._args, self._kwargs, self.handle,
dumps(self.n, recurse=True), self.done)
def set_from_cache(self, cache=None):
self.n = loads(self.n)
@@ -238,9 +199,9 @@ class Task:
def __repr__(self):
if self.done:
return 'Task {}, result: {}'.format(self.handle, self.result)
return f'Task {self.handle}, result: {self.result}'
else:
return 'Task {}'.format(self.handle)
return f'Task {self.handle}'
class Context(multiprocessing.context.SpawnContext):
@@ -255,38 +216,23 @@ class Context(multiprocessing.context.SpawnContext):
pass
class Parpool:
class ParPool:
""" Parallel processing with addition of iterations at any time and request of that result any time after that.
The target function and its argument can be changed at any time.
"""
def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, qsize=None):
""" fun, args, kwargs: target function and its arguments and keyword arguments
rP: ratio workers to cpu cores, default: 1
nP: number of workers, default, None, overrides rP if not None
bar, qbar: instances of tqdm and tqdmm to use for monitoring buffer and progress """
if rP is None and nP is None:
self.nP = cpu_count
elif nP is None:
self.nP = int(round(rP * cpu_count))
else:
self.nP = int(nP)
self.nP = max(self.nP, 2)
self.task = Task(fun, args, kwargs)
self.is_started = False
ctx = Context()
self.n_tasks = ctx.Value('i', self.nP)
self.event = ctx.Event()
self.queue_in = ctx.Queue(qsize or 3 * self.nP)
self.queue_out = ctx.Queue(qsize or 12 * self.nP)
self.pool = ctx.Pool(self.nP, self._Worker(self.queue_in, self.queue_out, self.n_tasks, self.event))
self.is_alive = True
def __init__(self, fun=None, args=None, kwargs=None, bar=None):
self.id = id(self)
self.handle = 0
self.tasks = {}
self.last_task = Task(self.id, fun, args, kwargs)
self.bar = bar
self.bar_lengths = {}
self.qbar = qbar
if self.qbar is not None:
self.qbar.total = 3 * self.nP
self.spool = PoolSingleton(self)
self.manager = self.spool.manager
self.is_started = False
def __getstate__(self):
raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
def __enter__(self, *args, **kwargs):
return self
@@ -294,27 +240,66 @@ class Parpool:
def __exit__(self, *args, **kwargs):
self.close()
def _get_from_queue(self):
""" Get an item from the queue and store it, return True if more messages are waiting. """
try:
code, *args = self.queue_out.get(True, 0.02)
getattr(self, code)(*args)
return True
except multiprocessing.queues.Empty:
for handle, task in self.tasks.items(): # retry a task if the process doing it was killed
if task.pid is not None and task.pid not in [child.pid for child in multiprocessing.active_children()]:
self.queue_in.put(task)
warn('Task {} was restarted because process {} was probably killed.'.format(task.handle, task.pid))
return False
def close(self):
if self.id in self.spool.pools:
self.spool.pools.pop(self.id)
def error(self, error):
self.close()
raise Exception('Error occurred in worker: {}'.format(error))
def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
if self.id not in self.spool.pools:
raise ValueError(f'this pool is not registered (anymore) with the pool singleton')
if handle is None:
new_handle = self.handle
self.handle += 1
else:
new_handle = handle
if new_handle in self:
raise ValueError(f'handle {new_handle} already present')
self.last_task = Task(self.id, fun or self.last_task.fun, args or self.last_task.args,
kwargs or self.last_task.kwargs, new_handle, n)
self.tasks[new_handle] = self.last_task
self.bar_lengths[new_handle] = barlength
self.spool.add_task(self.last_task)
if handle is None:
return new_handle
def __setitem__(self, handle, n):
""" Add new iteration. """
self(n, handle=handle)
def __getitem__(self, handle):
""" Request result and delete its record. Wait if result not yet available. """
if handle not in self:
raise ValueError(f'No handle: {handle} in pool')
while not self.tasks[handle].done:
if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
and not self.working:
for _ in range(10): # wait some time while processing possible new messages
self.spool.get_from_queue()
if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
and not self.working:
# retry a task if the process was killed while retrieving the task
self.spool.add_task(self.tasks[handle])
warn(f'Task {handle} was restarted because the process retrieving it was probably killed.')
result = self.tasks[handle].result
self.tasks.pop(handle)
return result
def __contains__(self, handle):
return handle in self.tasks
def __delitem__(self, handle):
self.tasks.pop(handle)
def get_newest(self):
return self.spool.get_newest_for_pool(self)
def process_queue(self):
self.spool.process_queue()
def task_error(self, handle, error):
if handle in self:
task = self.tasks[handle]
print('Error from process working on iteration {}:\n'.format(handle))
print(f'Error from process working on iteration {handle}:\n')
print(error)
self.close()
print('Retrying in main thread...')
@@ -325,166 +310,163 @@ class Parpool:
def done(self, task):
if task.handle in self: # if not, the task was restarted erroneously
self.tasks[task.handle] = task
if self.bar is not None:
if hasattr(self.bar, 'update'):
self.bar.update(self.bar_lengths.pop(task.handle))
self._qbar_update()
def started(self, handle, pid):
self.is_started = True
if handle in self: # if not, the task was restarted erroneously
self.tasks[handle].pid = pid
def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
""" Add new iteration, using optional manually defined handle."""
if self.is_alive and not self.event.is_set():
self.task = Task(fun or self.task.fun, args or self.task.args, kwargs or self.task.kwargs, handle, n)
while self.queue_in.full():
self._get_from_queue()
if handle is None:
handle = self.handle
self.handle += 1
self.tasks[handle] = self.task
self.queue_in.put(self.task)
self.bar_lengths[handle] = barlength
self._qbar_update()
return handle
elif handle not in self:
self.tasks[handle] = self.task
self.queue_in.put(self.task)
self.bar_lengths[handle] = barlength
self._qbar_update()
def _qbar_update(self):
if self.qbar is not None:
try:
self.qbar.n = self.queue_in.qsize()
except Exception:
pass
def __setitem__(self, handle, n):
""" Add new iteration. """
self(n, handle=handle)
def __getitem__(self, handle):
""" Request result and delete its record. Wait if result not yet available. """
if handle not in self:
raise ValueError('No handle: {}'.format(handle))
while not self.tasks[handle].done:
if not self._get_from_queue() and not self.tasks[handle].done and self.is_started and not self.working:
for _ in range(10): # wait some time while processing possible new messages
self._get_from_queue()
if not self._get_from_queue() and not self.tasks[handle].done and self.is_started and not self.working:
# retry a task if the process was killed while retrieving the task
self.queue_in.put(self.tasks[handle])
warn('Task {} was restarted because the process retrieving it was probably killed.'.format(handle))
result = self.tasks[handle].result
self.tasks.pop(handle)
return result
@property
def working(self):
return not all([task.pid is None for task in self.tasks.values()])
def get_newest(self):
class PoolSingleton:
def __new__(cls, *args, **kwargs):
if not hasattr(cls, 'instance') or cls.instance is None or not cls.instance.is_alive:
new = super().__new__(cls)
new.n_processes = cpu_count
new.instance = new
new.is_started = False
ctx = Context()
new.n_workers = ctx.Value('i', new.n_processes)
new.event = ctx.Event()
new.queue_in = ctx.Queue(3 * new.n_processes)
new.queue_out = ctx.Queue(new.n_processes)
new.pool = ctx.Pool(new.n_processes, Worker(new.queue_in, new.queue_out, new.n_workers, new.event))
new.is_alive = True
new.handle = 0
new.pools = {}
new.manager = ctx.Manager()
cls.instance = new
return cls.instance
def __init__(self, parpool=None):
if parpool is not None:
self.pools[parpool.id] = parpool
def __getstate__(self):
raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
def error(self, error):
self.close()
raise Exception(f'Error occurred in worker: {error}')
def process_queue(self):
while self.get_from_queue():
pass
def get_from_queue(self):
""" Get an item from the queue and store it, return True if more messages are waiting. """
try:
code, pool_id, *args = self.queue_out.get(True, 0.02)
if pool_id is None:
getattr(self, code)(*args)
elif pool_id in self.pools:
getattr(self.pools[pool_id], code)(*args)
return True
except multiprocessing.queues.Empty:
for pool in self.pools.values():
for handle, task in pool.tasks.items(): # retry a task if the process doing it was killed
if task.pid is not None \
and task.pid not in [child.pid for child in multiprocessing.active_children()]:
self.queue_in.put(task)
warn(f'Task {task.handle} was restarted because process {task.pid} was probably killed.')
return False
def add_task(self, task):
""" Add new iteration, using optional manually defined handle."""
if self.is_alive and not self.event.is_set():
while self.queue_in.full():
self.get_from_queue()
self.queue_in.put(task)
def get_newest_for_pool(self, pool):
""" Request the newest key and result and delete its record. Wait if result not yet available. """
while len(self.tasks):
self._get_from_queue()
for task in self.tasks.values():
while len(pool.tasks):
self.get_from_queue()
for task in pool.tasks.values():
if task.done:
handle, result = task.handle, task.result
self.tasks.pop(handle)
pool.tasks.pop(handle)
return handle, result
def __delitem__(self, handle):
self.tasks.pop(handle)
def __contains__(self, handle):
return handle in self.tasks
def __repr__(self):
if self.is_alive:
return '{} with {} workers.'.format(self.__class__, self.nP)
else:
return 'Closed {}'.format(self.__class__)
def close(self):
self.__class__.instance = None
def empty_queue(queue):
if not queue._closed:
while not queue.empty():
try:
queue.get(True, 0.02)
except multiprocessing.queues.Empty:
pass
def close_queue(queue):
empty_queue(queue)
if not queue._closed:
queue.close()
queue.join_thread()
if self.is_alive:
self.is_alive = False
self.event.set()
self.pool.close()
while self.n_tasks.value:
self._empty_queue(self.queue_in)
self._empty_queue(self.queue_out)
self._empty_queue(self.queue_in)
self._empty_queue(self.queue_out)
while self.n_workers.value:
empty_queue(self.queue_in)
empty_queue(self.queue_out)
empty_queue(self.queue_in)
empty_queue(self.queue_out)
self.pool.join()
self._close_queue(self.queue_in)
self._close_queue(self.queue_out)
close_queue(self.queue_in)
close_queue(self.queue_out)
self.handle = 0
self.tasks = {}
@staticmethod
def _empty_queue(queue):
if not queue._closed:
while not queue.empty():
try:
queue.get(True, 0.02)
except multiprocessing.queues.Empty:
pass
@staticmethod
def _close_queue(queue):
if not queue._closed:
while not queue.empty():
try:
queue.get(True, 0.02)
except multiprocessing.queues.Empty:
pass
queue.close()
queue.join_thread()
class Worker:
""" Manages executing the target function which will be executed in different processes. """
def __init__(self, queue_in, queue_out, n_workers, event, cachesize=48):
self.cache = DequeDict(cachesize)
self.queue_in = queue_in
self.queue_out = queue_out
self.n_workers = n_workers
self.event = event
class _Worker(object):
""" Manages executing the target function which will be executed in different processes. """
def __init__(self, queue_in, queue_out, n_tasks, event, cachesize=48):
self.cache = DequeDict(cachesize)
self.queue_in = queue_in
self.queue_out = queue_out
self.n_tasks = n_tasks
self.event = event
def add_to_queue(self, *args):
while not self.event.is_set():
try:
self.queue_out.put(args, timeout=0.1)
break
except multiprocessing.queues.Full:
continue
def add_to_queue(self, *args):
while not self.event.is_set():
def __call__(self):
pid = getpid()
while not self.event.is_set():
try:
task = self.queue_in.get(True, 0.02)
try:
self.queue_out.put(args, timeout=0.1)
break
except multiprocessing.queues.Full:
continue
def __call__(self):
pid = getpid()
while not self.event.is_set():
try:
task = self.queue_in.get(True, 0.02)
try:
self.add_to_queue('started', task.handle, pid)
task.set_from_cache(self.cache)
self.add_to_queue('done', task())
except Exception:
self.add_to_queue('task_error', task.handle, format_exc())
self.event.set()
except multiprocessing.queues.Empty:
continue
self.add_to_queue('started', task.pool_id, task.handle, pid)
task.set_from_cache(self.cache)
self.add_to_queue('done', task.pool_id, task())
except Exception:
self.add_to_queue('error', format_exc())
self.event.set()
for child in multiprocessing.active_children():
child.kill()
with self.n_tasks.get_lock():
self.n_tasks.value -= 1
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
except multiprocessing.queues.Empty:
continue
except Exception:
self.add_to_queue('error', None, format_exc())
self.event.set()
for child in multiprocessing.active_children():
child.kill()
with self.n_workers.get_lock():
self.n_workers.value -= 1
def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=True, qbar=False, terminator=None,
rP=1, nP=None, serial=None, qsize=None, length=None, **bar_kwargs):
def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=True, terminator=None,
serial=None, length=None, **bar_kwargs):
""" map a function fun to each iteration in iterable
use as a function: pmap
use as a decorator: parfor
@@ -500,11 +482,7 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=
desc: string with description of the progress bar
bar: bool enable progress bar,
or a callback function taking the number of passed iterations as an argument
pbar: bool enable buffer indicator bar, or a callback function taking the queue size as an argument
rP: ratio workers to cpu cores, default: 1
nP: number of workers, default, None, overrides rP if not None
serial: execute in series instead of parallel if True, None (default): let pmap decide
qsize: maximum size of the task queue
length: deprecated alias for total
**bar_kwargs: keywords arguments for tqdm.tqdm
@@ -569,6 +547,7 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=
else:
iterable = Chunks(iterable, ratio=5, length=total)
@wraps(fun)
def chunk_fun(iterator, *args, **kwargs):
return [fun(i, *args, **kwargs) for i in iterator]
@@ -587,10 +566,12 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=
else:
return sum([chunk_fun(c, *args, **kwargs) for c in tqdm(iterable, **bar_kwargs)], [])
else: # parallel case
with ExternalBar(callback=qbar) if callable(qbar) \
else TqdmMeter(total=0, desc='Task buffer', disable=not qbar, leave=False) as qbar, \
ExternalBar(callback=bar) if callable(bar) else tqdm(**bar_kwargs) as bar:
with Parpool(chunk_fun, args, kwargs, rP, nP, bar, qbar, qsize) as p:
with ExitStack() as stack:
if callable(bar):
bar = stack.enter_context(ExternalBar(callback=bar))
elif bar is True:
bar = stack.enter_context(tqdm(**bar_kwargs))
with ParPool(chunk_fun, args, kwargs, bar) as p:
for i, (j, l) in enumerate(zip(iterable, iterable.lengths)): # add work to the queue
p(j, handle=i, barlength=iterable.lengths[i])
if bar.total is None or bar.total < i+1:
@@ -622,6 +603,6 @@ def deprecated(cls, name):
# backwards compatibility
parpool = deprecated(Parpool, 'parpool')
tqdmm = deprecated(TqdmMeter, 'tqdmm')
parpool = deprecated(ParPool, 'parpool')
Parpool = deprecated(ParPool, 'Parpool')
chunks = deprecated(Chunks, 'chunks')


@@ -1,7 +1,7 @@
import dill
from pickle import PicklingError, dispatch_table
from io import BytesIO
from pickle import PicklingError, dispatch_table
import dill
failed_rv = (lambda *args, **kwargs: None, ())
loads = dill.loads