Use a shared memory approach.

- Track unpickled objects by CouldNotBePickled class.
- Bump minimal python to 3.8.
Wim Pomp
2023-10-12 14:46:42 +02:00
parent 0263b6a4c7
commit 265470e0ac
3 changed files with 139 additions and 122 deletions

View File

@@ -1,7 +1,6 @@
 import multiprocessing
-from collections import OrderedDict
+from collections import UserDict
 from contextlib import ExitStack
-from copy import copy
 from functools import wraps
 from os import getpid
 from traceback import format_exc
@@ -9,11 +8,68 @@ from warnings import warn

 from tqdm.auto import tqdm

-from .pickler import Pickler, dumps, loads
+from .pickler import dumps, loads

 cpu_count = int(multiprocessing.cpu_count())


+class SharedMemory(UserDict):
+    def __init__(self, manager):
+        super().__init__()
+        self.data = manager.dict()  # item_id: dilled representation of object
+        self.cache = {}  # item_id: object
+        self.pool_ids = {}  # item_id: {(pool_id, task_handle), ...}
+
+    def __getstate__(self):
+        return self.data
+
+    def __setitem__(self, key, value):
+        if key not in self:  # values will not be changed
+            try:
+                self.data[key] = False, value
+            except Exception:  # only use our pickler when necessary
+                self.data[key] = True, dumps(value, recurse=True)
+            self.cache[key] = value  # the id of the object will not be reused as long as the object exists
+
+    def add_item(self, item, pool_id, task_handle):
+        item_id = id(item)
+        self[item_id] = item
+        if item_id in self.pool_ids:
+            self.pool_ids[item_id].add((pool_id, task_handle))
+        else:
+            self.pool_ids[item_id] = {(pool_id, task_handle)}
+        return item_id
+
+    def remove_pool(self, pool_id):
+        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
+        for key in set(self.data.keys()) - set(self.pool_ids):
+            del self[key]
+        self.garbage_collect()
+
+    def remove_task(self, pool_id, task):
+        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
+        for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
+            del self[item_id]
+        self.garbage_collect()
+
+    # worker functions
+    def __setstate__(self, data):
+        self.data = data
+        self.cache = {}
+
+    def __getitem__(self, key):
+        if key not in self.cache:
+            dilled, value = self.data[key]
+            if dilled:
+                value = loads(value)
+            self.cache[key] = value
+        return self.cache[key]
+
+    def garbage_collect(self):
+        for key in set(self.cache) - set(self.data.keys()):
+            del self.cache[key]
+
+
 class Chunks:
     """ Yield successive chunks from lists.

     Usage: chunks(list0, list1, ...)
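
The pattern the new SharedMemory class implements, in isolation: each object is written into a Manager dict exactly once, keyed by its id(), stored raw when the standard pickler can handle it and dill-pickled otherwise; readers go through a per-process cache so each object is deserialized at most once per process. A minimal self-contained sketch of that store/load cycle (the names put and get are illustrative, not parfor's API):

from multiprocessing import Manager

import dill

cache = {}  # local to the current process


def put(store, obj):
    key = id(obj)
    if key not in store:
        try:
            store[key] = False, obj             # plain pickling works for most objects
        except Exception:
            store[key] = True, dill.dumps(obj)  # fall back to dill for the rest
    cache[key] = obj  # keep a reference so the id stays valid
    return key


def get(store, key):
    if key not in cache:  # deserialize at most once per process
        dilled, value = store[key]
        cache[key] = dill.loads(value) if dilled else value
    return cache[key]


if __name__ == '__main__':
    store = Manager().dict()  # shared between processes
    key = put(store, sum)
    assert get(store, key) is sum
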
@@ -107,95 +163,39 @@ class ExternalBar:
             self.callback(n)


-class Hasher:
-    def __init__(self, obj, hsh=None):
-        if hsh is not None:
-            self.obj, self.str, self.hash = None, obj, hsh
-        elif isinstance(obj, Hasher):
-            self.obj, self.str, self.hash = obj.obj, obj.str, obj.hash
-        else:
-            self.obj = obj
-            self.str = dumps(self.obj, recurse=True)
-            self.hash = hash(self.str)
-
-    def __reduce__(self):
-        return self.__class__, (self.str, self.hash)
-
-    def set_from_cache(self, cache=None):
-        if cache is None:
-            self.obj = loads(self.str)
-        elif self.hash in cache:
-            self.obj = cache[self.hash]
-        else:
-            self.obj = cache[self.hash] = loads(self.str)
-
-
-class HashDescriptor:
-    def __set_name__(self, owner, name):
-        self.owner, self.name = owner, '_' + name
-
-    def __set__(self, instance, value):
-        if isinstance(value, Hasher):
-            setattr(instance, self.name, value)
-        else:
-            setattr(instance, self.name, Hasher(value))
-
-    def __get__(self, instance, owner):
-        return getattr(instance, self.name).obj
-
-
-class DequeDict(OrderedDict):
-    def __init__(self, maxlen=None, *args, **kwargs):
-        self.maxlen = maxlen
-        super().__init__(*args, **kwargs)
-
-    def __truncate__(self):
-        while len(self) > self.maxlen:
-            self.popitem(False)
-
-    def __setitem__(self, *args, **kwargs):
-        super().__setitem__(*args, **kwargs)
-        self.__truncate__()
-
-    def update(self, *args, **kwargs):
-        super().update(*args, **kwargs)
-        self.__truncate__()
-
-
 class Task:
-    fun = HashDescriptor()
-    args = HashDescriptor()
-    kwargs = HashDescriptor()
-
-    def __init__(self, pool_id, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
+    def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: int, fun=None, args=(), kwargs=None):
         self.pool_id = pool_id
-        self.fun = fun or (lambda *args, **kwargs: None)
-        self.args = args or ()
-        self.kwargs = kwargs or {}
         self.handle = handle
-        self.n = n
-        self.done = done
-        self.result = loads(result) if self.done else None
+        self.fun = shared_memory.add_item(fun, pool_id, handle)
+        self.args = [shared_memory.add_item(arg, pool_id, handle) for arg in args]
+        self.kwargs = [] if kwargs is None else [shared_memory.add_item(item, pool_id, handle)
+                                                 for item in kwargs.items()]
+        self.name = fun.__name__ if hasattr(fun, '__name__') else None
+        self.done = False
+        self.result = None
         self.pid = None

-    def __reduce__(self):
-        if self.done:
-            return self.__class__, (self.pool_id, None, None, None, self.handle, None, self.done,
-                                    dumps(self.result, recurse=True))
-        else:
-            return self.__class__, (self.pool_id, self._fun, self._args, self._kwargs, self.handle,
-                                    dumps(self.n, recurse=True), self.done)
+    def __getstate__(self):
+        state = self.__dict__
+        if self.result is not None:
+            state['result'] = dumps(self.result, recurse=True)
+        return state

-    def set_from_cache(self, cache=None):
-        self.n = loads(self.n)
-        self._fun.set_from_cache(cache)
-        self._args.set_from_cache(cache)
-        self._kwargs.set_from_cache(cache)
+    def __setstate__(self, state):
+        self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
+        if state['result'] is None:
+            self.result = None
+        else:
+            self.result = loads(state['result'])

-    def __call__(self):
+    def __call__(self, shared_memory: SharedMemory):
         if not self.done:
-            self.result = self.fun(self.n, *self.args, **self.kwargs)
-            self.fun, self.args, self.kwargs, self.done = None, None, None, True  # Remove potentially big things
+            fun = shared_memory[self.fun] or (lambda *args, **kwargs: None)
+            args = [shared_memory[arg] for arg in self.args]
+            kwargs = dict([shared_memory[kwarg] for kwarg in self.kwargs])
+            self.result = fun(*args, **kwargs)
+            self.done = True
         return self

     def __repr__(self):
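
A Task now holds only the integer ids returned by SharedMemory.add_item, so pickling a Task onto the queue is cheap no matter how large fun, args and kwargs are; the worker resolves the ids back into objects just before calling. A minimal self-contained imitation of that indirection (MiniTask, registry and power are illustrative stand-ins, not parfor's classes):

registry = {}  # stands in for the shared memory


def add_item(obj):
    registry[id(obj)] = obj
    return id(obj)


class MiniTask:
    def __init__(self, fun, args=(), kwargs=None):
        self.fun = add_item(fun)  # store ids, not objects
        self.args = [add_item(arg) for arg in args]
        # kwargs are registered as (key, value) pairs and rebuilt with dict() on the other side
        self.kwargs = [add_item(item) for item in (kwargs or {}).items()]

    def __call__(self, shared):
        fun = shared[self.fun]
        args = [shared[arg] for arg in self.args]
        kwargs = dict(shared[kwarg] for kwarg in self.kwargs)
        return fun(*args, **kwargs)


def power(base, exp=2):
    return base ** exp


task = MiniTask(power, (3,), {'exp': 4})
print(task(registry))  # 81
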
@@ -225,11 +225,13 @@ class ParPool:
         self.id = id(self)
         self.handle = 0
         self.tasks = {}
-        self.last_task = Task(self.id, fun, args, kwargs)
         self.bar = bar
         self.bar_lengths = {}
         self.spool = PoolSingleton(self)
         self.manager = self.spool.manager
+        self.fun = fun
+        self.args = args
+        self.kwargs = kwargs
         self.is_started = False

     def __getstate__(self):
@@ -242,10 +244,12 @@ class ParPool:
         self.close()

     def close(self):
-        if self.id in self.spool.pools:
-            self.spool.pools.pop(self.id)
+        self.spool.remove_pool(self.id)

-    def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
+    def __call__(self, n, handle=None, barlength=1):
+        self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)
+
+    def add_task(self, fun=None, args=None, kwargs=None, handle=None, barlength=1):
         if self.id not in self.spool.pools:
             raise ValueError(f'this pool is not registered (anymore) with the pool singleton')
@@ -255,18 +259,11 @@ class ParPool:
             new_handle = handle
         if new_handle in self:
             raise ValueError(f'handle {new_handle} already present')
-        new_task = copy(self.last_task)
-        if fun is not None:
-            new_task.fun = fun
-        if args is not None:
-            new_task.args = args
-        if kwargs is not None:
-            new_task.kwargs = kwargs
-        new_task.handle = new_handle
-        new_task.n = n
-        self.tasks[new_handle] = new_task
-        self.last_task = new_task
-        self.spool.add_task(new_task)
+        task = Task(self.spool.shared_memory, self.id, new_handle,
+                    fun or self.fun, args or self.args, kwargs or self.kwargs)
+        self.tasks[new_handle] = task
+        self.last_task = task
+        self.spool.add_task(task)
         self.bar_lengths[new_handle] = barlength
         if handle is None:
             return new_handle
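
Because the pool now remembers fun, args and kwargs from its constructor, __call__(n) is just add_task with the iteration value prepended to the stored args, so a per-iteration call ships nothing but n. A hypothetical usage sketch; the constructor signature below is inferred from the assignments in __init__ and should be treated as an assumption:

def work(n, offset):
    return n + offset

pool = ParPool(work, args=(10,))  # assumed signature: fun first, defaults kept on the pool
for n in range(4):
    pool(n)  # queues work(n, 10) through add_task
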
@@ -286,9 +283,9 @@ class ParPool:
             self.spool.get_from_queue()
         if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
                 and not self.working:
-            # retry a task if the process was killed while retrieving the task
+            # retry a task if the process was killed while working on a task
             self.spool.add_task(self.tasks[handle])
-            warn(f'Task {handle} was restarted because the process retrieving it was probably killed.')
+            warn(f'Task {handle} was restarted because the process working on it was probably killed.')
         result = self.tasks[handle].result
         self.tasks.pop(handle)
         return result
@@ -310,17 +307,17 @@ class ParPool:
             task = self.tasks[handle]
             print(f'Error from process working on iteration {handle}:\n')
             print(error)
-            self.close()
             print('Retrying in main thread...')
-            fun = task.fun.__name__
-            task()
-            raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'.format(fun))
+            task(self.spool.shared_memory)
+            raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.')
+            self.spool.shared_memory.remove_task(self.id, self.tasks[handle])

     def done(self, task):
         if task.handle in self:  # if not, the task was restarted erroneously
             self.tasks[task.handle] = task
             if hasattr(self.bar, 'update'):
                 self.bar.update(self.bar_lengths.pop(task.handle))
+            self.spool.shared_memory.remove_task(self.id, task)

     def started(self, handle, pid):
         self.is_started = True
@@ -344,11 +341,13 @@ class PoolSingleton:
            new.event = ctx.Event()
            new.queue_in = ctx.Queue(3 * new.n_processes)
            new.queue_out = ctx.Queue(new.n_processes)
-            new.pool = ctx.Pool(new.n_processes, Worker(new.queue_in, new.queue_out, new.n_workers, new.event))
+            new.manager = ctx.Manager()
+            new.shared_memory = SharedMemory(new.manager)
+            new.pool = ctx.Pool(new.n_processes,
+                                Worker(new.shared_memory, new.queue_in, new.queue_out, new.n_workers, new.event))
            new.is_alive = True
            new.handle = 0
            new.pools = {}
-            new.manager = ctx.Manager()
            cls.instance = new
        return cls.instance
@@ -359,6 +358,11 @@ class PoolSingleton:
     def __getstate__(self):
         raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')

+    def remove_pool(self, pool_id):
+        self.shared_memory.remove_pool(pool_id)
+        if pool_id in self.pools:
+            self.pools.pop(pool_id)
+
     def error(self, error):
         self.close()
         raise Exception(f'Error occurred in worker: {error}')
@@ -391,6 +395,7 @@ class PoolSingleton:
         while self.queue_in.full():
             self.get_from_queue()
         self.queue_in.put(task)
+        self.shared_memory.garbage_collect()

     def get_newest_for_pool(self, pool):
         """ Request the newest key and result and delete its record. Wait if result not yet available. """
@@ -437,8 +442,8 @@
 class Worker:
     """ Manages executing the target function which will be executed in different processes. """

-    def __init__(self, queue_in, queue_out, n_workers, event, cachesize=48):
-        self.cache = DequeDict(cachesize)
+    def __init__(self, shared_memory: SharedMemory, queue_in, queue_out, n_workers, event):
+        self.shared_memory = shared_memory
         self.queue_in = queue_in
         self.queue_out = queue_out
         self.n_workers = n_workers
@@ -459,10 +464,10 @@ class Worker:
                 task = self.queue_in.get(True, 0.02)
                 try:
                     self.add_to_queue('started', task.pool_id, task.handle, pid)
-                    task.set_from_cache(self.cache)
-                    self.add_to_queue('done', task.pool_id, task())
+                    self.add_to_queue('done', task.pool_id, task(self.shared_memory))
                 except Exception:
                     self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
+                self.shared_memory.garbage_collect()
             except multiprocessing.queues.Empty:
                 continue
         except Exception:

View File

@@ -1,16 +1,28 @@
+import copyreg
 from io import BytesIO
-from pickle import PicklingError, dispatch_table
+from pickle import PicklingError

 import dill

-failed_rv = (lambda *args, **kwargs: None, ())
 loads = dill.loads


+class CouldNotBePickled:
+    def __init__(self, class_name):
+        self.class_name = class_name
+
+    def __repr__(self):
+        return f"Item of type '{self.class_name}' could not be pickled and was omitted."
+
+    @classmethod
+    def reduce(cls, item):
+        return cls, (type(item).__name__,)
+
+
 class Pickler(dill.Pickler):
-    """ Overload dill to ignore unpickleble parts of objects.
+    """ Overload dill to ignore unpicklable parts of objects.
         You probably didn't want to use these parts anyhow.
-        However, if you did, you'll have to find some way to make them pickleble.
+        However, if you did, you'll have to find some way to make them picklable.
     """
     def save(self, obj, save_persistent_id=True):
         """ Copied from pickle and amended. """
@@ -43,7 +55,7 @@ class Pickler(dill.Pickler):
         # Check private dispatch table if any, or else
         # copyreg.dispatch_table
-        reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
+        reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
         if reduce is not None:
             rv = reduce(obj)
         else:
@@ -66,14 +78,14 @@ class Pickler(dill.Pickler):
                     raise PicklingError("Can't pickle %r object: %r" %
                                         (t.__name__, obj))
         except Exception:
-            rv = failed_rv
+            rv = CouldNotBePickled.reduce(obj)

         # Check for string returned by reduce(), meaning "save as global"
         if isinstance(rv, str):
             try:
                 self.save_global(obj, rv)
             except Exception:
-                self.save_global(obj, failed_rv)
+                self.save_global(obj, CouldNotBePickled.reduce(obj))
             return

         # Assert that reduce() returned a tuple
@@ -90,7 +102,7 @@ class Pickler(dill.Pickler):
         try:
             self.save_reduce(obj=obj, *rv)
         except Exception:
-            self.save_reduce(obj=obj, *failed_rv)
+            self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))


 def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds):
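
Net effect of the new placeholder: a dump no longer silently swaps failed parts for a bare lambda; the offending part becomes a CouldNotBePickled instance whose repr names the dropped type. A sketch of the intended round trip, assuming the module is importable as parfor.pickler (thread locks are a stock example of something neither pickle nor dill will serialize):

import threading

from parfor.pickler import dumps, loads


class Holder:
    def __init__(self):
        self.value = 42
        self.lock = threading.Lock()  # unpicklable attribute


obj = loads(dumps(Holder()))
print(obj.value)  # 42 -- the picklable part survives
print(obj.lock)   # Item of type 'lock' could not be pickled and was omitted.
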

View File

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "parfor"
-version = "2023.9.0"
+version = "2023.10.0"
 description = "A package to mimic the use of parfor as done in Matlab."
 authors = ["Wim Pomp <wimpomp@gmail.com>"]
 license = "GPLv3"
@@ -9,7 +9,7 @@ keywords = ["parfor", "concurrency", "multiprocessing", "parallel"]
 repository = "https://github.com/wimpomp/parfor"

 [tool.poetry.dependencies]
-python = "^3.6"
+python = "^3.8"
 tqdm = ">=4.50.0"
 dill = ">=0.3.0"
 pytest = { version = "*", optional = true }