- n_processes can now be changed each time

- block adding tasks if pool is busy
This commit is contained in:
Wim Pomp
2026-01-09 11:37:55 +01:00
parent 1098239af9
commit 9fdbf49b5c
4 changed files with 135 additions and 96 deletions

View File

@@ -2,35 +2,35 @@ from __future__ import annotations
import logging import logging
import os import os
import warnings from contextlib import ExitStack, redirect_stderr, redirect_stdout
from contextlib import ExitStack, redirect_stdout, redirect_stderr
from io import StringIO
from functools import wraps from functools import wraps
from importlib.metadata import version from importlib.metadata import version
from io import StringIO
from multiprocessing.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
from traceback import format_exc from traceback import format_exc
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Generator, Generator,
Hashable,
Iterable, Iterable,
Iterator, Iterator,
Sized,
Hashable,
NoReturn, NoReturn,
Optional, Optional,
Protocol, Protocol,
Sequence, Sequence,
Sized,
) )
import numpy as np import numpy as np
import ray import ray
from numpy.typing import ArrayLike, DTypeLike from numpy.typing import ArrayLike, DTypeLike
from ray.remote_function import RemoteFunction
from ray.types import ObjectRef
from tqdm.auto import tqdm from tqdm.auto import tqdm
from .pickler import dumps, loads from .pickler import dumps, loads
__version__ = version("parfor") __version__ = version("parfor")
cpu_count = int(os.cpu_count()) cpu_count = int(os.cpu_count())
@@ -193,30 +193,42 @@ class ExternalBar(Iterable):
self.callback(n) self.callback(n)
@ray.remote def get_worker(n_processes) -> RemoteFunction:
def worker(task): n_processes = n_processes or PoolSingleton.cpu_count
try: num_cpus = None if n_processes is None else cpu_count / n_processes
with ExitStack() as stack: # noqa
if task.allow_output:
out = StringIO()
err = StringIO()
stack.enter_context(redirect_stdout(out))
stack.enter_context(redirect_stderr(err))
else:
stack.enter_context(redirect_stdout(open(os.devnull, "w")))
stack.enter_context(redirect_stderr(open(os.devnull, "w")))
try:
task()
task.status = ("done",)
except Exception: # noqa
task.status = "task_error", format_exc()
if task.allow_output:
task.out = out.getvalue()
task.err = err.getvalue()
except KeyboardInterrupt: # noqa
pass
return task if not ray.is_initialized():
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
ray.init(logging_level=logging.ERROR, log_to_driver=False)
def worker(task):
try:
with ExitStack() as stack: # noqa
if task.allow_output:
out = StringIO()
err = StringIO()
stack.enter_context(redirect_stdout(out))
stack.enter_context(redirect_stderr(err))
else:
stack.enter_context(redirect_stdout(open(os.devnull, "w")))
stack.enter_context(redirect_stderr(open(os.devnull, "w")))
try:
task()
task.status = ("done",)
except Exception: # noqa
task.status = "task_error", format_exc()
if task.allow_output:
task.out = out.getvalue()
task.err = err.getvalue()
except KeyboardInterrupt: # noqa
pass
return task
if num_cpus:
return ray.remote(num_cpus=num_cpus)(worker) # type: ignore
else:
return ray.remote(worker) # type: ignore
class Task: class Task:
@@ -322,7 +334,8 @@ class ParPool:
self.fun = fun self.fun = fun
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
PoolSingleton(n_processes) self.n_processes = n_processes or PoolSingleton.cpu_count
self.worker = get_worker(self.n_processes)
def __getstate__(self) -> NoReturn: def __getstate__(self) -> NoReturn:
raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.") raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.")
@@ -366,7 +379,8 @@ class ParPool:
kwargs or self.kwargs, kwargs or self.kwargs,
allow_output or self.allow_output, allow_output or self.allow_output,
) )
task.future = worker.remote(task) self.block_until_space_available()
task.future = self.worker.remote(task)
self.tasks[new_handle] = task self.tasks[new_handle] = task
self.bar_lengths[new_handle] = barlength self.bar_lengths[new_handle] = barlength
if handle is None: if handle is None:
@@ -381,13 +395,10 @@ class ParPool:
def __getitem__(self, handle: Hashable) -> Any: def __getitem__(self, handle: Hashable) -> Any:
"""Request result and delete its record. Wait if result not yet available.""" """Request result and delete its record. Wait if result not yet available."""
if handle not in self: if handle not in self:
raise ValueError(f"No handle: {handle} in pool") raise KeyError(f"No task with handle: {handle} in pool")
task = self.tasks[handle] task = self.finalize_task(self.tasks[handle])
if task.future is None: self.tasks.pop(task.handle)
return task.result return task.result
else:
task = ray.get(self.tasks[handle].future)
return self.finalize_task(task)
def __contains__(self, handle: Hashable) -> bool: def __contains__(self, handle: Hashable) -> bool:
return handle in self.tasks return handle in self.tasks
@@ -395,36 +406,58 @@ class ParPool:
def __delitem__(self, handle: Hashable) -> None: def __delitem__(self, handle: Hashable) -> None:
self.tasks.pop(handle) self.tasks.pop(handle)
def finalize_task(self, task: Task) -> Any: def finalize_task(self, future: ObjectRef | Task) -> Task:
code, *args = task.status if isinstance(future, Task):
if task.out: task: Task = future
if hasattr(self.bar, "write"): future = task.future
self.bar.write(task.out, end="") else:
else: task = None # type: ignore
print(task.out, end="")
if task.err:
if hasattr(self.bar, "write"):
self.bar.write(task.err, end="")
else:
print(task.err, end="")
getattr(self, code)(task, *args)
self.tasks.pop(task.handle)
return task.result
def get_newest(self) -> Optional[Any]: if future is not None:
"""Request the newest handle and result and delete its record. Wait if result not yet available.""" task: Task = ray.get(future) # type: ignore
code, *args = task.status
if task.out:
if hasattr(self.bar, "write"):
self.bar.write(task.out, end="")
else:
print(task.out, end="")
if task.err:
if hasattr(self.bar, "write"):
self.bar.write(task.err, end="")
else:
print(task.err, end="")
getattr(self, code)(task, *args)
self.tasks[task.handle] = task
return task
def block_until_space_available(self) -> None:
if len(self.tasks) < 3 * self.n_processes:
return
while True: while True:
if self.tasks: if self.tasks:
for handle, task in self.tasks.items(): futures = [task.future for task in self.tasks.values() if task.future is not None]
if handle in self.tasks: done, busy = ray.wait(futures, num_returns=1, timeout=0.01)
try: for d in done:
if task.future is None: self.finalize_task(d) # type: ignore
return task.handle, task.result if len(busy) < 3 * self.n_processes:
else: return
task = ray.get(task.future, timeout=0.01)
return task.handle, self.finalize_task(task) def get_newest(self) -> Any:
except ray.exceptions.GetTimeoutError: """Request the newest handle and result and delete its record. Wait if result not yet available."""
pass if self.tasks:
done = [task for task in self.tasks.values() if task.future is None]
if done:
task = done[0]
self.tasks.pop(task.handle)
return task.handle, task.result
while True:
futures = [task.future for task in self.tasks.values() if task.future is not None]
done, _ = ray.wait(futures, num_returns=1, timeout=0.01)
if done:
task = self.finalize_task(done[0])
self.tasks.pop(task.handle)
return task.handle, task.result
raise StopIteration
def task_error(self, task: Task, error: Exception) -> None: def task_error(self, task: Task, error: Exception) -> None:
if task.handle in self: if task.handle in self:
@@ -443,23 +476,7 @@ class ParPool:
class PoolSingleton: class PoolSingleton:
instance: PoolSingleton = None cpu_count: int = os.cpu_count()
cpu_count: int = int(os.cpu_count())
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
# restart if any workers have shut down or if we want to have a different number of processes
n_processes = n_processes or cls.cpu_count
if cls.instance is None or cls.instance.n_processes != n_processes:
cls.instance = super().__new__(cls)
cls.instance.n_processes = n_processes
if ray.is_initialized():
if cls.instance.n_processes != n_processes:
warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
f"probably with n_processes={cls.instance.n_processes}")
else:
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
return cls.instance
class Worker: class Worker:

View File

@@ -23,12 +23,13 @@ class CouldNotBePickled:
class Pickler(dill.Pickler): class Pickler(dill.Pickler):
""" Overload dill to ignore unpicklable parts of objects. """Overload dill to ignore unpicklable parts of objects.
You probably didn't want to use these parts anyhow. You probably didn't want to use these parts anyhow.
However, if you did, you'll have to find some way to make them picklable. However, if you did, you'll have to find some way to make them picklable.
""" """
def save(self, obj: Any, save_persistent_id: bool = True) -> None: def save(self, obj: Any, save_persistent_id: bool = True) -> None:
""" Copied from pickle and amended. """ """Copied from pickle and amended."""
self.framer.commit_frame() self.framer.commit_frame()
# Check for persistent id (defined by a subclass) # Check for persistent id (defined by a subclass)
@@ -58,7 +59,7 @@ class Pickler(dill.Pickler):
# Check private dispatch table if any, or else # Check private dispatch table if any, or else
# copyreg.dispatch_table # copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t) reduce = getattr(self, "dispatch_table", copyreg.dispatch_table).get(t)
if reduce is not None: if reduce is not None:
rv = reduce(obj) rv = reduce(obj)
else: else:
@@ -78,8 +79,7 @@ class Pickler(dill.Pickler):
if reduce is not None: if reduce is not None:
rv = reduce() rv = reduce()
else: else:
raise PicklingError("Can't pickle %r object: %r" % raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj))
(t.__name__, obj))
except Exception: # noqa except Exception: # noqa
rv = CouldNotBePickled.reduce(obj) rv = CouldNotBePickled.reduce(obj)
@@ -98,8 +98,7 @@ class Pickler(dill.Pickler):
# Assert that it returned an appropriately sized tuple # Assert that it returned an appropriately sized tuple
length = len(rv) length = len(rv)
if not (2 <= length <= 6): if not (2 <= length <= 6):
raise PicklingError("Tuple returned by %s must have " raise PicklingError("Tuple returned by %s must have two to six elements" % reduce)
"two to six elements" % reduce)
# Save the reduce() output and finally memoize the object # Save the reduce() output and finally memoize the object
try: try:
@@ -108,12 +107,13 @@ class Pickler(dill.Pickler):
self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj)) self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))
def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, def dumps(
**kwds: Any) -> bytes: obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, **kwds: Any
) -> bytes:
"""pickle an object to a string""" """pickle an object to a string"""
protocol = dill.settings['protocol'] if protocol is None else int(protocol) protocol = dill.settings["protocol"] if protocol is None else int(protocol)
_kwds = kwds.copy() _kwds = kwds.copy()
_kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse)) _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
with BytesIO() as file: with BytesIO() as file:
Pickler(file, protocol, **_kwds).dump(obj) Pickler(file, protocol, **_kwds).dump(obj)
return file.getvalue() return file.getvalue()

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "parfor" name = "parfor"
version = "2026.1.4" version = "2026.1.5"
description = "A package to mimic the use of parfor as done in Matlab." description = "A package to mimic the use of parfor as done in Matlab."
authors = [ authors = [
{ name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" } { name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }

View File

@@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from os import getpid
from time import sleep
from typing import Any, Iterator, Optional, Sequence from typing import Any, Iterator, Optional, Sequence
import numpy as np import numpy as np
@@ -141,6 +143,16 @@ def test_id_reuse() -> None:
assert all([i == j for i, j in enumerate(a)]) assert all([i == j for i, j in enumerate(a)])
@pytest.mark.parametrize("n_processes", (2, 4, 6))
def test_n_processes(n_processes) -> None:
@parfor(range(12), n_processes=n_processes)
def fun(i): # noqa
sleep(0.25)
return getpid()
assert len(set(fun)) <= n_processes
def test_shared_array() -> None: def test_shared_array() -> None:
def fun(i, a): def fun(i, a):
a[i] = i a[i] = i
@@ -150,3 +162,13 @@ def test_shared_array() -> None:
b = np.array(arr) b = np.array(arr)
assert np.all(b == np.arange(len(arr))) assert np.all(b == np.arange(len(arr)))
def test_nesting() -> None:
def a(i):
return i**2
def b(i):
return pmap(a, range(i, i + 50))
assert pmap(b, range(10)) == [[i**2 for i in range(j, j + 50)] for j in range(10)]