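- Protect the reference counter with a lock, because read-modify-write updates (e.g. `+= 1`) on a multiprocessing manager dict are not atomic: the read and the write are separate proxy calls, so another process can modify the counter in between.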

This commit is contained in:
Wim Pomp
2023-11-08 12:50:30 +01:00
parent 098d1810c3
commit ad8d9a4efb
2 changed files with 19 additions and 15 deletions

View File

@@ -18,12 +18,13 @@ class SharedMemory(UserDict):
super().__init__() super().__init__()
self.data = manager.dict() # item_id: dilled representation of object self.data = manager.dict() # item_id: dilled representation of object
self.references = manager.dict() # item_id: counter self.references = manager.dict() # item_id: counter
self.references_lock = manager.RLock()
self.cache = {} # item_id: object self.cache = {} # item_id: object
self.trash_can = {} self.trash_can = {}
self.pool_ids = {} # item_id: {(pool_id, task_handle), ...} self.pool_ids = {} # item_id: {(pool_id, task_handle), ...}
def __getstate__(self): def __getstate__(self):
return self.data, self.references return self.data, self.references, self.references_lock
def __setitem__(self, item_id, value): def __setitem__(self, item_id, value):
if item_id not in self: # values will not be changed if item_id not in self: # values will not be changed
@@ -31,9 +32,10 @@ class SharedMemory(UserDict):
self.data[item_id] = False, value self.data[item_id] = False, value
except Exception: # only use our pickler when necessary # noqa except Exception: # only use our pickler when necessary # noqa
self.data[item_id] = True, dumps(value, recurse=True) self.data[item_id] = True, dumps(value, recurse=True)
if item_id in self.references: with self.references_lock:
try:
self.references[item_id] += 1 self.references[item_id] += 1
else: except KeyError:
self.references[item_id] = 1 self.references[item_id] = 1
self.cache[item_id] = value # the id of the object will not be reused as long as the object exists self.cache[item_id] = value # the id of the object will not be reused as long as the object exists
@@ -62,7 +64,7 @@ class SharedMemory(UserDict):
# worker functions # worker functions
def __setstate__(self, state): def __setstate__(self, state):
self.data, self.references = state self.data, self.references, self.references_lock = state
self.cache = {} self.cache = {}
self.trash_can = None self.trash_can = None
@@ -71,6 +73,7 @@ class SharedMemory(UserDict):
dilled, value = self.data[item_id] dilled, value = self.data[item_id]
if dilled: if dilled:
value = loads(value) value = loads(value)
with self.references_lock:
if item_id in self.references: if item_id in self.references:
self.references[item_id] += 1 self.references[item_id] += 1
else: else:
@@ -81,9 +84,10 @@ class SharedMemory(UserDict):
def garbage_collect(self): def garbage_collect(self):
""" clean up the cache """ """ clean up the cache """
for item_id in set(self.cache) - set(self.data.keys()): for item_id in set(self.cache) - set(self.data.keys()):
if item_id in self.references: with self.references_lock:
try:
self.references[item_id] -= 1 self.references[item_id] -= 1
else: except KeyError:
self.references[item_id] = 0 self.references[item_id] = 0
if self.trash_can is not None and item_id not in self.trash_can: if self.trash_can is not None and item_id not in self.trash_can:
self.trash_can[item_id] = self.cache[item_id] self.trash_can[item_id] = self.cache[item_id]

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "parfor" name = "parfor"
version = "2023.11.1" version = "2023.11.2"
description = "A package to mimic the use of parfor as done in Matlab." description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"] authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3" license = "GPLv3"