- drop support for Python 2

- automatically kill javabridge (if it exists) at the end of a process
- better hashing and tasking
Wim Pomp
2022-03-03 12:43:00 +01:00
parent 4bddca82eb
commit fcf451cb23
3 changed files with 256 additions and 178 deletions

@@ -1,20 +1,17 @@
from __future__ import print_function
import sys
import multiprocessing
import dill
from os import getpid
from tqdm.auto import tqdm
from traceback import format_exc
from pickle import PicklingError, dispatch_table
PY3 = (sys.hexversion >= 0x3000000)
from psutil import Process
from collections import OrderedDict
from io import BytesIO
try:
from cStringIO import StringIO
from javabridge import kill_vm
except ImportError:
if PY3:
from io import BytesIO as StringIO
else:
from StringIO import StringIO
kill_vm = lambda: None
failed_rv = (lambda *args, **kwargs: None, ())
cpu_count = int(multiprocessing.cpu_count())
@@ -27,8 +24,7 @@ class Pickler(dill.Pickler):
"""
def save(self, obj, save_persistent_id=True):
""" Copied from pickle and amended. """
if PY3:
self.framer.commit_frame()
self.framer.commit_frame()
# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
@@ -112,12 +108,12 @@ def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds):
protocol = dill.settings['protocol'] if protocol is None else int(protocol)
_kwds = kwds.copy()
_kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
file = StringIO()
file = BytesIO()
Pickler(file, protocol, **_kwds).dump(obj)
return file.getvalue()
class chunks():
class chunks(object):
""" Yield successive chunks from lists.
Usage: chunks(s, list0, list1, ...)
chunks(list0, list1, ..., s=s)
@@ -132,35 +128,36 @@ class chunks():
def __init__(self, *args, **kwargs):
if 's' in kwargs and ('n' in kwargs or 'r' in kwargs):
N = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
number_of_items = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
n = kwargs['n'] if 'n' in kwargs else int(cpu_count * kwargs['r'])
n = n if N < kwargs['s'] * n else round(N / kwargs['s'])
n = n if number_of_items < kwargs['s'] * n else round(number_of_items / kwargs['s'])
elif 's' in kwargs: # size of chunks
N = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
n = round(N / kwargs['s'])
number_of_items = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
n = round(number_of_items / kwargs['s'])
elif 'n' in kwargs or 'r' in kwargs: # number of chunks
N = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
number_of_items = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
n = kwargs['n'] if 'n' in kwargs else int(cpu_count * kwargs['r'])
else: # size of chunks in 1st argument
s, *args = args
N = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
n = round(N / s)
number_of_items = min(*[len(a) for a in args]) if len(args) > 1 else len(args[0])
n = round(number_of_items / s)
self.args = args
self.A = len(args) == 1
self.N = N
self.len = max(1, min(N, n))
self.lengths = [((i + 1) * self.N // self.len) - (i * self.N // self.len) for i in range(self.len)]
self.number_of_arguments = len(args) == 1
self.number_of_items = number_of_items
self.len = max(1, min(number_of_items, n))
self.lengths = [((i + 1) * self.number_of_items // self.len) - (i * self.number_of_items // self.len)
for i in range(self.len)]
def __iter__(self):
for i in range(self.len):
p, q = (i * self.N // self.len), ((i + 1) * self.N // self.len)
yield self.args[0][p:q] if self.A else [a[p:q] for a in self.args]
p, q = (i * self.number_of_items // self.len), ((i + 1) * self.number_of_items // self.len)
yield self.args[0][p:q] if self.number_of_arguments else [a[p:q] for a in self.args]
def __len__(self):
return self.len
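A minimal sketch of how the chunks class above might be used (the lists here are made up for illustration; when several lists are given, corresponding chunks are yielded together):

a, b = list(range(10)), list(range(10, 20))
for chunk_a, chunk_b in chunks(a, b, s=3):  # chunks of roughly 3 items each
    print(chunk_a, chunk_b)                 # first yields [0, 1, 2] [10, 11, 12]
for chunk in chunks(a, n=4):                # alternatively: a fixed number of chunks
    print(chunk)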
class external_bar:
class ExternalBar:
def __init__(self, iterable=None, callback=None, total=0):
self.iterable = iterable
self.callback = callback
@@ -193,6 +190,9 @@ class external_bar:
self.callback(n)
external_bar = ExternalBar
class tqdmm(tqdm):
""" Overload tqdm to make a special version of tqdm functioning as a meter. """
@@ -305,11 +305,108 @@ def parfor(*args, **kwargs):
return decfun
class Hasher(object):
def __init__(self, obj, hsh=None):
if hsh is not None:
self.obj, self.str, self.hash = None, obj, hsh
elif isinstance(obj, Hasher):
self.obj, self.str, self.hash = obj.obj, obj.str, obj.hash
else:
self.obj = obj
self.str = dumps(self.obj, recurse=True)
self.hash = hash(self.str)
def __reduce__(self):
return self.__class__, (self.str, self.hash)
def set_from_cache(self, cache=None):
if cache is None:
self.obj = dill.loads(self.str)
elif self.hash in cache:
self.obj = cache[self.hash]
else:
self.obj = cache[self.hash] = dill.loads(self.str)
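To illustrate the mechanism (a sketch, not code from this commit): a Hasher serializes its object once with dill and travels as a (bytes, hash) pair via __reduce__; set_from_cache then deserializes at most once per hash on the receiving side:

cache = {}
h = dill.loads(dill.dumps(Hasher(max)))  # arrives with obj=None
h.set_from_cache(cache)                  # deserializes and stores under h.hash
assert h.obj(2, 3) == 3 and h.hash in cache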
class HashDescriptor(object):
def __set_name__(self, owner, name):
self.owner, self.name = owner, '_' + name
def __set__(self, instance, value):
if isinstance(value, Hasher):
setattr(instance, self.name, value)
else:
setattr(instance, self.name, Hasher(value))
def __get__(self, instance, owner):
return getattr(instance, self.name).obj
class deque_dict(OrderedDict):
def __init__(self, maxlen=None, *args, **kwargs):
self.maxlen = maxlen
super(deque_dict, self).__init__(*args, **kwargs)
def __truncate__(self):
while len(self) > self.maxlen:
self.popitem(False)
def __setitem__(self, *args, **kwargs):
super(deque_dict, self).__setitem__(*args, **kwargs)
self.__truncate__()
def update(self, *args, **kwargs):
super(deque_dict, self).update(*args, **kwargs)
self.__truncate__()
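In effect this is an OrderedDict with a size cap that evicts its oldest entries first, used below as the workers' deserialization cache; for example:

d = deque_dict(2)
d['a'], d['b'], d['c'] = 1, 2, 3
assert list(d) == ['b', 'c']  # 'a' was evicted first-in, first-out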
class Task(object):
fun = HashDescriptor()
args = HashDescriptor()
kwargs = HashDescriptor()
def __init__(self, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
self.fun = fun or (lambda *args, **kwargs: None)
self.args = args or ()
self.kwargs = kwargs or {}
self.handle = handle
self.n = n
self.done = done
self.result = dill.loads(result) if self.done else None
self.pid = None
def __reduce__(self):
if self.done:
return self.__class__, (None, None, None, self.handle, None, self.done, dumps(self.result, recurse=True))
else:
return self.__class__, (self._fun, self._args, self._kwargs, self.handle, dumps(self.n, recurse=True),
self.done)
def set_from_cache(self, cache=None):
self.n = dill.loads(self.n)
self._fun.set_from_cache(cache)
self._args.set_from_cache(cache)
self._kwargs.set_from_cache(cache)
def __call__(self):
if not self.done:
self.result = self.fun(self.n, *self.args, **self.kwargs)
self.fun, self.args, self.kwargs, self.done = None, None, None, True # Remove potentially big things
return self
def __repr__(self):
if self.done:
return 'Task {}, result: {}'.format(self.handle, self.result)
else:
return 'Task {}'.format(self.handle)
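A sketch of a Task's life cycle as implied by __reduce__ and set_from_cache (illustrative; in the pool this round trip happens through the queues):

task = Task(lambda n, x: n + x, args=(10,), handle=0, n=5)
task = dill.loads(dill.dumps(task))  # as if passed through queue_in
task.set_from_cache({})              # restore fun/args/kwargs from their Hashers
task()                               # run; keeps result, drops the big attributes
assert task.done and task.result == 15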
class parpool(object):
""" Parallel processing with addition of iterations at any time and request of that result any time after that.
The target function and its argument can be changed at any time.
"""
def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, terminator=None):
def __init__(self, fun=None, args=None, kwargs=None, rP=None, nP=None, bar=None, qbar=None, terminator=None,
qsize=None):
""" fun, args, kwargs: target function and its arguments and keyword arguments
rP: ratio of workers to cpu cores, default: 1
nP: number of workers, default: None; overrides rP if not None
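For orientation, the call pattern this class appears designed for (a sketch inferred from the methods in this diff, not an example shipped with the package):

with parpool(lambda n: n ** 2) as pool:
    handles = [pool(i) for i in range(10)]  # __call__ queues an iteration
    print([pool[h] for h in handles])       # __getitem__ waits for each result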
@@ -321,54 +418,24 @@ class parpool(object):
else:
self.nP = int(nP)
self.nP = max(self.nP, 2)
self.fun = fun or (lambda x: x)
self.args = args or ()
self.kwargs = kwargs or {}
self.task = Task(fun, args, kwargs)
if hasattr(multiprocessing, 'get_context'):
ctx = multiprocessing.get_context('spawn')
else:
ctx = multiprocessing
self.A = ctx.Value('i', self.nP)
self.E = ctx.Event()
self.Qi = ctx.Queue(3*self.nP)
self.Qo = ctx.Queue(3*self.nP)
self.P = ctx.Pool(self.nP, self._worker(self.Qi, self.Qo, self.A, self.E, terminator))
self.n_tasks = ctx.Value('i', self.nP)
self.event = ctx.Event()
self.queue_in = ctx.Queue(qsize or 3 * self.nP)
self.queue_out = ctx.Queue(qsize or 12 * self.nP)
self.pool = ctx.Pool(self.nP, self._worker(self.queue_in, self.queue_out, self.n_tasks, self.event, terminator))
self.is_alive = True
self.res = {}
self.handle = 0
self.handles = []
self.tasks = {}
self.bar = bar
self.barlengths = {}
self.bar_lengths = {}
self.qbar = qbar
if self.qbar is not None:
self.qbar.total = 3*self.nP
@property
def fun(self):
return self._fun[1:]
@fun.setter
def fun(self, fun):
funs = dumps(fun, recurse=True)
self._fun = (fun, hash(funs), funs)
@property
def args(self):
return self._args[1:]
@args.setter
def args(self, args):
argss = dumps(args, recurse=True)
self._args = (args, hash(argss), argss)
@property
def kwargs(self):
return self._kwargs[1:]
@kwargs.setter
def kwargs(self, kwargs):
kwargss = dumps(kwargs, recurse=True)
self._kwargs = (kwargs, hash(kwargss), kwargss)
self.qbar.total = 3 * self.nP
def __enter__(self, *args, **kwargs):
return self
@@ -376,57 +443,65 @@ class parpool(object):
def __exit__(self, *args, **kwargs):
self.close()
def _getfromq(self):
""" Get an item from the queue and store it. """
def _get_from_queue(self):
""" Get an item from the queue and store it, return True if more messages are waiting. """
try:
err, i, res = self.Qo.get(True, 0.02)
if not err:
self.res[i] = dill.loads(res)
else:
n, fun, args, kwargs, e = [dill.loads(r) for r in res]
print('Error from process working on iteration {}:\n'.format(i))
print(e)
self.close()
print('Retrying in main thread...')
self.res[i] = fun(n, *args, **kwargs)
raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'
.format(fun.__name__))
if self.bar is not None:
self.bar.update(self.barlengths.pop(i))
self._qbar_update()
code, *args = self.queue_out.get(True, 0.02)
getattr(self, code)(*args)
return True
except multiprocessing.queues.Empty:
pass
for handle, task in self.tasks.items(): # retry a task if the process doing it was killed
if task.pid is not None and task.pid not in [child.pid for child in Process().children()]:
self.queue_in.put(task)
return False
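The getattr dispatch above implies a small message protocol: each item on queue_out is a tuple whose first element names a parpool method, e.g. ('started', handle, pid), ('done', task) or ('task_error', handle, traceback), produced by the worker's add_to_queue calls further down.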
def error(self, error):
self.close()
raise Exception('Error occurred in worker: {}'.format(error))
def task_error(self, handle, error):
task = self.tasks[handle]
print('Error from process working on iteration {}:\n'.format(handle))
print(error)
self.close()
print('Retrying in main thread...')
fun = task.fun.__name__
task()
raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'.format(fun))
def done(self, task):
self.tasks[task.handle] = task
if self.bar is not None:
self.bar.update(self.bar_lengths.pop(task.handle))
self._qbar_update()
def started(self, handle, pid):
self.tasks[handle].pid = pid
def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
""" Add new iteration, using optional manually defined handle."""
if self.is_alive and not self.E.is_set():
n = dumps(n, recurse=True)
if fun is not None:
self.fun = fun
if args is not None:
self.args = args
if kwargs is not None:
self.kwargs = kwargs
while self.Qi.full():
self._getfromq()
if self.is_alive and not self.event.is_set():
self.task = Task(fun or self.task.fun, args or self.task.args, kwargs or self.task.kwargs, handle, n)
while self.queue_in.full():
self._get_from_queue()
if handle is None:
handle = self.handle
self.handle += 1
self.handles.append(handle)
self.Qi.put((handle, n, self.fun, self.args, self.kwargs))
self.barlengths[handle] = barlength
self.tasks[handle] = self.task
self.queue_in.put(self.task)
self.bar_lengths[handle] = barlength
self._qbar_update()
return handle
elif handle not in self:
self.handles.append(handle)
self.Qi.put((handle, n, self.fun, self.args, self.kwargs))
self.barlengths[handle] = barlength
self.tasks[handle] = self.task
self.queue_in.put(self.task)
self.bar_lengths[handle] = barlength
self._qbar_update()
def _qbar_update(self):
if self.qbar is not None:
try:
self.qbar.n = self.Qi.qsize()
self.qbar.n = self.queue_in.qsize()
except Exception:
pass
@@ -438,24 +513,33 @@ class parpool(object):
""" Request result and delete its record. Wait if result not yet available. """
if handle not in self:
raise ValueError('No handle: {}'.format(handle))
while handle not in self.res:
self._getfromq()
self.handles.remove(handle)
return self.res.pop(handle)
while not self.tasks[handle].done:
if not self._get_from_queue() and not self.tasks[handle].done and not self.working:
for _ in range(10): # wait some time while processing possible new messages
self._get_from_queue()
if not self._get_from_queue() and not self.tasks[handle].done and not self.working:
self.queue_in.put(self.tasks[handle])
result = self.tasks[handle].result
self.tasks.pop(handle)
return result
@property
def working(self):
return not all([task.pid is None for task in self.tasks.values()])
def get_newest(self):
""" Request the newest key and result and delete its record. Wait if result not yet available. """
if len(self.handles):
while not len(self.res):
self._getfromq()
key = list(self.res.keys())[0]
return key, self[key]
while len(self.tasks):
self._get_from_queue()
for task in self.tasks.values():
if task.done:
return task.handle, task.result
def __delitem__(self, handle):
self[handle]
self.tasks.pop(handle)
def __contains__(self, handle):
return handle in self.handles
return handle in self.tasks
def __repr__(self):
if self.is_alive:
@@ -466,93 +550,84 @@ class parpool(object):
def close(self):
if self.is_alive:
self.is_alive = False
self.E.set()
self.P.close()
while self.A.value:
self._empty_queue(self.Qi)
self._empty_queue(self.Qo)
self._empty_queue(self.Qi)
self._empty_queue(self.Qo)
self.P.join()
self._close_queue(self.Qi)
self._close_queue(self.Qo)
self.res = {}
self.event.set()
self.pool.close()
while self.n_tasks.value:
self._empty_queue(self.queue_in)
self._empty_queue(self.queue_out)
self._empty_queue(self.queue_in)
self._empty_queue(self.queue_out)
self.pool.join()
self._close_queue(self.queue_in)
self._close_queue(self.queue_out)
self.handle = 0
self.handles = []
self.tasks = {}
@staticmethod
def _empty_queue(Q):
if not Q._closed:
while not Q.empty():
def _empty_queue(queue):
if not queue._closed:
while not queue.empty():
try:
Q.get(True, 0.02)
queue.get(True, 0.02)
except multiprocessing.queues.Empty:
pass
@staticmethod
def _close_queue(Q):
if not Q._closed:
while not Q.empty():
def _close_queue(queue):
if not queue._closed:
while not queue.empty():
try:
Q.get(True, 0.02)
queue.get(True, 0.02)
except multiprocessing.queues.Empty:
pass
Q.close()
Q.join_thread()
queue.close()
queue.join_thread()
class _worker(object):
""" Manages executing the target function which will be executed in different processes. """
def __init__(self, Qi, Qo, A, E, terminator, cachesize=48):
self.cache = []
self.Qi = Qi
self.Qo = Qo
self.A = A
self.E = E
def __init__(self, queue_in, queue_out, n_tasks, event, terminator, cachesize=48):
self.cache = deque_dict(cachesize)
self.queue_in = queue_in
self.queue_out = queue_out
self.n_tasks = n_tasks
self.event = event
self.terminator = dumps(terminator, recurse=True)
self.cachesize = cachesize
def add_to_q(self, value):
while not self.E.is_set():
def add_to_queue(self, *args):
while not self.event.is_set():
try:
self.Qo.put(value, timeout=0.1)
self.queue_out.put(args, timeout=0.1)
break
except multiprocessing.queues.Full:
continue
def __call__(self):
while not self.E.is_set():
i, n, Fun, Args, Kwargs = [None]*5
pid = getpid()
while not self.event.is_set():
try:
i, n, Fun, Args, Kwargs = self.Qi.get(True, 0.02)
fun = self.get_from_cache(*Fun)
args = self.get_from_cache(*Args)
kwargs = self.get_from_cache(*Kwargs)
self.add_to_q((False, i, dumps(fun(dill.loads(n), *args, **kwargs), recurse=True)))
task = self.queue_in.get(True, 0.02)
try:
task.set_from_cache(self.cache)
self.add_to_queue('started', task.handle, pid)
self.add_to_queue('done', task())
except Exception:
self.add_to_queue('task_error', task.handle, format_exc())
self.event.set()
except multiprocessing.queues.Empty:
continue
except Exception:
self.add_to_q((True, i, (n, Fun[1], Args[1], Kwargs[1], dumps(format_exc(), recurse=True))))
self.E.set()
self.add_to_queue('error', format_exc())
self.event.set()
terminator = dill.loads(self.terminator)
kill_vm()
if terminator is not None:
terminator()
with self.A.get_lock():
self.A.value -= 1
def get_from_cache(self, h, ser):
if len(self.cache):
hs, objs = zip(*self.cache)
if h in hs:
return objs[hs.index(h)]
obj = dill.loads(ser)
self.cache.append((h, obj))
while len(self.cache) > self.cachesize:
self.cache.pop(0)
return obj
with self.n_tasks.get_lock():
self.n_tasks.value -= 1
def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar=True, qbar=False, terminator=None,
rP=1, nP=None, serial=4):
rP=1, nP=None, serial=4, qsize=None):
""" map a function fun to each iteration in iterable
best use: iterable is a generator and length is given to this function
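A minimal usage sketch (the squaring function is just for illustration; assuming the package exposes pmap at top level):

from parfor import pmap
print(pmap(lambda x, p: x ** p, range(10), args=(2,)))  # [0, 1, 4, ..., 81]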
@@ -577,17 +652,17 @@ def pmap(fun, iterable=None, args=None, kwargs=None, length=None, desc=None, bar
pass
if length and length < serial: # serial case
if callable(bar):
return [fun(c, *args, **kwargs) for c in external_bar(iterable, bar)]
elif bar is False:
return [fun(c, *args, **kwargs) for c in ExternalBar(iterable, bar)]
else:
return [fun(c, *args, **kwargs) for c in tqdm(iterable, total=length, desc=desc, disable=not bar)]
else: # parallel case
chunk = isinstance(iterable, chunks)
if chunk:
length = iterable.N
with external_bar(callback=qbar) if callable(qbar) \
length = iterable.number_of_items
with ExternalBar(callback=qbar) if callable(qbar) \
else tqdmm(total=0, desc='Task buffer', disable=not qbar, leave=False) as qbar, \
external_bar(callback=bar) if callable(bar) else tqdm(total=length, desc=desc, disable=not bar) as bar:
with parpool(fun, args, kwargs, rP, nP, bar, qbar, terminator) as p:
ExternalBar(callback=bar) if callable(bar) else tqdm(total=length, desc=desc, disable=not bar) as bar:
with parpool(fun, args, kwargs, rP, nP, bar, qbar, terminator, qsize) as p:
length = 0
for i, j in enumerate(iterable): # add work to the queue
if chunk: