From 265470e0ac980d0b47b6e1a737700f26ea4672ab Mon Sep 17 00:00:00 2001
From: Wim Pomp
Date: Thu, 12 Oct 2023 14:46:42 +0200
Subject: [PATCH] - Use a shared memory approach. - Track objects that could
 not be pickled with the CouldNotBePickled class. - Bump minimal Python to
 3.8.

---
 parfor/__init__.py | 229 +++++++++++++++++++++++----------------------
 parfor/pickler.py  |  28 ++++--
 pyproject.toml     |   4 +-
 3 files changed, 139 insertions(+), 122 deletions(-)

diff --git a/parfor/__init__.py b/parfor/__init__.py
index 9088a2d..7173893 100644
--- a/parfor/__init__.py
+++ b/parfor/__init__.py
@@ -1,7 +1,6 @@
 import multiprocessing
-from collections import OrderedDict
+from collections import UserDict
 from contextlib import ExitStack
-from copy import copy
 from functools import wraps
 from os import getpid
 from traceback import format_exc
@@ -9,11 +8,68 @@
 from warnings import warn
 
 from tqdm.auto import tqdm
 
-from .pickler import Pickler, dumps, loads
+from .pickler import dumps, loads
 
 cpu_count = int(multiprocessing.cpu_count())
 
 
+class SharedMemory(UserDict):
+    def __init__(self, manager):
+        super().__init__()
+        self.data = manager.dict()  # item_id: (dilled, object or its dilled representation)
+        self.cache = {}  # item_id: object
+        self.pool_ids = {}  # item_id: {(pool_id, task_handle), ...}
+
+    def __getstate__(self):
+        return self.data
+
+    def __setitem__(self, key, value):
+        if key not in self:  # values will not be changed
+            try:
+                self.data[key] = False, value
+            except Exception:  # only use our pickler when necessary
+                self.data[key] = True, dumps(value, recurse=True)
+            self.cache[key] = value  # the id of the object will not be reused as long as the object exists
+
+    def add_item(self, item, pool_id, task_handle):
+        item_id = id(item)
+        self[item_id] = item
+        if item_id in self.pool_ids:
+            self.pool_ids[item_id].add((pool_id, task_handle))
+        else:
+            self.pool_ids[item_id] = {(pool_id, task_handle)}
+        return item_id
+
+    def remove_pool(self, pool_id):
+        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
+        for key in set(self.data.keys()) - set(self.pool_ids):
+            del self[key]
+        self.garbage_collect()
+
+    def remove_task(self, pool_id, task):
+        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
+        for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
+            del self[item_id]
+        self.garbage_collect()
+
+    # worker functions
+    def __setstate__(self, data):
+        self.data = data
+        self.cache = {}
+
+    def __getitem__(self, key):
+        if key not in self.cache:
+            dilled, value = self.data[key]
+            if dilled:
+                value = loads(value)
+            self.cache[key] = value
+        return self.cache[key]
+
+    def garbage_collect(self):
+        for key in set(self.cache) - set(self.data.keys()):
+            del self.cache[key]
+
+
 class Chunks:
     """ Yield successive chunks from lists. Usage: chunks(list0, list1, ...)
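Note on the SharedMemory hunk above: the manager-backed dict stores each function and argument once per object identity, the local cache serves repeated lookups without round-trips to the manager, and pool_ids tracks which tasks still reference an item. A minimal sketch of those semantics, assuming this patch is applied; the payload list and the pool/task ids below are made up for illustration:

    import multiprocessing

    from parfor import SharedMemory  # added in this patch

    if __name__ == '__main__':
        shared = SharedMemory(multiprocessing.Manager())
        payload = list(range(10))
        # the same object submitted under two task handles is stored only once, keyed by id()
        item_id = shared.add_item(payload, pool_id=1, task_handle=0)
        assert shared.add_item(payload, pool_id=1, task_handle=1) == item_id
        assert shared[item_id] is payload  # reads in the parent process hit the local cache
        shared.remove_pool(1)              # drops items that no remaining task references
        assert item_id not in shared.data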
@@ -107,95 +163,39 @@ class ExternalBar:
         self.callback(n)
 
 
-class Hasher:
-    def __init__(self, obj, hsh=None):
-        if hsh is not None:
-            self.obj, self.str, self.hash = None, obj, hsh
-        elif isinstance(obj, Hasher):
-            self.obj, self.str, self.hash = obj.obj, obj.str, obj.hash
-        else:
-            self.obj = obj
-            self.str = dumps(self.obj, recurse=True)
-            self.hash = hash(self.str)
-
-    def __reduce__(self):
-        return self.__class__, (self.str, self.hash)
-
-    def set_from_cache(self, cache=None):
-        if cache is None:
-            self.obj = loads(self.str)
-        elif self.hash in cache:
-            self.obj = cache[self.hash]
-        else:
-            self.obj = cache[self.hash] = loads(self.str)
-
-
-class HashDescriptor:
-    def __set_name__(self, owner, name):
-        self.owner, self.name = owner, '_' + name
-
-    def __set__(self, instance, value):
-        if isinstance(value, Hasher):
-            setattr(instance, self.name, value)
-        else:
-            setattr(instance, self.name, Hasher(value))
-
-    def __get__(self, instance, owner):
-        return getattr(instance, self.name).obj
-
-
-class DequeDict(OrderedDict):
-    def __init__(self, maxlen=None, *args, **kwargs):
-        self.maxlen = maxlen
-        super().__init__(*args, **kwargs)
-
-    def __truncate__(self):
-        while len(self) > self.maxlen:
-            self.popitem(False)
-
-    def __setitem__(self, *args, **kwargs):
-        super().__setitem__(*args, **kwargs)
-        self.__truncate__()
-
-    def update(self, *args, **kwargs):
-        super().update(*args, **kwargs)
-        self.__truncate__()
-
-
 class Task:
-    fun = HashDescriptor()
-    args = HashDescriptor()
-    kwargs = HashDescriptor()
-
-    def __init__(self, pool_id, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
+    def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: int, fun=None, args=(), kwargs=None):
         self.pool_id = pool_id
-        self.fun = fun or (lambda *args, **kwargs: None)
-        self.args = args or ()
-        self.kwargs = kwargs or {}
         self.handle = handle
-        self.n = n
-        self.done = done
-        self.result = loads(result) if self.done else None
+        self.fun = shared_memory.add_item(fun, pool_id, handle)
+        self.args = [shared_memory.add_item(arg, pool_id, handle) for arg in args]
+        self.kwargs = [] if kwargs is None else [shared_memory.add_item(item, pool_id, handle)
+                                                 for item in kwargs.items()]
+        self.name = fun.__name__ if hasattr(fun, '__name__') else None
+        self.done = False
+        self.result = None
         self.pid = None
 
-    def __reduce__(self):
-        if self.done:
-            return self.__class__, (self.pool_id, None, None, None, self.handle, None, self.done,
-                                    dumps(self.result, recurse=True))
+    def __getstate__(self):
+        state = self.__dict__.copy()  # copy so that dilling the result does not overwrite it on the live task
+        if self.result is not None:
+            state['result'] = dumps(self.result, recurse=True)
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
+        if state['result'] is None:
+            self.result = None
         else:
-            return self.__class__, (self.pool_id, self._fun, self._args, self._kwargs, self.handle,
-                                    dumps(self.n, recurse=True), self.done)
+            self.result = loads(state['result'])
 
-    def set_from_cache(self, cache=None):
-        self.n = loads(self.n)
-        self._fun.set_from_cache(cache)
-        self._args.set_from_cache(cache)
-        self._kwargs.set_from_cache(cache)
-
-    def __call__(self):
+    def __call__(self, shared_memory: SharedMemory):
         if not self.done:
-            self.result = self.fun(self.n, *self.args, **self.kwargs)
-            self.fun, self.args, self.kwargs, self.done = None, None, None, True  # Remove potentially big things
+            fun = shared_memory[self.fun] or (lambda *args, **kwargs: None)
+            args = [shared_memory[arg] for arg in self.args]
+            kwargs = dict([shared_memory[kwarg] for kwarg in self.kwargs])
+            self.result = fun(*args, **kwargs)
+            self.done = True
         return self
 
     def __repr__(self):
@@ -225,11 +225,13 @@ class ParPool:
         self.id = id(self)
         self.handle = 0
         self.tasks = {}
-        self.last_task = Task(self.id, fun, args, kwargs)
         self.bar = bar
         self.bar_lengths = {}
         self.spool = PoolSingleton(self)
         self.manager = self.spool.manager
+        self.fun = fun
+        self.args = args
+        self.kwargs = kwargs
         self.is_started = False
 
     def __getstate__(self):
@@ -242,10 +244,12 @@
         self.close()
 
     def close(self):
-        if self.id in self.spool.pools:
-            self.spool.pools.pop(self.id)
+        self.spool.remove_pool(self.id)
 
-    def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
+    def __call__(self, n, handle=None, barlength=1):
+        return self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)
+
+    def add_task(self, fun=None, args=None, kwargs=None, handle=None, barlength=1):
         if self.id not in self.spool.pools:
             raise ValueError('this pool is not registered (anymore) with the pool singleton')
         if handle is None:
@@ -255,18 +259,11 @@
             new_handle = handle
         if new_handle in self:
             raise ValueError(f'handle {new_handle} already present')
-        new_task = copy(self.last_task)
-        if fun is not None:
-            new_task.fun = fun
-        if args is not None:
-            new_task.args = args
-        if kwargs is not None:
-            new_task.kwargs = kwargs
-        new_task.handle = new_handle
-        new_task.n = n
-        self.tasks[new_handle] = new_task
-        self.last_task = new_task
-        self.spool.add_task(new_task)
+        task = Task(self.spool.shared_memory, self.id, new_handle,
+                    fun or self.fun, args or self.args, kwargs or self.kwargs)
+        self.tasks[new_handle] = task
+        self.last_task = task
+        self.spool.add_task(task)
         self.bar_lengths[new_handle] = barlength
         if handle is None:
             return new_handle
@@ -286,9 +283,9 @@
             self.spool.get_from_queue()
         if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
                 and not self.working:
-            # retry a task if the process was killed while retrieving the task
+            # retry a task if the process was killed while working on a task
             self.spool.add_task(self.tasks[handle])
-            warn(f'Task {handle} was restarted because the process retrieving it was probably killed.')
+            warn(f'Task {handle} was restarted because the process working on it was probably killed.')
         result = self.tasks[handle].result
         self.tasks.pop(handle)
         return result
@@ -310,17 +307,17 @@
         task = self.tasks[handle]
         print(f'Error from process working on iteration {handle}:\n')
         print(error)
-        self.close()
         print('Retrying in main thread...')
-        fun = task.fun.__name__
-        task()
-        raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'.format(fun))
+        task(self.spool.shared_memory)
+        raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.')
+        self.spool.shared_memory.remove_task(self.id, self.tasks[handle])
 
     def done(self, task):
         if task.handle in self:  # if not, the task was restarted erroneously
             self.tasks[task.handle] = task
             if hasattr(self.bar, 'update'):
                 self.bar.update(self.bar_lengths.pop(task.handle))
+        self.spool.shared_memory.remove_task(self.id, task)
 
     def started(self, handle, pid):
         self.is_started = True
@@ -344,11 +341,13 @@
         new.event = ctx.Event()
         new.queue_in = ctx.Queue(3 * new.n_processes)
         new.queue_out = ctx.Queue(new.n_processes)
-        new.pool = ctx.Pool(new.n_processes,
-                            Worker(new.queue_in, new.queue_out, new.n_workers, new.event))
+        new.manager = ctx.Manager()
+        new.shared_memory = SharedMemory(new.manager)
+        new.pool = ctx.Pool(new.n_processes,
+                            Worker(new.shared_memory, new.queue_in, new.queue_out, new.n_workers, new.event))
         new.is_alive = True
         new.handle = 0
         new.pools = {}
-        new.manager = ctx.Manager()
         cls.instance = new
         return cls.instance
@@ -359,6 +358,11 @@
     def __getstate__(self):
         raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
 
+    def remove_pool(self, pool_id):
+        self.shared_memory.remove_pool(pool_id)
+        if pool_id in self.pools:
+            self.pools.pop(pool_id)
+
     def error(self, error):
         self.close()
         raise Exception(f'Error occurred in worker: {error}')
@@ -391,6 +395,7 @@
         while self.queue_in.full():
             self.get_from_queue()
         self.queue_in.put(task)
+        self.shared_memory.garbage_collect()
 
     def get_newest_for_pool(self, pool):
         """ Request the newest key and result and delete its record. Wait if result not yet available. """
@@ -437,8 +442,8 @@ class Worker:
     """ Manages executing the target function which will be executed in different processes. """
 
-    def __init__(self, queue_in, queue_out, n_workers, event, cachesize=48):
-        self.cache = DequeDict(cachesize)
+    def __init__(self, shared_memory: SharedMemory, queue_in, queue_out, n_workers, event):
+        self.shared_memory = shared_memory
         self.queue_in = queue_in
         self.queue_out = queue_out
         self.n_workers = n_workers
@@ -459,10 +464,10 @@
                     task = self.queue_in.get(True, 0.02)
                     try:
                         self.add_to_queue('started', task.pool_id, task.handle, pid)
-                        task.set_from_cache(self.cache)
-                        self.add_to_queue('done', task.pool_id, task())
+                        self.add_to_queue('done', task.pool_id, task(self.shared_memory))
                     except Exception:
                         self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
+                    self.shared_memory.garbage_collect()
                 except multiprocessing.queues.Empty:
                     continue
                 except Exception:
diff --git a/parfor/pickler.py b/parfor/pickler.py
index aa4041d..692e05a 100644
--- a/parfor/pickler.py
+++ b/parfor/pickler.py
@@ -1,16 +1,28 @@
+import copyreg
 from io import BytesIO
-from pickle import PicklingError, dispatch_table
+from pickle import PicklingError
 
 import dill
 
-failed_rv = (lambda *args, **kwargs: None, ())
 loads = dill.loads
 
 
+class CouldNotBePickled:
+    def __init__(self, class_name):
+        self.class_name = class_name
+
+    def __repr__(self):
+        return f"Item of type '{self.class_name}' could not be pickled and was omitted."
+
+    @classmethod
+    def reduce(cls, item):
+        return cls, (type(item).__name__,)
+
+
 class Pickler(dill.Pickler):
-    """ Overload dill to ignore unpickleble parts of objects.
+    """ Overload dill to ignore unpicklable parts of objects.
         You probably didn't want to use these parts anyhow.
-        However, if you did, you'll have to find some way to make them pickleble.
+        However, if you did, you'll have to find some way to make them picklable.
     """
     def save(self, obj, save_persistent_id=True):
         """ Copied from pickle and amended.
""" @@ -43,7 +55,7 @@ class Pickler(dill.Pickler): # Check private dispatch table if any, or else # copyreg.dispatch_table - reduce = getattr(self, 'dispatch_table', dispatch_table).get(t) + reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t) if reduce is not None: rv = reduce(obj) else: @@ -66,14 +78,14 @@ class Pickler(dill.Pickler): raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj)) except Exception: - rv = failed_rv + rv = CouldNotBePickled.reduce(obj) # Check for string returned by reduce(), meaning "save as global" if isinstance(rv, str): try: self.save_global(obj, rv) except Exception: - self.save_global(obj, failed_rv) + self.save_global(obj, CouldNotBePickled.reduce(obj)) return # Assert that reduce() returned a tuple @@ -90,7 +102,7 @@ class Pickler(dill.Pickler): try: self.save_reduce(obj=obj, *rv) except Exception: - self.save_reduce(obj=obj, *failed_rv) + self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj)) def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds): diff --git a/pyproject.toml b/pyproject.toml index a67f4e2..3a7b3e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "parfor" -version = "2023.9.0" +version = "2023.10.0" description = "A package to mimic the use of parfor as done in Matlab." authors = ["Wim Pomp "] license = "GPLv3" @@ -9,7 +9,7 @@ keywords = ["parfor", "concurrency", "multiprocessing", "parallel"] repository = "https://github.com/wimpomp/parfor" [tool.poetry.dependencies] -python = "^3.6" +python = "^3.8" tqdm = ">=4.50.0" dill = ">=0.3.0" pytest = { version = "*", optional = true }