diff --git a/parfor/__init__.py b/parfor/__init__.py index 8088789..53fd5b3 100644 --- a/parfor/__init__.py +++ b/parfor/__init__.py @@ -2,35 +2,35 @@ from __future__ import annotations import logging import os -import warnings -from contextlib import ExitStack, redirect_stdout, redirect_stderr -from io import StringIO +from contextlib import ExitStack, redirect_stderr, redirect_stdout from functools import wraps from importlib.metadata import version +from io import StringIO from multiprocessing.shared_memory import SharedMemory from traceback import format_exc from typing import ( Any, Callable, Generator, + Hashable, Iterable, Iterator, - Sized, - Hashable, NoReturn, Optional, Protocol, Sequence, + Sized, ) import numpy as np import ray from numpy.typing import ArrayLike, DTypeLike +from ray.remote_function import RemoteFunction +from ray.types import ObjectRef from tqdm.auto import tqdm from .pickler import dumps, loads - __version__ = version("parfor") cpu_count = int(os.cpu_count()) @@ -193,30 +193,42 @@ class ExternalBar(Iterable): self.callback(n) -@ray.remote -def worker(task): - try: - with ExitStack() as stack: # noqa - if task.allow_output: - out = StringIO() - err = StringIO() - stack.enter_context(redirect_stdout(out)) - stack.enter_context(redirect_stderr(err)) - else: - stack.enter_context(redirect_stdout(open(os.devnull, "w"))) - stack.enter_context(redirect_stderr(open(os.devnull, "w"))) - try: - task() - task.status = ("done",) - except Exception: # noqa - task.status = "task_error", format_exc() - if task.allow_output: - task.out = out.getvalue() - task.err = err.getvalue() - except KeyboardInterrupt: # noqa - pass +def get_worker(n_processes) -> RemoteFunction: + n_processes = n_processes or PoolSingleton.cpu_count + num_cpus = None if n_processes is None else cpu_count / n_processes - return task + if not ray.is_initialized(): + os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0" + ray.init(logging_level=logging.ERROR, log_to_driver=False) + + def worker(task): + try: + with ExitStack() as stack: # noqa + if task.allow_output: + out = StringIO() + err = StringIO() + stack.enter_context(redirect_stdout(out)) + stack.enter_context(redirect_stderr(err)) + else: + stack.enter_context(redirect_stdout(open(os.devnull, "w"))) + stack.enter_context(redirect_stderr(open(os.devnull, "w"))) + try: + task() + task.status = ("done",) + except Exception: # noqa + task.status = "task_error", format_exc() + if task.allow_output: + task.out = out.getvalue() + task.err = err.getvalue() + except KeyboardInterrupt: # noqa + pass + + return task + + if num_cpus: + return ray.remote(num_cpus=num_cpus)(worker) # type: ignore + else: + return ray.remote(worker) # type: ignore class Task: @@ -322,7 +334,8 @@ class ParPool: self.fun = fun self.args = args self.kwargs = kwargs - PoolSingleton(n_processes) + self.n_processes = n_processes or PoolSingleton.cpu_count + self.worker = get_worker(self.n_processes) def __getstate__(self) -> NoReturn: raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.") @@ -366,7 +379,8 @@ class ParPool: kwargs or self.kwargs, allow_output or self.allow_output, ) - task.future = worker.remote(task) + self.block_until_space_available() + task.future = self.worker.remote(task) self.tasks[new_handle] = task self.bar_lengths[new_handle] = barlength if handle is None: @@ -381,13 +395,10 @@ class ParPool: def __getitem__(self, handle: Hashable) -> Any: """Request result and delete its record. Wait if result not yet available.""" if handle not in self: - raise ValueError(f"No handle: {handle} in pool") - task = self.tasks[handle] - if task.future is None: - return task.result - else: - task = ray.get(self.tasks[handle].future) - return self.finalize_task(task) + raise KeyError(f"No task with handle: {handle} in pool") + task = self.finalize_task(self.tasks[handle]) + self.tasks.pop(task.handle) + return task.result def __contains__(self, handle: Hashable) -> bool: return handle in self.tasks @@ -395,36 +406,58 @@ class ParPool: def __delitem__(self, handle: Hashable) -> None: self.tasks.pop(handle) - def finalize_task(self, task: Task) -> Any: - code, *args = task.status - if task.out: - if hasattr(self.bar, "write"): - self.bar.write(task.out, end="") - else: - print(task.out, end="") - if task.err: - if hasattr(self.bar, "write"): - self.bar.write(task.err, end="") - else: - print(task.err, end="") - getattr(self, code)(task, *args) - self.tasks.pop(task.handle) - return task.result + def finalize_task(self, future: ObjectRef | Task) -> Task: + if isinstance(future, Task): + task: Task = future + future = task.future + else: + task = None # type: ignore - def get_newest(self) -> Optional[Any]: - """Request the newest handle and result and delete its record. Wait if result not yet available.""" + if future is not None: + task: Task = ray.get(future) # type: ignore + code, *args = task.status + if task.out: + if hasattr(self.bar, "write"): + self.bar.write(task.out, end="") + else: + print(task.out, end="") + if task.err: + if hasattr(self.bar, "write"): + self.bar.write(task.err, end="") + else: + print(task.err, end="") + getattr(self, code)(task, *args) + self.tasks[task.handle] = task + return task + + def block_until_space_available(self) -> None: + if len(self.tasks) < 3 * self.n_processes: + return while True: if self.tasks: - for handle, task in self.tasks.items(): - if handle in self.tasks: - try: - if task.future is None: - return task.handle, task.result - else: - task = ray.get(task.future, timeout=0.01) - return task.handle, self.finalize_task(task) - except ray.exceptions.GetTimeoutError: - pass + futures = [task.future for task in self.tasks.values() if task.future is not None] + done, busy = ray.wait(futures, num_returns=1, timeout=0.01) + for d in done: + self.finalize_task(d) # type: ignore + if len(busy) < 3 * self.n_processes: + return + + def get_newest(self) -> Any: + """Request the newest handle and result and delete its record. Wait if result not yet available.""" + if self.tasks: + done = [task for task in self.tasks.values() if task.future is None] + if done: + task = done[0] + self.tasks.pop(task.handle) + return task.handle, task.result + while True: + futures = [task.future for task in self.tasks.values() if task.future is not None] + done, _ = ray.wait(futures, num_returns=1, timeout=0.01) + if done: + task = self.finalize_task(done[0]) + self.tasks.pop(task.handle) + return task.handle, task.result + raise StopIteration def task_error(self, task: Task, error: Exception) -> None: if task.handle in self: @@ -443,23 +476,7 @@ class ParPool: class PoolSingleton: - instance: PoolSingleton = None - cpu_count: int = int(os.cpu_count()) - - def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton: - # restart if any workers have shut down or if we want to have a different number of processes - n_processes = n_processes or cls.cpu_count - if cls.instance is None or cls.instance.n_processes != n_processes: - cls.instance = super().__new__(cls) - cls.instance.n_processes = n_processes - if ray.is_initialized(): - if cls.instance.n_processes != n_processes: - warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, " - f"probably with n_processes={cls.instance.n_processes}") - else: - os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0" - ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False) - return cls.instance + cpu_count: int = os.cpu_count() class Worker: diff --git a/parfor/pickler.py b/parfor/pickler.py index 0a016a4..7b9cb4b 100644 --- a/parfor/pickler.py +++ b/parfor/pickler.py @@ -23,12 +23,13 @@ class CouldNotBePickled: class Pickler(dill.Pickler): - """ Overload dill to ignore unpicklable parts of objects. - You probably didn't want to use these parts anyhow. - However, if you did, you'll have to find some way to make them picklable. + """Overload dill to ignore unpicklable parts of objects. + You probably didn't want to use these parts anyhow. + However, if you did, you'll have to find some way to make them picklable. """ + def save(self, obj: Any, save_persistent_id: bool = True) -> None: - """ Copied from pickle and amended. """ + """Copied from pickle and amended.""" self.framer.commit_frame() # Check for persistent id (defined by a subclass) @@ -58,7 +59,7 @@ class Pickler(dill.Pickler): # Check private dispatch table if any, or else # copyreg.dispatch_table - reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t) + reduce = getattr(self, "dispatch_table", copyreg.dispatch_table).get(t) if reduce is not None: rv = reduce(obj) else: @@ -78,8 +79,7 @@ class Pickler(dill.Pickler): if reduce is not None: rv = reduce() else: - raise PicklingError("Can't pickle %r object: %r" % - (t.__name__, obj)) + raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj)) except Exception: # noqa rv = CouldNotBePickled.reduce(obj) @@ -98,8 +98,7 @@ class Pickler(dill.Pickler): # Assert that it returned an appropriately sized tuple length = len(rv) if not (2 <= length <= 6): - raise PicklingError("Tuple returned by %s must have " - "two to six elements" % reduce) + raise PicklingError("Tuple returned by %s must have two to six elements" % reduce) # Save the reduce() output and finally memoize the object try: @@ -108,12 +107,13 @@ class Pickler(dill.Pickler): self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj)) -def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, - **kwds: Any) -> bytes: +def dumps( + obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, **kwds: Any +) -> bytes: """pickle an object to a string""" - protocol = dill.settings['protocol'] if protocol is None else int(protocol) + protocol = dill.settings["protocol"] if protocol is None else int(protocol) _kwds = kwds.copy() _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse)) with BytesIO() as file: Pickler(file, protocol, **_kwds).dump(obj) - return file.getvalue() \ No newline at end of file + return file.getvalue() diff --git a/pyproject.toml b/pyproject.toml index 413bf47..3a91044 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "parfor" -version = "2026.1.4" +version = "2026.1.5" description = "A package to mimic the use of parfor as done in Matlab." authors = [ { name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" } diff --git a/tests/test_parfor.py b/tests/test_parfor.py index 2d178ab..22a346c 100644 --- a/tests/test_parfor.py +++ b/tests/test_parfor.py @@ -1,6 +1,8 @@ from __future__ import annotations from dataclasses import dataclass +from os import getpid +from time import sleep from typing import Any, Iterator, Optional, Sequence import numpy as np @@ -141,6 +143,16 @@ def test_id_reuse() -> None: assert all([i == j for i, j in enumerate(a)]) +@pytest.mark.parametrize("n_processes", (2, 4, 6)) +def test_n_processes(n_processes) -> None: + @parfor(range(12), n_processes=n_processes) + def fun(i): # noqa + sleep(0.25) + return getpid() + + assert len(set(fun)) <= n_processes + + def test_shared_array() -> None: def fun(i, a): a[i] = i @@ -150,3 +162,13 @@ def test_shared_array() -> None: b = np.array(arr) assert np.all(b == np.arange(len(arr))) + + +def test_nesting() -> None: + def a(i): + return i**2 + + def b(i): + return pmap(a, range(i, i + 50)) + + assert pmap(b, range(10)) == [[i**2 for i in range(j, j + 50)] for j in range(10)]