- Implement reference counting to make sure the parent process does not trash the item before all children trashed it.
This commit is contained in:
@@ -16,20 +16,26 @@ cpu_count = int(multiprocessing.cpu_count())
|
||||
class SharedMemory(UserDict):
|
||||
def __init__(self, manager):
|
||||
super().__init__()
|
||||
self.data = manager.dict() # item_id_c: dilled representation of object
|
||||
self.data = manager.dict() # item_id: dilled representation of object
|
||||
self.references = manager.dict() # item_id: counter
|
||||
self.cache = {} # item_id: object
|
||||
self.trash_can = {}
|
||||
self.pool_ids = {} # item_id: {(pool_id, task_handle), ...}
|
||||
|
||||
def __getstate__(self):
|
||||
return self.data
|
||||
return self.data, self.references
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key not in self: # values will not be changed
|
||||
def __setitem__(self, item_id, value):
|
||||
if item_id not in self: # values will not be changed
|
||||
try:
|
||||
self.data[key] = False, value
|
||||
self.data[item_id] = False, value
|
||||
except Exception: # only use our pickler when necessary
|
||||
self.data[key] = True, dumps(value, recurse=True)
|
||||
self.cache[key] = value # the id of the object will not be reused as long as the object exists
|
||||
self.data[item_id] = True, dumps(value, recurse=True)
|
||||
if item_id in self.references:
|
||||
self.references[item_id] += 1
|
||||
else:
|
||||
self.references[item_id] = 1
|
||||
self.cache[item_id] = value # the id of the object will not be reused as long as the object exists
|
||||
|
||||
def add_item(self, item, pool_id, task_handle):
|
||||
item_id = id(item)
|
||||
@@ -41,33 +47,51 @@ class SharedMemory(UserDict):
|
||||
return item_id
|
||||
|
||||
def remove_pool(self, pool_id):
|
||||
""" remove objects used by a pool that won't be needed anymore """
|
||||
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
|
||||
for key in set(self.data.keys()) - set(self.pool_ids):
|
||||
del self[key]
|
||||
for item_id in set(self.data.keys()) - set(self.pool_ids):
|
||||
del self[item_id]
|
||||
self.garbage_collect()
|
||||
|
||||
def remove_task(self, pool_id, task):
|
||||
""" remove objects used by a task that won't be needed anymore """
|
||||
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
|
||||
for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
|
||||
del self[item_id]
|
||||
self.garbage_collect()
|
||||
|
||||
# worker functions
|
||||
def __setstate__(self, data):
|
||||
self.data = data
|
||||
def __setstate__(self, state):
|
||||
self.data, self.references = state
|
||||
self.cache = {}
|
||||
self.trash_can = None
|
||||
|
||||
def __getitem__(self, key):
|
||||
if not key in self.cache:
|
||||
dilled, value = self.data[key]
|
||||
def __getitem__(self, item_id):
|
||||
if item_id not in self.cache:
|
||||
dilled, value = self.data[item_id]
|
||||
if dilled:
|
||||
value = loads(value)
|
||||
self.cache[key] = value
|
||||
return self.cache[key]
|
||||
if item_id in self.references:
|
||||
self.references[item_id] += 1
|
||||
else:
|
||||
self.references[item_id] = 1
|
||||
self.cache[item_id] = value
|
||||
return self.cache[item_id]
|
||||
|
||||
def garbage_collect(self):
|
||||
for key in set(self.cache) - set(self.data.keys()):
|
||||
del self.cache[key]
|
||||
""" clean up the cache """
|
||||
for item_id in set(self.cache) - set(self.data.keys()):
|
||||
self.references[item_id] -= 1
|
||||
if self.trash_can is not None and item_id not in self.trash_can:
|
||||
self.trash_can[item_id] = self.cache[item_id]
|
||||
del self.cache[item_id]
|
||||
|
||||
if self.trash_can:
|
||||
for item_id in set(self.trash_can):
|
||||
if self.references[item_id] == 0:
|
||||
# make sure every process removed the object before removing it in the parent
|
||||
del self.references[item_id]
|
||||
del self.trash_can[item_id]
|
||||
|
||||
|
||||
class Chunks:
|
||||
@@ -467,12 +491,12 @@ class Worker:
|
||||
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
|
||||
except Exception:
|
||||
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
|
||||
self.shared_memory.garbage_collect()
|
||||
except multiprocessing.queues.Empty:
|
||||
continue
|
||||
pass
|
||||
except Exception:
|
||||
self.add_to_queue('error', None, format_exc())
|
||||
self.event.set()
|
||||
self.shared_memory.garbage_collect()
|
||||
for child in multiprocessing.active_children():
|
||||
child.kill()
|
||||
with self.n_workers.get_lock():
|
||||
|
||||
Reference in New Issue
Block a user