diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..54a7f77
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+/build/
+/dist/
+/parfor.egg-info/
+.idea
diff --git a/parfor/__init__.py b/parfor/__init__.py
index 32d4590..f303087 100644
--- a/parfor/__init__.py
+++ b/parfor/__init__.py
@@ -1,20 +1,17 @@
-from __future__ import print_function
-import sys
 import multiprocessing
 import dill
+from os import getpid
 from tqdm.auto import tqdm
 from traceback import format_exc
 from pickle import PicklingError, dispatch_table
-
-PY3 = (sys.hexversion >= 0x3000000)
+from psutil import Process
+from collections import OrderedDict
+from io import BytesIO
 try:
-    from cStringIO import StringIO
+    from javabridge import kill_vm
 except ImportError:
-    if PY3:
-        from io import BytesIO as StringIO
-    else:
-        from StringIO import StringIO
+    kill_vm = lambda: None
 
 
 failed_rv = (lambda *args, **kwargs: None, ())
 cpu_count = int(multiprocessing.cpu_count())
@@ -27,8 +24,7 @@ class Pickler(dill.Pickler):
     """
     def save(self, obj, save_persistent_id=True):
        """ Copied from pickle and amended. """
-        if PY3:
-            self.framer.commit_frame()
+        self.framer.commit_frame()
 
         # Check for persistent id (defined by a subclass)
         pid = self.persistent_id(obj)
@@ -112,12 +108,12 @@ def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds):
     protocol = dill.settings['protocol'] if protocol is None else int(protocol)
     _kwds = kwds.copy()
     _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
-    file = StringIO()
+    file = BytesIO()
     Pickler(file, protocol, **_kwds).dump(obj)
     return file.getvalue()
 
 
-class chunks():
+class chunks(object):
     """ Yield successive chunks from lists.
         Usage: chunks(s, list0, list1, ...)
                chunks(list0, list1, ..., s=s)
@@ -132,35 +128,36 @@ class chunks():
     def __init__(self, *args, **kwargs):
         if 's' in kwargs and ('n' in kwargs or 'r' in kwargs):
-            N = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
+            number_of_items = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
             n = kwargs['n'] if 'n' in kwargs else int(cpu_count * kwargs['r'])
-            n = n if N < kwargs['s'] * n else round(N / kwargs['s'])
+            n = n if number_of_items < kwargs['s'] * n else round(number_of_items / kwargs['s'])
         elif 's' in kwargs:  # size of chunks
-            N = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
-            n = round(N / kwargs['s'])
+            number_of_items = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
+            n = round(number_of_items / kwargs['s'])
         elif 'n' in kwargs or 'r' in kwargs:  # number of chunks
-            N = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
+            number_of_items = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
             n = kwargs['n'] if 'n' in kwargs else int(cpu_count * kwargs['r'])
         else:  # size of chunks in 1st argument
             s, *args = args
-            N = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
-            n = round(N / s)
+            number_of_items = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
+            n = round(number_of_items / s)
         self.args = args
-        self.A = len(args) == 1
-        self.N = N
-        self.len = max(1, min(N, n))
-        self.lengths = [((i + 1) * self.N // self.len) - (i * self.N // self.len) for i in range(self.len)]
+        self.number_of_arguments = len(args) == 1
+        self.number_of_items = number_of_items
+        self.len = max(1, min(number_of_items, n))
+        self.lengths = [((i + 1) * self.number_of_items // self.len) - (i * self.number_of_items // self.len)
+                        for i in range(self.len)]
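+        # lengths[i] is the size of chunk i; the integer arithmetic spreads the
+        # items over the chunks so that chunk sizes differ by at most one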
 
     def __iter__(self):
         for i in range(self.len):
-            p, q = (i * self.N // self.len), ((i + 1) * self.N // self.len)
-            yield self.args[0][p:q] if self.A else [a[p:q] for a in self.args]
+            p, q = (i * self.number_of_items // self.len), ((i + 1) * self.number_of_items // self.len)
+            yield self.args[0][p:q] if self.number_of_arguments else [a[p:q] for a in self.args]
 
     def __len__(self):
         return self.len
 
 
-class external_bar:
+class ExternalBar:
     def __init__(self, iterable=None, callback=None, total=0):
         self.iterable = iterable
         self.callback = callback
@@ -193,6 +190,9 @@
             self.callback(n)
 
 
+external_bar = ExternalBar  # alias kept so the old lowercase name stays importable
+
+
 class tqdmm(tqdm):
     """ Overload tqdm to make a special version of tqdm functioning as a meter. """
 
@@ -305,11 +305,108 @@ def parfor(*args, **kwargs):
     return decfun
 
 
+class Hasher(object):
+    def __init__(self, obj, hsh=None):
+        if hsh is not None:
+            self.obj, self.str, self.hash = None, obj, hsh
+        elif isinstance(obj, Hasher):
+            self.obj, self.str, self.hash = obj.obj, obj.str, obj.hash
+        else:
+            self.obj = obj
+            self.str = dumps(self.obj, recurse=True)
+            self.hash = hash(self.str)
+
+    def __reduce__(self):
+        return self.__class__, (self.str, self.hash)
+
+    def set_from_cache(self, cache=None):
+        if cache is None:
+            self.obj = dill.loads(self.str)
+        elif self.hash in cache:
+            self.obj = cache[self.hash]
+        else:
+            self.obj = cache[self.hash] = dill.loads(self.str)
+
+
+class HashDescriptor(object):
+    def __set_name__(self, owner, name):
+        self.owner, self.name = owner, '_' + name
+
+    def __set__(self, instance, value):
+        if isinstance(value, Hasher):
+            setattr(instance, self.name, value)
+        else:
+            setattr(instance, self.name, Hasher(value))
+
+    def __get__(self, instance, owner):
+        return getattr(instance, self.name).obj
+
+
+class deque_dict(OrderedDict):
+    def __init__(self, maxlen=None, *args, **kwargs):
+        self.maxlen = maxlen
+        super(deque_dict, self).__init__(*args, **kwargs)
+
+    def _truncate(self):
+        while self.maxlen is not None and len(self) > self.maxlen:
+            self.popitem(False)
+
+    def __setitem__(self, *args, **kwargs):
+        super(deque_dict, self).__setitem__(*args, **kwargs)
+        self._truncate()
+
+    def update(self, *args, **kwargs):
+        super(deque_dict, self).update(*args, **kwargs)
+        self._truncate()
+
+
+class Task(object):
+    fun = HashDescriptor()
+    args = HashDescriptor()
+    kwargs = HashDescriptor()
+
+    def __init__(self, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
+        self.fun = fun or (lambda *args, **kwargs: None)
+        self.args = args or ()
+        self.kwargs = kwargs or {}
+        self.handle = handle
+        self.n = n
+        self.done = done
+        self.result = dill.loads(result) if self.done else None
+        self.pid = None
+
+    def __reduce__(self):
+        if self.done:
+            return self.__class__, (None, None, None, self.handle, None, self.done, dumps(self.result, recurse=True))
+        else:
+            return self.__class__, (self._fun, self._args, self._kwargs, self.handle, dumps(self.n, recurse=True),
+                                    self.done)
+
+    def set_from_cache(self, cache=None):
+        self.n = dill.loads(self.n)
+        self._fun.set_from_cache(cache)
+        self._args.set_from_cache(cache)
+        self._kwargs.set_from_cache(cache)
+
+    def __call__(self):
+        if not self.done:
+            self.result = self.fun(self.n, *self.args, **self.kwargs)
+            self.fun, self.args, self.kwargs, self.done = None, None, None, True  # Remove potentially big things
+        return self
+
+    def __repr__(self):
+        if self.done:
+            return 'Task {}, result: {}'.format(self.handle, self.result)
+        else:
+            return 'Task {}'.format(self.handle)
+
+
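+# How the pieces above fit together: fun, args and kwargs are wrapped in Hashers,
+# so a Task pickles to dill strings plus hashes on its way to a worker, and to
+# just its handle and pickled result on the way back (see Task.__reduce__).
+# Rough usage sketch of the pool built on top of this (API as defined below):
+#
+#     with parpool(lambda x: x ** 2) as pool:
+#         handles = [pool(i) for i in range(10)]    # submit iterations
+#         results = [pool[h] for h in handles]      # fetch results in any order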
 class parpool(object):
     """ Parallel processing with addition of iterations at any time and request of that result any time
         after that. The target function and its arguments can be changed at any time.
     """
-    def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, terminator=None):
+    def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, terminator=None,
+                 qsize=None):
         """ fun, args, kwargs: target function and its arguments and keyword arguments
             rP: ratio workers to cpu cores, default: 1
             nP: number of workers, default: None, overrides rP if not None
@@ -321,54 +418,24 @@
         else:
             self.nP = int(nP)
         self.nP = max(self.nP, 2)
-        self.fun = fun or (lambda x: x)
-        self.args = args or ()
-        self.kwargs = kwargs or {}
+        self.task = Task(fun, args, kwargs)
         if hasattr(multiprocessing, 'get_context'):
             ctx = multiprocessing.get_context('spawn')
         else:
             ctx = multiprocessing
-        self.A = ctx.Value('i', self.nP)
-        self.E = ctx.Event()
-        self.Qi = ctx.Queue(3*self.nP)
-        self.Qo = ctx.Queue(3*self.nP)
-        self.P = ctx.Pool(self.nP, self._worker(self.Qi, self.Qo, self.A, self.E, terminator))
+        self.n_tasks = ctx.Value('i', self.nP)
+        self.event = ctx.Event()
+        self.queue_in = ctx.Queue(qsize or 3 * self.nP)
+        self.queue_out = ctx.Queue(qsize or 12 * self.nP)
+        self.pool = ctx.Pool(self.nP, self._worker(self.queue_in, self.queue_out, self.n_tasks, self.event, terminator))
         self.is_alive = True
-        self.res = {}
         self.handle = 0
-        self.handles = []
+        self.tasks = {}
         self.bar = bar
-        self.barlengths = {}
+        self.bar_lengths = {}
         self.qbar = qbar
         if self.qbar is not None:
-            self.qbar.total = 3*self.nP
-
-    @property
-    def fun(self):
-        return self._fun[1:]
-
-    @fun.setter
-    def fun(self, fun):
-        funs = dumps(fun, recurse=True)
-        self._fun = (fun, hash(funs), funs)
-
-    @property
-    def args(self):
-        return self._args[1:]
-
-    @args.setter
-    def args(self, args):
-        argss = dumps(args, recurse=True)
-        self._args = (args, hash(argss), argss)
-
-    @property
-    def kwargs(self):
-        return self._kwargs[1:]
-
-    @kwargs.setter
-    def kwargs(self, kwargs):
-        kwargss = dumps(kwargs, recurse=True)
-        self._kwargs = (kwargs, hash(kwargss), kwargss)
+            self.qbar.total = 3 * self.nP
 
     def __enter__(self, *args, **kwargs):
         return self
 
@@ -376,57 +443,65 @@
     def __exit__(self, *args, **kwargs):
         self.close()
 
-    def _getfromq(self):
-        """ Get an item from the queue and store it. """
+    def _get_from_queue(self):
+        """ Get an item from the queue and handle it, return True if a message was processed. """
         try:
-            err, i, res = self.Qo.get(True, 0.02)
-            if not err:
-                self.res[i] = dill.loads(res)
-            else:
-                n, fun, args, kwargs, e = [dill.loads(r) for r in res]
-                print('Error from process working on iteration {}:\n'.format(i))
-                print(e)
-                self.close()
-                print('Retrying in main thread...')
-                self.res[i] = fun(n, *args, **kwargs)
-                raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'
-                                .format(fun.__name__))
-            if self.bar is not None:
-                self.bar.update(self.barlengths.pop(i))
-            self._qbar_update()
+            code, *args = self.queue_out.get(True, 0.02)
+            getattr(self, code)(*args)
+            return True
         except multiprocessing.queues.Empty:
-            pass
+            for task in self.tasks.values():  # retry a task if the process doing it was killed
+                if task.pid is not None and task.pid not in [child.pid for child in Process().children()]:
+                    self.queue_in.put(task)
+            return False
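+
+    # handlers for the messages a worker can put on queue_out; _get_from_queue
+    # dispatches on the first tuple element: ('started', handle, pid),
+    # ('done', task), ('task_error', handle, traceback) or ('error', traceback)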
+    def error(self, error):
+        self.close()
+        raise Exception('Error occurred in worker: {}'.format(error))
+
+    def task_error(self, handle, error):
+        task = self.tasks[handle]
+        print('Error from process working on iteration {}:\n'.format(handle))
+        print(error)
+        self.close()
+        print('Retrying in main thread...')
+        fun = task.fun.__name__
+        task()
+        raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'.format(fun))
+
+    def done(self, task):
+        self.tasks[task.handle] = task
+        if self.bar is not None:
+            self.bar.update(self.bar_lengths.pop(task.handle))
+        self._qbar_update()
+
+    def started(self, handle, pid):
+        self.tasks[handle].pid = pid
 
     def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
         """ Add new iteration, using optional manually defined handle."""
-        if self.is_alive and not self.E.is_set():
-            n = dumps(n, recurse=True)
-            if fun is not None:
-                self.fun = fun
-            if args is not None:
-                self.args = args
-            if kwargs is not None:
-                self.kwargs = kwargs
-            while self.Qi.full():
-                self._getfromq()
+        if self.is_alive and not self.event.is_set():
+            self.task = Task(fun or self.task.fun, args or self.task.args, kwargs or self.task.kwargs, handle, n)
+            while self.queue_in.full():
+                self._get_from_queue()
             if handle is None:
                 handle = self.handle
                 self.handle += 1
-            self.handles.append(handle)
-            self.Qi.put((handle, n, self.fun, self.args, self.kwargs))
-            self.barlengths[handle] = barlength
+            self.tasks[handle] = self.task
+            self.queue_in.put(self.task)
+            self.bar_lengths[handle] = barlength
            self._qbar_update()
             return handle
         elif handle not in self:
-            self.handles.append(handle)
-            self.Qi.put((handle, n, self.fun, self.args, self.kwargs))
-            self.barlengths[handle] = barlength
+            self.tasks[handle] = self.task
+            self.queue_in.put(self.task)
+            self.bar_lengths[handle] = barlength
             self._qbar_update()
 
     def _qbar_update(self):
         if self.qbar is not None:
             try:
-                self.qbar.n = self.Qi.qsize()
+                self.qbar.n = self.queue_in.qsize()
             except Exception:
                 pass
 
@@ -438,24 +513,33 @@ def __getitem__(self, handle):
         """ Request result and delete its record. Wait if result not yet available. """
         if handle not in self:
             raise ValueError('No handle: {}'.format(handle))
-        while handle not in self.res:
-            self._getfromq()
-        self.handles.remove(handle)
-        return self.res.pop(handle)
+        while not self.tasks[handle].done:
+            if not self._get_from_queue() and not self.tasks[handle].done and not self.working:
+                for _ in range(10):  # wait some time while processing possible new messages
+                    self._get_from_queue()
+                if not self._get_from_queue() and not self.tasks[handle].done and not self.working:
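+                    # nothing is marked as running and no result arrived, so the
+                    # task was probably lost: submit it to the queue once more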
""" - if len(self.handles): - while not len(self.res): - self._getfromq() - key = list(self.res.keys())[0] - return key, self[key] + while len(self.tasks): + self._get_from_queue() + for task in self.tasks: + if task.done: + return task.handle, task.result def __delitem__(self, handle): - self[handle] + self.tasks.pop(handle) def __contains__(self, handle): - return handle in self.handles + return handle in self.tasks def __repr__(self): if self.is_alive: @@ -466,93 +550,84 @@ class parpool(object): def close(self): if self.is_alive: self.is_alive = False - self.E.set() - self.P.close() - while self.A.value: - self._empty_queue(self.Qi) - self._empty_queue(self.Qo) - self._empty_queue(self.Qi) - self._empty_queue(self.Qo) - self.P.join() - self._close_queue(self.Qi) - self._close_queue(self.Qo) - self.res = {} + self.event.set() + self.pool.close() + while self.n_tasks.value: + self._empty_queue(self.queue_in) + self._empty_queue(self.queue_out) + self._empty_queue(self.queue_in) + self._empty_queue(self.queue_out) + self.pool.join() + self._close_queue(self.queue_in) + self._close_queue(self.queue_out) self.handle = 0 - self.handles = [] + self.tasks = {} @staticmethod - def _empty_queue(Q): - if not Q._closed: - while not Q.empty(): + def _empty_queue(queue): + if not queue._closed: + while not queue.empty(): try: - Q.get(True, 0.02) + queue.get(True, 0.02) except multiprocessing.queues.Empty: pass @staticmethod - def _close_queue(Q): - if not Q._closed: - while not Q.empty(): + def _close_queue(queue): + if not queue._closed: + while not queue.empty(): try: - Q.get(True, 0.02) + queue.get(True, 0.02) except multiprocessing.queues.Empty: pass - Q.close() - Q.join_thread() + queue.close() + queue.join_thread() class _worker(object): """ Manages executing the target function which will be executed in different processes. 
""" - def __init__(self, Qi, Qo, A, E, terminator, cachesize=48): - self.cache = [] - self.Qi = Qi - self.Qo = Qo - self.A = A - self.E = E + def __init__(self, queue_in, queue_out, n_tasks, event, terminator, cachesize=48): + self.cache = deque_dict(cachesize) + self.queue_in = queue_in + self.queue_out = queue_out + self.n_tasks = n_tasks + self.event = event self.terminator = dumps(terminator, recurse=True) - self.cachesize = cachesize - def add_to_q(self, value): - while not self.E.is_set(): + def add_to_queue(self, *args): + while not self.event.is_set(): try: - self.Qo.put(value, timeout=0.1) + self.queue_out.put(args, timeout=0.1) break except multiprocessing.queues.Full: continue def __call__(self): - while not self.E.is_set(): - i, n, Fun, Args, Kwargs = [None]*5 + pid = getpid() + while not self.event.is_set(): try: - i, n, Fun, Args, Kwargs = self.Qi.get(True, 0.02) - fun = self.get_from_cache(*Fun) - args = self.get_from_cache(*Args) - kwargs = self.get_from_cache(*Kwargs) - self.add_to_q((False, i, dumps(fun(dill.loads(n), *args, **kwargs), recurse=True))) + task = self.queue_in.get(True, 0.02) + try: + task.set_from_cache(self.cache) + self.add_to_queue('started', task.handle, pid) + self.add_to_queue('done', task()) + except Exception: + self.add_to_queue('task_error', task.handle, format_exc()) + self.event.set() except multiprocessing.queues.Empty: continue except Exception: - self.add_to_q((True, i, (n, Fun[1], Args[1], Kwargs[1], dumps(format_exc(), recurse=True)))) - self.E.set() + self.add_to_queue('error', format_exc()) + self.event.set() terminator = dill.loads(self.terminator) + kill_vm() if terminator is not None: terminator() - with self.A.get_lock(): - self.A.value -= 1 - - def get_from_cache(self, h, ser): - if len(self.cache): - hs, objs = zip(*self.cache) - if h in hs: - return objs[hs.index(h)] - obj = dill.loads(ser) - self.cache.append((h, obj)) - while len(self.cache) > self.cachesize: - self.cache.pop(0) - return obj + with self.n_tasks.get_lock(): + self.n_tasks.value -= 1 def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar=True, qbar=False, terminator=None, - rP=1, nP=None, serial=4): + rP=1, nP=None, serial=4, qsize=None): """ map a function fun to each iteration in iterable best use: iterable is a generator and length is given to this function @@ -577,17 +652,17 @@ def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar pass if length and length < serial: # serial case if callable(bar): - return [fun(c, *args, **kwargs) for c in external_bar(iterable, bar)] - elif bar is False: + return [fun(c, *args, **kwargs) for c in ExternalBar(iterable, bar)] + else: return [fun(c, *args, **kwargs) for c in tqdm(iterable, total=length, desc=desc, disable=not bar)] else: # parallel case chunk = isinstance(iterable, chunks) if chunk: - length = iterable.N - with external_bar(callback=qbar) if callable(qbar) \ + length = iterable.number_of_items + with ExternalBar(callback=qbar) if callable(qbar) \ else tqdmm(total=0, desc='Task buffer', disable=not qbar, leave=False) as qbar, \ - external_bar(callback=bar) if callable(bar) else tqdm(total=length, desc=desc, disable=not bar) as bar: - with parpool(fun, args, kwargs, rP, nP, bar, qbar, terminator) as p: + ExternalBar(callback=bar) if callable(bar) else tqdm(total=length, desc=desc, disable=not bar) as bar: + with parpool(fun, args, kwargs, rP, nP, bar, qbar, terminator, qsize) as p: length = 0 for i, j in enumerate(iterable): # add work to the queue if chunk: diff 
+            with parpool(fun, args, kwargs, rP, nP, bar, qbar, terminator, qsize) as p:
                 length = 0
                 for i, j in enumerate(iterable):  # add work to the queue
                     if chunk:
diff --git a/setup.py b/setup.py
index 1fdcdcb..f918075 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
 
 setuptools.setup(
     name="parfor",
-    version="2021.7.1",
+    version="2022.3.0",
     author="Wim Pomp",
     author_email="wimpomp@gmail.com",
     description="A package to mimic the use of parfor as done in Matlab.",
@@ -14,11 +14,10 @@ setuptools.setup(
     url="https://github.com/wimpomp/parfor",
     packages=setuptools.find_packages(),
     classifiers=[
-        "Programming Language :: Python :: 2.7",
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
         "Operating System :: OS Independent",
     ],
-    python_requires='>=2.7',
-    install_requires=['tqdm>=4.50.0', 'dill>=0.3.0'],
+    python_requires='>=3.6',
+    install_requires=['tqdm>=4.50.0', 'dill>=0.3.0', 'psutil'],
 )
\ No newline at end of file