- Implement reference counting to make sure the parent process does not trash an item before all child processes have trashed it.
This commit is contained in:
@@ -16,20 +16,26 @@ cpu_count = int(multiprocessing.cpu_count())
|
|||||||
class SharedMemory(UserDict):
|
class SharedMemory(UserDict):
|
||||||
def __init__(self, manager):
|
def __init__(self, manager):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.data = manager.dict() # item_id_c: dilled representation of object
|
self.data = manager.dict() # item_id: dilled representation of object
|
||||||
|
self.references = manager.dict() # item_id: counter
|
||||||
self.cache = {} # item_id: object
|
self.cache = {} # item_id: object
|
||||||
|
self.trash_can = {}
|
||||||
self.pool_ids = {} # item_id: {(pool_id, task_handle), ...}
|
self.pool_ids = {} # item_id: {(pool_id, task_handle), ...}
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return self.data
|
return self.data, self.references
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, item_id, value):
|
||||||
if key not in self: # values will not be changed
|
if item_id not in self: # values will not be changed
|
||||||
try:
|
try:
|
||||||
self.data[key] = False, value
|
self.data[item_id] = False, value
|
||||||
except Exception: # only use our pickler when necessary
|
except Exception: # only use our pickler when necessary
|
||||||
self.data[key] = True, dumps(value, recurse=True)
|
self.data[item_id] = True, dumps(value, recurse=True)
|
||||||
self.cache[key] = value # the id of the object will not be reused as long as the object exists
|
if item_id in self.references:
|
||||||
|
self.references[item_id] += 1
|
||||||
|
else:
|
||||||
|
self.references[item_id] = 1
|
||||||
|
self.cache[item_id] = value # the id of the object will not be reused as long as the object exists
|
||||||
|
|
||||||
def add_item(self, item, pool_id, task_handle):
|
def add_item(self, item, pool_id, task_handle):
|
||||||
item_id = id(item)
|
item_id = id(item)
|
||||||
@@ -41,33 +47,51 @@ class SharedMemory(UserDict):
|
|||||||
return item_id
|
return item_id
|
||||||
|
|
||||||
def remove_pool(self, pool_id):
|
def remove_pool(self, pool_id):
|
||||||
|
""" remove objects used by a pool that won't be needed anymore """
|
||||||
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
|
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
|
||||||
for key in set(self.data.keys()) - set(self.pool_ids):
|
for item_id in set(self.data.keys()) - set(self.pool_ids):
|
||||||
del self[key]
|
del self[item_id]
|
||||||
self.garbage_collect()
|
self.garbage_collect()
|
||||||
|
|
||||||
def remove_task(self, pool_id, task):
|
def remove_task(self, pool_id, task):
|
||||||
|
""" remove objects used by a task that won't be needed anymore """
|
||||||
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
|
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
|
||||||
for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
|
for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
|
||||||
del self[item_id]
|
del self[item_id]
|
||||||
self.garbage_collect()
|
self.garbage_collect()
|
||||||
|
|
||||||
# worker functions
|
# worker functions
|
||||||
def __setstate__(self, data):
|
def __setstate__(self, state):
|
||||||
self.data = data
|
self.data, self.references = state
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
self.trash_can = None
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, item_id):
|
||||||
if not key in self.cache:
|
if item_id not in self.cache:
|
||||||
dilled, value = self.data[key]
|
dilled, value = self.data[item_id]
|
||||||
if dilled:
|
if dilled:
|
||||||
value = loads(value)
|
value = loads(value)
|
||||||
self.cache[key] = value
|
if item_id in self.references:
|
||||||
return self.cache[key]
|
self.references[item_id] += 1
|
||||||
|
else:
|
||||||
|
self.references[item_id] = 1
|
||||||
|
self.cache[item_id] = value
|
||||||
|
return self.cache[item_id]
|
||||||
|
|
||||||
def garbage_collect(self):
|
def garbage_collect(self):
|
||||||
for key in set(self.cache) - set(self.data.keys()):
|
""" clean up the cache """
|
||||||
del self.cache[key]
|
for item_id in set(self.cache) - set(self.data.keys()):
|
||||||
|
self.references[item_id] -= 1
|
||||||
|
if self.trash_can is not None and item_id not in self.trash_can:
|
||||||
|
self.trash_can[item_id] = self.cache[item_id]
|
||||||
|
del self.cache[item_id]
|
||||||
|
|
||||||
|
if self.trash_can:
|
||||||
|
for item_id in set(self.trash_can):
|
||||||
|
if self.references[item_id] == 0:
|
||||||
|
# make sure every process removed the object before removing it in the parent
|
||||||
|
del self.references[item_id]
|
||||||
|
del self.trash_can[item_id]
|
||||||
|
|
||||||
|
|
||||||
class Chunks:
|
class Chunks:
|
||||||
@@ -467,12 +491,12 @@ class Worker:
|
|||||||
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
|
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
|
||||||
except Exception:
|
except Exception:
|
||||||
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
|
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
|
||||||
self.shared_memory.garbage_collect()
|
|
||||||
except multiprocessing.queues.Empty:
|
except multiprocessing.queues.Empty:
|
||||||
continue
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
self.add_to_queue('error', None, format_exc())
|
self.add_to_queue('error', None, format_exc())
|
||||||
self.event.set()
|
self.event.set()
|
||||||
|
self.shared_memory.garbage_collect()
|
||||||
for child in multiprocessing.active_children():
|
for child in multiprocessing.active_children():
|
||||||
child.kill()
|
child.kill()
|
||||||
with self.n_workers.get_lock():
|
with self.n_workers.get_lock():
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "parfor"
|
name = "parfor"
|
||||||
version = "2023.10.0"
|
version = "2023.10.1"
|
||||||
description = "A package to mimic the use of parfor as done in Matlab."
|
description = "A package to mimic the use of parfor as done in Matlab."
|
||||||
authors = ["Wim Pomp <wimpomp@gmail.com>"]
|
authors = ["Wim Pomp <wimpomp@gmail.com>"]
|
||||||
license = "GPLv3"
|
license = "GPLv3"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from parfor import Chunks, ParPool, parfor, pmap
|
from parfor import Chunks, ParPool, parfor, pmap
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
class SequenceIterator:
|
class SequenceIterator:
|
||||||
@@ -68,3 +69,21 @@ def test_pmap():
|
|||||||
return i * j * k
|
return i * j * k
|
||||||
|
|
||||||
assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
|
assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_reuse():
|
||||||
|
def fun(i):
|
||||||
|
return i[0].a
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class T:
|
||||||
|
a: int = 3
|
||||||
|
|
||||||
|
def gen(total):
|
||||||
|
for i in range(total):
|
||||||
|
t = T(i)
|
||||||
|
yield t
|
||||||
|
del t
|
||||||
|
|
||||||
|
a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000)
|
||||||
|
assert all([i == j for i, j in enumerate(a)])
|
||||||
|
|||||||
Reference in New Issue
Block a user