- Protect the reference counter with a lock, because read-modify-write updates (e.g. `references[item_id] += 1`) on a multiprocessing manager dict are not atomic.

This commit is contained in:
Wim Pomp
2023-11-08 12:50:30 +01:00
parent 098d1810c3
commit ad8d9a4efb
2 changed files with 19 additions and 15 deletions

View File

@@ -18,12 +18,13 @@ class SharedMemory(UserDict):
super().__init__() super().__init__()
self.data = manager.dict() # item_id: dilled representation of object self.data = manager.dict() # item_id: dilled representation of object
self.references = manager.dict() # item_id: counter self.references = manager.dict() # item_id: counter
self.references_lock = manager.RLock()
self.cache = {} # item_id: object self.cache = {} # item_id: object
self.trash_can = {} self.trash_can = {}
self.pool_ids = {} # item_id: {(pool_id, task_handle), ...} self.pool_ids = {} # item_id: {(pool_id, task_handle), ...}
def __getstate__(self): def __getstate__(self):
return self.data, self.references return self.data, self.references, self.references_lock
def __setitem__(self, item_id, value): def __setitem__(self, item_id, value):
if item_id not in self: # values will not be changed if item_id not in self: # values will not be changed
@@ -31,10 +32,11 @@ class SharedMemory(UserDict):
self.data[item_id] = False, value self.data[item_id] = False, value
except Exception: # only use our pickler when necessary # noqa except Exception: # only use our pickler when necessary # noqa
self.data[item_id] = True, dumps(value, recurse=True) self.data[item_id] = True, dumps(value, recurse=True)
if item_id in self.references: with self.references_lock:
self.references[item_id] += 1 try:
else: self.references[item_id] += 1
self.references[item_id] = 1 except KeyError:
self.references[item_id] = 1
self.cache[item_id] = value # the id of the object will not be reused as long as the object exists self.cache[item_id] = value # the id of the object will not be reused as long as the object exists
def add_item(self, item, pool_id, task_handle): def add_item(self, item, pool_id, task_handle):
@@ -62,7 +64,7 @@ class SharedMemory(UserDict):
# worker functions # worker functions
def __setstate__(self, state): def __setstate__(self, state):
self.data, self.references = state self.data, self.references, self.references_lock = state
self.cache = {} self.cache = {}
self.trash_can = None self.trash_can = None
@@ -71,20 +73,22 @@ class SharedMemory(UserDict):
dilled, value = self.data[item_id] dilled, value = self.data[item_id]
if dilled: if dilled:
value = loads(value) value = loads(value)
if item_id in self.references: with self.references_lock:
self.references[item_id] += 1 if item_id in self.references:
else: self.references[item_id] += 1
self.references[item_id] = 1 else:
self.references[item_id] = 1
self.cache[item_id] = value self.cache[item_id] = value
return self.cache[item_id] return self.cache[item_id]
def garbage_collect(self): def garbage_collect(self):
""" clean up the cache """ """ clean up the cache """
for item_id in set(self.cache) - set(self.data.keys()): for item_id in set(self.cache) - set(self.data.keys()):
if item_id in self.references: with self.references_lock:
self.references[item_id] -= 1 try:
else: self.references[item_id] -= 1
self.references[item_id] = 0 except KeyError:
self.references[item_id] = 0
if self.trash_can is not None and item_id not in self.trash_can: if self.trash_can is not None and item_id not in self.trash_can:
self.trash_can[item_id] = self.cache[item_id] self.trash_can[item_id] = self.cache[item_id]
del self.cache[item_id] del self.cache[item_id]

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "parfor" name = "parfor"
version = "2023.11.1" version = "2023.11.2"
description = "A package to mimic the use of parfor as done in Matlab." description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"] authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3" license = "GPLv3"