- Use a shared memory approach.
- Track unpicklable objects with a CouldNotBePickled class.
- Bump the minimum Python version to 3.8.
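Note: the heart of the change is a process-shared store. Objects are put into a multiprocessing Manager dict keyed by their id(), plain pickling is tried first, and dill is used only as a fallback; workers keep a local cache of resolved objects. A minimal standalone sketch of that pattern (SharedStore, put and get are illustrative names, not the package's API):

import multiprocessing

import dill


class SharedStore:
    def __init__(self, manager):
        self.data = manager.dict()  # item_id: (dilled?, payload), shared between processes
        self.cache = {}             # item_id: live object, local to this process

    def put(self, obj):
        key = id(obj)
        if key not in self.cache:
            try:
                self.data[key] = False, obj  # plain pickling suffices
            except Exception:
                self.data[key] = True, dill.dumps(obj)  # fall back to dill
            self.cache[key] = obj  # keeps obj alive so its id is not reused
        return key

    def get(self, key):
        if key not in self.cache:
            dilled, payload = self.data[key]
            self.cache[key] = dill.loads(payload) if dilled else payload
        return self.cache[key]


if __name__ == '__main__':
    with multiprocessing.Manager() as manager:
        store = SharedStore(manager)
        key = store.put([1, 2, 3])
        assert store.get(key) == [1, 2, 3]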
@@ -1,7 +1,6 @@
 import multiprocessing
-from collections import OrderedDict
+from collections import UserDict
 from contextlib import ExitStack
-from copy import copy
 from functools import wraps
 from os import getpid
 from traceback import format_exc
@@ -9,11 +8,68 @@ from warnings import warn

 from tqdm.auto import tqdm

-from .pickler import Pickler, dumps, loads
+from .pickler import dumps, loads

 cpu_count = int(multiprocessing.cpu_count())


+class SharedMemory(UserDict):
+    def __init__(self, manager):
+        super().__init__()
+        self.data = manager.dict()  # item_id: dilled representation of object
+        self.cache = {}  # item_id: object
+        self.pool_ids = {}  # item_id: {(pool_id, task_handle), ...}
+
+    def __getstate__(self):
+        return self.data
+
+    def __setitem__(self, key, value):
+        if key not in self:  # values will not be changed
+            try:
+                self.data[key] = False, value
+            except Exception:  # only use our pickler when necessary
+                self.data[key] = True, dumps(value, recurse=True)
+            self.cache[key] = value  # the id of the object will not be reused as long as the object exists
+
+    def add_item(self, item, pool_id, task_handle):
+        item_id = id(item)
+        self[item_id] = item
+        if item_id in self.pool_ids:
+            self.pool_ids[item_id].add((pool_id, task_handle))
+        else:
+            self.pool_ids[item_id] = {(pool_id, task_handle)}
+        return item_id
+
+    def remove_pool(self, pool_id):
+        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
+        for key in set(self.data.keys()) - set(self.pool_ids):
+            del self[key]
+        self.garbage_collect()
+
+    def remove_task(self, pool_id, task):
+        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
+        for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
+            del self[item_id]
+        self.garbage_collect()
+
+    # worker functions
+    def __setstate__(self, data):
+        self.data = data
+        self.cache = {}
+
+    def __getitem__(self, key):
+        if key not in self.cache:
+            dilled, value = self.data[key]
+            if dilled:
+                value = loads(value)
+            self.cache[key] = value
+        return self.cache[key]
+
+    def garbage_collect(self):
+        for key in set(self.cache) - set(self.data.keys()):
+            del self.cache[key]
+
+
 class Chunks:
     """ Yield successive chunks from lists.
         Usage: chunks(list0, list1, ...)
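Note: the pruning comprehensions in remove_pool and remove_task rely on assignment expressions (PEP 572), which is what forces the minimum Python version up to 3.8. The same pattern in isolation, with illustrative data: drop one pool's references, then keep only entries that still have references left.

refs = {'a': {(1, 0), (2, 0)}, 'b': {(1, 1)}}  # item_id: {(pool_id, task_handle), ...}
pool_to_remove = 1
refs = {k: v for k, value in refs.items() if (v := {i for i in value if i[0] != pool_to_remove})}
assert refs == {'a': {(2, 0)}}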
@@ -107,95 +163,39 @@ class ExternalBar:
         self.callback(n)


-class Hasher:
-    def __init__(self, obj, hsh=None):
-        if hsh is not None:
-            self.obj, self.str, self.hash = None, obj, hsh
-        elif isinstance(obj, Hasher):
-            self.obj, self.str, self.hash = obj.obj, obj.str, obj.hash
-        else:
-            self.obj = obj
-            self.str = dumps(self.obj, recurse=True)
-            self.hash = hash(self.str)
-
-    def __reduce__(self):
-        return self.__class__, (self.str, self.hash)
-
-    def set_from_cache(self, cache=None):
-        if cache is None:
-            self.obj = loads(self.str)
-        elif self.hash in cache:
-            self.obj = cache[self.hash]
-        else:
-            self.obj = cache[self.hash] = loads(self.str)
-
-
-class HashDescriptor:
-    def __set_name__(self, owner, name):
-        self.owner, self.name = owner, '_' + name
-
-    def __set__(self, instance, value):
-        if isinstance(value, Hasher):
-            setattr(instance, self.name, value)
-        else:
-            setattr(instance, self.name, Hasher(value))
-
-    def __get__(self, instance, owner):
-        return getattr(instance, self.name).obj
-
-
-class DequeDict(OrderedDict):
-    def __init__(self, maxlen=None, *args, **kwargs):
-        self.maxlen = maxlen
-        super().__init__(*args, **kwargs)
-
-    def __truncate__(self):
-        while len(self) > self.maxlen:
-            self.popitem(False)
-
-    def __setitem__(self, *args, **kwargs):
-        super().__setitem__(*args, **kwargs)
-        self.__truncate__()
-
-    def update(self, *args, **kwargs):
-        super().update(*args, **kwargs)
-        self.__truncate__()
-
-
 class Task:
-    fun = HashDescriptor()
-    args = HashDescriptor()
-    kwargs = HashDescriptor()
-
-    def __init__(self, pool_id, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
+    def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: int, fun=None, args=(), kwargs=None):
         self.pool_id = pool_id
-        self.fun = fun or (lambda *args, **kwargs: None)
-        self.args = args or ()
-        self.kwargs = kwargs or {}
         self.handle = handle
-        self.n = n
-        self.done = done
-        self.result = loads(result) if self.done else None
+        self.fun = shared_memory.add_item(fun, pool_id, handle)
+        self.args = [shared_memory.add_item(arg, pool_id, handle) for arg in args]
+        self.kwargs = [] if kwargs is None else [shared_memory.add_item(item, pool_id, handle)
+                                                 for item in kwargs.items()]
+        self.name = fun.__name__ if hasattr(fun, '__name__') else None
+        self.done = False
+        self.result = None
+        self.pid = None

-    def __reduce__(self):
-        if self.done:
-            return self.__class__, (self.pool_id, None, None, None, self.handle, None, self.done,
-                                    dumps(self.result, recurse=True))
-        return self.__class__, (self.pool_id, self._fun, self._args, self._kwargs, self.handle,
-                                dumps(self.n, recurse=True), self.done)
+    def __getstate__(self):
+        state = self.__dict__
+        if self.result is not None:
+            state['result'] = dumps(self.result, recurse=True)
+        return state

-    def set_from_cache(self, cache=None):
-        self.n = loads(self.n)
-        self._fun.set_from_cache(cache)
-        self._args.set_from_cache(cache)
-        self._kwargs.set_from_cache(cache)
+    def __setstate__(self, state):
+        self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
+        if state['result'] is None:
+            self.result = None
+        else:
+            self.result = loads(state['result'])

-    def __call__(self):
+    def __call__(self, shared_memory: SharedMemory):
         if not self.done:
-            self.result = self.fun(self.n, *self.args, **self.kwargs)
-            self.fun, self.args, self.kwargs, self.done = None, None, None, True  # Remove potentially big things
+            fun = shared_memory[self.fun] or (lambda *args, **kwargs: None)
+            args = [shared_memory[arg] for arg in self.args]
+            kwargs = dict([shared_memory[kwarg] for kwarg in self.kwargs])
+            self.result = fun(*args, **kwargs)
+            self.done = True
         return self

     def __repr__(self):
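Note: only the result may need dill when a finished Task travels back to the parent; everything else in __dict__ pickles normally. The same getstate/setstate pattern in isolation (Box is an illustrative stand-in for Task; unlike the code above it copies __dict__ so the live instance is not mutated):

import pickle

import dill


class Box:
    def __init__(self, result=None):
        self.result = result

    def __getstate__(self):
        state = dict(self.__dict__)  # copy, so the live instance stays untouched
        if self.result is not None:
            state['result'] = dill.dumps(self.result)
        return state

    def __setstate__(self, state):
        self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
        self.result = None if state['result'] is None else dill.loads(state['result'])


box = pickle.loads(pickle.dumps(Box(result=lambda x: x + 1)))  # lambdas need dill
assert box.result(1) == 2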
@@ -225,11 +225,13 @@ class ParPool:
         self.id = id(self)
         self.handle = 0
         self.tasks = {}
-        self.last_task = Task(self.id, fun, args, kwargs)
         self.bar = bar
         self.bar_lengths = {}
         self.spool = PoolSingleton(self)
         self.manager = self.spool.manager
+        self.fun = fun
+        self.args = args
+        self.kwargs = kwargs
         self.is_started = False

     def __getstate__(self):
@@ -242,10 +244,12 @@ class ParPool:
         self.close()

     def close(self):
         if self.id in self.spool.pools:
-            self.spool.pools.pop(self.id)
+            self.spool.remove_pool(self.id)

-    def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
+    def __call__(self, n, handle=None, barlength=1):
+        self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)

     def add_task(self, fun=None, args=None, kwargs=None, handle=None, barlength=1):
         if self.id not in self.spool.pools:
             raise ValueError(f'this pool is not registered (anymore) with the pool singleton')
         if handle is None:
@@ -255,18 +259,11 @@ class ParPool:
             new_handle = handle
         if new_handle in self:
             raise ValueError(f'handle {new_handle} already present')
-        new_task = copy(self.last_task)
-        if fun is not None:
-            new_task.fun = fun
-        if args is not None:
-            new_task.args = args
-        if kwargs is not None:
-            new_task.kwargs = kwargs
-        new_task.handle = new_handle
-        new_task.n = n
-        self.tasks[new_handle] = new_task
-        self.last_task = new_task
-        self.spool.add_task(new_task)
+        task = Task(self.spool.shared_memory, self.id, new_handle,
+                    fun or self.fun, args or self.args, kwargs or self.kwargs)
+        self.tasks[new_handle] = task
+        self.last_task = task
+        self.spool.add_task(task)
         self.bar_lengths[new_handle] = barlength
         if handle is None:
             return new_handle
@@ -286,9 +283,9 @@ class ParPool:
             self.spool.get_from_queue()
         if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
                 and not self.working:
-            # retry a task if the process was killed while retrieving the task
+            # retry a task if the process was killed while working on a task
            self.spool.add_task(self.tasks[handle])
-            warn(f'Task {handle} was restarted because the process retrieving it was probably killed.')
+            warn(f'Task {handle} was restarted because the process working on it was probably killed.')
         result = self.tasks[handle].result
         self.tasks.pop(handle)
         return result
@@ -310,17 +307,17 @@ class ParPool:
         task = self.tasks[handle]
         print(f'Error from process working on iteration {handle}:\n')
         print(error)
         self.close()
         print('Retrying in main thread...')
-        fun = task.fun.__name__
-        task()
-        raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'.format(fun))
+        task(self.spool.shared_memory)
+        raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.')
+        self.spool.shared_memory.remove_task(self.id, self.tasks[handle])

     def done(self, task):
         if task.handle in self:  # if not, the task was restarted erroneously
             self.tasks[task.handle] = task
             if hasattr(self.bar, 'update'):
                 self.bar.update(self.bar_lengths.pop(task.handle))
+            self.spool.shared_memory.remove_task(self.id, task)

     def started(self, handle, pid):
         self.is_started = True
@@ -344,11 +341,13 @@ class PoolSingleton:
             new.event = ctx.Event()
             new.queue_in = ctx.Queue(3 * new.n_processes)
             new.queue_out = ctx.Queue(new.n_processes)
-            new.pool = ctx.Pool(new.n_processes, Worker(new.queue_in, new.queue_out, new.n_workers, new.event))
+            new.manager = ctx.Manager()
+            new.shared_memory = SharedMemory(new.manager)
+            new.pool = ctx.Pool(new.n_processes,
+                                Worker(new.shared_memory, new.queue_in, new.queue_out, new.n_workers, new.event))
             new.is_alive = True
             new.handle = 0
             new.pools = {}
-            new.manager = ctx.Manager()
             cls.instance = new
         return cls.instance
@@ -359,6 +358,11 @@ class PoolSingleton:
     def __getstate__(self):
         raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')

+    def remove_pool(self, pool_id):
+        self.shared_memory.remove_pool(pool_id)
+        if pool_id in self.pools:
+            self.pools.pop(pool_id)
+
     def error(self, error):
         self.close()
         raise Exception(f'Error occurred in worker: {error}')
@@ -391,6 +395,7 @@ class PoolSingleton:
         while self.queue_in.full():
             self.get_from_queue()
         self.queue_in.put(task)
+        self.shared_memory.garbage_collect()

     def get_newest_for_pool(self, pool):
         """ Request the newest key and result and delete its record. Wait if result not yet available. """
@@ -437,8 +442,8 @@ class PoolSingleton:

 class Worker:
     """ Manages executing the target function which will be executed in different processes. """
-    def __init__(self, queue_in, queue_out, n_workers, event, cachesize=48):
-        self.cache = DequeDict(cachesize)
+    def __init__(self, shared_memory: SharedMemory, queue_in, queue_out, n_workers, event):
+        self.shared_memory = shared_memory
         self.queue_in = queue_in
         self.queue_out = queue_out
         self.n_workers = n_workers
@@ -459,10 +464,10 @@ class Worker:
                 task = self.queue_in.get(True, 0.02)
                 try:
                     self.add_to_queue('started', task.pool_id, task.handle, pid)
-                    task.set_from_cache(self.cache)
-                    self.add_to_queue('done', task.pool_id, task())
+                    self.add_to_queue('done', task.pool_id, task(self.shared_memory))
                 except Exception:
                     self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
+                self.shared_memory.garbage_collect()
            except multiprocessing.queues.Empty:
                continue
        except Exception:
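Note: the remaining hunks are in the pickler module. Before those, the worker-side call path above in miniature: a Task carries only ids, and the worker looks the real objects back up in shared memory before calling; kwargs travel as (key, value) items. A plain dict stands in for the SharedMemory mapping, and add_item is an illustrative helper:

store = {}  # item_id: object, stand-in for the SharedMemory mapping


def add_item(obj):
    store[id(obj)] = obj
    return id(obj)


fun_id = add_item(lambda x, *, offset: x + offset)
arg_ids = [add_item(3)]
kwarg_ids = [add_item(('offset', 4))]

fun, args = store[fun_id], [store[i] for i in arg_ids]
kwargs = dict(store[i] for i in kwarg_ids)
assert fun(*args, **kwargs) == 7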
@@ -1,16 +1,28 @@
+import copyreg
 from io import BytesIO
-from pickle import PicklingError, dispatch_table
+from pickle import PicklingError

 import dill

-failed_rv = (lambda *args, **kwargs: None, ())
 loads = dill.loads


+class CouldNotBePickled:
+    def __init__(self, class_name):
+        self.class_name = class_name
+
+    def __repr__(self):
+        return f"Item of type '{self.class_name}' could not be pickled and was omitted."
+
+    @classmethod
+    def reduce(cls, item):
+        return cls, (type(item).__name__,)
+
+
 class Pickler(dill.Pickler):
-    """ Overload dill to ignore unpickleble parts of objects.
+    """ Overload dill to ignore unpicklable parts of objects.
         You probably didn't want to use these parts anyhow.
-        However, if you did, you'll have to find some way to make them pickleble.
+        However, if you did, you'll have to find some way to make them picklable.
     """
     def save(self, obj, save_persistent_id=True):
         """ Copied from pickle and amended. """
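Note: the placeholder round-trip in isolation. reduce() hands pickle a constructor and its argument, so anything that fails to pickle comes back as a descriptive stand-in rather than raising; the file handle below is just an example of an unpicklable object.

import pickle


class CouldNotBePickled:
    def __init__(self, class_name):
        self.class_name = class_name

    def __repr__(self):
        return f"Item of type '{self.class_name}' could not be pickled and was omitted."

    @classmethod
    def reduce(cls, item):
        return cls, (type(item).__name__,)


with open(__file__) as unpicklable:  # file handles cannot be pickled
    constructor, args = CouldNotBePickled.reduce(unpicklable)
    placeholder = pickle.loads(pickle.dumps(constructor(*args)))
print(placeholder)  # Item of type 'TextIOWrapper' could not be pickled and was omitted.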
@@ -43,7 +55,7 @@ class Pickler(dill.Pickler):

         # Check private dispatch table if any, or else
         # copyreg.dispatch_table
-        reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
+        reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
         if reduce is not None:
             rv = reduce(obj)
         else:
@@ -66,14 +78,14 @@ class Pickler(dill.Pickler):
                     raise PicklingError("Can't pickle %r object: %r" %
                                         (t.__name__, obj))
         except Exception:
-            rv = failed_rv
+            rv = CouldNotBePickled.reduce(obj)

         # Check for string returned by reduce(), meaning "save as global"
         if isinstance(rv, str):
             try:
                 self.save_global(obj, rv)
             except Exception:
-                self.save_global(obj, failed_rv)
+                self.save_global(obj, CouldNotBePickled.reduce(obj))
             return

         # Assert that reduce() returned a tuple
@@ -90,7 +102,7 @@ class Pickler(dill.Pickler):
         try:
             self.save_reduce(obj=obj, *rv)
         except Exception:
-            self.save_reduce(obj=obj, *failed_rv)
+            self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))


 def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds):