From e936292905d653eddbcfc617b8bb3b9f314732e3 Mon Sep 17 00:00:00 2001
From: Wim Pomp
Date: Fri, 27 Oct 2023 17:48:52 +0200
Subject: [PATCH] - Implement reference counting to make sure the parent
 process does not trash an item before all children have trashed it.

---
 parfor/__init__.py   | 64 ++++++++++++++++++++++++++++++--------------
 pyproject.toml       |  2 +-
 tests/test_parfor.py | 19 +++++++++++++
 3 files changed, 64 insertions(+), 21 deletions(-)

diff --git a/parfor/__init__.py b/parfor/__init__.py
index 7173893..5de2abe 100644
--- a/parfor/__init__.py
+++ b/parfor/__init__.py
@@ -16,20 +16,26 @@ cpu_count = int(multiprocessing.cpu_count())
 class SharedMemory(UserDict):
     def __init__(self, manager):
         super().__init__()
-        self.data = manager.dict()  # item_id_c: dilled representation of object
+        self.data = manager.dict()  # item_id: dilled representation of object
+        self.references = manager.dict()  # item_id: counter
         self.cache = {}  # item_id: object
+        self.trash_can = {}
         self.pool_ids = {}  # item_id: {(pool_id, task_handle), ...}
 
     def __getstate__(self):
-        return self.data
+        return self.data, self.references
 
-    def __setitem__(self, key, value):
-        if key not in self:  # values will not be changed
+    def __setitem__(self, item_id, value):
+        if item_id not in self:  # values will not be changed
             try:
-                self.data[key] = False, value
+                self.data[item_id] = False, value
             except Exception:  # only use our pickler when necessary
-                self.data[key] = True, dumps(value, recurse=True)
-            self.cache[key] = value  # the id of the object will not be reused as long as the object exists
+                self.data[item_id] = True, dumps(value, recurse=True)
+            if item_id in self.references:
+                self.references[item_id] += 1
+            else:
+                self.references[item_id] = 1
+            self.cache[item_id] = value  # the id of the object will not be reused as long as the object exists
 
     def add_item(self, item, pool_id, task_handle):
         item_id = id(item)
@@ -41,33 +47,51 @@ class SharedMemory(UserDict):
         return item_id
 
     def remove_pool(self, pool_id):
+        """ remove objects used by a pool that won't be needed anymore """
         self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
-        for key in set(self.data.keys()) - set(self.pool_ids):
-            del self[key]
+        for item_id in set(self.data.keys()) - set(self.pool_ids):
+            del self[item_id]
         self.garbage_collect()
 
     def remove_task(self, pool_id, task):
+        """ remove objects used by a task that won't be needed anymore """
         self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
         for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
             del self[item_id]
         self.garbage_collect()
 
     # worker functions
-    def __setstate__(self, data):
-        self.data = data
+    def __setstate__(self, state):
+        self.data, self.references = state
         self.cache = {}
+        self.trash_can = None
 
-    def __getitem__(self, key):
-        if not key in self.cache:
-            dilled, value = self.data[key]
+    def __getitem__(self, item_id):
+        if item_id not in self.cache:
+            dilled, value = self.data[item_id]
             if dilled:
                 value = loads(value)
-            self.cache[key] = value
-        return self.cache[key]
+            if item_id in self.references:
+                self.references[item_id] += 1
+            else:
+                self.references[item_id] = 1
+            self.cache[item_id] = value
+        return self.cache[item_id]
 
     def garbage_collect(self):
-        for key in set(self.cache) - set(self.data.keys()):
-            del self.cache[key]
+        """ clean up the cache """
+        for item_id in set(self.cache) - set(self.data.keys()):
+            self.references[item_id] -= 1
+            if self.trash_can is not None and item_id not in self.trash_can:
+                self.trash_can[item_id] = self.cache[item_id]
+            del self.cache[item_id]
+
+        if self.trash_can:
+            for item_id in set(self.trash_can):
+                if self.references[item_id] == 0:
+                    # make sure every process removed the object before removing it in the parent
+                    del self.references[item_id]
+                    del self.trash_can[item_id]
 
 
 class Chunks:
@@ -467,12 +491,12 @@ class Worker:
                         self.add_to_queue('done', task.pool_id, task(self.shared_memory))
                     except Exception:
                         self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
-                    self.shared_memory.garbage_collect()
                 except multiprocessing.queues.Empty:
-                    continue
+                    pass
         except Exception:
             self.add_to_queue('error', None, format_exc())
         self.event.set()
+        self.shared_memory.garbage_collect()
         for child in multiprocessing.active_children():
             child.kill()
         with self.n_workers.get_lock():
diff --git a/pyproject.toml b/pyproject.toml
index 3a7b3e8..7821594 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "parfor"
-version = "2023.10.0"
+version = "2023.10.1"
 description = "A package to mimic the use of parfor as done in Matlab."
 authors = ["Wim Pomp "]
 license = "GPLv3"
diff --git a/tests/test_parfor.py b/tests/test_parfor.py
index 5f806cd..ef6561a 100644
--- a/tests/test_parfor.py
+++ b/tests/test_parfor.py
@@ -1,5 +1,6 @@
 import pytest
 from parfor import Chunks, ParPool, parfor, pmap
+from dataclasses import dataclass
 
 
 class SequenceIterator:
@@ -68,3 +69,21 @@ def test_pmap():
         return i * j * k
 
     assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
+
+
+def test_id_reuse():
+    def fun(i):
+        return i[0].a
+
+    @dataclass
+    class T:
+        a: int = 3
+
+    def gen(total):
+        for i in range(total):
+            t = T(i)
+            yield t
+            del t
+
+    a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000)
+    assert all([i == j for i, j in enumerate(a)])
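
Note, not part of the patch: the reference counting is needed because CPython recycles id() values. SharedMemory keys cached objects by id, so once the parent drops an item, a brand-new object can be handed the same id while a worker still holds the stale cached value; the patch therefore parks evicted items in trash_can until their counter in the shared references dict reaches zero, which is what test_id_reuse exercises. The snippet below is only a minimal, self-contained illustration of that id reuse; it does not touch parfor internals and all names in it are illustrative.

    # Illustrate that CPython may hand a freed object's id to a new object.
    class Item:
        pass

    seen_ids = set()
    for _ in range(5):
        item = Item()            # new object each iteration
        seen_ids.add(id(item))   # record its id
        del item                 # object is freed, so its id may be recycled
    # Typically fewer than 5 distinct ids are observed, because ids get reused.
    print(len(seen_ids))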