diff --git a/parfor/__init__.py b/parfor/__init__.py index 960d3b1..57c7a60 100644 --- a/parfor/__init__.py +++ b/parfor/__init__.py @@ -1,7 +1,6 @@ from __future__ import print_function import sys import multiprocessing -import warnings import dill from tqdm.auto import tqdm from traceback import format_exc @@ -19,6 +18,7 @@ except ImportError: failed_rv = (lambda *args, **kwargs: None, ()) + class Pickler(dill.Pickler): """ Overload dill to ignore unpickleble parts of objects. You probably didn't want to use these parts anyhow. @@ -105,6 +105,7 @@ class Pickler(dill.Pickler): except: self.save_reduce(obj=obj, *failed_rv) + def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds): """pickle an object to a string""" protocol = dill.settings['protocol'] if protocol is None else int(protocol) @@ -114,9 +115,10 @@ def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds): Pickler(file, protocol, **_kwds).dump(obj) return file.getvalue() + def chunks(n, *args): """ Yield successive n-sized chunks from lists. """ - A = len(args)==1 + A = len(args) == 1 N = len(args[0]) n = int(round(N/max(1, round(N/n)))) for i in range(0, N, n) if N else []: @@ -125,13 +127,14 @@ def chunks(n, *args): else: yield [a[i:i+n] for a in args] + class tqdmm(tqdm): """ Overload tqdm to make a special version of tqdm functioning as a meter. """ def __init__(self, *args, **kwargs): self._n = 0 self.disable = False - if not 'bar_format' in kwargs and len(args) < 16: + if 'bar_format' not in kwargs and len(args) < 16: kwargs['bar_format'] = '{n}/{total}' super(tqdmm, self).__init__(*args, **kwargs) @@ -150,8 +153,9 @@ class tqdmm(tqdm): self.n = self.total super(tqdmm, self).__exit__(exc_type, exc_value, traceback) + def parfor(*args, **kwargs): - """ @parfor(iterator=None, args=(), kwargs={}, length=None, desc=None, bar=True, qbar=True, rP=1/3, serial=4, debug=False): + """ @parfor(iterator=None, args=(), kwargs={}, length=None, desc=None, bar=True, qbar=True, rP=1/3, serial=4): decorator to parallize for-loops required arguments: @@ -169,7 +173,6 @@ def parfor(*args, **kwargs): nP: number of workers, default: None, overrides rP if not None number of workers will always be at least 2 serial: switch to serial if number of tasks less than serial, default: 4 - debug: if an error occurs in an iteration, return the erorr instead of retrying in the main process output: list with results from applying the decorated function to each iteration of the iterator specified as the first argument to the function @@ -236,12 +239,12 @@ def parfor(*args, **kwargs): return pmap(fun, *args, **kwargs) return decfun + class parpool(object): """ Parallel processing with addition of iterations at any time and request of that result any time after that. The target function and its argument can be changed at any time. """ - def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, terminator=None, - debug=False): + def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, terminator=None): """ fun, args, kwargs: target function and its arguments and keyword arguments rP: ratio workers to cpu cores, default: 1 nP: number of workers, default, None, overrides rP if not None @@ -256,22 +259,22 @@ class parpool(object): self.fun = fun or (lambda x: x) self.args = args or () self.kwargs = kwargs or {} - self.debug = debug if hasattr(multiprocessing, 'get_context'): ctx = multiprocessing.get_context('spawn') else: ctx = multiprocessing + self.A = ctx.Value('i', self.nP) self.E = ctx.Event() self.Qi = ctx.Queue(3*self.nP) self.Qo = ctx.Queue(3*self.nP) - self.P = ctx.Pool(self.nP, self._worker(self.Qi, self.Qo, self.E, terminator, self.debug)) + self.P = ctx.Pool(self.nP, self._worker(self.Qi, self.Qo, self.A, self.E, terminator)) self.is_alive = True self.res = {} self.handle = 0 self.handles = [] self.bar = bar self.qbar = qbar - if not self.qbar is None: + if self.qbar is not None: self.qbar.total = 3*self.nP @property @@ -310,28 +313,19 @@ class parpool(object): def _getfromq(self): """ Get an item from the queue and store it. """ try: - r = self.Qo.get(True, 0.02) - if r[0] is None: - self.res[r[1]] = dill.loads(r[2]) - elif r[0] is False: - pfmt = warnings.formatwarning - warnings.formatwarning = lambda message, *args: '{}\n'.format(message) - warnings.warn( - 'Warning, error occurred in iteration {}. The iteration will be retried and should raise a ' - 'debuggable error. If it doesn\'t, it\'s an error specific to parallel execution.' - .format(r[1])) - warnings.formatwarning = pfmt - fun, args, kwargs = [dill.loads(f[1]) for f in r[2][1:]] - r = (False, r[1], fun(dill.loads(r[2][0]), *args, **kwargs)) - self.res[r[1]] = r[2] + err, i, res = self.Qo.get(True, 0.02) + if not err: + self.res[i] = dill.loads(res) else: - err = dill.loads(r[2]) - pfmt = warnings.formatwarning - warnings.formatwarning = lambda message, *args: '{}\n'.format(message) - warnings.warn('Warning, error occurred in iteration {}:\n{}'.format(r[1], err)) - warnings.formatwarning = pfmt - self.res[r[1]] = err - if not self.bar is None: + n, fun, args, kwargs, e = [dill.loads(r) for r in res] + print('Error from process working on iteration {}:\n'.format(i)) + print(e) + self.close() + print('Retrying in main thread...') + self.res[i] = fun(n, *args, **kwargs) + raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.' + .format(fun.__name__)) + if self.bar is not None: self.bar.update() self._qbar_update() except multiprocessing.queues.Empty: @@ -339,29 +333,30 @@ class parpool(object): def __call__(self, n, fun=None, args=None, kwargs=None, handle=None): """ Add new iteration, using optional manually defined handle.""" - n = dumps(n, recurse=True) - if not fun is None: - self.fun = fun - if not args is None: - self.args = args - if not kwargs is None: - self.kwargs = kwargs - while self.Qi.full(): - self._getfromq() - if handle is None: - handle = self.handle - self.handle += 1 - self.handles.append(handle) - self.Qi.put((handle, n, self.fun, self.args, self.kwargs)) + if self.is_alive and not self.E.is_set(): + n = dumps(n, recurse=True) + if fun is not None: + self.fun = fun + if args is not None: + self.args = args + if kwargs is not None: + self.kwargs = kwargs + while self.Qi.full(): + self._getfromq() + if handle is None: + handle = self.handle + self.handle += 1 + self.handles.append(handle) + self.Qi.put((handle, n, self.fun, self.args, self.kwargs)) + self._qbar_update() + return handle + elif handle not in self: + self.handles.append(handle) + self.Qi.put((handle, n, self.fun, self.args, self.kwargs)) self._qbar_update() - return handle - elif not handle in self: - self.handles.append(handle) - self.Qi.put((handle, n, self.fun, self.args, self.kwargs)) - self._qbar_update() def _qbar_update(self): - if not self.qbar is None: + if self.qbar is not None: try: self.qbar.n = self.Qi.qsize() except: @@ -373,16 +368,16 @@ class parpool(object): def __getitem__(self, handle): """ Request result and delete its record. Wait if result not yet available. """ - if not handle in self: + if handle not in self: raise ValueError('No handle: {}'.format(handle)) - while not handle in self.res: + while handle not in self.res: self._getfromq() self.handles.remove(handle) return self.res.pop(handle) def get_newest(self): - """ Request the newest key and result and delete its record. Wait if result not yet available. """ - if len(self.handles): + """ Request the newest key and result and delete its record. Wait if result not yet available. """ + if len(self.handles): while not len(self.res): self._getfromq() key = list(self.res.keys())[0] @@ -402,20 +397,34 @@ class parpool(object): def close(self): if self.is_alive: + self.is_alive = False self.E.set() + self.P.close() + while self.A.value: + self._empty_queue(self.Qi) + self._empty_queue(self.Qo) + self._empty_queue(self.Qi) + self._empty_queue(self.Qo) + self.P.join() self._close_queue(self.Qi) self._close_queue(self.Qo) - self.P.close() - self.P.join() - self.is_alive = False self.res = {} self.handle = 0 self.handles = [] + @staticmethod + def _empty_queue(Q): + if not Q._closed: + while not Q.empty(): + try: + Q.get(True, 0.02) + except multiprocessing.queues.Empty: + pass + @staticmethod def _close_queue(Q): if not Q._closed: - while Q.full(): + while not Q.empty(): try: Q.get(True, 0.02) except multiprocessing.queues.Empty: @@ -425,15 +434,22 @@ class parpool(object): class _worker(object): """ Manages executing the target function which will be executed in different processes. """ - def __init__(self, Qi, Qo, E, terminator, debug=False, cachesize=48): + def __init__(self, Qi, Qo, A, E, terminator, cachesize=48): self.cache = [] self.Qi = Qi self.Qo = Qo + self.A = A self.E = E self.terminator = dumps(terminator, recurse=True) - self.debug = debug self.cachesize = cachesize - # print(self.terminator) + + def add_to_q(self, value): + while not self.E.is_set(): + try: + self.Qo.put(value, timeout=0.1) + break + except multiprocessing.queues.Full: + continue def __call__(self): while not self.E.is_set(): @@ -443,17 +459,17 @@ class parpool(object): fun = self.get_from_cache(*Fun) args = self.get_from_cache(*Args) kwargs = self.get_from_cache(*Kwargs) - self.Qo.put((None, i, dumps(fun(dill.loads(n), *args, **kwargs), recurse=True))) + self.add_to_q((False, i, dumps(fun(dill.loads(n), *args, **kwargs), recurse=True))) except multiprocessing.queues.Empty: continue except: - if self.debug: - self.Qo.put((True, i, dumps(format_exc(), recurse=True))) - else: - self.Qo.put((False, i, (n, Fun, Args, Kwargs))) + self.add_to_q((True, i, (n, Fun[1], Args[1], Kwargs[1], dumps(format_exc(), recurse=True)))) + self.E.set() terminator = dill.loads(self.terminator) - if not terminator is None: + if terminator is not None: terminator() + with self.A.get_lock(): + self.A.value -= 1 def get_from_cache(self, h, ser): if len(self.cache): @@ -466,8 +482,9 @@ class parpool(object): self.cache.pop(0) return obj + def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar=True, qbar=False, terminator=None, - rP=1, nP=None, serial=4, debug=False): + rP=1, nP=None, serial=4): """ map a function fun to each iteration in iterable best use: iterable is a generator and length is given to this function @@ -483,7 +500,6 @@ def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar rP: ratio workers to cpu cores, default: 1 nP: number of workers, default, None, overrides rP if not None serial: switch to serial if number of tasks less than serial, default: 4 - debug: if an error occurs in an iteration, return the erorr instead of retrying in the main process """ args = args or () kwargs = kwargs or {} @@ -491,16 +507,16 @@ def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar length = len(iterable) except: pass - if length and length