- Better and faster error handling.

Wim Pomp
2021-03-19 20:39:38 +01:00
parent ddc0189bad
commit 5d57081713
2 changed files with 92 additions and 76 deletions
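
In short: the debug keyword is gone. When an iteration raises in a worker, the worker now sends the traceback together with the pickled function and arguments back over the result queue and sets the shared stop event; the main process prints the error, closes the pool, retries the failing iteration in the main thread so it raises a normal, debuggable traceback, and raises an Exception if the retry unexpectedly succeeds. A minimal sketch of the caller-side effect (illustrative only; the function fragile is not from this commit, and it assumes pmap is importable from the installed parfor package):

    from parfor import pmap

    def fragile(i):
        # one bad iteration: with this commit the pool shuts down and the
        # iteration is retried in the main process with a normal traceback
        if i == 3:
            raise ValueError('bad iteration')
        return i ** 2

    if __name__ == '__main__':
        print(pmap(fragile, range(20)))  # 20 tasks > serial=4, so this runs in parallel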

View File

@@ -1,7 +1,6 @@
 from __future__ import print_function
 import sys
 import multiprocessing
-import warnings
 import dill
 from tqdm.auto import tqdm
 from traceback import format_exc
@@ -19,6 +18,7 @@ except ImportError:
     failed_rv = (lambda *args, **kwargs: None, ())

+
 class Pickler(dill.Pickler):
     """ Overload dill to ignore unpickleble parts of objects.
         You probably didn't want to use these parts anyhow.
@@ -105,6 +105,7 @@ class Pickler(dill.Pickler):
         except:
             self.save_reduce(obj=obj, *failed_rv)

+
 def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds):
     """pickle an object to a string"""
     protocol = dill.settings['protocol'] if protocol is None else int(protocol)
@@ -114,9 +115,10 @@ def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds):
         Pickler(file, protocol, **_kwds).dump(obj)
     return file.getvalue()

+
 def chunks(n, *args):
     """ Yield successive n-sized chunks from lists. """
-    A = len(args)==1
+    A = len(args) == 1
     N = len(args[0])
     n = int(round(N/max(1, round(N/n))))
     for i in range(0, N, n) if N else []:
@@ -125,13 +127,14 @@ def chunks(n, *args):
     else:
         yield [a[i:i+n] for a in args]

+
 class tqdmm(tqdm):
     """ Overload tqdm to make a special version of tqdm functioning as a meter. """
     def __init__(self, *args, **kwargs):
         self._n = 0
         self.disable = False
-        if not 'bar_format' in kwargs and len(args) < 16:
+        if 'bar_format' not in kwargs and len(args) < 16:
             kwargs['bar_format'] = '{n}/{total}'
         super(tqdmm, self).__init__(*args, **kwargs)
@@ -150,8 +153,9 @@ class tqdmm(tqdm):
             self.n = self.total
         super(tqdmm, self).__exit__(exc_type, exc_value, traceback)

+
 def parfor(*args, **kwargs):
-    """ @parfor(iterator=None, args=(), kwargs={}, length=None, desc=None, bar=True, qbar=True, rP=1/3, serial=4, debug=False):
+    """ @parfor(iterator=None, args=(), kwargs={}, length=None, desc=None, bar=True, qbar=True, rP=1/3, serial=4):
         decorator to parallize for-loops
         required arguments:
@@ -169,7 +173,6 @@ def parfor(*args, **kwargs):
             nP: number of workers, default: None, overrides rP if not None
                 number of workers will always be at least 2
             serial: switch to serial if number of tasks less than serial, default: 4
-            debug: if an error occurs in an iteration, return the erorr instead of retrying in the main process
         output: list with results from applying the decorated function to each iteration of the iterator
                 specified as the first argument to the function
@@ -236,12 +239,12 @@ def parfor(*args, **kwargs):
             return pmap(fun, *args, **kwargs)
         return decfun

+
 class parpool(object):
     """ Parallel processing with addition of iterations at any time and request of that result any time after that.
         The target function and its argument can be changed at any time.
     """
-    def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, terminator=None,
-                 debug=False):
+    def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, terminator=None):
        """ fun, args, kwargs: target function and its arguments and keyword arguments
            rP: ratio workers to cpu cores, default: 1
            nP: number of workers, default, None, overrides rP if not None
@@ -256,22 +259,22 @@ class parpool(object):
         self.fun = fun or (lambda x: x)
         self.args = args or ()
         self.kwargs = kwargs or {}
-        self.debug = debug
         if hasattr(multiprocessing, 'get_context'):
             ctx = multiprocessing.get_context('spawn')
         else:
             ctx = multiprocessing
+        self.A = ctx.Value('i', self.nP)
         self.E = ctx.Event()
         self.Qi = ctx.Queue(3*self.nP)
         self.Qo = ctx.Queue(3*self.nP)
-        self.P = ctx.Pool(self.nP, self._worker(self.Qi, self.Qo, self.E, terminator, self.debug))
+        self.P = ctx.Pool(self.nP, self._worker(self.Qi, self.Qo, self.A, self.E, terminator))
         self.is_alive = True
         self.res = {}
         self.handle = 0
         self.handles = []
         self.bar = bar
         self.qbar = qbar
-        if not self.qbar is None:
+        if self.qbar is not None:
             self.qbar.total = 3*self.nP

     @property
@@ -310,28 +313,19 @@ class parpool(object):
     def _getfromq(self):
         """ Get an item from the queue and store it. """
         try:
-            r = self.Qo.get(True, 0.02)
-            if r[0] is None:
-                self.res[r[1]] = dill.loads(r[2])
-            elif r[0] is False:
-                pfmt = warnings.formatwarning
-                warnings.formatwarning = lambda message, *args: '{}\n'.format(message)
-                warnings.warn(
-                    'Warning, error occurred in iteration {}. The iteration will be retried and should raise a '
-                    'debuggable error. If it doesn\'t, it\'s an error specific to parallel execution.'
-                    .format(r[1]))
-                warnings.formatwarning = pfmt
-                fun, args, kwargs = [dill.loads(f[1]) for f in r[2][1:]]
-                r = (False, r[1], fun(dill.loads(r[2][0]), *args, **kwargs))
-                self.res[r[1]] = r[2]
+            err, i, res = self.Qo.get(True, 0.02)
+            if not err:
+                self.res[i] = dill.loads(res)
             else:
-                err = dill.loads(r[2])
-                pfmt = warnings.formatwarning
-                warnings.formatwarning = lambda message, *args: '{}\n'.format(message)
-                warnings.warn('Warning, error occurred in iteration {}:\n{}'.format(r[1], err))
-                warnings.formatwarning = pfmt
-                self.res[r[1]] = err
-            if not self.bar is None:
+                n, fun, args, kwargs, e = [dill.loads(r) for r in res]
+                print('Error from process working on iteration {}:\n'.format(i))
+                print(e)
+                self.close()
+                print('Retrying in main thread...')
+                self.res[i] = fun(n, *args, **kwargs)
+                raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'
+                                .format(fun.__name__))
+            if self.bar is not None:
                 self.bar.update()
             self._qbar_update()
         except multiprocessing.queues.Empty:
@@ -339,12 +333,13 @@ class parpool(object):
     def __call__(self, n, fun=None, args=None, kwargs=None, handle=None):
         """ Add new iteration, using optional manually defined handle."""
+        if self.is_alive and not self.E.is_set():
             n = dumps(n, recurse=True)
-            if not fun is None:
+            if fun is not None:
                 self.fun = fun
-            if not args is None:
+            if args is not None:
                 self.args = args
-            if not kwargs is None:
+            if kwargs is not None:
                 self.kwargs = kwargs
             while self.Qi.full():
                 self._getfromq()
@@ -355,13 +350,13 @@ class parpool(object):
                 self.Qi.put((handle, n, self.fun, self.args, self.kwargs))
                 self._qbar_update()
                 return handle
-            elif not handle in self:
+            elif handle not in self:
                 self.handles.append(handle)
                 self.Qi.put((handle, n, self.fun, self.args, self.kwargs))
                 self._qbar_update()

     def _qbar_update(self):
-        if not self.qbar is None:
+        if self.qbar is not None:
             try:
                 self.qbar.n = self.Qi.qsize()
             except:
@@ -373,9 +368,9 @@ class parpool(object):
     def __getitem__(self, handle):
         """ Request result and delete its record. Wait if result not yet available. """
-        if not handle in self:
+        if handle not in self:
             raise ValueError('No handle: {}'.format(handle))
-        while not handle in self.res:
+        while handle not in self.res:
             self._getfromq()
         self.handles.remove(handle)
         return self.res.pop(handle)
@@ -402,20 +397,34 @@ class parpool(object):
     def close(self):
         if self.is_alive:
+            self.is_alive = False
             self.E.set()
+            self.P.close()
+            while self.A.value:
+                self._empty_queue(self.Qi)
+                self._empty_queue(self.Qo)
+            self._empty_queue(self.Qi)
+            self._empty_queue(self.Qo)
+            self.P.join()
             self._close_queue(self.Qi)
             self._close_queue(self.Qo)
-            self.P.close()
-            self.P.join()
-            self.is_alive = False
             self.res = {}
             self.handle = 0
             self.handles = []

+    @staticmethod
+    def _empty_queue(Q):
+        if not Q._closed:
+            while not Q.empty():
+                try:
+                    Q.get(True, 0.02)
+                except multiprocessing.queues.Empty:
+                    pass
+
     @staticmethod
     def _close_queue(Q):
         if not Q._closed:
-            while Q.full():
+            while not Q.empty():
                 try:
                     Q.get(True, 0.02)
                 except multiprocessing.queues.Empty:
@@ -425,15 +434,22 @@ class parpool(object):
     class _worker(object):
         """ Manages executing the target function which will be executed in different processes. """
-        def __init__(self, Qi, Qo, E, terminator, debug=False, cachesize=48):
+        def __init__(self, Qi, Qo, A, E, terminator, cachesize=48):
             self.cache = []
             self.Qi = Qi
             self.Qo = Qo
+            self.A = A
             self.E = E
             self.terminator = dumps(terminator, recurse=True)
-            self.debug = debug
             self.cachesize = cachesize
-            # print(self.terminator)

+        def add_to_q(self, value):
+            while not self.E.is_set():
+                try:
+                    self.Qo.put(value, timeout=0.1)
+                    break
+                except multiprocessing.queues.Full:
+                    continue
+
         def __call__(self):
             while not self.E.is_set():
@@ -443,17 +459,17 @@ class parpool(object):
                     fun = self.get_from_cache(*Fun)
                     args = self.get_from_cache(*Args)
                     kwargs = self.get_from_cache(*Kwargs)
-                    self.Qo.put((None, i, dumps(fun(dill.loads(n), *args, **kwargs), recurse=True)))
+                    self.add_to_q((False, i, dumps(fun(dill.loads(n), *args, **kwargs), recurse=True)))
                 except multiprocessing.queues.Empty:
                     continue
                 except:
-                    if self.debug:
-                        self.Qo.put((True, i, dumps(format_exc(), recurse=True)))
-                    else:
-                        self.Qo.put((False, i, (n, Fun, Args, Kwargs)))
+                    self.add_to_q((True, i, (n, Fun[1], Args[1], Kwargs[1], dumps(format_exc(), recurse=True))))
+                    self.E.set()
             terminator = dill.loads(self.terminator)
-            if not terminator is None:
+            if terminator is not None:
                 terminator()
+            with self.A.get_lock():
+                self.A.value -= 1

         def get_from_cache(self, h, ser):
             if len(self.cache):
@@ -466,8 +482,9 @@ class parpool(object):
                 self.cache.pop(0)
             return obj

+
def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar=True, qbar=False, terminator=None,
-         rP=1, nP=None, serial=4, debug=False):
+         rP=1, nP=None, serial=4):
    """ map a function fun to each iteration in iterable
        best use: iterable is a generator and length is given to this function
@@ -483,7 +500,6 @@ def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar
        rP: ratio workers to cpu cores, default: 1
        nP: number of workers, default, None, overrides rP if not None
        serial: switch to serial if number of tasks less than serial, default: 4
-        debug: if an error occurs in an iteration, return the erorr instead of retrying in the main process
    """
    args = args or ()
    kwargs = kwargs or {}
@@ -491,16 +507,16 @@ def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar
        length = len(iterable)
    except:
        pass
-    if length and length<serial: #serial case
+    if length and length < serial: # serial case
        return [fun(c, *args, **kwargs) for c in tqdm(iterable, total=length, desc=desc, disable=not bar)]
-    else: #parallel case
+    else: # parallel case
        with tqdmm(total=0, desc='Task buffer', disable=not qbar, leave=False) as qbar,\
                tqdm(total=length, desc=desc, disable=not bar) as bar:
-            with parpool(fun, args, kwargs, rP, nP, bar, qbar, terminator, debug) as p:
+            with parpool(fun, args, kwargs, rP, nP, bar, qbar, terminator) as p:
                length = 0
-                for i, j in enumerate(iterable): #add work to the queue
+                for i, j in enumerate(iterable): # add work to the queue
                    p[i] = j
                    if bar.total is None or bar.total < i+1:
                        bar.total = i+1
                    length += 1
-                return [p[i] for i in range(length)] #collect the results
+                return [p[i] for i in range(length)] # collect the results
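
Besides error propagation, the close() rewrite changes shutdown coordination: a shared Value('i', nP) counts live workers, each worker decrements it on exit, and close() keeps draining both queues until the counter reaches zero, so no worker stays blocked on a full queue. A standalone sketch of this pattern, not code from the package (the names Qi, Qo, A and E mirror the diff; the squaring worker is illustrative):

    import multiprocessing

    def worker(Qi, Qo, A, E):
        while not E.is_set():
            try:
                item = Qi.get(timeout=0.02)
            except multiprocessing.queues.Empty:
                continue
            Qo.put(item * item)
        with A.get_lock():
            A.value -= 1  # tell the closer this worker has exited

    if __name__ == '__main__':
        ctx = multiprocessing.get_context('spawn')
        nP = 2
        A = ctx.Value('i', nP)  # number of live workers
        E = ctx.Event()         # set() -> workers wind down
        Qi, Qo = ctx.Queue(3 * nP), ctx.Queue(3 * nP)
        workers = [ctx.Process(target=worker, args=(Qi, Qo, A, E)) for _ in range(nP)]
        for w in workers:
            w.start()
        for i in range(5):
            Qi.put(i)
        results = sorted(Qo.get() for _ in range(5))
        E.set()           # request shutdown
        while A.value:    # drain until every worker has checked out
            while not Qi.empty():
                Qi.get()
            while not Qo.empty():
                Qo.get()
        for w in workers:
            w.join()
        print(results)    # [0, 1, 4, 9, 16]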

View File

@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
 setuptools.setup(
     name="parfor",
-    version="2021.3.1",
+    version="2021.3.2",
     author="Wim Pomp",
     author_email="wimpomp@gmail.com",
     description="A package to mimic the use of parfor as done in Matlab.",