- n_processes can now be changed each time
- block adding tasks if pool is busy
This commit is contained in:
@@ -2,35 +2,35 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import ExitStack, redirect_stdout, redirect_stderr
|
||||
from io import StringIO
|
||||
from contextlib import ExitStack, redirect_stderr, redirect_stdout
|
||||
from functools import wraps
|
||||
from importlib.metadata import version
|
||||
from io import StringIO
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from traceback import format_exc
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Hashable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sized,
|
||||
Hashable,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Sized,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
from numpy.typing import ArrayLike, DTypeLike
|
||||
from ray.remote_function import RemoteFunction
|
||||
from ray.types import ObjectRef
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .pickler import dumps, loads
|
||||
|
||||
|
||||
__version__ = version("parfor")
|
||||
cpu_count = int(os.cpu_count())
|
||||
|
||||
@@ -193,30 +193,42 @@ class ExternalBar(Iterable):
|
||||
self.callback(n)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def worker(task):
|
||||
try:
|
||||
with ExitStack() as stack: # noqa
|
||||
if task.allow_output:
|
||||
out = StringIO()
|
||||
err = StringIO()
|
||||
stack.enter_context(redirect_stdout(out))
|
||||
stack.enter_context(redirect_stderr(err))
|
||||
else:
|
||||
stack.enter_context(redirect_stdout(open(os.devnull, "w")))
|
||||
stack.enter_context(redirect_stderr(open(os.devnull, "w")))
|
||||
try:
|
||||
task()
|
||||
task.status = ("done",)
|
||||
except Exception: # noqa
|
||||
task.status = "task_error", format_exc()
|
||||
if task.allow_output:
|
||||
task.out = out.getvalue()
|
||||
task.err = err.getvalue()
|
||||
except KeyboardInterrupt: # noqa
|
||||
pass
|
||||
def get_worker(n_processes) -> RemoteFunction:
|
||||
n_processes = n_processes or PoolSingleton.cpu_count
|
||||
num_cpus = None if n_processes is None else cpu_count / n_processes
|
||||
|
||||
return task
|
||||
if not ray.is_initialized():
|
||||
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
|
||||
ray.init(logging_level=logging.ERROR, log_to_driver=False)
|
||||
|
||||
def worker(task):
|
||||
try:
|
||||
with ExitStack() as stack: # noqa
|
||||
if task.allow_output:
|
||||
out = StringIO()
|
||||
err = StringIO()
|
||||
stack.enter_context(redirect_stdout(out))
|
||||
stack.enter_context(redirect_stderr(err))
|
||||
else:
|
||||
stack.enter_context(redirect_stdout(open(os.devnull, "w")))
|
||||
stack.enter_context(redirect_stderr(open(os.devnull, "w")))
|
||||
try:
|
||||
task()
|
||||
task.status = ("done",)
|
||||
except Exception: # noqa
|
||||
task.status = "task_error", format_exc()
|
||||
if task.allow_output:
|
||||
task.out = out.getvalue()
|
||||
task.err = err.getvalue()
|
||||
except KeyboardInterrupt: # noqa
|
||||
pass
|
||||
|
||||
return task
|
||||
|
||||
if num_cpus:
|
||||
return ray.remote(num_cpus=num_cpus)(worker) # type: ignore
|
||||
else:
|
||||
return ray.remote(worker) # type: ignore
|
||||
|
||||
|
||||
class Task:
|
||||
@@ -322,7 +334,8 @@ class ParPool:
|
||||
self.fun = fun
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
PoolSingleton(n_processes)
|
||||
self.n_processes = n_processes or PoolSingleton.cpu_count
|
||||
self.worker = get_worker(self.n_processes)
|
||||
|
||||
def __getstate__(self) -> NoReturn:
|
||||
raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.")
|
||||
@@ -366,7 +379,8 @@ class ParPool:
|
||||
kwargs or self.kwargs,
|
||||
allow_output or self.allow_output,
|
||||
)
|
||||
task.future = worker.remote(task)
|
||||
self.block_until_space_available()
|
||||
task.future = self.worker.remote(task)
|
||||
self.tasks[new_handle] = task
|
||||
self.bar_lengths[new_handle] = barlength
|
||||
if handle is None:
|
||||
@@ -381,13 +395,10 @@ class ParPool:
|
||||
def __getitem__(self, handle: Hashable) -> Any:
|
||||
"""Request result and delete its record. Wait if result not yet available."""
|
||||
if handle not in self:
|
||||
raise ValueError(f"No handle: {handle} in pool")
|
||||
task = self.tasks[handle]
|
||||
if task.future is None:
|
||||
return task.result
|
||||
else:
|
||||
task = ray.get(self.tasks[handle].future)
|
||||
return self.finalize_task(task)
|
||||
raise KeyError(f"No task with handle: {handle} in pool")
|
||||
task = self.finalize_task(self.tasks[handle])
|
||||
self.tasks.pop(task.handle)
|
||||
return task.result
|
||||
|
||||
def __contains__(self, handle: Hashable) -> bool:
|
||||
return handle in self.tasks
|
||||
@@ -395,36 +406,58 @@ class ParPool:
|
||||
def __delitem__(self, handle: Hashable) -> None:
|
||||
self.tasks.pop(handle)
|
||||
|
||||
def finalize_task(self, task: Task) -> Any:
|
||||
code, *args = task.status
|
||||
if task.out:
|
||||
if hasattr(self.bar, "write"):
|
||||
self.bar.write(task.out, end="")
|
||||
else:
|
||||
print(task.out, end="")
|
||||
if task.err:
|
||||
if hasattr(self.bar, "write"):
|
||||
self.bar.write(task.err, end="")
|
||||
else:
|
||||
print(task.err, end="")
|
||||
getattr(self, code)(task, *args)
|
||||
self.tasks.pop(task.handle)
|
||||
return task.result
|
||||
def finalize_task(self, future: ObjectRef | Task) -> Task:
|
||||
if isinstance(future, Task):
|
||||
task: Task = future
|
||||
future = task.future
|
||||
else:
|
||||
task = None # type: ignore
|
||||
|
||||
def get_newest(self) -> Optional[Any]:
|
||||
"""Request the newest handle and result and delete its record. Wait if result not yet available."""
|
||||
if future is not None:
|
||||
task: Task = ray.get(future) # type: ignore
|
||||
code, *args = task.status
|
||||
if task.out:
|
||||
if hasattr(self.bar, "write"):
|
||||
self.bar.write(task.out, end="")
|
||||
else:
|
||||
print(task.out, end="")
|
||||
if task.err:
|
||||
if hasattr(self.bar, "write"):
|
||||
self.bar.write(task.err, end="")
|
||||
else:
|
||||
print(task.err, end="")
|
||||
getattr(self, code)(task, *args)
|
||||
self.tasks[task.handle] = task
|
||||
return task
|
||||
|
||||
def block_until_space_available(self) -> None:
|
||||
if len(self.tasks) < 3 * self.n_processes:
|
||||
return
|
||||
while True:
|
||||
if self.tasks:
|
||||
for handle, task in self.tasks.items():
|
||||
if handle in self.tasks:
|
||||
try:
|
||||
if task.future is None:
|
||||
return task.handle, task.result
|
||||
else:
|
||||
task = ray.get(task.future, timeout=0.01)
|
||||
return task.handle, self.finalize_task(task)
|
||||
except ray.exceptions.GetTimeoutError:
|
||||
pass
|
||||
futures = [task.future for task in self.tasks.values() if task.future is not None]
|
||||
done, busy = ray.wait(futures, num_returns=1, timeout=0.01)
|
||||
for d in done:
|
||||
self.finalize_task(d) # type: ignore
|
||||
if len(busy) < 3 * self.n_processes:
|
||||
return
|
||||
|
||||
def get_newest(self) -> Any:
|
||||
"""Request the newest handle and result and delete its record. Wait if result not yet available."""
|
||||
if self.tasks:
|
||||
done = [task for task in self.tasks.values() if task.future is None]
|
||||
if done:
|
||||
task = done[0]
|
||||
self.tasks.pop(task.handle)
|
||||
return task.handle, task.result
|
||||
while True:
|
||||
futures = [task.future for task in self.tasks.values() if task.future is not None]
|
||||
done, _ = ray.wait(futures, num_returns=1, timeout=0.01)
|
||||
if done:
|
||||
task = self.finalize_task(done[0])
|
||||
self.tasks.pop(task.handle)
|
||||
return task.handle, task.result
|
||||
raise StopIteration
|
||||
|
||||
def task_error(self, task: Task, error: Exception) -> None:
|
||||
if task.handle in self:
|
||||
@@ -443,23 +476,7 @@ class ParPool:
|
||||
|
||||
|
||||
class PoolSingleton:
|
||||
instance: PoolSingleton = None
|
||||
cpu_count: int = int(os.cpu_count())
|
||||
|
||||
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
|
||||
# restart if any workers have shut down or if we want to have a different number of processes
|
||||
n_processes = n_processes or cls.cpu_count
|
||||
if cls.instance is None or cls.instance.n_processes != n_processes:
|
||||
cls.instance = super().__new__(cls)
|
||||
cls.instance.n_processes = n_processes
|
||||
if ray.is_initialized():
|
||||
if cls.instance.n_processes != n_processes:
|
||||
warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
|
||||
f"probably with n_processes={cls.instance.n_processes}")
|
||||
else:
|
||||
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
|
||||
ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
|
||||
return cls.instance
|
||||
cpu_count: int = os.cpu_count()
|
||||
|
||||
|
||||
class Worker:
|
||||
|
||||
@@ -23,12 +23,13 @@ class CouldNotBePickled:
|
||||
|
||||
|
||||
class Pickler(dill.Pickler):
|
||||
""" Overload dill to ignore unpicklable parts of objects.
|
||||
You probably didn't want to use these parts anyhow.
|
||||
However, if you did, you'll have to find some way to make them picklable.
|
||||
"""Overload dill to ignore unpicklable parts of objects.
|
||||
You probably didn't want to use these parts anyhow.
|
||||
However, if you did, you'll have to find some way to make them picklable.
|
||||
"""
|
||||
|
||||
def save(self, obj: Any, save_persistent_id: bool = True) -> None:
|
||||
""" Copied from pickle and amended. """
|
||||
"""Copied from pickle and amended."""
|
||||
self.framer.commit_frame()
|
||||
|
||||
# Check for persistent id (defined by a subclass)
|
||||
@@ -58,7 +59,7 @@ class Pickler(dill.Pickler):
|
||||
|
||||
# Check private dispatch table if any, or else
|
||||
# copyreg.dispatch_table
|
||||
reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
|
||||
reduce = getattr(self, "dispatch_table", copyreg.dispatch_table).get(t)
|
||||
if reduce is not None:
|
||||
rv = reduce(obj)
|
||||
else:
|
||||
@@ -78,8 +79,7 @@ class Pickler(dill.Pickler):
|
||||
if reduce is not None:
|
||||
rv = reduce()
|
||||
else:
|
||||
raise PicklingError("Can't pickle %r object: %r" %
|
||||
(t.__name__, obj))
|
||||
raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj))
|
||||
except Exception: # noqa
|
||||
rv = CouldNotBePickled.reduce(obj)
|
||||
|
||||
@@ -98,8 +98,7 @@ class Pickler(dill.Pickler):
|
||||
# Assert that it returned an appropriately sized tuple
|
||||
length = len(rv)
|
||||
if not (2 <= length <= 6):
|
||||
raise PicklingError("Tuple returned by %s must have "
|
||||
"two to six elements" % reduce)
|
||||
raise PicklingError("Tuple returned by %s must have two to six elements" % reduce)
|
||||
|
||||
# Save the reduce() output and finally memoize the object
|
||||
try:
|
||||
@@ -108,12 +107,13 @@ class Pickler(dill.Pickler):
|
||||
self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))
|
||||
|
||||
|
||||
def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True,
|
||||
**kwds: Any) -> bytes:
|
||||
def dumps(
|
||||
obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, **kwds: Any
|
||||
) -> bytes:
|
||||
"""pickle an object to a string"""
|
||||
protocol = dill.settings['protocol'] if protocol is None else int(protocol)
|
||||
protocol = dill.settings["protocol"] if protocol is None else int(protocol)
|
||||
_kwds = kwds.copy()
|
||||
_kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
|
||||
with BytesIO() as file:
|
||||
Pickler(file, protocol, **_kwds).dump(obj)
|
||||
return file.getvalue()
|
||||
return file.getvalue()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "parfor"
|
||||
version = "2026.1.4"
|
||||
version = "2026.1.5"
|
||||
description = "A package to mimic the use of parfor as done in Matlab."
|
||||
authors = [
|
||||
{ name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from os import getpid
|
||||
from time import sleep
|
||||
from typing import Any, Iterator, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
@@ -141,6 +143,16 @@ def test_id_reuse() -> None:
|
||||
assert all([i == j for i, j in enumerate(a)])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_processes", (2, 4, 6))
|
||||
def test_n_processes(n_processes) -> None:
|
||||
@parfor(range(12), n_processes=n_processes)
|
||||
def fun(i): # noqa
|
||||
sleep(0.25)
|
||||
return getpid()
|
||||
|
||||
assert len(set(fun)) <= n_processes
|
||||
|
||||
|
||||
def test_shared_array() -> None:
|
||||
def fun(i, a):
|
||||
a[i] = i
|
||||
@@ -150,3 +162,13 @@ def test_shared_array() -> None:
|
||||
b = np.array(arr)
|
||||
|
||||
assert np.all(b == np.arange(len(arr)))
|
||||
|
||||
|
||||
def test_nesting() -> None:
|
||||
def a(i):
|
||||
return i**2
|
||||
|
||||
def b(i):
|
||||
return pmap(a, range(i, i + 50))
|
||||
|
||||
assert pmap(b, range(10)) == [[i**2 for i in range(j, j + 50)] for j in range(10)]
|
||||
|
||||
Reference in New Issue
Block a user