- Use a shared memory approach.

- Track unpickled objects by CouldNotBePickled class.
- Bump minimal python to 3.8.
Wim Pomp
2023-10-12 14:46:42 +02:00
parent 0263b6a4c7
commit 265470e0ac
3 changed files with 139 additions and 122 deletions
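
The heart of the change: instead of pickling the function and its arguments into every task, objects are stored once in a Manager-backed dict keyed by id(), as-is when the manager can transfer them and as a pickled payload otherwise, and workers look them up on demand. A minimal self-contained sketch of that idea, using stdlib pickle where parfor uses its dill-based dumps/loads (the work helper and two-tuple layout here are illustrative, not parfor's API):

    import multiprocessing
    import pickle

    def work(shared, key):
        # worker side: fetch the payload and unpickle only when it was flagged
        pickled, payload = shared[key]
        return (pickle.loads(payload) if pickled else payload) * 2

    if __name__ == '__main__':
        with multiprocessing.Manager() as manager:
            shared = manager.dict()  # item_id: (pickled?, payload)
            value = 21
            try:
                shared[id(value)] = False, value  # store as-is when possible
            except Exception:
                shared[id(value)] = True, pickle.dumps(value)  # fall back to a pickled payload
            with multiprocessing.Pool(1) as pool:
                print(pool.apply(work, (shared, id(value))))  # -> 42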

@@ -1,7 +1,6 @@
 import multiprocessing
-from collections import OrderedDict
+from collections import UserDict
 from contextlib import ExitStack
-from copy import copy
 from functools import wraps
 from os import getpid
 from traceback import format_exc
@@ -9,11 +8,68 @@ from warnings import warn
 from tqdm.auto import tqdm
-from .pickler import Pickler, dumps, loads
+from .pickler import dumps, loads
 cpu_count = int(multiprocessing.cpu_count())
+class SharedMemory(UserDict):
+    def __init__(self, manager):
+        super().__init__()
+        self.data = manager.dict()  # item_id_c: dilled representation of object
+        self.cache = {}  # item_id: object
+        self.pool_ids = {}  # item_id: {(pool_id, task_handle), ...}
+    def __getstate__(self):
+        return self.data
+    def __setitem__(self, key, value):
+        if key not in self:  # values will not be changed
+            try:
+                self.data[key] = False, value
+            except Exception:  # only use our pickler when necessary
+                self.data[key] = True, dumps(value, recurse=True)
+            self.cache[key] = value  # the id of the object will not be reused as long as the object exists
+    def add_item(self, item, pool_id, task_handle):
+        item_id = id(item)
+        self[item_id] = item
+        if item_id in self.pool_ids:
+            self.pool_ids[item_id].add((pool_id, task_handle))
+        else:
+            self.pool_ids[item_id] = {(pool_id, task_handle)}
+        return item_id
+    def remove_pool(self, pool_id):
+        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
+        for key in set(self.data.keys()) - set(self.pool_ids):
+            del self[key]
+        self.garbage_collect()
+    def remove_task(self, pool_id, task):
+        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
+        for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
+            del self[item_id]
+        self.garbage_collect()
+    # worker functions
+    def __setstate__(self, data):
+        self.data = data
+        self.cache = {}
+    def __getitem__(self, key):
+        if not key in self.cache:
+            dilled, value = self.data[key]
+            if dilled:
+                value = loads(value)
+            self.cache[key] = value
+        return self.cache[key]
+    def garbage_collect(self):
+        for key in set(self.cache) - set(self.data.keys()):
+            del self.cache[key]
 class Chunks:
     """ Yield successive chunks from lists.
         Usage: chunks(list0, list1, ...)
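
Aside: the remove_pool/remove_task comprehensions above are the reason this commit bumps the minimum Python to 3.8; they use assignment expressions (PEP 572) to filter each reference set and drop empty entries in one pass. A standalone sketch of that idiom, with made-up data:

    # drop pool 0's references; keys whose reference set becomes empty disappear
    pool_ids = {1: {(0, 'a'), (1, 'b')}, 2: {(0, 'c')}}
    pool_ids = {key: v for key, value in pool_ids.items()
                if (v := {i for i in value if i[0] != 0})}
    print(pool_ids)  # -> {1: {(1, 'b')}}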
@@ -107,95 +163,39 @@ class ExternalBar:
         self.callback(n)
-class Hasher:
-    def __init__(self, obj, hsh=None):
-        if hsh is not None:
-            self.obj, self.str, self.hash = None, obj, hsh
-        elif isinstance(obj, Hasher):
-            self.obj, self.str, self.hash = obj.obj, obj.str, obj.hash
-        else:
-            self.obj = obj
-            self.str = dumps(self.obj, recurse=True)
-            self.hash = hash(self.str)
-    def __reduce__(self):
-        return self.__class__, (self.str, self.hash)
-    def set_from_cache(self, cache=None):
-        if cache is None:
-            self.obj = loads(self.str)
-        elif self.hash in cache:
-            self.obj = cache[self.hash]
-        else:
-            self.obj = cache[self.hash] = loads(self.str)
-class HashDescriptor:
-    def __set_name__(self, owner, name):
-        self.owner, self.name = owner, '_' + name
-    def __set__(self, instance, value):
-        if isinstance(value, Hasher):
-            setattr(instance, self.name, value)
-        else:
-            setattr(instance, self.name, Hasher(value))
-    def __get__(self, instance, owner):
-        return getattr(instance, self.name).obj
-class DequeDict(OrderedDict):
-    def __init__(self, maxlen=None, *args, **kwargs):
-        self.maxlen = maxlen
-        super().__init__(*args, **kwargs)
-    def __truncate__(self):
-        while len(self) > self.maxlen:
-            self.popitem(False)
-    def __setitem__(self, *args, **kwargs):
-        super().__setitem__(*args, **kwargs)
-        self.__truncate__()
-    def update(self, *args, **kwargs):
-        super().update(*args, **kwargs)
-        self.__truncate__()
 class Task:
-    fun = HashDescriptor()
-    args = HashDescriptor()
-    kwargs = HashDescriptor()
-    def __init__(self, pool_id, fun=None, args=None, kwargs=None, handle=None, n=None, done=False, result=None):
+    def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: int, fun=None, args=(), kwargs=None):
         self.pool_id = pool_id
-        self.fun = fun or (lambda *args, **kwargs: None)
-        self.args = args or ()
-        self.kwargs = kwargs or {}
         self.handle = handle
-        self.n = n
-        self.done = done
-        self.result = loads(result) if self.done else None
+        self.fun = shared_memory.add_item(fun, pool_id, handle)
+        self.args = [shared_memory.add_item(arg, pool_id, handle) for arg in args]
+        self.kwargs = [] if kwargs is None else [shared_memory.add_item(item, pool_id, handle)
+                                                 for item in kwargs.items()]
+        self.name = fun.__name__ if hasattr(fun, '__name__') else None
+        self.done = False
+        self.result = None
         self.pid = None
-    def __reduce__(self):
-        if self.done:
-            return self.__class__, (self.pool_id, None, None, None, self.handle, None, self.done,
-                                    dumps(self.result, recurse=True))
+    def __getstate__(self):
+        state = self.__dict__
+        if self.result is not None:
+            state['result'] = dumps(self.result, recurse=True)
+        return state
+    def __setstate__(self, state):
+        self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
+        if state['result'] is None:
+            self.result = None
         else:
-            return self.__class__, (self.pool_id, self._fun, self._args, self._kwargs, self.handle,
-                                    dumps(self.n, recurse=True), self.done)
+            self.result = loads(state['result'])
-    def set_from_cache(self, cache=None):
-        self.n = loads(self.n)
-        self._fun.set_from_cache(cache)
-        self._args.set_from_cache(cache)
-        self._kwargs.set_from_cache(cache)
-    def __call__(self):
+    def __call__(self, shared_memory: SharedMemory):
         if not self.done:
-            self.result = self.fun(self.n, *self.args, **self.kwargs)
-            self.fun, self.args, self.kwargs, self.done = None, None, None, True  # Remove potentially big things
+            fun = shared_memory[self.fun] or (lambda *args, **kwargs: None)
+            args = [shared_memory[arg] for arg in self.args]
+            kwargs = dict([shared_memory[kwarg] for kwarg in self.kwargs])
+            self.result = fun(*args, **kwargs)
+            self.done = True
         return self
     def __repr__(self):
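
The __reduce__ round trip is gone: a finished Task travels back through the queue via __getstate__/__setstate__, which give only the result field a custom pickle round trip. A standalone sketch of that pattern, using plain pickle instead of parfor's dill-based dumps/loads (ResultCarrier is an illustrative name, and it copies the state dict where the diff mutates self.__dict__ directly):

    import pickle

    class ResultCarrier:
        def __init__(self, result=None):
            self.done = result is not None
            self.result = result

        def __getstate__(self):
            state = dict(self.__dict__)  # copy, so the live instance keeps its result
            if state['result'] is not None:
                state['result'] = pickle.dumps(state['result'])
            return state

        def __setstate__(self, state):
            self.__dict__.update({k: v for k, v in state.items() if k != 'result'})
            self.result = None if state['result'] is None else pickle.loads(state['result'])

    carrier = pickle.loads(pickle.dumps(ResultCarrier([1, 2, 3])))
    print(carrier.done, carrier.result)  # -> True [1, 2, 3]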
@@ -225,11 +225,13 @@ class ParPool:
         self.id = id(self)
         self.handle = 0
         self.tasks = {}
-        self.last_task = Task(self.id, fun, args, kwargs)
         self.bar = bar
         self.bar_lengths = {}
         self.spool = PoolSingleton(self)
         self.manager = self.spool.manager
+        self.fun = fun
+        self.args = args
+        self.kwargs = kwargs
         self.is_started = False
     def __getstate__(self):
@@ -242,10 +244,12 @@ class ParPool:
             self.close()
     def close(self):
         if self.id in self.spool.pools:
-            self.spool.pools.pop(self.id)
+            self.spool.remove_pool(self.id)
-    def __call__(self, n, fun=None, args=None, kwargs=None, handle=None, barlength=1):
+    def __call__(self, n, handle=None, barlength=1):
+        self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)
+    def add_task(self, fun=None, args=None, kwargs=None, handle=None, barlength=1):
         if self.id not in self.spool.pools:
             raise ValueError(f'this pool is not registered (anymore) with the pool singleton')
         if handle is None:
@@ -255,18 +259,11 @@ class ParPool:
             new_handle = handle
         if new_handle in self:
             raise ValueError(f'handle {new_handle} already present')
-        new_task = copy(self.last_task)
-        if fun is not None:
-            new_task.fun = fun
-        if args is not None:
-            new_task.args = args
-        if kwargs is not None:
-            new_task.kwargs = kwargs
-        new_task.handle = new_handle
-        new_task.n = n
-        self.tasks[new_handle] = new_task
-        self.last_task = new_task
-        self.spool.add_task(new_task)
+        task = Task(self.spool.shared_memory, self.id, new_handle,
+                    fun or self.fun, args or self.args, kwargs or self.kwargs)
+        self.tasks[new_handle] = task
+        self.last_task = task
+        self.spool.add_task(task)
         self.bar_lengths[new_handle] = barlength
         if handle is None:
             return new_handle
@@ -286,9 +283,9 @@ class ParPool:
             self.spool.get_from_queue()
         if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
                 and not self.working:
-            # retry a task if the process was killed while retrieving the task
+            # retry a task if the process was killed while working on a task
             self.spool.add_task(self.tasks[handle])
-            warn(f'Task {handle} was restarted because the process retrieving it was probably killed.')
+            warn(f'Task {handle} was restarted because the process working on it was probably killed.')
         result = self.tasks[handle].result
         self.tasks.pop(handle)
         return result
@@ -310,17 +307,17 @@ class ParPool:
         task = self.tasks[handle]
         print(f'Error from process working on iteration {handle}:\n')
         print(error)
         self.close()
         print('Retrying in main thread...')
-        fun = task.fun.__name__
-        task()
-        raise Exception('Function \'{}\' cannot be executed by parfor, amend or execute in serial.'.format(fun))
+        task(self.spool.shared_memory)
+        raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.')
+        self.spool.shared_memory.remove_task(self.id, self.tasks[handle])
     def done(self, task):
         if task.handle in self:  # if not, the task was restarted erroneously
             self.tasks[task.handle] = task
             if hasattr(self.bar, 'update'):
                 self.bar.update(self.bar_lengths.pop(task.handle))
+            self.spool.shared_memory.remove_task(self.id, task)
     def started(self, handle, pid):
         self.is_started = True
@@ -344,11 +341,13 @@ class PoolSingleton:
             new.event = ctx.Event()
             new.queue_in = ctx.Queue(3 * new.n_processes)
             new.queue_out = ctx.Queue(new.n_processes)
-            new.pool = ctx.Pool(new.n_processes, Worker(new.queue_in, new.queue_out, new.n_workers, new.event))
+            new.manager = ctx.Manager()
+            new.shared_memory = SharedMemory(new.manager)
+            new.pool = ctx.Pool(new.n_processes,
+                                Worker(new.shared_memory, new.queue_in, new.queue_out, new.n_workers, new.event))
             new.is_alive = True
             new.handle = 0
             new.pools = {}
-            new.manager = ctx.Manager()
             cls.instance = new
         return cls.instance
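
Note the ordering in this hunk: the Manager, and the SharedMemory wrapping it, now has to exist before the Pool is created, because each Worker receives the shared store at process start. A minimal sketch of that shape, with an illustrative initializer rather than parfor's Worker callable:

    import multiprocessing

    def init_worker(shared):
        global SHARED  # each pool process stashes the manager-backed dict once
        SHARED = shared

    if __name__ == '__main__':
        ctx = multiprocessing.get_context()
        manager = ctx.Manager()  # must be running before the pool starts
        shared = manager.dict()
        with ctx.Pool(2, init_worker, (shared,)) as pool:
            pass  # every pool process now sees the same shared dict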
@@ -359,6 +358,11 @@ class PoolSingleton:
     def __getstate__(self):
         raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
+    def remove_pool(self, pool_id):
+        self.shared_memory.remove_pool(pool_id)
+        if pool_id in self.pools:
+            self.pools.pop(pool_id)
     def error(self, error):
         self.close()
         raise Exception(f'Error occurred in worker: {error}')
@@ -391,6 +395,7 @@ class PoolSingleton:
         while self.queue_in.full():
             self.get_from_queue()
         self.queue_in.put(task)
+        self.shared_memory.garbage_collect()
     def get_newest_for_pool(self, pool):
         """ Request the newest key and result and delete its record. Wait if result not yet available. """
@@ -437,8 +442,8 @@ class PoolSingleton:
 class Worker:
     """ Manages executing the target function which will be executed in different processes. """
-    def __init__(self, queue_in, queue_out, n_workers, event, cachesize=48):
-        self.cache = DequeDict(cachesize)
+    def __init__(self, shared_memory: SharedMemory, queue_in, queue_out, n_workers, event):
+        self.shared_memory = shared_memory
         self.queue_in = queue_in
         self.queue_out = queue_out
         self.n_workers = n_workers
@@ -459,10 +464,10 @@ class Worker:
                 task = self.queue_in.get(True, 0.02)
                 try:
                     self.add_to_queue('started', task.pool_id, task.handle, pid)
-                    task.set_from_cache(self.cache)
-                    self.add_to_queue('done', task.pool_id, task())
+                    self.add_to_queue('done', task.pool_id, task(self.shared_memory))
                 except Exception:
                     self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
+                self.shared_memory.garbage_collect()
             except multiprocessing.queues.Empty:
                 continue
             except Exception: