- n_processes can now be changed each time a pool is created

- block adding tasks when the pool is busy
Wim Pomp
2026-01-09 11:37:55 +01:00
parent 1098239af9
commit 9fdbf49b5c
4 changed files with 135 additions and 96 deletions
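
A brief usage sketch of the change (hedged: the @parfor decorator and its n_processes argument are the API exercised by the tests in this commit; the loop bodies and names here are illustrative):

from parfor import parfor

# Each parallel loop can now request its own worker count; the decorated
# name ends up holding the list of results, as in the tests further down.
@parfor(range(20), n_processes=4)
def squares(i):
    return i ** 2

@parfor(range(20), n_processes=8)  # a different n_processes for the next loop
def cubes(i):
    return i ** 3

assert squares[3] == 9 and cubes[2] == 8

# Internally, submission now blocks once roughly 3 * n_processes tasks are
# pending (ParPool.block_until_space_available), so the driver cannot queue
# unbounded work ahead of the workers.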

View File

@@ -2,35 +2,35 @@ from __future__ import annotations
import logging
import os
import warnings
from contextlib import ExitStack, redirect_stdout, redirect_stderr
from io import StringIO
from contextlib import ExitStack, redirect_stderr, redirect_stdout
from functools import wraps
from importlib.metadata import version
from io import StringIO
from multiprocessing.shared_memory import SharedMemory
from traceback import format_exc
from typing import (
Any,
Callable,
Generator,
Hashable,
Iterable,
Iterator,
Sized,
Hashable,
NoReturn,
Optional,
Protocol,
Sequence,
Sized,
)
import numpy as np
import ray
from numpy.typing import ArrayLike, DTypeLike
from ray.remote_function import RemoteFunction
from ray.types import ObjectRef
from tqdm.auto import tqdm
from .pickler import dumps, loads
__version__ = version("parfor")
cpu_count = int(os.cpu_count())
@@ -193,7 +193,14 @@ class ExternalBar(Iterable):
self.callback(n)
@ray.remote
def get_worker(n_processes) -> RemoteFunction:
n_processes = n_processes or PoolSingleton.cpu_count
num_cpus = None if n_processes is None else cpu_count / n_processes
if not ray.is_initialized():
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
ray.init(logging_level=logging.ERROR, log_to_driver=False)
def worker(task):
try:
with ExitStack() as stack: # noqa
@@ -218,6 +225,11 @@ def worker(task):
return task
if num_cpus:
return ray.remote(num_cpus=num_cpus)(worker) # type: ignore
else:
return ray.remote(worker) # type: ignore
class Task:
def __init__(
@@ -322,7 +334,8 @@ class ParPool:
self.fun = fun
self.args = args
self.kwargs = kwargs
PoolSingleton(n_processes)
self.n_processes = n_processes or PoolSingleton.cpu_count
self.worker = get_worker(self.n_processes)
def __getstate__(self) -> NoReturn:
raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.")
@@ -366,7 +379,8 @@ class ParPool:
kwargs or self.kwargs,
allow_output or self.allow_output,
)
task.future = worker.remote(task)
self.block_until_space_available()
task.future = self.worker.remote(task)
self.tasks[new_handle] = task
self.bar_lengths[new_handle] = barlength
if handle is None:
@@ -381,13 +395,10 @@ class ParPool:
def __getitem__(self, handle: Hashable) -> Any:
"""Request result and delete its record. Wait if result not yet available."""
if handle not in self:
raise ValueError(f"No handle: {handle} in pool")
task = self.tasks[handle]
if task.future is None:
raise KeyError(f"No task with handle: {handle} in pool")
task = self.finalize_task(self.tasks[handle])
self.tasks.pop(task.handle)
return task.result
else:
task = ray.get(self.tasks[handle].future)
return self.finalize_task(task)
def __contains__(self, handle: Hashable) -> bool:
return handle in self.tasks
@@ -395,7 +406,15 @@ class ParPool:
def __delitem__(self, handle: Hashable) -> None:
self.tasks.pop(handle)
def finalize_task(self, task: Task) -> Any:
def finalize_task(self, future: ObjectRef | Task) -> Task:
if isinstance(future, Task):
task: Task = future
future = task.future
else:
task = None # type: ignore
if future is not None:
task: Task = ray.get(future) # type: ignore
code, *args = task.status
if task.out:
if hasattr(self.bar, "write"):
@@ -408,23 +427,37 @@ class ParPool:
else:
print(task.err, end="")
getattr(self, code)(task, *args)
self.tasks.pop(task.handle)
return task.result
self.tasks[task.handle] = task
return task
def get_newest(self) -> Optional[Any]:
"""Request the newest handle and result and delete its record. Wait if result not yet available."""
def block_until_space_available(self) -> None:
if len(self.tasks) < 3 * self.n_processes:
return
while True:
if self.tasks:
for handle, task in self.tasks.items():
if handle in self.tasks:
try:
if task.future is None:
futures = [task.future for task in self.tasks.values() if task.future is not None]
done, busy = ray.wait(futures, num_returns=1, timeout=0.01)
for d in done:
self.finalize_task(d) # type: ignore
if len(busy) < 3 * self.n_processes:
return
def get_newest(self) -> Any:
"""Request the newest handle and result and delete its record. Wait if result not yet available."""
if self.tasks:
done = [task for task in self.tasks.values() if task.future is None]
if done:
task = done[0]
self.tasks.pop(task.handle)
return task.handle, task.result
else:
task = ray.get(task.future, timeout=0.01)
return task.handle, self.finalize_task(task)
except ray.exceptions.GetTimeoutError:
pass
while True:
futures = [task.future for task in self.tasks.values() if task.future is not None]
done, _ = ray.wait(futures, num_returns=1, timeout=0.01)
if done:
task = self.finalize_task(done[0])
self.tasks.pop(task.handle)
return task.handle, task.result
raise StopIteration
def task_error(self, task: Task, error: Exception) -> None:
if task.handle in self:
@@ -443,23 +476,7 @@ class ParPool:
class PoolSingleton:
instance: PoolSingleton = None
cpu_count: int = int(os.cpu_count())
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
# restart if any workers have shut down or if we want to have a different number of processes
n_processes = n_processes or cls.cpu_count
if cls.instance is None or cls.instance.n_processes != n_processes:
cls.instance = super().__new__(cls)
cls.instance.n_processes = n_processes
if ray.is_initialized():
if cls.instance.n_processes != n_processes:
warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
f"probably with n_processes={cls.instance.n_processes}")
else:
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
return cls.instance
cpu_count: int = os.cpu_count()
class Worker:
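
Two notes on the mechanism above, with a hedged standalone sketch. First, get_worker sizes each Ray worker at cpu_count / n_processes CPUs, so Ray itself runs at most n_processes of them concurrently even though ray.init no longer pins num_cpus. Second, block_until_space_available caps the number of in-flight tasks at roughly 3 * n_processes using ray.wait; the snippet below (illustrative names, plain Ray API only, not code from this commit) shows that backpressure pattern in isolation:

import ray

@ray.remote
def work(i):
    return i * i

n_processes = 4
limit = 3 * n_processes  # same bound as block_until_space_available
pending = []

for i in range(100):
    # Once the in-flight set reaches the limit, drain at least one finished
    # future before submitting more, so the driver never outruns the workers.
    while len(pending) >= limit:
        done, pending = ray.wait(pending, num_returns=1, timeout=0.01)
        for ref in done:
            ray.get(ref)  # collect (and here discard) a finished result
    pending.append(work.remote(i))

results = ray.get(pending)  # gather the results still in flight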

View File

@@ -27,6 +27,7 @@ class Pickler(dill.Pickler):
You probably didn't want to use these parts anyhow.
However, if you did, you'll have to find some way to make them picklable.
"""
def save(self, obj: Any, save_persistent_id: bool = True) -> None:
"""Copied from pickle and amended."""
self.framer.commit_frame()
@@ -58,7 +59,7 @@ class Pickler(dill.Pickler):
# Check private dispatch table if any, or else
# copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
reduce = getattr(self, "dispatch_table", copyreg.dispatch_table).get(t)
if reduce is not None:
rv = reduce(obj)
else:
@@ -78,8 +79,7 @@ class Pickler(dill.Pickler):
if reduce is not None:
rv = reduce()
else:
raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj))
raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj))
except Exception: # noqa
rv = CouldNotBePickled.reduce(obj)
@@ -98,8 +98,7 @@ class Pickler(dill.Pickler):
# Assert that it returned an appropriately sized tuple
length = len(rv)
if not (2 <= length <= 6):
raise PicklingError("Tuple returned by %s must have "
"two to six elements" % reduce)
raise PicklingError("Tuple returned by %s must have two to six elements" % reduce)
# Save the reduce() output and finally memoize the object
try:
@@ -108,10 +107,11 @@ class Pickler(dill.Pickler):
self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))
def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True,
**kwds: Any) -> bytes:
def dumps(
obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, **kwds: Any
) -> bytes:
"""pickle an object to a string"""
protocol = dill.settings['protocol'] if protocol is None else int(protocol)
protocol = dill.settings["protocol"] if protocol is None else int(protocol)
_kwds = kwds.copy()
_kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
with BytesIO() as file:

View File

@@ -1,6 +1,6 @@
[project]
name = "parfor"
version = "2026.1.4"
version = "2026.1.5"
description = "A package to mimic the use of parfor as done in Matlab."
authors = [
{ name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
from os import getpid
from time import sleep
from typing import Any, Iterator, Optional, Sequence
import numpy as np
@@ -141,6 +143,16 @@ def test_id_reuse() -> None:
assert all([i == j for i, j in enumerate(a)])
@pytest.mark.parametrize("n_processes", (2, 4, 6))
def test_n_processes(n_processes) -> None:
@parfor(range(12), n_processes=n_processes)
def fun(i): # noqa
sleep(0.25)
return getpid()
assert len(set(fun)) <= n_processes
def test_shared_array() -> None:
def fun(i, a):
a[i] = i
@@ -150,3 +162,13 @@ def test_shared_array() -> None:
b = np.array(arr)
assert np.all(b == np.arange(len(arr)))
def test_nesting() -> None:
def a(i):
return i**2
def b(i):
return pmap(a, range(i, i + 50))
assert pmap(b, range(10)) == [[i**2 for i in range(j, j + 50)] for j in range(10)]