- Implement reference counting to make sure the parent process does not trash an item before all child processes have trashed it.
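
In outline: items are stored in shared memory keyed by id(item), every process keeps a local cache of deserialised objects, and a shared counter records how many caches still hold each item. A minimal single-process sketch of that protocol (class and method names here are illustrative, not parfor's API):

    class RefCountedStore:
        """Toy model: every cache holding an item owns one reference; the
        parent only really discards an item once the count drops to zero."""

        def __init__(self):
            self.references = {}  # item_id: number of caches holding the item
            self.trash_can = {}   # item_id: object kept alive until released

        def acquire(self, item_id):
            # a process puts the item in its local cache and takes a reference
            self.references[item_id] = self.references.get(item_id, 0) + 1

        def release(self, item_id, obj):
            # a process drops its cached copy; instead of freeing the object,
            # the parent parks it in the trash can
            self.references[item_id] -= 1
            self.trash_can.setdefault(item_id, obj)

        def collect(self):
            # only items that no cache references any more are thrown away
            for item_id in [i for i, n in self.references.items() if n == 0]:
                del self.references[item_id]
                del self.trash_can[item_id]

Parking the object matters because the key is id(item): CPython can reuse an object's id as soon as the object is freed, so trashing too early would let a new item collide with a key a child still holds (see the new test at the bottom of the diff).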
@@ -16,20 +16,26 @@ cpu_count = int(multiprocessing.cpu_count())
 class SharedMemory(UserDict):
     def __init__(self, manager):
         super().__init__()
-        self.data = manager.dict()  # item_id_c: dilled representation of object
+        self.data = manager.dict()  # item_id: dilled representation of object
+        self.references = manager.dict()  # item_id: counter
         self.cache = {}  # item_id: object
+        self.trash_can = {}
         self.pool_ids = {}  # item_id: {(pool_id, task_handle), ...}
 
     def __getstate__(self):
-        return self.data
+        return self.data, self.references
 
-    def __setitem__(self, key, value):
-        if key not in self:  # values will not be changed
+    def __setitem__(self, item_id, value):
+        if item_id not in self:  # values will not be changed
             try:
-                self.data[key] = False, value
+                self.data[item_id] = False, value
             except Exception:  # only use our pickler when necessary
-                self.data[key] = True, dumps(value, recurse=True)
-            self.cache[key] = value  # the id of the object will not be reused as long as the object exists
+                self.data[item_id] = True, dumps(value, recurse=True)
+            if item_id in self.references:
+                self.references[item_id] += 1
+            else:
+                self.references[item_id] = 1
+            self.cache[item_id] = value  # the id of the object will not be reused as long as the object exists
 
     def add_item(self, item, pool_id, task_handle):
         item_id = id(item)
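
Both self.data and self.references are manager.dict() proxies, so children mutate the same underlying dictionaries as the parent, and __getstate__ now ships both proxies to the workers. For context, a toy example of sharing such a counter across processes (not parfor code):

    import multiprocessing


    def hold_and_release(references, item_id):
        # a child takes a reference while it uses the item, then gives it back
        references[item_id] += 1
        references[item_id] -= 1


    if __name__ == '__main__':
        manager = multiprocessing.Manager()
        references = manager.dict()
        references['item'] = 1  # the parent's own reference
        child = multiprocessing.Process(target=hold_and_release, args=(references, 'item'))
        child.start()
        child.join()
        references['item'] -= 1  # the parent drops its reference
        assert references['item'] == 0  # now safe to trash

Note that references[item_id] += 1 on a DictProxy is a read-modify-write round trip to the manager process, not an atomic increment, so correctness relies on different processes not racing on the same key.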
@@ -41,33 +47,51 @@ class SharedMemory(UserDict):
         return item_id
 
     def remove_pool(self, pool_id):
         """ remove objects used by a pool that won't be needed anymore """
         self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
-        for key in set(self.data.keys()) - set(self.pool_ids):
-            del self[key]
+        for item_id in set(self.data.keys()) - set(self.pool_ids):
+            del self[item_id]
+        self.garbage_collect()
 
     def remove_task(self, pool_id, task):
         """ remove objects used by a task that won't be needed anymore """
         self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
         for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
             del self[item_id]
+        self.garbage_collect()
 
     # worker functions
-    def __setstate__(self, data):
-        self.data = data
+    def __setstate__(self, state):
+        self.data, self.references = state
         self.cache = {}
+        self.trash_can = None
 
-    def __getitem__(self, key):
-        if not key in self.cache:
-            dilled, value = self.data[key]
+    def __getitem__(self, item_id):
+        if item_id not in self.cache:
+            dilled, value = self.data[item_id]
             if dilled:
                 value = loads(value)
-            self.cache[key] = value
-        return self.cache[key]
+            if item_id in self.references:
+                self.references[item_id] += 1
+            else:
+                self.references[item_id] = 1
+            self.cache[item_id] = value
+        return self.cache[item_id]
 
     def garbage_collect(self):
-        for key in set(self.cache) - set(self.data.keys()):
-            del self.cache[key]
+        """ clean up the cache """
+        for item_id in set(self.cache) - set(self.data.keys()):
+            self.references[item_id] -= 1
+            if self.trash_can is not None and item_id not in self.trash_can:
+                self.trash_can[item_id] = self.cache[item_id]
+            del self.cache[item_id]
+
+        if self.trash_can:
+            for item_id in set(self.trash_can):
+                if self.references[item_id] == 0:
+                    # make sure every process removed the object before removing it in the parent
+                    del self.references[item_id]
+                    del self.trash_can[item_id]
 
 
 class Chunks:
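
Garbage collection is now two-sided. In a worker (where trash_can is None after __setstate__) it decrements the counter and simply drops the local cached copy; in the parent it additionally parks the object in trash_can, and an entry is only deleted, together with its counter, once every process has released it. A rough trace for one item shared with two workers (counts illustrative):

    parent __setitem__                       references = 1 (parent cache holds it)
    both workers __getitem__                 references = 3
    item deleted, parent garbage_collect     references = 2, object parked in trash_can
    both workers garbage_collect             references = 0, worker caches dropped
    next parent garbage_collect              references == 0, trash_can entry removed

Until that last step the parent keeps the Python object alive, so its id() cannot be handed out to a new item.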
@@ -467,12 +491,12 @@ class Worker:
                         self.add_to_queue('done', task.pool_id, task(self.shared_memory))
                     except Exception:
                         self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
                     self.shared_memory.garbage_collect()
                 except multiprocessing.queues.Empty:
-                    continue
+                    pass
             except Exception:
                 self.add_to_queue('error', None, format_exc())
                 self.event.set()
         self.shared_memory.garbage_collect()
         for child in multiprocessing.active_children():
             child.kill()
         with self.n_workers.get_lock():
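
The behavioural change in the worker loop is continue → pass when the task queue times out: the worker now falls through to the rest of the loop body instead of jumping straight back to the top. The exception caught here, multiprocessing.queues.Empty, is the same class as queue.Empty and is raised by a timed get(). A minimal sketch of the polling pattern (the loop body is hypothetical, not parfor's):

    import queue


    def poll_once(tasks, shared_memory):
        # one iteration of a hypothetical worker loop: run a task if one
        # arrives within the timeout, otherwise fall through so that the
        # housekeeping after the try block still runs
        try:
            task = tasks.get(timeout=0.1)
            task()
        except queue.Empty:
            pass
        shared_memory.garbage_collect()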
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "parfor"
-version = "2023.10.0"
+version = "2023.10.1"
 description = "A package to mimic the use of parfor as done in Matlab."
 authors = ["Wim Pomp <wimpomp@gmail.com>"]
 license = "GPLv3"
@@ -1,5 +1,6 @@
 import pytest
 from parfor import Chunks, ParPool, parfor, pmap
+from dataclasses import dataclass
 
 
 class SequenceIterator:
@@ -68,3 +69,21 @@ def test_pmap():
         return i * j * k
 
     assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
+
+
+def test_id_reuse():
+    def fun(i):
+        return i[0].a
+
+    @dataclass
+    class T:
+        a: int = 3
+
+    def gen(total):
+        for i in range(total):
+            t = T(i)
+            yield t
+            del t
+
+    a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000)
+    assert all([i == j for i, j in enumerate(a)])
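
The new test_id_reuse pins down the failure mode behind this commit: gen deletes each item right after yielding it, so CPython is free to hand the freed object's id() to the next item, and a store keyed by id(item) would then serve stale cached values if the parent trashed objects before all children were done with them. The id reuse itself is easy to reproduce:

    class T:
        def __init__(self, a):
            self.a = a


    # the first instance is freed as soon as id() returns, so CPython will
    # typically give the next object of the same size the same address/id
    first = id(T(1))
    second = id(T(2))
    print(first == second)  # usually True on CPython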