Now based on ray, which enables nested parallel computations.
.github/workflows/pytest.yml (vendored): 4 changes (2 additions, 2 deletions)
@@ -11,9 +11,9 @@ jobs:
         os: [ubuntu-latest, windows-latest, macOS-latest]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v6
         with:
          python-version: ${{ matrix.python-version }}
      - name: Install
parfor/__init__.py
@@ -1,47 +1,114 @@
 from __future__ import annotations
 
-import sys
-from contextlib import ExitStack
+import logging
+import os
+import warnings
+from contextlib import ExitStack, redirect_stdout, redirect_stderr
 from functools import wraps
 from importlib.metadata import version
-from typing import Any, Callable, Generator, Iterable, Iterator, Sized
-from warnings import warn
+from multiprocessing.shared_memory import SharedMemory
+from traceback import format_exc
+from typing import Any, Callable, Generator, Iterable, Iterator, Sized, Hashable, NoReturn, Optional, Protocol, Sequence
 
+import numpy as np
+import ray
+from numpy.typing import ArrayLike, DTypeLike
 from tqdm.auto import tqdm
 
-from . import gil, nogil
-from .common import Bar, SharedArray, cpu_count
-
-if hasattr(sys, '_is_gil_enabled') and not sys._is_gil_enabled():  # noqa
-    from .nogil import ParPool, PoolSingleton, Task, Worker
-else:
-    from .gil import ParPool, PoolSingleton, Task, Worker
-
-__version__ = version('parfor')
+__version__ = version("parfor")
+cpu_count = int(os.cpu_count())
+
+
+class Bar(Protocol):
+    def update(self, n: int = 1) -> None: ...
+
+
+class SharedArray(np.ndarray):
+    """Numpy array whose memory can be shared between processes, so that memory use is reduced and changes in one
+    process are reflected in all other processes. Changes are not atomic, so protect changes with a lock to prevent
+    race conditions!
+    """
+
+    def __new__(
+        cls,
+        shape: int | Sequence[int],
+        dtype: DTypeLike = float,
+        shm: str | SharedMemory = None,
+        offset: int = 0,
+        strides: tuple[int, int] = None,
+        order: str = None,
+    ) -> SharedArray:
+        if isinstance(shm, str):
+            shm = SharedMemory(shm)
+        elif shm is None:
+            shm = SharedMemory(create=True, size=np.prod(shape) * np.dtype(dtype).itemsize)  # type: ignore
+        new = super().__new__(cls, shape, dtype, shm.buf, offset, strides, order)
+        new.shm = shm
+        return new
+
+    def __reduce__(
+        self,
+    ) -> tuple[
+        Callable[[int | Sequence[int], DTypeLike, str], SharedArray],
+        tuple[int | tuple[int, ...], np.dtype, str],
+    ]:
+        return self.__class__, (self.shape, self.dtype, self.shm.name)
+
+    def __enter__(self) -> SharedArray:
+        return self
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        if hasattr(self, "shm"):
+            self.shm.close()
+            self.shm.unlink()
+
+    def __del__(self) -> None:
+        if hasattr(self, "shm"):
+            self.shm.close()
+
+    def __array_finalize__(self, obj: np.ndarray | None) -> None:
+        if isinstance(obj, np.ndarray) and not isinstance(obj, SharedArray):
+            raise TypeError("view casting to SharedArray is not implemented because right now we need to make a copy")
+
+    @classmethod
+    def from_array(cls, array: ArrayLike) -> SharedArray:
+        """copy existing array into a SharedArray"""
+        array = np.asarray(array)
+        new = cls(array.shape, array.dtype)
+        new[:] = array[:]
+        return new
 
 
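Note: SharedArray pickles as (class, (shape, dtype, shm.name)), so a worker that unpickles it re-attaches to the same shared-memory block by name instead of copying the data. A minimal usage sketch (illustrative only, not part of the diff; the pmap call form follows the docstring examples further down):

    import numpy as np
    from parfor import SharedArray, pmap

    def row_sum(i, array):
        # each worker attaches to the same shared buffer, no copy per task
        return float(array[i].sum())

    with SharedArray.from_array(np.arange(12).reshape(3, 4)) as shared:
        print(pmap(row_sum, range(3), (shared,)))  # expected: [6.0, 22.0, 38.0]

Exiting the with block closes and unlinks the shared-memory segment, per __exit__ above.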
 class Chunks(Iterable):
-    """ Yield successive chunks from lists.
+    """Yield successive chunks from lists.
     Usage: chunks(list0, list1, ...)
            chunks(list0, list1, ..., size=s)
            chunks(list0, list1, ..., number=n)
            chunks(list0, list1, ..., ratio=r)
     size: size of chunks, might change to optimize division between chunks
     number: number of chunks, coerced to 1 <= n <= len(list0)
     ratio: number of chunks / number of cpus, coerced to 1 <= n <= len(list0)
     both size and number or ratio are given: use number or ratio, unless the chunk size would be bigger than size
     both ratio and number are given: use ratio
     """
 
-    def __init__(self, *iterables: Iterable[Any] | Sized[Any], size: int = None, number: int = None,
-                 ratio: float = None, length: int = None) -> None:
+    def __init__(
+        self,
+        *iterables: Iterable[Any] | Sized,
+        size: int = None,
+        number: int = None,
+        ratio: float = None,
+        length: int = None,
+    ) -> None:
         if length is None:
             try:
                 length = min(*[len(iterable) for iterable in iterables]) if len(iterables) > 1 else len(iterables[0])
             except TypeError:
-                raise TypeError('Cannot determine the length of the iterables(s), so the length must be provided as an'
-                                ' argument.')
+                raise TypeError(
+                    "Cannot determine the length of the iterables(s), so the length must be provided as an argument."
+                )
         if size is not None and (number is not None or ratio is not None):
             if number is None:
                 number = int(cpu_count * ratio)
@@ -54,12 +121,17 @@ class Chunks(Iterable):
         self.iterators = [iter(arg) for arg in iterables]
         self.number_of_items = length
         self.length = min(length, number)
-        self.lengths = [((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length)
-                        for i in range(self.length)]
+        self.lengths = [
+            ((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length)
+            for i in range(self.length)
+        ]
 
     def __iter__(self) -> Iterator[Any]:
         for i in range(self.length):
-            p, q = (i * self.number_of_items // self.length), ((i + 1) * self.number_of_items // self.length)
+            p, q = (
+                (i * self.number_of_items // self.length),
+                ((i + 1) * self.number_of_items // self.length),
+            )
             if len(self.iterators) == 1:
                 yield [next(self.iterators[0]) for _ in range(q - p)]
             else:
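For reference, Chunks splits the work as evenly as integer arithmetic allows; a short sketch of the behaviour the docstring describes (assuming Chunks is importable from the package):

    from parfor import Chunks

    # 10 items in 3 chunks: sizes differ by at most one
    print(list(Chunks(range(10), number=3)))
    # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]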
@@ -70,7 +142,12 @@ class Chunks(Iterable):
 
 
 class ExternalBar(Iterable):
-    def __init__(self, iterable: Iterable = None, callback: Callable[[int], None] = None, total: int = 0) -> None:
+    def __init__(
+        self,
+        iterable: Iterable = None,
+        callback: Callable[[int], None] = None,
+        total: int = 0,
+    ) -> None:
         self.iterable = iterable
         self.callback = callback
         self.total = total
@@ -102,90 +179,312 @@ class ExternalBar(Iterable):
             self.callback(n)
 
 
-def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
-         args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
-         bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
-         n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
-         **bar_kwargs: Any) -> Generator[Any, None, None]:
-    """ map a function fun to each iteration in iterable
-        use as a function: pmap
-        use as a decorator: parfor
-        best use: iterable is a generator and length is given to this function as 'total'
-
-        required:
-            fun: function taking arguments: iteration from iterable, other arguments defined in args & kwargs
-            iterable: iterable or iterator from which an item is given to fun as a first argument
-        optional:
-            args: tuple with other unnamed arguments to fun
-            kwargs: dict with other named arguments to fun
-            total: give the length of the iterator in cases where len(iterator) results in an error
-            desc: string with description of the progress bar
-            bar: bool enable progress bar,
-                or a callback function taking the number of passed iterations as an argument
-            serial: execute in series instead of parallel if True, None (default): let pmap decide
-            length: deprecated alias for total
-            n_processes: number of processes to use,
-                the parallel pool will be restarted if the current pool does not have the right number of processes
-            yield_ordered: return the result in the same order as the iterable
-            yield_index: return the index of the result too
-            **bar_kwargs: keywords arguments for tqdm.tqdm
-
-        output:
-            list (pmap) or generator (gmap) with results from applying the function \'fun\' to each iteration
-                of the iterable / iterator
-
-        examples:
-            << from time import sleep
-            <<
-            @parfor(range(10), (3,))
-            def fun(i, a):
-                sleep(1)
-                return a * i ** 2
-            fun
-            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
-
-            <<
-            def fun(i, a):
-                sleep(1)
-                return a * i ** 2
-            pmap(fun, range(10), (3,))
-            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
-
-            equivalent to using the deco module:
-            <<
-            @concurrent
-            def fun(i, a):
-                time.sleep(1)
-                return a * i ** 2
-
-            @synchronized
-            def run(iterator, a):
-                res = []
-                for i in iterator:
-                    res.append(fun(i, a))
-                return res
-            run(range(10), 3)
-            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
-
-            all equivalent to the serial for-loop:
-            <<
-            a = 3
-            fun = []
-            for i in range(10):
-                sleep(1)
-                fun.append(a * i ** 2)
-            fun
-            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
-    """
-    if total is None and length is not None:
-        total = length
-        warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=2)
-        warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=3)
-    if terminator is not None:
-        warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
-             DeprecationWarning, stacklevel=2)
-        warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
-             DeprecationWarning, stacklevel=3)
+@ray.remote
+def worker(task):
+    try:
+        with (
+            warnings.catch_warnings(),
+            redirect_stdout(open(os.devnull, "w")),
+            redirect_stderr(open(os.devnull, "w")),
+        ):
+            warnings.simplefilter("ignore", category=FutureWarning)
+            try:
+                task()
+                task.status = "done",
+            except Exception:  # noqa
+                task.status = "task_error", format_exc()
+    except KeyboardInterrupt:  # noqa
+        pass
+
+    return task
+
+
+class Task:
+    def __init__(
+        self,
+        handle: Hashable,
+        fun: Callable[[Any, ...], Any],
+        args: tuple[Any, ...] = (),
+        kwargs: dict[str, Any] = None,
+    ) -> None:
+        self.handle = handle
+        self.fun = fun
+        self.args = args
+        self.kwargs = kwargs
+        self.name = fun.__name__ if hasattr(fun, "__name__") else None
+        self.done = False
+        self.result = None
+        self.future = None
+        self.status = "starting"
+
+    @property
+    def fun(self) -> Callable[[Any, ...], Any]:
+        return ray.get(self._fun)
+
+    @fun.setter
+    def fun(self, fun: Callable[[Any, ...], Any]):
+        self._fun = ray.put(fun)
+
+    @property
+    def args(self) -> tuple[Any, ...]:
+        return tuple([ray.get(arg) for arg in self._args])
+
+    @args.setter
+    def args(self, args: tuple[Any, ...]) -> None:
+        self._args = [ray.put(arg) for arg in args]
+
+    @property
+    def kwargs(self) -> dict[str, Any]:
+        return {key: ray.get(value) for key, value in self._kwargs.items()}
+
+    @kwargs.setter
+    def kwargs(self, kwargs: dict[str, Any]) -> None:
+        self._kwargs = {key: ray.put(value) for key, value in kwargs.items()}
+
+    @property
+    def result(self) -> Any:
+        return ray.get(self._result)
+
+    @result.setter
+    def result(self, result: Any) -> None:
+        self._result = ray.put(result)
+
+    def __call__(self) -> Task:
+        if not self.done:
+            self.result = self.fun(*self.args, **self.kwargs)  # noqa
+            self.done = True
+        return self
+
+    def __repr__(self) -> str:
+        if self.done:
+            return f"Task {self.handle}, result: {self.result}"
+        else:
+            return f"Task {self.handle}"
+
+
+class ParPool:
+    """Parallel processing with addition of iterations at any time and request of that result any time after that.
+    The target function and its argument can be changed at any time.
+    """
+
+    def __init__(
+        self,
+        fun: Callable[[Any, ...], Any] = None,
+        args: tuple[Any] = None,
+        kwargs: dict[str, Any] = None,
+        n_processes: int = None,
+        bar: Bar = None,
+    ):
+        self.handle = 0
+        self.tasks = {}
+        self.bar = bar
+        self.bar_lengths = {}
+        self.fun = fun
+        self.args = args
+        self.kwargs = kwargs
+        PoolSingleton(n_processes)
+
+    def __getstate__(self) -> NoReturn:
+        raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.")
+
+    def __enter__(self) -> ParPool:
+        return self
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
+        self.add_task(
+            args=(n, *(() if self.args is None else self.args)),
+            handle=handle,
+            barlength=barlength,
+        )
+
+    def add_task(
+        self,
+        fun: Callable[[Any, ...], Any] = None,
+        args: tuple[Any, ...] = None,
+        kwargs: dict[str, Any] = None,
+        handle: Hashable = None,
+        barlength: int = 1,
+    ) -> Optional[int]:
+        if handle is None:
+            new_handle = self.handle
+            self.handle += 1
+        else:
+            new_handle = handle
+        if new_handle in self:
+            raise ValueError(f"handle {new_handle} already present")
+        task = Task(
+            new_handle,
+            fun or self.fun,
+            args or self.args,
+            kwargs or self.kwargs,
+        )
+        task.future = worker.remote(task)
+        self.tasks[new_handle] = task
+        self.bar_lengths[new_handle] = barlength
+        if handle is None:
+            return new_handle
+        else:
+            return None
+
+    def __setitem__(self, handle: Hashable, n: Any) -> None:
+        """Add new iteration."""
+        self(n, handle=handle)
+
+    def __getitem__(self, handle: Hashable) -> Any:
+        """Request result and delete its record. Wait if result not yet available."""
+        if handle not in self:
+            raise ValueError(f"No handle: {handle} in pool")
+        task = ray.get(self.tasks[handle].future)
+        return self.finalize_task(task)
+
+    def __contains__(self, handle: Hashable) -> bool:
+        return handle in self.tasks
+
+    def __delitem__(self, handle: Hashable) -> None:
+        self.tasks.pop(handle)
+
+    def finalize_task(self, task: Task) -> Any:
+        code, *args = task.status
+        getattr(self, code)(task, *args)
+        self.tasks.pop(task.handle)
+        return task.result
+
+    def get_newest(self) -> Optional[Any]:
+        """Request the newest handle and result and delete its record. Wait if result not yet available."""
+        while True:
+            for handle, task in self.tasks.items():
+                if handle in self.bar_lengths:
+                    try:
+                        task = ray.get(task.future, timeout=0.01)
+                        return task.handle, self.finalize_task(task)
+                    except ray.exceptions.GetTimeoutError:
+                        pass
+
+    def task_error(self, task: Task, error: Exception) -> None:
+        if task.handle in self:
+            task = self.tasks[task.handle]
+            print(f"Error from process working on iteration {task.handle}:\n")
+            print(error)
+            print("Retrying in main process...")
+            task()
+            raise Exception(f"Function '{task.name}' cannot be executed by parfor, amend or execute in serial.")
+
+    def done(self, task: Task) -> None:
+        if task.handle in self:  # if not, the task was restarted erroneously
+            self.tasks[task.handle] = task
+            if hasattr(self.bar, "update"):
+                self.bar.update(self.bar_lengths.pop(task.handle))
+
+
+class PoolSingleton:
+    instance: PoolSingleton = None
+    cpu_count: int = int(os.cpu_count())
+
+    def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
+        # restart if any workers have shut down or if we want to have a different number of processes
+        n_processes = n_processes or cls.cpu_count
+        if cls.instance is None or cls.instance.n_processes != n_processes:
+            cls.instance = super().__new__(cls)
+            cls.instance.n_processes = n_processes
+            if ray.is_initialized():
+                warnings.warn("not setting n_processes because parallel pool was already initialized")
+            else:
+                os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
+                ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
+        return cls.instance
+
+
+class Worker:
+    nested: bool = False
+
+
+def gmap(
+    fun: Callable[[Any, ...], Any],
+    iterable: Iterable[Any] = None,
+    args: tuple[Any, ...] = None,
+    kwargs: dict[str, Any] = None,
+    total: int = None,
+    desc: str = None,
+    bar: Bar | bool = True,
+    serial: bool = None,
+    n_processes: int = None,
+    yield_ordered: bool = True,
+    yield_index: bool = False,
+    **bar_kwargs: Any,
+) -> Generator[Any, None, None]:
+    """map a function fun to each iteration in iterable
+        use as a function: pmap
+        use as a decorator: parfor
+        best use: iterable is a generator and length is given to this function as 'total'
+
+        required:
+            fun: function taking arguments: iteration from iterable, other arguments defined in args & kwargs
+            iterable: iterable or iterator from which an item is given to fun as a first argument
+        optional:
+            args: tuple with other unnamed arguments to fun
+            kwargs: dict with other named arguments to fun
+            total: give the length of the iterator in cases where len(iterator) results in an error
+            desc: string with description of the progress bar
+            bar: bool enable progress bar,
+                or a callback function taking the number of passed iterations as an argument
+            serial: execute in series instead of parallel if True, None (default): let pmap decide
+            length: deprecated alias for total
+            n_processes: number of processes to use,
+                the parallel pool will be restarted if the current pool does not have the right number of processes
+            yield_ordered: return the result in the same order as the iterable
+            yield_index: return the index of the result too
+            **bar_kwargs: keywords arguments for tqdm.tqdm
+
+        output:
+            list (pmap) or generator (gmap) with results from applying the function \'fun\' to each iteration
+                of the iterable / iterator
+
+        examples:
+            << from time import sleep
+            <<
+            @parfor(range(10), (3,))
+            def fun(i, a):
+                sleep(1)
+                return a * i ** 2
+            fun
+            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+
+            <<
+            def fun(i, a):
+                sleep(1)
+                return a * i ** 2
+            pmap(fun, range(10), (3,))
+            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+
+            equivalent to using the deco module:
+            <<
+            @concurrent
+            def fun(i, a):
+                time.sleep(1)
+                return a * i ** 2
+
+            @synchronized
+            def run(iterator, a):
+                res = []
+                for i in iterator:
+                    res.append(fun(i, a))
+                return res
+            run(range(10), 3)
+            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+
+            all equivalent to the serial for-loop:
+            <<
+            a = 3
+            fun = []
+            for i in range(10):
+                sleep(1)
+                fun.append(a * i ** 2)
+            fun
+            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+    """
     is_chunked = isinstance(iterable, Chunks)
     if is_chunked:
         chunk_fun = fun
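This block carries the commit's headline claim: tasks now run as ray remote functions, and the old rule that forced Worker.nested workers into serial mode is gone (see the serial-case change in the next hunk). A sketch of a nested call (illustrative only, not part of the diff; bars disabled for clarity):

    from parfor import pmap

    def inner(j, factor):
        return factor * j

    def outer(i):
        # this pmap call itself runs inside a ray worker
        return sum(pmap(inner, range(4), (i,), bar=False))

    print(pmap(outer, range(5), bar=False))  # expected: [0, 6, 12, 18, 24]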
@@ -201,13 +500,13 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
     if kwargs is None:
         kwargs = {}
 
-    if 'total' not in bar_kwargs:
-        bar_kwargs['total'] = sum(iterable.lengths)
-    if 'desc' not in bar_kwargs:
-        bar_kwargs['desc'] = desc
-    if 'disable' not in bar_kwargs:
-        bar_kwargs['disable'] = not bar
-    if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested:  # serial case
+    if "total" not in bar_kwargs:
+        bar_kwargs["total"] = sum(iterable.lengths)
+    if "desc" not in bar_kwargs:
+        bar_kwargs["desc"] = desc
+    if "disable" not in bar_kwargs:
+        bar_kwargs["disable"] = not bar
+    if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)):  # serial case
 
         def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Any]:  # noqa
             with tqdm(*args, **kwargs) as b:
@@ -215,8 +514,9 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
                     yield chunk
                     b.update(length)
 
-    iterable = (ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar)
-                else tqdm_chunks(iterable, **bar_kwargs))
+    iterable = (
+        ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar) else tqdm_chunks(iterable, **bar_kwargs)  # type: ignore
+    )
     if is_chunked:
         if yield_index:
             for i, c in enumerate(iterable):
@@ -233,10 +533,10 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
         for c in iterable:
             yield from chunk_fun(c, *args, **kwargs)
 
     else:  # parallel case
-        with ExitStack() as stack:
+        with ExitStack() as stack:  # noqa
             if callable(bar):
-                bar = stack.enter_context(ExternalBar(callback=bar))
+                bar = stack.enter_context(ExternalBar(callback=bar))  # noqa
             else:
                 bar = stack.enter_context(tqdm(**bar_kwargs))
             with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p:  # type: ignore
@@ -287,12 +587,13 @@ def pmap(*args, **kwargs) -> list[Any]:
 def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Any, ...], Any]], list[Any]]:
     def decfun(fun: Callable[[Any, ...], Any]) -> list[Any]:
         return pmap(fun, *args, **kwargs)
 
     return decfun
 
 
 try:
     parfor.__doc__ = pmap.__doc__ = gmap.__doc__
     pmap.__annotations__ = gmap.__annotations__ | pmap.__annotations__
-    parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != 'fun'}
+    parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != "fun"}
 except AttributeError:
     pass
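The reworked ParPool can also be driven by hand through __setitem__ and __getitem__. Note that gmap normalizes kwargs to {} before it constructs the pool; when using ParPool directly, pass a dict yourself, since Task's kwargs setter iterates over it (sketch, not part of the diff):

    from parfor import ParPool

    def square(i):
        return i ** 2

    with ParPool(square, kwargs={}) as pool:
        for i in range(5):
            pool[i] = i                     # submit iteration i under handle i
        print([pool[i] for i in range(5)])  # expected: [0, 1, 4, 9, 16]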
parfor/common.py: 58 deletions (file deleted)
@@ -1,58 +0,0 @@
-from __future__ import annotations
-
-import os
-from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Callable, Protocol, Sequence
-
-import numpy as np
-from numpy.typing import ArrayLike, DTypeLike
-
-cpu_count = int(os.cpu_count())
-
-
-class Bar(Protocol):
-    def update(self, n: int = 1) -> None: ...
-
-
-class SharedArray(np.ndarray):
-    """ Numpy array whose memory can be shared between processes, so that memory use is reduced and changes in one
-    process are reflected in all other processes. Changes are not atomic, so protect changes with a lock to prevent
-    race conditions!
-    """
-
-    def __new__(cls, shape: int | Sequence[int], dtype: DTypeLike = float, shm: str | SharedMemory = None,
-                offset: int = 0, strides: tuple[int, int] = None, order: str = None) -> SharedArray:
-        if isinstance(shm, str):
-            shm = SharedMemory(shm)
-        elif shm is None:
-            shm = SharedMemory(create=True, size=np.prod(shape) * np.dtype(dtype).itemsize)
-        new = super().__new__(cls, shape, dtype, shm.buf, offset, strides, order)
-        new.shm = shm
-        return new
-
-    def __reduce__(self) -> tuple[Callable[[int | Sequence[int], DTypeLike, str], SharedArray],
-                                  tuple[int | tuple[int, ...], np.dtype, str]]:
-        return self.__class__, (self.shape, self.dtype, self.shm.name)
-
-    def __enter__(self) -> SharedArray:
-        return self
-
-    def __exit__(self, *args: Any, **kwargs: Any) -> None:
-        if hasattr(self, 'shm'):
-            self.shm.close()
-            self.shm.unlink()
-
-    def __del__(self) -> None:
-        if hasattr(self, 'shm'):
-            self.shm.close()
-
-    def __array_finalize__(self, obj: np.ndarray | None) -> None:
-        if isinstance(obj, np.ndarray) and not isinstance(obj, SharedArray):
-            raise TypeError('view casting to SharedArray is not implemented because right now we need to make a copy')
-
-    @classmethod
-    def from_array(cls, array: ArrayLike) -> SharedArray:
-        """ copy existing array into a SharedArray """
-        new = cls(array.shape, array.dtype)
-        new[:] = array[:]
-        return new
parfor/gil.py: 463 deletions (file deleted)
@@ -1,463 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import multiprocessing
-from collections import UserDict
-from contextlib import redirect_stderr, redirect_stdout
-from os import cpu_count, devnull, getpid
-from time import time
-from traceback import format_exc
-from typing import Any, Callable, Hashable, NoReturn, Optional
-from warnings import warn
-
-from .common import Bar
-from .pickler import dumps, loads
-
-
-class SharedMemory(UserDict):
-    def __init__(self, manager: multiprocessing.Manager) -> None:
-        super().__init__()
-        self.data = manager.dict()  # item_id: dilled representation of object
-        self.references = manager.dict()  # item_id: counter
-        self.references_lock = manager.Lock()
-        self.cache = {}  # item_id: object
-        self.trash_can = {}
-        self.pool_ids = {}  # item_id: {(pool_id, task_handle), ...}
-
-    def __getstate__(self) -> tuple[dict[int, bytes], dict[int, int], multiprocessing.Lock]:
-        return self.data, self.references, self.references_lock
-
-    def __setitem__(self, item_id: int, value: Any) -> None:
-        if item_id not in self:  # values will not be changed
-            try:
-                self.data[item_id] = False, value
-            except Exception:  # only use our pickler when necessary  # noqa
-                self.data[item_id] = True, dumps(value, recurse=True)
-            with self.references_lock:
-                try:
-                    self.references[item_id] += 1
-                except KeyError:
-                    self.references[item_id] = 1
-            self.cache[item_id] = value  # the id of the object will not be reused as long as the object exists
-
-    def add_item(self, item: Any, pool_id: int, task_handle: Hashable) -> int:
-        item_id = id(item)
-        self[item_id] = item
-        if item_id in self.pool_ids:
-            self.pool_ids[item_id].add((pool_id, task_handle))
-        else:
-            self.pool_ids[item_id] = {(pool_id, task_handle)}
-        return item_id
-
-    def remove_pool(self, pool_id: int) -> None:
-        """ remove objects used by a pool that won't be needed anymore """
-        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
-        for item_id in set(self.data.keys()) - set(self.pool_ids):
-            del self[item_id]
-        self.garbage_collect()
-
-    def remove_task(self, pool_id: int, task: Task) -> None:
-        """ remove objects used by a task that won't be needed anymore """
-        self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
-        for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
-            del self[item_id]
-        self.garbage_collect()
-
-    # worker functions
-    def __setstate__(self, state: dict) -> None:
-        self.data, self.references, self.references_lock = state
-        self.cache = {}
-        self.trash_can = None
-
-    def __getitem__(self, item_id: int) -> Any:
-        if item_id not in self.cache:
-            dilled, value = self.data[item_id]
-            if dilled:
-                value = loads(value)
-            with self.references_lock:
-                if item_id in self.references:
-                    self.references[item_id] += 1
-                else:
-                    self.references[item_id] = 1
-            self.cache[item_id] = value
-        return self.cache[item_id]
-
-    def garbage_collect(self) -> None:
-        """ clean up the cache """
-        for item_id in set(self.cache) - set(self.data.keys()):
-            with self.references_lock:
-                try:
-                    self.references[item_id] -= 1
-                except KeyError:
-                    self.references[item_id] = 0
-            if self.trash_can is not None and item_id not in self.trash_can:
-                self.trash_can[item_id] = self.cache[item_id]
-            del self.cache[item_id]
-
-        if self.trash_can:
-            for item_id in set(self.trash_can):
-                if self.references[item_id] == 0:
-                    # make sure every process removed the object before removing it in the parent
-                    del self.references[item_id]
-                    del self.trash_can[item_id]
-
-
-class Task:
-    def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: Hashable, fun: Callable[[Any, ...], Any],
-                 args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None) -> None:
-        self.pool_id = pool_id
-        self.handle = handle
-        self.fun = shared_memory.add_item(fun, pool_id, handle)
-        self.args = [shared_memory.add_item(arg, pool_id, handle) for arg in args]
-        self.kwargs = [] if kwargs is None else [shared_memory.add_item(item, pool_id, handle)
-                                                 for item in kwargs.items()]
-        self.name = fun.__name__ if hasattr(fun, '__name__') else None
-        self.done = False
-        self.result = None
-        self.pid = None
-
-    def __getstate__(self) -> dict[str, Any]:
-        state = self.__dict__
-        if self.result is not None:
-            state['result'] = dumps(self.result, recurse=True)
-        return state
-
-    def __setstate__(self, state: dict[str, Any]) -> None:
-        self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
-        if state['result'] is None:
-            self.result = None
-        else:
-            self.result = loads(state['result'])
-
-    def __call__(self, shared_memory: SharedMemory) -> Task:
-        if not self.done:
-            fun = shared_memory[self.fun] or (lambda *args, **kwargs: None)  # noqa
-            args = [shared_memory[arg] for arg in self.args]
-            kwargs = dict([shared_memory[kwarg] for kwarg in self.kwargs])
-            self.result = fun(*args, **kwargs)  # noqa
-            self.done = True
-        return self
-
-    def __repr__(self) -> str:
-        if self.done:
-            return f'Task {self.handle}, result: {self.result}'
-        else:
-            return f'Task {self.handle}'
-
-
-class Context(multiprocessing.context.SpawnContext):
-    """ Provide a context where child processes never are daemonic. """
-    class Process(multiprocessing.context.SpawnProcess):
-        @property
-        def daemon(self) -> bool:
-            return False
-
-        @daemon.setter
-        def daemon(self, value: bool) -> None:
-            pass
-
-
-class ParPool:
-    """ Parallel processing with addition of iterations at any time and request of that result any time after that.
-        The target function and its argument can be changed at any time.
-    """
-    def __init__(self, fun: Callable[[Any, ...], Any] = None,
-                 args: tuple[Any] = None, kwargs: dict[str, Any] = None, n_processes: int = None, bar: Bar = None):
-        self.id = id(self)
-        self.handle = 0
-        self.tasks = {}
-        self.bar = bar
-        self.bar_lengths = {}
-        self.spool = PoolSingleton(n_processes, self)
-        self.manager = self.spool.manager
-        self.fun = fun
-        self.args = args
-        self.kwargs = kwargs
-        self.is_started = False
-
-    def __getstate__(self) -> NoReturn:
-        raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
-
-    def __enter__(self) -> ParPool:
-        return self
-
-    def __exit__(self, *args: Any, **kwargs: Any) -> None:
-        self.close()
-
-    def close(self) -> None:
-        self.spool.remove_pool(self.id)
-
-    def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
-        self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)
-
-    def add_task(self, fun: Callable[[Any, ...], Any] = None, args: tuple[Any, ...] = None,
-                 kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]:
-        if self.id not in self.spool.pools:
-            raise ValueError(f'this pool is not registered (anymore) with the pool singleton')
-        if handle is None:
-            new_handle = self.handle
-            self.handle += 1
-        else:
-            new_handle = handle
-        if new_handle in self:
-            raise ValueError(f'handle {new_handle} already present')
-        task = Task(self.spool.shared_memory, self.id, new_handle,
-                    fun or self.fun, args or self.args, kwargs or self.kwargs)
-        self.tasks[new_handle] = task
-        self.spool.add_task(task)
-        self.bar_lengths[new_handle] = barlength
-        if handle is None:
-            return new_handle
-
-    def __setitem__(self, handle: Hashable, n: Any) -> None:
-        """ Add new iteration. """
-        self(n, handle=handle)
-
-    def __getitem__(self, handle: Hashable) -> Any:
-        """ Request result and delete its record. Wait if result not yet available. """
-        if handle not in self:
-            raise ValueError(f'No handle: {handle} in pool')
-        while not self.tasks[handle].done:
-            if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
-                    and not self.working:
-                for _ in range(10):  # wait some time while processing possible new messages
-                    self.spool.get_from_queue()
-                if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
-                        and not self.working:
-                    # retry a task if the process was killed while working on a task
-                    self.spool.add_task(self.tasks[handle])
-                    warn(f'Task {handle} was restarted because the process working on it was probably killed.')
-        result = self.tasks[handle].result
-        self.tasks.pop(handle)
-        return result
-
-    def __contains__(self, handle: Hashable) -> bool:
-        return handle in self.tasks
-
-    def __delitem__(self, handle: Hashable) -> None:
-        self.tasks.pop(handle)
-
-    def get_newest(self) -> Any:
-        return self.spool.get_newest_for_pool(self)
-
-    def process_queue(self) -> None:
-        self.spool.process_queue()
-
-    def task_error(self, handle: Hashable, error: Exception) -> None:
-        if handle in self:
-            task = self.tasks[handle]
-            print(f'Error from process working on iteration {handle}:\n')
-            print(error)
-            print('Retrying in main process...')
-            task(self.spool.shared_memory)
-            self.spool.shared_memory.remove_task(self.id, task)
-            raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.')
-
-    def done(self, task: Task) -> None:
-        if task.handle in self:  # if not, the task was restarted erroneously
-            self.tasks[task.handle] = task
-            if hasattr(self.bar, 'update'):
-                self.bar.update(self.bar_lengths.pop(task.handle))
-            self.spool.shared_memory.remove_task(self.id, task)
-
-    def started(self, handle: Hashable, pid: int) -> None:
-        self.is_started = True
-        if handle in self:  # if not, the task was restarted erroneously
-            self.tasks[handle].pid = pid
-
-    @property
-    def working(self) -> bool:
-        return not all([task.pid is None for task in self.tasks.values()])
-
-
-class PoolSingleton:
-    """ There can be only one pool at a time, but the pool can be restarted by calling close() and then constructing a
-        new pool. The pool will close itself after 10 minutes of idle time. """
-
-    instance = None
-    cpu_count = cpu_count()
-
-    def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
-        # restart if any workers have shut down or if we want to have a different number of processes
-        if cls.instance is not None:
-            if (cls.instance.n_workers.value < cls.instance.n_processes or
-                    cls.instance.n_processes != (n_processes or cls.cpu_count)):
-                cls.instance.close()
-        if cls.instance is None or not cls.instance.is_alive:
-            new = super().__new__(cls)
-            new.n_processes = n_processes or cls.cpu_count
-            new.instance = new
-            new.is_started = False
-            ctx = Context()
-            new.n_workers = ctx.Value('i', new.n_processes)
-            new.event = ctx.Event()
-            new.queue_in = ctx.Queue(3 * new.n_processes)
-            new.queue_out = ctx.Queue(new.n_processes)
-            new.manager = ctx.Manager()
-            new.shared_memory = SharedMemory(new.manager)
-            new.pool = ctx.Pool(new.n_processes,
-                                Worker(new.shared_memory, new.queue_in, new.queue_out, new.n_workers, new.event))
-            new.is_alive = True
-            new.handle = 0
-            new.pools = {}
-            new.time_out = None
-            cls.instance = new
-        return cls.instance
-
-    def __init__(self, n_processes: int = None, parpool: Parpool = None) -> None:  # noqa
-        if parpool is not None:
-            self.pools[parpool.id] = parpool
-        if self.time_out is not None:
-            self.time_out.cancel()
-            self.time_out = None
-
-    def __getstate__(self) -> NoReturn:
-        raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
-
-    def remove_pool(self, pool_id: int) -> None:
-        self.shared_memory.remove_pool(pool_id)
-        if pool_id in self.pools:
-            self.pools.pop(pool_id)
-        if len(self.pools) == 0:
-            try:
-                self.time_out = asyncio.get_running_loop().call_later(600, self.close)  # noqa
-            except RuntimeError:
-                self.time_out = asyncio.new_event_loop().call_later(600, self.close)  # noqa
-
-    def error(self, error: Exception) -> NoReturn:
-        self.close()
-        raise Exception(f'Error occurred in worker: {error}')
-
-    def process_queue(self) -> None:
-        while self.get_from_queue():
-            pass
-
-    def get_from_queue(self) -> bool:
-        """ Get an item from the queue and store it, return True if more messages are waiting. """
-        try:
-            code, pool_id, *args = self.queue_out.get(True, 0.02)
-            if pool_id is None:
-                getattr(self, code)(*args)
-            elif pool_id in self.pools:
-                getattr(self.pools[pool_id], code)(*args)
-            return True
-        except multiprocessing.queues.Empty:  # noqa
-            for pool in self.pools.values():
-                for handle, task in pool.tasks.items():  # retry a task if the process doing it was killed
-                    if task.pid is not None \
-                            and task.pid not in [child.pid for child in multiprocessing.active_children()]:
-                        self.queue_in.put(task)
-                        warn(f'Task {task.handle} was restarted because process {task.pid} was probably killed.')
-            return False
-
-    def add_task(self, task: Task) -> None:
-        """ Add new iteration, using optional manually defined handle."""
-        if self.is_alive and not self.event.is_set():
-            while self.queue_in.full():
-                self.get_from_queue()
-            self.queue_in.put(task)
-            self.shared_memory.garbage_collect()
-
-    def get_newest_for_pool(self, pool: ParPool) -> tuple[Hashable, Any]:
-        """ Request the newest key and result and delete its record. Wait if result not yet available. """
-        while len(pool.tasks):
-            self.get_from_queue()
-            for task in pool.tasks.values():
-                if task.done:
-                    handle, result = task.handle, task.result
-                    pool.tasks.pop(handle)
-                    return handle, result
-
-    @classmethod
-    def close(cls) -> None:
-        if cls.instance is not None:
-            instance = cls.instance
-            cls.instance = None
-            if instance.time_out is not None:
-                instance.time_out.cancel()
-
-            def empty_queue(queue):
-                try:
-                    if not queue._closed:  # noqa
-                        while not queue.empty():
-                            try:
-                                queue.get(True, 0.02)
-                            except multiprocessing.queues.Empty:  # noqa
-                                pass
-                except OSError:
-                    pass
-
-            def close_queue(queue: multiprocessing.queues.Queue) -> None:
-                empty_queue(queue)  # noqa
-                if not queue._closed:  # noqa
-                    queue.close()
-                queue.join_thread()
-
-            if instance.is_alive:
-                instance.is_alive = False
-                instance.event.set()
-                instance.pool.close()
-                t = time()
-                while instance.n_workers.value:
-                    empty_queue(instance.queue_in)
-                    empty_queue(instance.queue_out)
-                    if time() - t > 10:
-                        warn(f'Parfor: Closing pool timed out, {instance.n_workers.value} processes still alive.')
-                        instance.pool.terminate()
-                        break
-                empty_queue(instance.queue_in)
-                empty_queue(instance.queue_out)
-                instance.pool.join()
-                close_queue(instance.queue_in)
-                close_queue(instance.queue_out)
-                instance.manager.shutdown()
-            instance.handle = 0
-
-
-class Worker:
-    """ Manages executing the target function which will be executed in different processes. """
-    nested = False
-
-    def __init__(self, shared_memory: SharedMemory, queue_in: multiprocessing.queues.Queue,
-                 queue_out: multiprocessing.queues.Queue, n_workers: multiprocessing.Value,
-                 event: multiprocessing.Event) -> None:
-        self.shared_memory = shared_memory
-        self.queue_in = queue_in
-        self.queue_out = queue_out
-        self.n_workers = n_workers
-        self.event = event
-
-    def add_to_queue(self, *args: Any) -> None:
-        while not self.event.is_set():
-            try:
-                self.queue_out.put(args, timeout=0.1)
-                break
-            except multiprocessing.queues.Full:  # noqa
-                continue
-
-    def __call__(self) -> None:
-        Worker.nested = True
-        pid = getpid()
-        last_active_time = time()
-        while not self.event.is_set() and time() - last_active_time < 600:
-            try:
-                with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
-                    task = self.queue_in.get(True, 0.02)
-                    try:
-                        self.add_to_queue('started', task.pool_id, task.handle, pid)
-                        self.add_to_queue('done', task.pool_id, task(self.shared_memory))
-                    except Exception:  # noqa
-                        self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
-                        self.event.set()
-                    self.shared_memory.garbage_collect()
-                    last_active_time = time()
-            except (multiprocessing.queues.Empty, KeyboardInterrupt):  # noqa
-                pass
-            except Exception:  # noqa
-                self.add_to_queue('error', None, format_exc())
-                self.event.set()
-        self.shared_memory.garbage_collect()
-        for child in multiprocessing.active_children():
-            child.kill()
-        with self.n_workers:
-            self.n_workers.value -= 1
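The hand-rolled shared memory above (a manager dict of dilled blobs plus reference counting and queue plumbing) is what ray's object store now replaces. For comparison, a minimal ray sketch of the same idea (not part of the diff):

    import ray

    ray.init(num_cpus=2)

    @ray.remote
    def head(items, n):
        return items[:n]

    ref = ray.put(list(range(1_000_000)))  # stored once, shared by all workers
    print(ray.get(head.remote(ref, 3)))    # [0, 1, 2]
    ray.shutdown()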
parfor/nogil.py: 158 deletions (file deleted)
@@ -1,158 +0,0 @@
-from __future__ import annotations
-
-import queue
-import threading
-from os import cpu_count
-from typing import Any, Callable, Hashable, NoReturn, Optional
-
-from .common import Bar
-
-
-class Worker:
-    nested = False
-
-    def __init__(self, *args, **kwargs):
-        pass
-
-
-class PoolSingleton:
-    cpu_count = cpu_count()
-
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def close(self):
-        pass
-
-
-class Task:
-    def __init__(self, queue: queue.Queue, handle: Hashable, fun: Callable[[Any, ...], Any],  # noqa
-                 args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None) -> None:
-        self.queue = queue
-        self.handle = handle
-        self.fun = fun
-        self.args = args
-        self.kwargs = {} if kwargs is None else kwargs
-        self.name = fun.__name__ if hasattr(fun, '__name__') else None
-        self.started = False
-        self.done = False
-        self.result = None
-
-    def __call__(self):
-        if not self.done:
-            self.result = self.fun(*self.args, **self.kwargs)
-            try:
-                self.queue.put(self.handle)
-            except queue.ShutDown:
-                pass
-
-    def __repr__(self) -> str:
-        if self.done:
-            return f'Task {self.handle}, result: {self.result}'
-        else:
-            return f'Task {self.handle}'
-
-
-class ParPool:
-    """ Parallel processing with addition of iterations at any time and request of that result any time after that.
-        The target function and its argument can be changed at any time.
-    """
-    def __init__(self, fun: Callable[[Any, ...], Any] = None,
-                 args: tuple[Any] = None, kwargs: dict[str, Any] = None, n_processes: int = None, bar: Bar = None):
-        self.queue = queue.Queue()
-        self.handle = 0
-        self.tasks = {}
-        self.bar = bar
-        self.bar_lengths = {}
-        self.fun = fun
-        self.args = args
-        self.kwargs = kwargs
-        self.n_processes = n_processes or PoolSingleton.cpu_count
-        self.threads = {}
-
-    def __getstate__(self) -> NoReturn:
-        raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
-
-    def __enter__(self) -> ParPool:
-        return self
-
-    def __exit__(self, *args: Any, **kwargs: Any) -> None:
-        self.close()
-
-    def close(self) -> None:
-        self.queue.shutdown()  # noqa python3.13
-        for thread in self.threads.values():
-            thread.join()
-
-    def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
-        self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)
-
-    def add_task(self, fun: Callable[[Any, ...], Any] = None, args: tuple[Any, ...] = None,
-                 kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]:
-        if handle is None:
-            new_handle = self.handle
-            self.handle += 1
-        else:
-            new_handle = handle
-        if new_handle in self:
-            raise ValueError(f'handle {new_handle} already present')
-        task = Task(self.queue, new_handle, fun or self.fun, args or self.args, kwargs or self.kwargs)
-        while len(self.threads) > self.n_processes:
-            self.get_from_queue()
-        thread = threading.Thread(target=task)
-        thread.start()
-        self.threads[new_handle] = thread
-        self.tasks[new_handle] = task
-        self.bar_lengths[new_handle] = barlength
-        if handle is None:
-            return new_handle
-
-    def __setitem__(self, handle: Hashable, n: Any) -> None:
-        """ Add new iteration. """
-        self(n, handle=handle)
-
-    def __getitem__(self, handle: Hashable) -> Any:
-        """ Request result and delete its record. Wait if result not yet available. """
-        if handle not in self:
-            raise ValueError(f'No handle: {handle} in pool')
-        while not self.tasks[handle].done:
-            self.get_from_queue()
-        task = self.tasks.pop(handle)
-        return task.result
-
-    def __contains__(self, handle: Hashable) -> bool:
-        return handle in self.tasks
-
-    def __delitem__(self, handle: Hashable) -> None:
-        self.tasks.pop(handle)
-
-    def get_from_queue(self) -> bool:
-        """ Get an item from the queue and store it, return True if more messages are waiting. """
-        try:
-            handle = self.queue.get(True, 0.02)
-            self.done(handle)
-            return True
-        except (queue.Empty, queue.ShutDown):
-            return False
-
-    def get_newest(self) -> Any:
-        """ Request the newest key and result and delete its record. Wait if result not yet available. """
-        while len(self.tasks):
-            self.get_from_queue()
-            for task in self.tasks.values():
-                if task.done:
-                    handle, result = task.handle, task.result
-                    self.tasks.pop(handle)
-                    return handle, result
-
-    def process_queue(self) -> None:
-        while self.get_from_queue():
-            pass
-
-    def done(self, handle: Hashable) -> None:
-        thread = self.threads.pop(handle)
-        thread.join()
-        task = self.tasks[handle]
-        task.done = True
-        if hasattr(self.bar, 'update'):
-            self.bar.update(self.bar_lengths.pop(handle))
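nogil.py relied on queue.Queue.shutdown(), added in Python 3.13, to unblock waiting threads when the pool closes; a minimal sketch of that mechanism (illustrative only):

    import queue
    import threading

    q = queue.Queue()

    def consumer():
        while True:
            try:
                print('got', q.get(timeout=0.1))
            except queue.Empty:
                continue
            except queue.ShutDown:  # raised by get() once the queue is shut down
                return

    t = threading.Thread(target=consumer)
    t.start()
    q.put(1)
    q.shutdown()  # Python 3.13+
    t.join()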
@@ -1,119 +0,0 @@
from __future__ import annotations

import copyreg
from io import BytesIO
from pickle import PicklingError
from typing import Any, Callable

import dill

loads = dill.loads


class CouldNotBePickled:
    def __init__(self, class_name: str) -> None:
        self.class_name = class_name

    def __repr__(self) -> str:
        return f"Item of type '{self.class_name}' could not be pickled and was omitted."

    @classmethod
    def reduce(cls, item: Any) -> tuple[Callable[[str], CouldNotBePickled], tuple[str]]:
        return cls, (type(item).__name__,)
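CouldNotBePickled.reduce follows the standard __reduce__ protocol: it returns a (callable, args) pair, so unpickling calls CouldNotBePickled(...) and yields a placeholder instead of the original object. A minimal sketch of that contract, with threading.Lock standing in for an arbitrary offending object:

import threading

cls, args = CouldNotBePickled.reduce(threading.Lock())  # (CouldNotBePickled, ('lock',))
print(cls(*args))  # Item of type 'lock' could not be pickled and was omitted.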
class Pickler(dill.Pickler):
    """ Overload dill to ignore unpicklable parts of objects.
        You probably didn't want to use these parts anyhow.
        However, if you did, you'll have to find some way to make them picklable.
    """

    def save(self, obj: Any, save_persistent_id: bool = True) -> None:
        """ Copied from pickle and amended. """
        self.framer.commit_frame()

        # Check for persistent id (defined by a subclass)
        pid = self.persistent_id(obj)
        if pid is not None and save_persistent_id:
            self.save_pers(pid)
            return

        # Check the memo
        x = self.memo.get(id(obj))
        if x is not None:
            self.write(self.get(x[0]))
            return

        rv = NotImplemented
        reduce = getattr(self, "reducer_override", None)
        if reduce is not None:
            rv = reduce(obj)

        if rv is NotImplemented:
            # Check the type dispatch table
            t = type(obj)
            f = self.dispatch.get(t)
            if f is not None:
                f(self, obj)  # Call unbound method with explicit self
                return

            # Check private dispatch table if any, or else copyreg.dispatch_table
            reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
            if reduce is not None:
                rv = reduce(obj)
            else:
                # Check for a class with a custom metaclass; treat as a regular class
                if issubclass(t, type):
                    self.save_global(obj)
                    return

                # Check for a __reduce_ex__ method, fall back to __reduce__
                reduce = getattr(obj, "__reduce_ex__", None)
                try:
                    if reduce is not None:
                        rv = reduce(self.proto)
                    else:
                        reduce = getattr(obj, "__reduce__", None)
                        if reduce is not None:
                            rv = reduce()
                        else:
                            raise PicklingError("Can't pickle %r object: %r" % (t.__name__, obj))
                except Exception:  # noqa
                    rv = CouldNotBePickled.reduce(obj)

        # Check for a string returned by reduce(), meaning "save as global"
        if isinstance(rv, str):
            try:
                self.save_global(obj, rv)
            except Exception:  # noqa
                self.save_global(obj, CouldNotBePickled.reduce(obj))
            return

        # Assert that reduce() returned a tuple
        if not isinstance(rv, tuple):
            raise PicklingError("%s must return string or tuple" % reduce)

        # Assert that it returned an appropriately sized tuple
        length = len(rv)
        if not (2 <= length <= 6):
            raise PicklingError("Tuple returned by %s must have two to six elements" % reduce)

        # Save the reduce() output and finally memoize the object
        try:
            self.save_reduce(obj=obj, *rv)
        except Exception:  # noqa
            self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))

def dumps(obj: Any, protocol: int | str = None, byref: bool = None, fmode: str = None, recurse: bool = True,
          **kwds: Any) -> bytes:
    """ Pickle an object to bytes. """
    protocol = dill.settings['protocol'] if protocol is None else int(protocol)
    _kwds = kwds.copy()
    _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
    with BytesIO() as file:
        Pickler(file, protocol, **_kwds).dump(obj)
        return file.getvalue()
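The dumps defined here (removed in this commit) never raised on partly unpicklable objects: any offending member was swapped for a CouldNotBePickled placeholder. A minimal sketch of that behaviour, assuming dill cannot serialise the generator member; argparse.Namespace is just a convenient container:

from argparse import Namespace

obj = Namespace(value=42, gen=(i for i in range(3)))  # generators are not picklable
restored = loads(dumps(obj))
print(restored.value)  # 42
print(restored.gen)    # Item of type 'generator' could not be pickled and was omitted.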
@@ -1,6 +1,6 @@
[tool.poetry]
name = "parfor"
-version = "2025.1.0"
+version = "2026.1.0"
description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3"
@@ -17,6 +17,10 @@ pytest = { version = "*", optional = true }
[tool.poetry.extras]
test = ["pytest", "numpy"]

+[tool.ruff]
+line-length = 119
+indent-width = 4
+
[tool.isort]
line_length = 119
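With the version bump, an installed checkout of this commit reports the new release through package metadata. A quick sanity check, assuming the package has been installed from this commit:

from importlib.metadata import version

assert version("parfor") == "2026.1.0"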
@@ -1,9 +1,6 @@
from __future__ import annotations

-import sys
from dataclasses import dataclass
-from os import getpid
-from time import sleep
from typing import Any, Iterator, Optional, Sequence

import numpy as np
@@ -11,14 +8,6 @@ import pytest

from parfor import Chunks, ParPool, SharedArray, parfor, pmap

-try:
-    if sys._is_gil_enabled():  # noqa
-        gil = True
-    else:
-        gil = False
-except Exception:  # noqa
-    gil = True
-

class SequenceIterator:
    def __init__(self, sequence: Sequence) -> None:
@@ -56,7 +45,7 @@ def iterators() -> tuple[Iterator, Optional[int]]:
    yield Iterable(range(10)), 10


-@pytest.mark.parametrize('iterator', iterators())
+@pytest.mark.parametrize("iterator", iterators())
def test_chunks(iterator: tuple[Iterator, Optional[int]]) -> None:
    chunks = Chunks(iterator[0], size=2, length=iterator[1])
    assert list(chunks) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
@@ -66,7 +55,7 @@ def test_parpool() -> None:
    def fun(i, j, k) -> int:  # noqa
        return i * j * k

-    with ParPool(fun, (3,), {'k': 2}) as pool:  # noqa
+    with ParPool(fun, (3,), {"k": 2}) as pool:  # noqa
        for i in range(10):
            pool[i] = i

@@ -74,40 +63,66 @@ def test_parpool() -> None:


def test_parfor() -> None:
-    @parfor(range(10), (3,), {'k': 2})
+    @parfor(range(10), (3,), {"k": 2})
    def fun(i, j, k):
        return i * j * k

    assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]


-@pytest.mark.parametrize('serial', (True, False))
+@pytest.mark.parametrize("serial", (True, False))
def test_pmap(serial) -> None:
    def fun(i, j, k):
        return i * j * k

-    assert pmap(fun, range(10), (3,), {'k': 2}, serial=serial) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
+    assert pmap(fun, range(10), (3,), {"k": 2}, serial=serial) == [
+        0,
+        6,
+        12,
+        18,
+        24,
+        30,
+        36,
+        42,
+        48,
+        54,
+    ]


-@pytest.mark.parametrize('serial', (True, False))
+@pytest.mark.parametrize("serial", (True, False))
def test_pmap_with_idx(serial) -> None:
    def fun(i, j, k):
        return i * j * k

-    assert (pmap(fun, range(10), (3,), {'k': 2}, serial=serial, yield_index=True) ==
-            [(0, 0), (1, 6), (2, 12), (3, 18), (4, 24), (5, 30), (6, 36), (7, 42), (8, 48), (9, 54)])
+    assert pmap(fun, range(10), (3,), {"k": 2}, serial=serial, yield_index=True) == [
+        (0, 0),
+        (1, 6),
+        (2, 12),
+        (3, 18),
+        (4, 24),
+        (5, 30),
+        (6, 36),
+        (7, 42),
+        (8, 48),
+        (9, 54),
+    ]


-@pytest.mark.parametrize('serial', (True, False))
+@pytest.mark.parametrize("serial", (True, False))
def test_pmap_chunks(serial) -> None:
    def fun(i, j, k):
        return [i_ * j * k for i_ in i]

    chunks = Chunks(range(10), size=2)
-    assert pmap(fun, chunks, (3,), {'k': 2}, serial=serial) == [[0, 6], [12, 18], [24, 30], [36, 42], [48, 54]]
+    assert pmap(fun, chunks, (3,), {"k": 2}, serial=serial) == [
+        [0, 6],
+        [12, 18],
+        [24, 30],
+        [36, 42],
+        [48, 54],
+    ]


-@pytest.mark.skipif(not gil, reason='test if gil enabled only')
def test_id_reuse() -> None:
    def fun(i):
        return i[0].a
@@ -126,18 +141,6 @@ def test_id_reuse() -> None:
    assert all([i == j for i, j in enumerate(a)])


-@pytest.mark.skipif(not gil, reason='test if gil enabled only')
-@pytest.mark.parametrize('n_processes', (2, 4, 6))
-def test_n_processes(n_processes) -> None:
-
-    @parfor(range(12), n_processes=n_processes)
-    def fun(i):  # noqa
-        sleep(0.25)
-        return getpid()
-
-    assert len(set(fun)) == n_processes
-

def test_shared_array() -> None:
    def fun(i, a):
        a[i] = i