- n_processes can now be changed each time
- block adding tasks if pool is busy
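In practice the pool size is now taken from each call instead of a process-wide singleton, and adding a task first waits in block_until_space_available so that no more than roughly 3 x n_processes tasks are in flight. A minimal usage sketch, adapted from the test added in this commit (the "from parfor import parfor" import path and the concrete numbers are assumptions for illustration, not part of the diff):

from os import getpid
from time import sleep

from parfor import parfor  # assumed import path

# request a 4-worker pool for this call only; another call may use a different size
@parfor(range(12), n_processes=4)
def fun(i):  # noqa
    sleep(0.25)       # keep workers busy long enough to overlap
    return getpid()   # record which worker process ran each iteration

assert len(set(fun)) <= 4  # at most 4 distinct worker processes were used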
@@ -2,35 +2,35 @@ from __future__ import annotations
 import logging
 import os
-import warnings
-from contextlib import ExitStack, redirect_stdout, redirect_stderr
-from io import StringIO
+from contextlib import ExitStack, redirect_stderr, redirect_stdout
 from functools import wraps
 from importlib.metadata import version
+from io import StringIO
 from multiprocessing.shared_memory import SharedMemory
 from traceback import format_exc
 from typing import (
     Any,
     Callable,
     Generator,
+    Hashable,
     Iterable,
     Iterator,
-    Sized,
-    Hashable,
     NoReturn,
     Optional,
     Protocol,
     Sequence,
+    Sized,
 )
 
 import numpy as np
 import ray
 from numpy.typing import ArrayLike, DTypeLike
+from ray.remote_function import RemoteFunction
+from ray.types import ObjectRef
 from tqdm.auto import tqdm
 
 from .pickler import dumps, loads
 
 
 __version__ = version("parfor")
 cpu_count = int(os.cpu_count())
 
@@ -193,30 +193,42 @@ class ExternalBar(Iterable):
         self.callback(n)
 
 
-@ray.remote
-def worker(task):
-    try:
-        with ExitStack() as stack:  # noqa
-            if task.allow_output:
-                out = StringIO()
-                err = StringIO()
-                stack.enter_context(redirect_stdout(out))
-                stack.enter_context(redirect_stderr(err))
-            else:
-                stack.enter_context(redirect_stdout(open(os.devnull, "w")))
-                stack.enter_context(redirect_stderr(open(os.devnull, "w")))
-            try:
-                task()
-                task.status = ("done",)
-            except Exception:  # noqa
-                task.status = "task_error", format_exc()
-            if task.allow_output:
-                task.out = out.getvalue()
-                task.err = err.getvalue()
-    except KeyboardInterrupt:  # noqa
-        pass
-
-    return task
+def get_worker(n_processes) -> RemoteFunction:
+    n_processes = n_processes or PoolSingleton.cpu_count
+    num_cpus = None if n_processes is None else cpu_count / n_processes
+    if not ray.is_initialized():
+        os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
+        ray.init(logging_level=logging.ERROR, log_to_driver=False)
+
+    def worker(task):
+        try:
+            with ExitStack() as stack:  # noqa
+                if task.allow_output:
+                    out = StringIO()
+                    err = StringIO()
+                    stack.enter_context(redirect_stdout(out))
+                    stack.enter_context(redirect_stderr(err))
+                else:
+                    stack.enter_context(redirect_stdout(open(os.devnull, "w")))
+                    stack.enter_context(redirect_stderr(open(os.devnull, "w")))
+                try:
+                    task()
+                    task.status = ("done",)
+                except Exception:  # noqa
+                    task.status = "task_error", format_exc()
+                if task.allow_output:
+                    task.out = out.getvalue()
+                    task.err = err.getvalue()
+        except KeyboardInterrupt:  # noqa
+            pass
+
+        return task
+
+    if num_cpus:
+        return ray.remote(num_cpus=num_cpus)(worker)  # type: ignore
+    else:
+        return ray.remote(worker)  # type: ignore
 
 
 class Task:
@@ -322,7 +334,8 @@ class ParPool:
         self.fun = fun
         self.args = args
         self.kwargs = kwargs
-        PoolSingleton(n_processes)
+        self.n_processes = n_processes or PoolSingleton.cpu_count
+        self.worker = get_worker(self.n_processes)
 
     def __getstate__(self) -> NoReturn:
         raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.")
@@ -366,7 +379,8 @@ class ParPool:
             kwargs or self.kwargs,
             allow_output or self.allow_output,
         )
-        task.future = worker.remote(task)
+        self.block_until_space_available()
+        task.future = self.worker.remote(task)
         self.tasks[new_handle] = task
         self.bar_lengths[new_handle] = barlength
         if handle is None:
@@ -381,13 +395,10 @@ class ParPool:
     def __getitem__(self, handle: Hashable) -> Any:
         """Request result and delete its record. Wait if result not yet available."""
         if handle not in self:
-            raise ValueError(f"No handle: {handle} in pool")
-        task = self.tasks[handle]
-        if task.future is None:
-            return task.result
-        else:
-            task = ray.get(self.tasks[handle].future)
-            return self.finalize_task(task)
+            raise KeyError(f"No task with handle: {handle} in pool")
+        task = self.finalize_task(self.tasks[handle])
+        self.tasks.pop(task.handle)
+        return task.result
 
     def __contains__(self, handle: Hashable) -> bool:
         return handle in self.tasks
@@ -395,36 +406,58 @@ class ParPool:
     def __delitem__(self, handle: Hashable) -> None:
         self.tasks.pop(handle)
 
-    def finalize_task(self, task: Task) -> Any:
-        code, *args = task.status
-        if task.out:
-            if hasattr(self.bar, "write"):
-                self.bar.write(task.out, end="")
-            else:
-                print(task.out, end="")
-        if task.err:
-            if hasattr(self.bar, "write"):
-                self.bar.write(task.err, end="")
-            else:
-                print(task.err, end="")
-        getattr(self, code)(task, *args)
-        self.tasks.pop(task.handle)
-        return task.result
-
-    def get_newest(self) -> Optional[Any]:
-        """Request the newest handle and result and delete its record. Wait if result not yet available."""
+    def finalize_task(self, future: ObjectRef | Task) -> Task:
+        if isinstance(future, Task):
+            task: Task = future
+            future = task.future
+        else:
+            task = None  # type: ignore
+        if future is not None:
+            task: Task = ray.get(future)  # type: ignore
+            code, *args = task.status
+            if task.out:
+                if hasattr(self.bar, "write"):
+                    self.bar.write(task.out, end="")
+                else:
+                    print(task.out, end="")
+            if task.err:
+                if hasattr(self.bar, "write"):
+                    self.bar.write(task.err, end="")
+                else:
+                    print(task.err, end="")
+            getattr(self, code)(task, *args)
+            self.tasks[task.handle] = task
+        return task
+
+    def block_until_space_available(self) -> None:
+        if len(self.tasks) < 3 * self.n_processes:
+            return
         while True:
             if self.tasks:
-                for handle, task in self.tasks.items():
-                    if handle in self.tasks:
-                        try:
-                            if task.future is None:
-                                return task.handle, task.result
-                            else:
-                                task = ray.get(task.future, timeout=0.01)
-                                return task.handle, self.finalize_task(task)
-                        except ray.exceptions.GetTimeoutError:
-                            pass
+                futures = [task.future for task in self.tasks.values() if task.future is not None]
+                done, busy = ray.wait(futures, num_returns=1, timeout=0.01)
+                for d in done:
+                    self.finalize_task(d)  # type: ignore
+                if len(busy) < 3 * self.n_processes:
+                    return
+
+    def get_newest(self) -> Any:
+        """Request the newest handle and result and delete its record. Wait if result not yet available."""
+        if self.tasks:
+            done = [task for task in self.tasks.values() if task.future is None]
+            if done:
+                task = done[0]
+                self.tasks.pop(task.handle)
+                return task.handle, task.result
+            while True:
+                futures = [task.future for task in self.tasks.values() if task.future is not None]
+                done, _ = ray.wait(futures, num_returns=1, timeout=0.01)
+                if done:
+                    task = self.finalize_task(done[0])
+                    self.tasks.pop(task.handle)
+                    return task.handle, task.result
+        raise StopIteration
 
     def task_error(self, task: Task, error: Exception) -> None:
         if task.handle in self:
@@ -443,23 +476,7 @@ class ParPool:
 
 
 class PoolSingleton:
-    instance: PoolSingleton = None
-    cpu_count: int = int(os.cpu_count())
-
-    def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
-        # restart if any workers have shut down or if we want to have a different number of processes
-        n_processes = n_processes or cls.cpu_count
-        if cls.instance is None or cls.instance.n_processes != n_processes:
-            cls.instance = super().__new__(cls)
-            cls.instance.n_processes = n_processes
-            if ray.is_initialized():
-                if cls.instance.n_processes != n_processes:
-                    warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
-                                  f"probably with n_processes={cls.instance.n_processes}")
-            else:
-                os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
-                ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
-        return cls.instance
+    cpu_count: int = os.cpu_count()
 
 
 class Worker:
@@ -23,12 +23,13 @@ class CouldNotBePickled:
 
+
 class Pickler(dill.Pickler):
-    """ Overload dill to ignore unpicklable parts of objects.
+    """Overload dill to ignore unpicklable parts of objects.
     You probably didn't want to use these parts anyhow.
     However, if you did, you'll have to find some way to make them picklable.
     """
 
     def save(self, obj: Any, save_persistent_id: bool = True) -> None:
-        """ Copied from pickle and amended. """
+        """Copied from pickle and amended."""
         self.framer.commit_frame()
 
         # Check for persistent id (defined by a subclass)
@@ -58,7 +59,7 @@ class Pickler(dill.Pickler):
 
         # Check private dispatch table if any, or else
         # copyreg.dispatch_table
-        reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
+        reduce = getattr(self, "dispatch_table", copyreg.dispatch_table).get(t)
         if reduce is not None:
             rv = reduce(obj)
         else:
@@ -78,8 +79,7 @@ class Pickler(dill.Pickler):
                 if reduce is not None:
                     rv = reduce()
                 else:
-                    raise PicklingError("Can't pickle %r object: %r" %
-                                        (t.__name__, obj))
+                    raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj))
         except Exception:  # noqa
             rv = CouldNotBePickled.reduce(obj)
 
@@ -98,8 +98,7 @@ class Pickler(dill.Pickler):
         # Assert that it returned an appropriately sized tuple
         length = len(rv)
         if not (2 <= length <= 6):
-            raise PicklingError("Tuple returned by %s must have "
-                                "two to six elements" % reduce)
+            raise PicklingError("Tuple returned by %s must have two to six elements" % reduce)
 
         # Save the reduce() output and finally memoize the object
         try:
@@ -108,12 +107,13 @@ class Pickler(dill.Pickler):
             self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))
 
 
-def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True,
-          **kwds: Any) -> bytes:
+def dumps(
+    obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, **kwds: Any
+) -> bytes:
     """pickle an object to a string"""
-    protocol = dill.settings['protocol'] if protocol is None else int(protocol)
+    protocol = dill.settings["protocol"] if protocol is None else int(protocol)
     _kwds = kwds.copy()
     _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
     with BytesIO() as file:
         Pickler(file, protocol, **_kwds).dump(obj)
         return file.getvalue()
@@ -1,6 +1,6 @@
 [project]
 name = "parfor"
-version = "2026.1.4"
+version = "2026.1.5"
 description = "A package to mimic the use of parfor as done in Matlab."
 authors = [
     { name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from os import getpid
+from time import sleep
 from typing import Any, Iterator, Optional, Sequence
 
 import numpy as np
@@ -141,6 +143,16 @@ def test_id_reuse() -> None:
     assert all([i == j for i, j in enumerate(a)])
 
 
+@pytest.mark.parametrize("n_processes", (2, 4, 6))
+def test_n_processes(n_processes) -> None:
+    @parfor(range(12), n_processes=n_processes)
+    def fun(i):  # noqa
+        sleep(0.25)
+        return getpid()
+
+    assert len(set(fun)) <= n_processes
+
+
 def test_shared_array() -> None:
     def fun(i, a):
         a[i] = i
@@ -150,3 +162,13 @@ def test_shared_array() -> None:
     b = np.array(arr)
 
     assert np.all(b == np.arange(len(arr)))
+
+
+def test_nesting() -> None:
+    def a(i):
+        return i**2
+
+    def b(i):
+        return pmap(a, range(i, i + 50))
+
+    assert pmap(b, range(10)) == [[i**2 for i in range(j, j + 50)] for j in range(10)]