- introduce n_processes to change the number of processes in the pool

This commit is contained in:
Wim Pomp
2024-05-24 16:57:35 +02:00
parent ac4d599646
commit 9783c1d1f2
5 changed files with 62 additions and 34 deletions

View File

@@ -256,13 +256,13 @@ class ParPool:
The target function and its argument can be changed at any time.
"""
def __init__(self, fun: Callable[[Any, ...], Any] = None,
args: tuple[Any] = None, kwargs: dict[str, Any] = None, bar: Bar = None):
args: tuple[Any] = None, kwargs: dict[str, Any] = None, n_processes: int = None, bar: Bar = None):
self.id = id(self)
self.handle = 0
self.tasks = {}
self.bar = bar
self.bar_lengths = {}
self.spool = PoolSingleton(self)
self.spool = PoolSingleton(n_processes, self)
self.manager = self.spool.manager
self.fun = fun
self.args = args
@@ -372,13 +372,14 @@ class PoolSingleton:
instance = None
def __new__(cls, *args: Any, **kwargs: Any) -> PoolSingleton:
if cls.instance is not None: # restart if any workers have shut down
if cls.instance.n_workers.value < cls.instance.n_processes:
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
# restart if any workers have shut down or if we want to have a different number of processes
if cls.instance is not None:
if cls.instance.n_workers.value < cls.instance.n_processes or cls.instance.n_processes != n_processes:
cls.instance.close()
if cls.instance is None or not cls.instance.is_alive:
new = super().__new__(cls)
new.n_processes = cpu_count
new.n_processes = n_processes or cpu_count
new.instance = new
new.is_started = False
ctx = Context()
@@ -396,7 +397,7 @@ class PoolSingleton:
cls.instance = new
return cls.instance
def __init__(self, parpool: Parpool = None) -> None: # noqa
def __init__(self, n_processes: int = None, parpool: Parpool = None) -> None: # noqa
if parpool is not None:
self.pools[parpool.id] = parpool
@@ -457,7 +458,7 @@ class PoolSingleton:
@classmethod
def close(cls) -> None:
if hasattr(cls, 'instance') and cls.instance is not None:
if cls.instance is not None:
instance = cls.instance
cls.instance = None
@@ -549,7 +550,7 @@ class Worker:
def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
**bar_kwargs: Any) -> list[Result]:
n_processes: int = None, **bar_kwargs: Any) -> list[Result]:
""" map a function fun to each iteration in iterable
use as a function: pmap
use as a decorator: parfor
@@ -567,6 +568,8 @@ def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
or a callback function taking the number of passed iterations as an argument
serial: execute in series instead of parallel if True, None (default): let pmap decide
length: deprecated alias for total
n_processes: number of processes to use,
the parallel pool will be restarted if the current pool does not have the right number of processes
**bar_kwargs: keywords arguments for tqdm.tqdm
output:
@@ -654,7 +657,7 @@ def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
bar = stack.enter_context(ExternalBar(callback=bar))
else:
bar = stack.enter_context(tqdm(**bar_kwargs))
with ParPool(chunk_fun, args, kwargs, bar) as p:
with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p:
for i, (j, l) in enumerate(zip(iterable, iterable.lengths)): # add work to the queue
p(j, handle=i, barlength=iterable.lengths[i])
if bar.total is None or bar.total < i+1: