- n_processes can now be changed each time a pool is created

- block adding new tasks while the pool is busy
This commit is contained in:
Wim Pomp
2026-01-09 11:37:55 +01:00
parent 1098239af9
commit 9fdbf49b5c
4 changed files with 135 additions and 96 deletions

View File

@@ -2,35 +2,35 @@ from __future__ import annotations
import logging
import os
import warnings
from contextlib import ExitStack, redirect_stdout, redirect_stderr
from io import StringIO
from contextlib import ExitStack, redirect_stderr, redirect_stdout
from functools import wraps
from importlib.metadata import version
from io import StringIO
from multiprocessing.shared_memory import SharedMemory
from traceback import format_exc
from typing import (
Any,
Callable,
Generator,
Hashable,
Iterable,
Iterator,
Sized,
Hashable,
NoReturn,
Optional,
Protocol,
Sequence,
Sized,
)
import numpy as np
import ray
from numpy.typing import ArrayLike, DTypeLike
from ray.remote_function import RemoteFunction
from ray.types import ObjectRef
from tqdm.auto import tqdm
from .pickler import dumps, loads
__version__ = version("parfor")
cpu_count = int(os.cpu_count())
@@ -193,30 +193,42 @@ class ExternalBar(Iterable):
self.callback(n)
@ray.remote
def worker(task):
try:
with ExitStack() as stack: # noqa
if task.allow_output:
out = StringIO()
err = StringIO()
stack.enter_context(redirect_stdout(out))
stack.enter_context(redirect_stderr(err))
else:
stack.enter_context(redirect_stdout(open(os.devnull, "w")))
stack.enter_context(redirect_stderr(open(os.devnull, "w")))
try:
task()
task.status = ("done",)
except Exception: # noqa
task.status = "task_error", format_exc()
if task.allow_output:
task.out = out.getvalue()
task.err = err.getvalue()
except KeyboardInterrupt: # noqa
pass
def get_worker(n_processes) -> RemoteFunction:
n_processes = n_processes or PoolSingleton.cpu_count
num_cpus = None if n_processes is None else cpu_count / n_processes
return task
if not ray.is_initialized():
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
ray.init(logging_level=logging.ERROR, log_to_driver=False)
def worker(task):
    """Run ``task`` with its stdout/stderr either captured or discarded.

    The task is called with both streams redirected: into in-memory buffers
    when ``task.allow_output`` is truthy, otherwise into ``os.devnull``.
    The outcome is recorded on the task instead of being raised:

    * ``task.status`` becomes ``("done",)`` on success, or
      ``("task_error", <formatted traceback>)`` when the call raises an ``Exception``.
    * ``task.out`` and ``task.err`` receive the captured text, only when
      ``task.allow_output`` is truthy.

    A ``KeyboardInterrupt`` is swallowed so that the task is still returned.

    Args:
        task: a callable object exposing an ``allow_output`` attribute.

    Returns:
        The same ``task`` object, updated in place.
    """
    try:
        with ExitStack() as stack:  # noqa
            if task.allow_output:
                out = StringIO()
                err = StringIO()
            else:
                # Register the devnull handle with the stack so it gets closed on exit:
                # redirect_stdout/redirect_stderr only restore the streams, they never close their target.
                out = err = stack.enter_context(open(os.devnull, "w"))
            stack.enter_context(redirect_stdout(out))
            stack.enter_context(redirect_stderr(err))
            try:
                task()
                task.status = ("done",)
            except Exception:  # noqa
                # Report the failure through the status instead of propagating it.
                task.status = "task_error", format_exc()
            if task.allow_output:
                task.out = out.getvalue()
                task.err = err.getvalue()
    except KeyboardInterrupt:  # noqa
        pass
    return task
if num_cpus:
return ray.remote(num_cpus=num_cpus)(worker) # type: ignore
else:
return ray.remote(worker) # type: ignore
class Task:
@@ -322,7 +334,8 @@ class ParPool:
self.fun = fun
self.args = args
self.kwargs = kwargs
PoolSingleton(n_processes)
self.n_processes = n_processes or PoolSingleton.cpu_count
self.worker = get_worker(self.n_processes)
def __getstate__(self) -> NoReturn:
raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.")
@@ -366,7 +379,8 @@ class ParPool:
kwargs or self.kwargs,
allow_output or self.allow_output,
)
task.future = worker.remote(task)
self.block_until_space_available()
task.future = self.worker.remote(task)
self.tasks[new_handle] = task
self.bar_lengths[new_handle] = barlength
if handle is None:
@@ -381,13 +395,10 @@ class ParPool:
def __getitem__(self, handle: Hashable) -> Any:
"""Request result and delete its record. Wait if result not yet available."""
if handle not in self:
raise ValueError(f"No handle: {handle} in pool")
task = self.tasks[handle]
if task.future is None:
return task.result
else:
task = ray.get(self.tasks[handle].future)
return self.finalize_task(task)
raise KeyError(f"No task with handle: {handle} in pool")
task = self.finalize_task(self.tasks[handle])
self.tasks.pop(task.handle)
return task.result
def __contains__(self, handle: Hashable) -> bool:
return handle in self.tasks
@@ -395,36 +406,58 @@ class ParPool:
def __delitem__(self, handle: Hashable) -> None:
self.tasks.pop(handle)
def finalize_task(self, task: Task) -> Any:
    """Emit a task's captured output, dispatch on its status and return its result.

    The first element of ``task.status`` names a method of this object; that
    method is called with the task followed by the remaining status elements.
    Afterwards the task's record is removed from ``self.tasks``.

    Args:
        task: the task to finalize.

    Returns:
        ``task.result``.
    """
    status_name, *status_args = task.status
    # Non-empty captured text goes through the bar's ``write`` when the bar has one, else through print.
    for captured in (task.out, task.err):
        if not captured:
            continue
        if hasattr(self.bar, "write"):
            self.bar.write(captured, end="")
        else:
            print(captured, end="")
    handler = getattr(self, status_name)
    handler(task, *status_args)
    self.tasks.pop(task.handle)
    return task.result
def finalize_task(self, future: ObjectRef | Task) -> Task:
if isinstance(future, Task):
task: Task = future
future = task.future
else:
task = None # type: ignore
def get_newest(self) -> Optional[Any]:
"""Request the newest handle and result and delete its record. Wait if result not yet available."""
if future is not None:
task: Task = ray.get(future) # type: ignore
code, *args = task.status
if task.out:
if hasattr(self.bar, "write"):
self.bar.write(task.out, end="")
else:
print(task.out, end="")
if task.err:
if hasattr(self.bar, "write"):
self.bar.write(task.err, end="")
else:
print(task.err, end="")
getattr(self, code)(task, *args)
self.tasks[task.handle] = task
return task
def block_until_space_available(self) -> None:
if len(self.tasks) < 3 * self.n_processes:
return
while True:
if self.tasks:
for handle, task in self.tasks.items():
if handle in self.tasks:
try:
if task.future is None:
return task.handle, task.result
else:
task = ray.get(task.future, timeout=0.01)
return task.handle, self.finalize_task(task)
except ray.exceptions.GetTimeoutError:
pass
futures = [task.future for task in self.tasks.values() if task.future is not None]
done, busy = ray.wait(futures, num_returns=1, timeout=0.01)
for d in done:
self.finalize_task(d) # type: ignore
if len(busy) < 3 * self.n_processes:
return
def get_newest(self) -> Any:
    """Return the handle and result of a finished task and delete its record.

    Waits, polling Ray in short intervals, when no result is available yet.

    Returns:
        A ``(handle, result)`` tuple.

    Raises:
        StopIteration: when the pool holds no tasks at all.
    """
    if not self.tasks:
        raise StopIteration

    # Tasks with no future attached are handed out straight away, without waiting on Ray.
    finished = next((t for t in self.tasks.values() if t.future is None), None)
    if finished is not None:
        self.tasks.pop(finished.handle)
        return finished.handle, finished.result

    while True:
        pending = [t.future for t in self.tasks.values() if t.future is not None]
        ready, _ = ray.wait(pending, num_returns=1, timeout=0.01)
        if not ready:
            continue
        newest = self.finalize_task(ready[0])
        self.tasks.pop(newest.handle)
        return newest.handle, newest.result
def task_error(self, task: Task, error: Exception) -> None:
if task.handle in self:
@@ -443,23 +476,7 @@ class ParPool:
class PoolSingleton:
instance: PoolSingleton = None
cpu_count: int = int(os.cpu_count())
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
# restart if any workers have shut down or if we want to have a different number of processes
n_processes = n_processes or cls.cpu_count
if cls.instance is None or cls.instance.n_processes != n_processes:
cls.instance = super().__new__(cls)
cls.instance.n_processes = n_processes
if ray.is_initialized():
if cls.instance.n_processes != n_processes:
warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
f"probably with n_processes={cls.instance.n_processes}")
else:
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
return cls.instance
cpu_count: int = os.cpu_count()
class Worker:

View File

@@ -23,12 +23,13 @@ class CouldNotBePickled:
class Pickler(dill.Pickler):
""" Overload dill to ignore unpicklable parts of objects.
You probably didn't want to use these parts anyhow.
However, if you did, you'll have to find some way to make them picklable.
"""Overload dill to ignore unpicklable parts of objects.
You probably didn't want to use these parts anyhow.
However, if you did, you'll have to find some way to make them picklable.
"""
def save(self, obj: Any, save_persistent_id: bool = True) -> None:
""" Copied from pickle and amended. """
"""Copied from pickle and amended."""
self.framer.commit_frame()
# Check for persistent id (defined by a subclass)
@@ -58,7 +59,7 @@ class Pickler(dill.Pickler):
# Check private dispatch table if any, or else
# copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
reduce = getattr(self, "dispatch_table", copyreg.dispatch_table).get(t)
if reduce is not None:
rv = reduce(obj)
else:
@@ -78,8 +79,7 @@ class Pickler(dill.Pickler):
if reduce is not None:
rv = reduce()
else:
raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj))
raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj))
except Exception: # noqa
rv = CouldNotBePickled.reduce(obj)
@@ -98,8 +98,7 @@ class Pickler(dill.Pickler):
# Assert that it returned an appropriately sized tuple
length = len(rv)
if not (2 <= length <= 6):
raise PicklingError("Tuple returned by %s must have "
"two to six elements" % reduce)
raise PicklingError("Tuple returned by %s must have two to six elements" % reduce)
# Save the reduce() output and finally memoize the object
try:
@@ -108,12 +107,13 @@ class Pickler(dill.Pickler):
self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))
def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True,
**kwds: Any) -> bytes:
def dumps(
obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, **kwds: Any
) -> bytes:
"""pickle an object to a string"""
protocol = dill.settings['protocol'] if protocol is None else int(protocol)
protocol = dill.settings["protocol"] if protocol is None else int(protocol)
_kwds = kwds.copy()
_kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
with BytesIO() as file:
Pickler(file, protocol, **_kwds).dump(obj)
return file.getvalue()
return file.getvalue()

View File

@@ -1,6 +1,6 @@
[project]
name = "parfor"
version = "2026.1.4"
version = "2026.1.5"
description = "A package to mimic the use of parfor as done in Matlab."
authors = [
{ name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
from os import getpid
from time import sleep
from typing import Any, Iterator, Optional, Sequence
import numpy as np
@@ -141,6 +143,16 @@ def test_id_reuse() -> None:
assert all([i == j for i, j in enumerate(a)])
@pytest.mark.parametrize("n_processes", (2, 4, 6))
def test_n_processes(n_processes) -> None:
    """Check that the iterations of one parfor run are spread over at most ``n_processes`` processes."""

    @parfor(range(12), n_processes=n_processes)
    def fun(i):  # noqa
        # Sleep briefly, then report the id of the process that ran this iteration.
        sleep(0.25)
        return getpid()

    # After decoration ``fun`` holds the per-iteration results, here process ids.
    distinct_pids = set(fun)
    assert len(distinct_pids) <= n_processes
def test_shared_array() -> None:
def fun(i, a):
a[i] = i
@@ -150,3 +162,13 @@ def test_shared_array() -> None:
b = np.array(arr)
assert np.all(b == np.arange(len(arr)))
def test_nesting() -> None:
    """Check that pmap can be called from a function that is itself run through pmap."""

    def square(value):
        return value**2

    def squares_from(start):
        # Inner parallel map, started from within the outer one.
        return pmap(square, range(start, start + 50))

    nested = pmap(squares_from, range(10))
    expected = [[i**2 for i in range(j, j + 50)] for j in range(10)]
    assert nested == expected