- Use a shared memory approach (a brief sketch follows below).

- Track objects that could not be pickled with a CouldNotBePickled class.
- Bump the minimum Python version to 3.8.
Wim Pomp
2023-10-12 14:46:42 +02:00
parent 0263b6a4c7
commit 265470e0ac
3 changed files with 139 additions and 122 deletions
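For context, a minimal, self-contained sketch of the mechanism the first bullet describes, to make the diff below easier to follow. The helper names store and load are illustrative only, not parfor's API: objects are stored once in a manager dict keyed by id(), dill is used only when the default pickler refuses, and each process memoizes what it has already decoded.

import multiprocessing

import dill


def store(shared, cache, obj):
    # Store obj once under its id. The manager proxy pickles the value with the
    # default pickler, so a failure here is the cue to fall back to dill.
    key = id(obj)
    if key not in shared:
        try:
            shared[key] = False, obj
        except Exception:
            shared[key] = True, dill.dumps(obj, recurse=True)
    cache[key] = obj  # strong reference keeps id(obj) from being reused
    return key


def load(shared, cache, key):
    # Resolve a key back to the object, decoding with dill at most once per process.
    if key not in cache:
        dilled, value = shared[key]
        cache[key] = dill.loads(value) if dilled else value
    return cache[key]


if __name__ == '__main__':
    manager = multiprocessing.Manager()
    shared, cache = manager.dict(), {}
    key = store(shared, cache, [1, 2, 3])
    assert load(shared, cache, key) == [1, 2, 3]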


@@ -1,7 +1,6 @@
import multiprocessing
from collections import OrderedDict
from collections import UserDict
from contextlib import ExitStack
from copy import copy
from functools import wraps
from os import getpid
from traceback import format_exc
@@ -9,11 +8,68 @@ from warnings import warn
from tqdm.auto import tqdm
from .pickler import Pickler, dumps, loads
from .pickler import dumps, loads
cpu_count = int(multiprocessing.cpu_count())
class SharedMemory(UserDict):
def __init__(self, manager):
super().__init__()
self.data = manager.dict() # item_id_c: dilled representation of object
self.cache = {} # item_id: object
self.pool_ids = {} # item_id: {(pool_id, task_handle), ...}
def __getstate__(self):
return self.data
def __setitem__(self, key, value):
if key not in self: # values will not be changed
try:
self.data[key] = False, value
except Exception: # only use our pickler when necessary
self.data[key] = True, dumps(value, recurse=True)
self.cache[key] = value # the id of the object will not be reused as long as the object exists
def add_item(self, item, pool_id, task_handle):
item_id = id(item)
self[item_id] = item
if item_id in self.pool_ids:
self.pool_ids[item_id].add((pool_id, task_handle))
else:
self.pool_ids[item_id] = {(pool_id, task_handle)}
return item_id
def remove_pool(self, pool_id):
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
for key in set(self.data.keys()) - set(self.pool_ids):
del self[key]
self.garbage_collect()
def remove_task(self, pool_id, task):
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
del self[item_id]
self.garbage_collect()
# worker functions
def __setstate__(self, data):
self.data = data
self.cache = {}
def __getitem__(self, key):
if not key in self.cache:
dilled, value = self.data[key]
if dilled:
value = loads(value)
self.cache[key] = value
return self.cache[key]
def garbage_collect(self):
for key in set(self.cache) - set(self.data.keys()):
del self.cache[key]
class Chunks:
""" Yield successive chunks from lists.
Usage: chunks(list0, list1, ...)
@@ -107,95 +163,39 @@ class ExternalBar:
self.callback(n)
class Hasher:
def __init__(self, obj, hsh=None):
if hsh is not None:
self.obj, self.str, self.hash = None, obj, hsh
elif isinstance(obj, Hasher):
self.obj, self.str, self.hash = obj.obj, obj.str, obj.hash
else:
self.obj = obj
self.str = dumps(self.obj, recurse=True)
self.hash = hash(self.str)
def __reduce__(self):
return self.__class__, (self.str, self.hash)
def set_from_cache(self, cache=None):
if cache is None:
self.obj = loads(self.str)
elif self.hash in cache:
self.obj = cache[self.hash]
else:
self.obj = cache[self.hash] = loads(self.str)
class HashDescriptor:
def __set_name__(self, owner, name):
self.owner, self.name = owner, '_' + name
def __set__(self, instance, value):
if isinstance(value, Hasher):
setattr(instance, self.name, value)
else:
setattr(instance, self.name, Hasher(value))
def __get__(self, instance, owner):
return getattr(instance, self.name).obj
class DequeDict(OrderedDict):
def __init__(self, maxlen=None, *args, **kwargs):
self.maxlen = maxlen
super().__init__(*args, **kwargs)
def __truncate__(self):
while len(self) > self.maxlen:
self.popitem(False)
def __setitem__(self, *args, **kwargs):
super().__setitem__(*args, **kwargs)
self.__truncate__()
def update(self, *args, **kwargs):
super().update(*args, **kwargs)
self.__truncate__()
class Task:
fun = HashDescriptor()
args = HashDescriptor()
kwargs = HashDescriptor()
def __init__(self, pool_id, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: int, fun=None, args=(), kwargs=None):
self.pool_id = pool_id
self.fun = fun or (lambda *args, **kwargs: None)
self.args = args or ()
self.kwargs = kwargs or {}
self.handle = handle
self.n = n
self.done = done
self.result = loads(result) if self.done else None
self.fun = shared_memory.add_item(fun, pool_id, handle)
self.args = [shared_memory.add_item(arg, pool_id, handle) for arg in args]
self.kwargs = [] if kwargs is None else [shared_memory.add_item(item, pool_id, handle)
for item in kwargs.items()]
self.name = fun.__name__ if hasattr(fun, '__name__') else None
self.done = False
self.result = None
self.pid = None
def __reduce__(self):
if self.done:
return self.__class__, (self.pool_id, None, None, None, self.handle, None, self.done,
dumps(self.result, recurse=True))
else:
return self.__class__, (self.pool_id, self._fun, self._args, self._kwargs, self.handle,
dumps(self.n, recurse=True), self.done)
def __getstate__(self):
state = self.__dict__
if self.result is not None:
state['result'] = dumps(self.result, recurse=True)
return state
def __setstate__(self, state):
self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
if state['result'] is None:
self.result = None
else:
self.result = loads(state['result'])
def set_from_cache(self, cache=None):
self.n = loads(self.n)
self._fun.set_from_cache(cache)
self._args.set_from_cache(cache)
self._kwargs.set_from_cache(cache)
def __call__(self):
def __call__(self, shared_memory: SharedMemory):
if not self.done:
self.result = self.fun(self.n, *self.args, **self.kwargs)
self.fun, self.args, self.kwargs, self.done = None, None, None, True # Remove potentially big things
fun = shared_memory[self.fun] or (lambda *args, **kwargs: None)
args = [shared_memory[arg] for arg in self.args]
kwargs = dict([shared_memory[kwarg] for kwarg in self.kwargs])
self.result = fun(*args, **kwargs)
self.done = True
return self
def __repr__(self):
@@ -225,11 +225,13 @@ class ParPool:
self.id = id(self)
self.handle = 0
self.tasks = {}
self.last_task = Task(self.id, fun, args, kwargs)
self.bar = bar
self.bar_lengths = {}
self.spool = PoolSingleton(self)
self.manager = self.spool.manager
self.fun = fun
self.args = args
self.kwargs = kwargs
self.is_started = False
def __getstate__(self):
@@ -242,10 +244,12 @@ class ParPool:
self.close()
def close(self):
if self.id in self.spool.pools:
self.spool.pools.pop(self.id)
self.spool.remove_pool(self.id)
def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
def __call__(self, n, handle=None, barlength=1):
self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)
def add_task(self, fun=None, args=None, kwargs=None, handle=None, barlength=1):
if self.id not in self.spool.pools:
raise ValueError(f'this pool is not registered (anymore) with the pool singleton')
if handle is None:
@@ -255,18 +259,11 @@ class ParPool:
new_handle = handle
if new_handle in self:
raise ValueError(f'handle {new_handle} already present')
new_task = copy(self.last_task)
if fun is not None:
new_task.fun = fun
if args is not None:
new_task.args = args
if kwargs is not None:
new_task.kwargs = kwargs
new_task.handle = new_handle
new_task.n = n
self.tasks[new_handle] = new_task
self.last_task = new_task
self.spool.add_task(new_task)
task = Task(self.spool.shared_memory, self.id, new_handle,
fun or self.fun, args or self.args, kwargs or self.kwargs)
self.tasks[new_handle] = task
self.last_task = task
self.spool.add_task(task)
self.bar_lengths[new_handle] = barlength
if handle is None:
return new_handle
@@ -286,9 +283,9 @@ class ParPool:
self.spool.get_from_queue()
if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
and not self.working:
# retry a task if the process was killed while retrieving the task
# retry a task if the process was killed while working on a task
self.spool.add_task(self.tasks[handle])
warn(f'Task {handle} was restarted because the process retrieving it was probably killed.')
warn(f'Task {handle} was restarted because the process working on it was probably killed.')
result = self.tasks[handle].result
self.tasks.pop(handle)
return result
@@ -310,17 +307,17 @@ class ParPool:
task = self.tasks[handle]
print(f'Error from process working on iteration {handle}:\n')
print(error)
self.close()
print('Retrying in main thread...')
fun = task.fun.__name__
task()
raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'.format(fun))
task(self.spool.shared_memory)
raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.')
self.spool.shared_memory.remove_task(self.id, self.tasks[handle])
def done(self, task):
if task.handle in self: # if not, the task was restarted erroneously
self.tasks[task.handle] = task
if hasattr(self.bar, 'update'):
self.bar.update(self.bar_lengths.pop(task.handle))
self.spool.shared_memory.remove_task(self.id, task)
def started(self, handle, pid):
self.is_started = True
@@ -344,11 +341,13 @@ class PoolSingleton:
new.event = ctx.Event()
new.queue_in = ctx.Queue(3 * new.n_processes)
new.queue_out = ctx.Queue(new.n_processes)
new.pool = ctx.Pool(new.n_processes, Worker(new.queue_in, new.queue_out, new.n_workers, new.event))
new.manager = ctx.Manager()
new.shared_memory = SharedMemory(new.manager)
new.pool = ctx.Pool(new.n_processes,
Worker(new.shared_memory, new.queue_in, new.queue_out, new.n_workers, new.event))
new.is_alive = True
new.handle = 0
new.pools = {}
new.manager = ctx.Manager()
cls.instance = new
return cls.instance
@@ -359,6 +358,11 @@ class PoolSingleton:
def __getstate__(self):
raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
def remove_pool(self, pool_id):
self.shared_memory.remove_pool(pool_id)
if pool_id in self.pools:
self.pools.pop(pool_id)
def error(self, error):
self.close()
raise Exception(f'Error occurred in worker: {error}')
@@ -391,6 +395,7 @@ class PoolSingleton:
while self.queue_in.full():
self.get_from_queue()
self.queue_in.put(task)
self.shared_memory.garbage_collect()
def get_newest_for_pool(self, pool):
""" Request the newest key and result and delete its record. Wait if result not yet available. """
@@ -437,8 +442,8 @@ class PoolSingleton:
class Worker:
""" Manages executing the target function which will be executed in different processes. """
def __init__(self, queue_in, queue_out, n_workers, event, cachesize=48):
self.cache = DequeDict(cachesize)
def __init__(self, shared_memory: SharedMemory, queue_in, queue_out, n_workers, event):
self.shared_memory = shared_memory
self.queue_in = queue_in
self.queue_out = queue_out
self.n_workers = n_workers
@@ -459,10 +464,10 @@ class Worker:
task = self.queue_in.get(True, 0.02)
try:
self.add_to_queue('started', task.pool_id, task.handle, pid)
task.set_from_cache(self.cache)
self.add_to_queue('done', task.pool_id, task())
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
except Exception:
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
self.shared_memory.garbage_collect()
except multiprocessing.queues.Empty:
continue
except Exception:

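A note on the bookkeeping in this file: SharedMemory.remove_pool and remove_task compress the reference tracking into walrus-filtered comprehensions. A simplified, single-process sketch of the same idea (RefTracker and its method names are made up for illustration):

class RefTracker:
    def __init__(self):
        self.items = {}   # item_id -> object
        self.owners = {}  # item_id -> {(pool_id, handle), ...}

    def add(self, obj, pool_id, handle):
        item_id = id(obj)
        self.items.setdefault(item_id, obj)
        self.owners.setdefault(item_id, set()).add((pool_id, handle))
        return item_id

    def remove_task(self, pool_id, handle):
        # Drop this owner from every set, keep only entries that still have
        # owners, then delete items nobody references any more.
        self.owners = {key: v for key, value in self.owners.items()
                       if (v := value - {(pool_id, handle)})}
        for item_id in set(self.items) - set(self.owners):
            del self.items[item_id]

Keying by id() is only sound because a strong reference to every stored object is kept (the cache) for as long as any task owns it; CPython may reuse an id once an object is garbage collected, which is what the comment in __setitem__ alludes to.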

@@ -1,16 +1,28 @@
import copyreg
from io import BytesIO
from pickle import PicklingError, dispatch_table
from pickle import PicklingError
import dill
failed_rv = (lambda *args, **kwargs: None, ())
loads = dill.loads
class CouldNotBePickled:
def __init__(self, class_name):
self.class_name = class_name
def __repr__(self):
return f"Item of type '{self.class_name}' could not be pickled and was omitted."
@classmethod
def reduce(cls, item):
return cls, (type(item).__name__,)
class Pickler(dill.Pickler):
""" Overload dill to ignore unpickleble parts of objects.
""" Overload dill to ignore unpicklable parts of objects.
You probably didn't want to use these parts anyhow.
However, if you did, you'll have to find some way to make them pickleble.
However, if you did, you'll have to find some way to make them picklable.
"""
def save(self, obj, save_persistent_id=True):
""" Copied from pickle and amended. """
@@ -43,7 +55,7 @@ class Pickler(dill.Pickler):
# Check private dispatch table if any, or else
# copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
if reduce is not None:
rv = reduce(obj)
else:
@@ -66,14 +78,14 @@ class Pickler(dill.Pickler):
raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj))
except Exception:
rv = failed_rv
rv = CouldNotBePickled.reduce(obj)
# Check for string returned by reduce(), meaning "save as global"
if isinstance(rv, str):
try:
self.save_global(obj, rv)
except Exception:
self.save_global(obj, failed_rv)
self.save_global(obj, CouldNotBePickled.reduce(obj))
return
# Assert that reduce() returned a tuple
@@ -90,7 +102,7 @@ class Pickler(dill.Pickler):
try:
self.save_reduce(obj=obj, *rv)
except Exception:
self.save_reduce(obj=obj, *failed_rv)
self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))
def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds):

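A usage sketch of the new fallback behaviour, assuming parfor.pickler.dumps routes through the Pickler subclass above (its body falls outside this hunk): parts that cannot be serialized come back as placeholders instead of aborting the whole dump.

import threading

from parfor.pickler import dumps, loads


class Holder:
    def __init__(self):
        self.lock = threading.Lock()  # _thread.lock objects cannot be pickled
        self.value = 42


restored = loads(dumps(Holder()))
print(restored.value)  # 42
print(restored.lock)   # Item of type 'lock' could not be pickled and was omitted.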

@@ -1,6 +1,6 @@
[tool.poetry]
name = "parfor"
version = "2023.9.0"
version = "2023.10.0"
description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3"
@@ -9,7 +9,7 @@ keywords = ["parfor", "concurrency", "multiprocessing", "parallel"]
repository = "https://github.com/wimpomp/parfor"
[tool.poetry.dependencies]
python = "^3.6"
python = "^3.8"
tqdm = ">=4.50.0"
dill = ">=0.3.0"
pytest = { version = "*", optional = true }
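The new floor is load-bearing: SharedMemory's comprehensions use assignment expressions (PEP 572), a Python 3.8 feature, and under Poetry's caret semantics "^3.8" resolves to >=3.8,<4.0. A hypothetical runtime guard expressing the same floor (not part of parfor) would be:

import sys

# Assignment expressions such as the walrus-filtered comprehensions in
# SharedMemory are a SyntaxError on Python < 3.8, hence the bump.
if sys.version_info < (3, 8):
    raise RuntimeError('parfor requires Python >= 3.8')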