- Implement reference counting to make sure the parent process does not discard an item before all child processes have discarded it.

This commit is contained in:
Wim Pomp
2023-10-27 17:48:52 +02:00
parent 265470e0ac
commit e936292905
3 changed files with 64 additions and 21 deletions

View File

@@ -16,20 +16,26 @@ cpu_count = int(multiprocessing.cpu_count())
class SharedMemory(UserDict): class SharedMemory(UserDict):
def __init__(self, manager): def __init__(self, manager):
super().__init__() super().__init__()
self.data = manager.dict() # item_id_c: dilled representation of object self.data = manager.dict() # item_id: dilled representation of object
self.references = manager.dict() # item_id: counter
self.cache = {} # item_id: object self.cache = {} # item_id: object
self.trash_can = {}
self.pool_ids = {} # item_id: {(pool_id, task_handle), ...} self.pool_ids = {} # item_id: {(pool_id, task_handle), ...}
def __getstate__(self): def __getstate__(self):
return self.data return self.data, self.references
def __setitem__(self, key, value): def __setitem__(self, item_id, value):
if key not in self: # values will not be changed if item_id not in self: # values will not be changed
try: try:
self.data[key] = False, value self.data[item_id] = False, value
except Exception: # only use our pickler when necessary except Exception: # only use our pickler when necessary
self.data[key] = True, dumps(value, recurse=True) self.data[item_id] = True, dumps(value, recurse=True)
self.cache[key] = value # the id of the object will not be reused as long as the object exists if item_id in self.references:
self.references[item_id] += 1
else:
self.references[item_id] = 1
self.cache[item_id] = value # the id of the object will not be reused as long as the object exists
def add_item(self, item, pool_id, task_handle): def add_item(self, item, pool_id, task_handle):
item_id = id(item) item_id = id(item)
@@ -41,33 +47,51 @@ class SharedMemory(UserDict):
return item_id return item_id
def remove_pool(self, pool_id): def remove_pool(self, pool_id):
""" remove objects used by a pool that won't be needed anymore """
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})} self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
for key in set(self.data.keys()) - set(self.pool_ids): for item_id in set(self.data.keys()) - set(self.pool_ids):
del self[key] del self[item_id]
self.garbage_collect() self.garbage_collect()
def remove_task(self, pool_id, task): def remove_task(self, pool_id, task):
""" remove objects used by a task that won't be needed anymore """
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})} self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids): for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
del self[item_id] del self[item_id]
self.garbage_collect() self.garbage_collect()
# worker functions # worker functions
def __setstate__(self, data): def __setstate__(self, state):
self.data = data self.data, self.references = state
self.cache = {} self.cache = {}
self.trash_can = None
def __getitem__(self, key): def __getitem__(self, item_id):
if not key in self.cache: if item_id not in self.cache:
dilled, value = self.data[key] dilled, value = self.data[item_id]
if dilled: if dilled:
value = loads(value) value = loads(value)
self.cache[key] = value if item_id in self.references:
return self.cache[key] self.references[item_id] += 1
else:
self.references[item_id] = 1
self.cache[item_id] = value
return self.cache[item_id]
def garbage_collect(self): def garbage_collect(self):
for key in set(self.cache) - set(self.data.keys()): """ clean up the cache """
del self.cache[key] for item_id in set(self.cache) - set(self.data.keys()):
self.references[item_id] -= 1
if self.trash_can is not None and item_id not in self.trash_can:
self.trash_can[item_id] = self.cache[item_id]
del self.cache[item_id]
if self.trash_can:
for item_id in set(self.trash_can):
if self.references[item_id] == 0:
# make sure every process removed the object before removing it in the parent
del self.references[item_id]
del self.trash_can[item_id]
class Chunks: class Chunks:
@@ -467,12 +491,12 @@ class Worker:
self.add_to_queue('done', task.pool_id, task(self.shared_memory)) self.add_to_queue('done', task.pool_id, task(self.shared_memory))
except Exception: except Exception:
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc()) self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
self.shared_memory.garbage_collect()
except multiprocessing.queues.Empty: except multiprocessing.queues.Empty:
continue pass
except Exception: except Exception:
self.add_to_queue('error', None, format_exc()) self.add_to_queue('error', None, format_exc())
self.event.set() self.event.set()
self.shared_memory.garbage_collect()
for child in multiprocessing.active_children(): for child in multiprocessing.active_children():
child.kill() child.kill()
with self.n_workers.get_lock(): with self.n_workers.get_lock():

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "parfor" name = "parfor"
version = "2023.10.0" version = "2023.10.1"
description = "A package to mimic the use of parfor as done in Matlab." description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"] authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3" license = "GPLv3"

View File

@@ -1,5 +1,6 @@
import pytest import pytest
from parfor import Chunks, ParPool, parfor, pmap from parfor import Chunks, ParPool, parfor, pmap
from dataclasses import dataclass
class SequenceIterator: class SequenceIterator:
@@ -68,3 +69,21 @@ def test_pmap():
return i * j * k return i * j * k
assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54] assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
def test_id_reuse():
def fun(i):
return i[0].a
@dataclass
class T:
a: int = 3
def gen(total):
for i in range(total):
t = T(i)
yield t
del t
a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000)
assert all([i == j for i, j in enumerate(a)])