From 5528b54edeb721b3619314c92252d895f38a92dc Mon Sep 17 00:00:00 2001
From: Wim Pomp 
Date: Thu, 8 Jan 2026 00:48:14 +0100
Subject: [PATCH] Now based on ray, which enables nested parallel computations.

---
 .github/workflows/pytest.yml |   4 +-
 parfor/__init__.py           | 537 +++++++++++++++++++++++++++--------
 parfor/common.py             |  58 ----
 parfor/gil.py                | 463 ------------------------------
 parfor/nogil.py              | 158 -----------
 parfor/pickler.py            | 119 --------
 pyproject.toml               |   6 +-
 tests/test_parfor.py         |  71 ++---
 8 files changed, 463 insertions(+), 953 deletions(-)
 delete mode 100644 parfor/common.py
 delete mode 100644 parfor/gil.py
 delete mode 100644 parfor/nogil.py
 delete mode 100644 parfor/pickler.py

diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index c61decd..1451798 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -11,9 +11,9 @@ jobs:
         os: [ubuntu-latest, windows-latest, macOS-latest]
 
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v6
        with:
          python-version: ${{ matrix.python-version }}
       - name: Install
diff --git a/parfor/__init__.py b/parfor/__init__.py
index fc7792b..1756462 100644
--- a/parfor/__init__.py
+++ b/parfor/__init__.py
@@ -1,47 +1,114 @@
 from __future__ import annotations
 
-import sys
-from contextlib import ExitStack
+import logging
+import os
+import warnings
+from contextlib import ExitStack, redirect_stderr, redirect_stdout
 from functools import wraps
 from importlib.metadata import version
-from typing import Any, Callable, Generator, Iterable, Iterator, Sized
-from warnings import warn
+from multiprocessing.shared_memory import SharedMemory
+from traceback import format_exc
+from typing import Any, Callable, Generator, Hashable, Iterable, Iterator, NoReturn, Optional, Protocol, Sequence, Sized
 
+import numpy as np
+import ray
+from numpy.typing import ArrayLike, DTypeLike
 from tqdm.auto import tqdm
 
-from . import gil, nogil
-from .common import Bar, SharedArray, cpu_count
-
-if hasattr(sys, '_is_gil_enabled') and not sys._is_gil_enabled():  # noqa
-    from .nogil import ParPool, PoolSingleton, Task, Worker
-else:
-    from .gil import ParPool, PoolSingleton, Task, Worker
+__version__ = version("parfor")
+cpu_count = int(os.cpu_count())
 
-__version__ = version('parfor')
 
+class Bar(Protocol):
+    def update(self, n: int = 1) -> None: ...
+
+
+class SharedArray(np.ndarray):
+    """Numpy array whose memory can be shared between processes, so that memory use is reduced and changes in one
+    process are reflected in all other processes. Changes are not atomic, so protect changes with a lock to prevent
+    race conditions!
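+
+    A minimal sketch of intended use, mirroring test_shared_array in tests/test_parfor.py (pmap is defined below;
+    fun here is a stand-in for any picklable function):
+
+        def fun(i, array):
+            array[i] = i  # each worker writes into the same shared buffer
+
+        with SharedArray((10,), int) as array:
+            pmap(fun, range(10), (array,))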
+    """
+
+    def __new__(
+        cls,
+        shape: int | Sequence[int],
+        dtype: DTypeLike = float,
+        shm: str | SharedMemory = None,
+        offset: int = 0,
+        strides: tuple[int, ...] = None,
+        order: str = None,
+    ) -> SharedArray:
+        if isinstance(shm, str):
+            shm = SharedMemory(shm)
+        elif shm is None:
+            shm = SharedMemory(create=True, size=np.prod(shape) * np.dtype(dtype).itemsize)  # type: ignore
+        new = super().__new__(cls, shape, dtype, shm.buf, offset, strides, order)
+        new.shm = shm
+        return new
+
+    def __reduce__(
+        self,
+    ) -> tuple[
+        Callable[[int | Sequence[int], DTypeLike, str], SharedArray],
+        tuple[int | tuple[int, ...], np.dtype, str],
+    ]:
+        return self.__class__, (self.shape, self.dtype, self.shm.name)
+
+    def __enter__(self) -> SharedArray:
+        return self
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        if hasattr(self, "shm"):
+            self.shm.close()
+            self.shm.unlink()
+
+    def __del__(self) -> None:
+        if hasattr(self, "shm"):
+            self.shm.close()
+
+    def __array_finalize__(self, obj: np.ndarray | None) -> None:
+        if isinstance(obj, np.ndarray) and not isinstance(obj, SharedArray):
+            raise TypeError("view casting to SharedArray is not implemented; use from_array to make a shared copy")
+
+    @classmethod
+    def from_array(cls, array: ArrayLike) -> SharedArray:
+        """copy an existing array into a new SharedArray"""
+        array = np.asarray(array)
+        new = cls(array.shape, array.dtype)
+        new[:] = array[:]
+        return new
+

 class Chunks(Iterable):
-    """ Yield successive chunks from lists.
-        Usage: chunks(list0, list1, ...)
-               chunks(list0, list1, ..., size=s)
-               chunks(list0, list1, ..., number=n)
-               chunks(list0, list1, ..., ratio=r)
-        size: size of chunks, might change to optimize division between chunks
-        number: number of chunks, coerced to 1 <= n <= len(list0)
-        ratio: number of chunks / number of cpus, coerced to 1 <= n <= len(list0)
-        both size and number or ratio are given: use number or ratio, unless the chunk size would be bigger than size
-        both ratio and number are given: use ratio
+    """Yield successive chunks from lists.
+    Usage: Chunks(list0, list1, ...)
+           Chunks(list0, list1, ..., size=s)
+           Chunks(list0, list1, ..., number=n)
+           Chunks(list0, list1, ..., ratio=r)
+    size: size of chunks, might change to optimize the division between chunks
+    number: number of chunks, coerced to 1 <= n <= len(list0)
+    ratio: number of chunks / number of cpus, coerced to 1 <= n <= len(list0)
+    if both size and number (or ratio) are given: use number (or ratio), unless the chunk size would exceed size
+    if both ratio and number are given: use ratio
     """

-    def __init__(self, *iterables: Iterable[Any] | Sized[Any], size: int = None, number: int = None,
-                 ratio: float = None, length: int = None) -> None:
+    def __init__(
+        self,
+        *iterables: Iterable[Any] | Sized,
+        size: int = None,
+        number: int = None,
+        ratio: float = None,
+        length: int = None,
+    ) -> None:
         if length is None:
             try:
                 length = min(*[len(iterable) for iterable in iterables]) if len(iterables) > 1 else len(iterables[0])
             except TypeError:
-                raise TypeError('Cannot determine the length of the iterables(s), so the length must be provided as an'
-                                ' argument.')
+                raise TypeError(
+                    "Cannot determine the length of the iterable(s), so the length must be provided as an argument."
+ ) if size is not None and (number is not None or ratio is not None): if number is None: number = int(cpu_count * ratio) @@ -54,12 +121,17 @@ class Chunks(Iterable): self.iterators = [iter(arg) for arg in iterables] self.number_of_items = length self.length = min(length, number) - self.lengths = [((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length) - for i in range(self.length)] + self.lengths = [ + ((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length) + for i in range(self.length) + ] def __iter__(self) -> Iterator[Any]: for i in range(self.length): - p, q = (i * self.number_of_items // self.length), ((i + 1) * self.number_of_items // self.length) + p, q = ( + (i * self.number_of_items // self.length), + ((i + 1) * self.number_of_items // self.length), + ) if len(self.iterators) == 1: yield [next(self.iterators[0]) for _ in range(q - p)] else: @@ -70,7 +142,12 @@ class Chunks(Iterable): class ExternalBar(Iterable): - def __init__(self, iterable: Iterable = None, callback: Callable[[int], None] = None, total: int = 0) -> None: + def __init__( + self, + iterable: Iterable = None, + callback: Callable[[int], None] = None, + total: int = 0, + ) -> None: self.iterable = iterable self.callback = callback self.total = total @@ -102,90 +179,312 @@ class ExternalBar(Iterable): self.callback(n) -def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None, - args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None, - bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None, - n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False, - **bar_kwargs: Any) -> Generator[Any, None, None]: - """ map a function fun to each iteration in iterable - use as a function: pmap - use as a decorator: parfor - best use: iterable is a generator and length is given to this function as 'total' +@ray.remote +def worker(task): + try: + with ( + warnings.catch_warnings(), + redirect_stdout(open(os.devnull, "w")), + redirect_stderr(open(os.devnull, "w")), + ): + warnings.simplefilter("ignore", category=FutureWarning) + try: + task() + task.status = "done", + except Exception: # noqa + task.status = "task_error", format_exc() + except KeyboardInterrupt: # noqa + pass - required: - fun: function taking arguments: iteration from iterable, other arguments defined in args & kwargs - iterable: iterable or iterator from which an item is given to fun as a first argument - optional: - args: tuple with other unnamed arguments to fun - kwargs: dict with other named arguments to fun - total: give the length of the iterator in cases where len(iterator) results in an error - desc: string with description of the progress bar - bar: bool enable progress bar, - or a callback function taking the number of passed iterations as an argument - serial: execute in series instead of parallel if True, None (default): let pmap decide - length: deprecated alias for total - n_processes: number of processes to use, - the parallel pool will be restarted if the current pool does not have the right number of processes - yield_ordered: return the result in the same order as the iterable - yield_index: return the index of the result too - **bar_kwargs: keywords arguments for tqdm.tqdm + return task - output: - list (pmap) or generator (gmap) with results from applying the function \'fun\' to each iteration - of the iterable / iterator - examples: - << from time 
import sleep
-        <<
-        @parfor(range(10), (3,))
-        def fun(i, a):
-            sleep(1)
-            return a * i ** 2
-        fun
-        >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+class Task:
+    def __init__(
+        self,
+        handle: Hashable,
+        fun: Callable[[Any, ...], Any],
+        args: tuple[Any, ...] = (),
+        kwargs: dict[str, Any] = None,
+    ) -> None:
+        self.handle = handle
+        self.fun = fun
+        # guard against None: the args and kwargs setters iterate over their values to ray.put them
+        self.args = () if args is None else args
+        self.kwargs = {} if kwargs is None else kwargs
+        self.name = fun.__name__ if hasattr(fun, "__name__") else None
+        self.done = False
+        self.result = None
+        self.future = None
+        self.status = ("starting",)  # a 1-tuple: unpacked as code, *args in ParPool.finalize_task
 
-        <<
-        def fun(i, a):
-            sleep(1)
-            return a * i ** 2
-        pmap(fun, range(10), (3,))
-        >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+    @property
+    def fun(self) -> Callable[[Any, ...], Any]:
+        return ray.get(self._fun)
 
-    equivalent to using the deco module:
-        <<
-        @concurrent
-        def fun(i, a):
-            time.sleep(1)
-            return a * i ** 2
+    @fun.setter
+    def fun(self, fun: Callable[[Any, ...], Any]):
+        self._fun = ray.put(fun)
 
-        @synchronized
-        def run(iterator, a):
-            res = []
-            for i in iterator:
-                res.append(fun(i, a))
-            return res
-        run(range(10), 3)
-        >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+    @property
+    def args(self) -> tuple[Any, ...]:
+        return tuple([ray.get(arg) for arg in self._args])
 
-    all equivalent to the serial for-loop:
-        <<
-        a = 3
-        fun = []
-        for i in range(10):
-            sleep(1)
-            fun.append(a * i ** 2)
-        fun
-        >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+    @args.setter
+    def args(self, args: tuple[Any, ...]) -> None:
+        self._args = [ray.put(arg) for arg in args]
+
+    @property
+    def kwargs(self) -> dict[str, Any]:
+        return {key: ray.get(value) for key, value in self._kwargs.items()}
+
+    @kwargs.setter
+    def kwargs(self, kwargs: dict[str, Any]) -> None:
+        self._kwargs = {key: ray.put(value) for key, value in kwargs.items()}
+
+    @property
+    def result(self) -> Any:
+        return ray.get(self._result)
+
+    @result.setter
+    def result(self, result: Any) -> None:
+        self._result = ray.put(result)
+
+    def __call__(self) -> Task:
+        if not self.done:
+            self.result = self.fun(*self.args, **self.kwargs)  # noqa
+            self.done = True
+        return self
+
+    def __repr__(self) -> str:
+        if self.done:
+            return f"Task {self.handle}, result: {self.result}"
+        else:
+            return f"Task {self.handle}"
+
+
+class ParPool:
+    """Parallel processing with addition of iterations at any time and retrieval of results at any time after that.
+    The target function and its arguments can be changed at any time.
+    """
+
+    def __init__(
+        self,
+        fun: Callable[[Any, ...], Any] = None,
+        args: tuple[Any, ...] = None,
+        kwargs: dict[str, Any] = None,
+        n_processes: int = None,
+        bar: Bar = None,
+    ):
+        self.handle = 0
+        self.tasks = {}
+        self.bar = bar
+        self.bar_lengths = {}
+        self.fun = fun
+        self.args = args
+        self.kwargs = kwargs
+        PoolSingleton(n_processes)  # make sure ray is initialized before tasks are submitted
+
+    def __getstate__(self) -> NoReturn:
+        raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.")
+
+    def __enter__(self) -> ParPool:
+        return self
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
+        self.add_task(
+            args=(n, *(() if self.args is None else self.args)),
+            handle=handle,
+            barlength=barlength,
+        )
+
+    def add_task(
+        self,
+        fun: Callable[[Any, ...], Any] = None,
+        args: tuple[Any, ...] 
= None, + kwargs: dict[str, Any] = None, + handle: Hashable = None, + barlength: int = 1, + ) -> Optional[int]: + if handle is None: + new_handle = self.handle + self.handle += 1 + else: + new_handle = handle + if new_handle in self: + raise ValueError(f"handle {new_handle} already present") + task = Task( + new_handle, + fun or self.fun, + args or self.args, + kwargs or self.kwargs, + ) + task.future = worker.remote(task) + self.tasks[new_handle] = task + self.bar_lengths[new_handle] = barlength + if handle is None: + return new_handle + else: + return None + + def __setitem__(self, handle: Hashable, n: Any) -> None: + """Add new iteration.""" + self(n, handle=handle) + + def __getitem__(self, handle: Hashable) -> Any: + """Request result and delete its record. Wait if result not yet available.""" + if handle not in self: + raise ValueError(f"No handle: {handle} in pool") + task = ray.get(self.tasks[handle].future) + return self.finalize_task(task) + + def __contains__(self, handle: Hashable) -> bool: + return handle in self.tasks + + def __delitem__(self, handle: Hashable) -> None: + self.tasks.pop(handle) + + def finalize_task(self, task: Task) -> Any: + code, *args = task.status + getattr(self, code)(task, *args) + self.tasks.pop(task.handle) + return task.result + + def get_newest(self) -> Optional[Any]: + """Request the newest handle and result and delete its record. Wait if result not yet available.""" + while True: + for handle, task in self.tasks.items(): + if handle in self.bar_lengths: + try: + task = ray.get(task.future, timeout=0.01) + return task.handle, self.finalize_task(task) + except ray.exceptions.GetTimeoutError: + pass + + def task_error(self, task: Task, error: Exception) -> None: + if task.handle in self: + task = self.tasks[task.handle] + print(f"Error from process working on iteration {task.handle}:\n") + print(error) + print("Retrying in main process...") + task() + raise Exception(f"Function '{task.name}' cannot be executed by parfor, amend or execute in serial.") + + def done(self, task: Task) -> None: + if task.handle in self: # if not, the task was restarted erroneously + self.tasks[task.handle] = task + if hasattr(self.bar, "update"): + self.bar.update(self.bar_lengths.pop(task.handle)) + + +class PoolSingleton: + instance: PoolSingleton = None + cpu_count: int = int(os.cpu_count()) + + def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton: + # restart if any workers have shut down or if we want to have a different number of processes + n_processes = n_processes or cls.cpu_count + if cls.instance is None or cls.instance.n_processes != n_processes: + cls.instance = super().__new__(cls) + cls.instance.n_processes = n_processes + if ray.is_initialized(): + warnings.warn("not setting n_processes because parallel pool was already initialized") + else: + os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0" + ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False) + return cls.instance + + +class Worker: + nested: bool = False + + +def gmap( + fun: Callable[[Any, ...], Any], + iterable: Iterable[Any] = None, + args: tuple[Any, ...] 
= None,
+    kwargs: dict[str, Any] = None,
+    total: int = None,
+    desc: str = None,
+    bar: Bar | bool = True,
+    serial: bool = None,
+    n_processes: int = None,
+    yield_ordered: bool = True,
+    yield_index: bool = False,
+    **bar_kwargs: Any,
+) -> Generator[Any, None, None]:
+    """map a function fun to each iteration in iterable
+    use as a function: pmap
+    use as a decorator: parfor
+    best use: iterable is a generator and length is given to this function as 'total'
+
+    required:
+        fun: function taking arguments: iteration from iterable, other arguments defined in args & kwargs
+        iterable: iterable or iterator from which an item is given to fun as a first argument
+    optional:
+        args: tuple with other unnamed arguments to fun
+        kwargs: dict with other named arguments to fun
+        total: give the length of the iterator in cases where len(iterator) results in an error
+        desc: string with description of the progress bar
+        bar: bool enable progress bar,
+            or a callback function taking the number of passed iterations as an argument
+        serial: execute in series instead of parallel if True, None (default): let pmap decide
+        n_processes: number of processes to use,
+            the parallel pool will be restarted if the current pool does not have the right number of processes
+        yield_ordered: return the result in the same order as the iterable
+        yield_index: return the index of the result too
+        **bar_kwargs: keyword arguments for tqdm.tqdm
+
+    output:
+        list (pmap) or generator (gmap) with results from applying the function 'fun' to each iteration
+            of the iterable / iterator
+
+    examples:
+        << from time import sleep
+        <<
+        @parfor(range(10), (3,))
+        def fun(i, a):
+            sleep(1)
+            return a * i ** 2
+        fun
+        >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+
+        <<
+        def fun(i, a):
+            sleep(1)
+            return a * i ** 2
+        pmap(fun, range(10), (3,))
+        >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+
+    equivalent to using the deco module:
+        <<
+        @concurrent
+        def fun(i, a):
+            time.sleep(1)
+            return a * i ** 2
+
+        @synchronized
+        def run(iterator, a):
+            res = []
+            for i in iterator:
+                res.append(fun(i, a))
+            return res
+        run(range(10), 3)
+        >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
+
+    all equivalent to the serial for-loop:
+        <<
+        a = 3
+        fun = []
+        for i in range(10):
+            sleep(1)
+            fun.append(a * i ** 2)
+        fun
+        >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
     """
-    if total is None and length is not None:
-        total = length
-        warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=2)
-        warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=3)
-    if terminator is not None:
-        warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
-             DeprecationWarning, stacklevel=2)
-        warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
-             DeprecationWarning, stacklevel=3)
     is_chunked = isinstance(iterable, Chunks)
     if is_chunked:
         chunk_fun = fun
@@ -201,13 +500,13 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
     if kwargs is None:
         kwargs = {}
 
-    if 'total' not in bar_kwargs:
-        bar_kwargs['total'] = sum(iterable.lengths)
-    if 'desc' not in bar_kwargs:
-        bar_kwargs['desc'] = desc
-    if 'disable' not in bar_kwargs:
-        bar_kwargs['disable'] = not bar
-    if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested:  # serial case
+    if "total" not in bar_kwargs:
+        bar_kwargs["total"] = sum(iterable.lengths)
+    if 
"desc" not in bar_kwargs: + bar_kwargs["desc"] = desc + if "disable" not in bar_kwargs: + bar_kwargs["disable"] = not bar + if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)): # serial case def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Any]: # noqa with tqdm(*args, **kwargs) as b: @@ -215,8 +514,9 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None, yield chunk b.update(length) - iterable = (ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar) - else tqdm_chunks(iterable, **bar_kwargs)) + iterable = ( + ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar) else tqdm_chunks(iterable, **bar_kwargs) # type: ignore + ) if is_chunked: if yield_index: for i, c in enumerate(iterable): @@ -233,10 +533,10 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None, for c in iterable: yield from chunk_fun(c, *args, **kwargs) - else: # parallel case - with ExitStack() as stack: + else: # parallel case + with ExitStack() as stack: # noqa if callable(bar): - bar = stack.enter_context(ExternalBar(callback=bar)) + bar = stack.enter_context(ExternalBar(callback=bar)) # noqa else: bar = stack.enter_context(tqdm(**bar_kwargs)) with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p: # type: ignore @@ -287,12 +587,13 @@ def pmap(*args, **kwargs) -> list[Any]: def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Any, ...], Any]], list[Any]]: def decfun(fun: Callable[[Any, ...], Any]) -> list[Any]: return pmap(fun, *args, **kwargs) + return decfun try: parfor.__doc__ = pmap.__doc__ = gmap.__doc__ pmap.__annotations__ = gmap.__annotations__ | pmap.__annotations__ - parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != 'fun'} + parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != "fun"} except AttributeError: pass diff --git a/parfor/common.py b/parfor/common.py deleted file mode 100644 index b067eb2..0000000 --- a/parfor/common.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -import os -from multiprocessing.shared_memory import SharedMemory -from typing import Any, Callable, Protocol, Sequence - -import numpy as np -from numpy.typing import ArrayLike, DTypeLike - -cpu_count = int(os.cpu_count()) - - -class Bar(Protocol): - def update(self, n: int = 1) -> None: ... - - -class SharedArray(np.ndarray): - """ Numpy array whose memory can be shared between processes, so that memory use is reduced and changes in one - process are reflected in all other processes. Changes are not atomic, so protect changes with a lock to prevent - race conditions! 
- """ - - def __new__(cls, shape: int | Sequence[int], dtype: DTypeLike = float, shm: str | SharedMemory = None, - offset: int = 0, strides: tuple[int, int] = None, order: str = None) -> SharedArray: - if isinstance(shm, str): - shm = SharedMemory(shm) - elif shm is None: - shm = SharedMemory(create=True, size=np.prod(shape) * np.dtype(dtype).itemsize) - new = super().__new__(cls, shape, dtype, shm.buf, offset, strides, order) - new.shm = shm - return new - - def __reduce__(self) -> tuple[Callable[[int | Sequence[int], DTypeLike, str], SharedArray], - tuple[int | tuple[int, ...], np.dtype, str]]: - return self.__class__, (self.shape, self.dtype, self.shm.name) - - def __enter__(self) -> SharedArray: - return self - - def __exit__(self, *args: Any, **kwargs: Any) -> None: - if hasattr(self, 'shm'): - self.shm.close() - self.shm.unlink() - - def __del__(self) -> None: - if hasattr(self, 'shm'): - self.shm.close() - - def __array_finalize__(self, obj: np.ndarray | None) -> None: - if isinstance(obj, np.ndarray) and not isinstance(obj, SharedArray): - raise TypeError('view casting to SharedArray is not implemented because right now we need to make a copy') - - @classmethod - def from_array(cls, array: ArrayLike) -> SharedArray: - """ copy existing array into a SharedArray """ - new = cls(array.shape, array.dtype) - new[:] = array[:] - return new diff --git a/parfor/gil.py b/parfor/gil.py deleted file mode 100644 index d672391..0000000 --- a/parfor/gil.py +++ /dev/null @@ -1,463 +0,0 @@ -from __future__ import annotations - -import asyncio -import multiprocessing -from collections import UserDict -from contextlib import redirect_stderr, redirect_stdout -from os import cpu_count, devnull, getpid -from time import time -from traceback import format_exc -from typing import Any, Callable, Hashable, NoReturn, Optional -from warnings import warn - -from .common import Bar -from .pickler import dumps, loads - - -class SharedMemory(UserDict): - def __init__(self, manager: multiprocessing.Manager) -> None: - super().__init__() - self.data = manager.dict() # item_id: dilled representation of object - self.references = manager.dict() # item_id: counter - self.references_lock = manager.Lock() - self.cache = {} # item_id: object - self.trash_can = {} - self.pool_ids = {} # item_id: {(pool_id, task_handle), ...} - - def __getstate__(self) -> tuple[dict[int, bytes], dict[int, int], multiprocessing.Lock]: - return self.data, self.references, self.references_lock - - def __setitem__(self, item_id: int, value: Any) -> None: - if item_id not in self: # values will not be changed - try: - self.data[item_id] = False, value - except Exception: # only use our pickler when necessary # noqa - self.data[item_id] = True, dumps(value, recurse=True) - with self.references_lock: - try: - self.references[item_id] += 1 - except KeyError: - self.references[item_id] = 1 - self.cache[item_id] = value # the id of the object will not be reused as long as the object exists - - def add_item(self, item: Any, pool_id: int, task_handle: Hashable) -> int: - item_id = id(item) - self[item_id] = item - if item_id in self.pool_ids: - self.pool_ids[item_id].add((pool_id, task_handle)) - else: - self.pool_ids[item_id] = {(pool_id, task_handle)} - return item_id - - def remove_pool(self, pool_id: int) -> None: - """ remove objects used by a pool that won't be needed anymore """ - self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})} - for item_id in set(self.data.keys()) - 
set(self.pool_ids): - del self[item_id] - self.garbage_collect() - - def remove_task(self, pool_id: int, task: Task) -> None: - """ remove objects used by a task that won't be needed anymore """ - self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})} - for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids): - del self[item_id] - self.garbage_collect() - - # worker functions - def __setstate__(self, state: dict) -> None: - self.data, self.references, self.references_lock = state - self.cache = {} - self.trash_can = None - - def __getitem__(self, item_id: int) -> Any: - if item_id not in self.cache: - dilled, value = self.data[item_id] - if dilled: - value = loads(value) - with self.references_lock: - if item_id in self.references: - self.references[item_id] += 1 - else: - self.references[item_id] = 1 - self.cache[item_id] = value - return self.cache[item_id] - - def garbage_collect(self) -> None: - """ clean up the cache """ - for item_id in set(self.cache) - set(self.data.keys()): - with self.references_lock: - try: - self.references[item_id] -= 1 - except KeyError: - self.references[item_id] = 0 - if self.trash_can is not None and item_id not in self.trash_can: - self.trash_can[item_id] = self.cache[item_id] - del self.cache[item_id] - - if self.trash_can: - for item_id in set(self.trash_can): - if self.references[item_id] == 0: - # make sure every process removed the object before removing it in the parent - del self.references[item_id] - del self.trash_can[item_id] - - -class Task: - def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: Hashable, fun: Callable[[Any, ...], Any], - args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None) -> None: - self.pool_id = pool_id - self.handle = handle - self.fun = shared_memory.add_item(fun, pool_id, handle) - self.args = [shared_memory.add_item(arg, pool_id, handle) for arg in args] - self.kwargs = [] if kwargs is None else [shared_memory.add_item(item, pool_id, handle) - for item in kwargs.items()] - self.name = fun.__name__ if hasattr(fun, '__name__') else None - self.done = False - self.result = None - self.pid = None - - def __getstate__(self) -> dict[str, Any]: - state = self.__dict__ - if self.result is not None: - state['result'] = dumps(self.result, recurse=True) - return state - - def __setstate__(self, state: dict[str, Any]) -> None: - self.__dict__.update({key: value for key, value in state.items() if key != 'result'}) - if state['result'] is None: - self.result = None - else: - self.result = loads(state['result']) - - def __call__(self, shared_memory: SharedMemory) -> Task: - if not self.done: - fun = shared_memory[self.fun] or (lambda *args, **kwargs: None) # noqa - args = [shared_memory[arg] for arg in self.args] - kwargs = dict([shared_memory[kwarg] for kwarg in self.kwargs]) - self.result = fun(*args, **kwargs) # noqa - self.done = True - return self - - def __repr__(self) -> str: - if self.done: - return f'Task {self.handle}, result: {self.result}' - else: - return f'Task {self.handle}' - - -class Context(multiprocessing.context.SpawnContext): - """ Provide a context where child processes never are daemonic. """ - class Process(multiprocessing.context.SpawnProcess): - @property - def daemon(self) -> bool: - return False - - @daemon.setter - def daemon(self, value: bool) -> None: - pass - - -class ParPool: - """ Parallel processing with addition of iterations at any time and request of that result any time after that. 
- The target function and its argument can be changed at any time. - """ - def __init__(self, fun: Callable[[Any, ...], Any] = None, - args: tuple[Any] = None, kwargs: dict[str, Any] = None, n_processes: int = None, bar: Bar = None): - self.id = id(self) - self.handle = 0 - self.tasks = {} - self.bar = bar - self.bar_lengths = {} - self.spool = PoolSingleton(n_processes, self) - self.manager = self.spool.manager - self.fun = fun - self.args = args - self.kwargs = kwargs - self.is_started = False - - def __getstate__(self) -> NoReturn: - raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.') - - def __enter__(self) -> ParPool: - return self - - def __exit__(self, *args: Any, **kwargs: Any) -> None: - self.close() - - def close(self) -> None: - self.spool.remove_pool(self.id) - - def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None: - self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength) - - def add_task(self, fun: Callable[[Any, ...], Any] = None, args: tuple[Any, ...] = None, - kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]: - if self.id not in self.spool.pools: - raise ValueError(f'this pool is not registered (anymore) with the pool singleton') - if handle is None: - new_handle = self.handle - self.handle += 1 - else: - new_handle = handle - if new_handle in self: - raise ValueError(f'handle {new_handle} already present') - task = Task(self.spool.shared_memory, self.id, new_handle, - fun or self.fun, args or self.args, kwargs or self.kwargs) - self.tasks[new_handle] = task - self.spool.add_task(task) - self.bar_lengths[new_handle] = barlength - if handle is None: - return new_handle - - def __setitem__(self, handle: Hashable, n: Any) -> None: - """ Add new iteration. """ - self(n, handle=handle) - - def __getitem__(self, handle: Hashable) -> Any: - """ Request result and delete its record. Wait if result not yet available. 
""" - if handle not in self: - raise ValueError(f'No handle: {handle} in pool') - while not self.tasks[handle].done: - if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \ - and not self.working: - for _ in range(10): # wait some time while processing possible new messages - self.spool.get_from_queue() - if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \ - and not self.working: - # retry a task if the process was killed while working on a task - self.spool.add_task(self.tasks[handle]) - warn(f'Task {handle} was restarted because the process working on it was probably killed.') - result = self.tasks[handle].result - self.tasks.pop(handle) - return result - - def __contains__(self, handle: Hashable) -> bool: - return handle in self.tasks - - def __delitem__(self, handle: Hashable) -> None: - self.tasks.pop(handle) - - def get_newest(self) -> Any: - return self.spool.get_newest_for_pool(self) - - def process_queue(self) -> None: - self.spool.process_queue() - - def task_error(self, handle: Hashable, error: Exception) -> None: - if handle in self: - task = self.tasks[handle] - print(f'Error from process working on iteration {handle}:\n') - print(error) - print('Retrying in main process...') - task(self.spool.shared_memory) - self.spool.shared_memory.remove_task(self.id, task) - raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.') - - def done(self, task: Task) -> None: - if task.handle in self: # if not, the task was restarted erroneously - self.tasks[task.handle] = task - if hasattr(self.bar, 'update'): - self.bar.update(self.bar_lengths.pop(task.handle)) - self.spool.shared_memory.remove_task(self.id, task) - - def started(self, handle: Hashable, pid: int) -> None: - self.is_started = True - if handle in self: # if not, the task was restarted erroneously - self.tasks[handle].pid = pid - - @property - def working(self) -> bool: - return not all([task.pid is None for task in self.tasks.values()]) - - -class PoolSingleton: - """ There can be only one pool at a time, but the pool can be restarted by calling close() and then constructing a - new pool. The pool will close itself after 10 minutes of idle time. 
""" - - instance = None - cpu_count = cpu_count() - - def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton: - # restart if any workers have shut down or if we want to have a different number of processes - if cls.instance is not None: - if (cls.instance.n_workers.value < cls.instance.n_processes or - cls.instance.n_processes != (n_processes or cls.cpu_count)): - cls.instance.close() - if cls.instance is None or not cls.instance.is_alive: - new = super().__new__(cls) - new.n_processes = n_processes or cls.cpu_count - new.instance = new - new.is_started = False - ctx = Context() - new.n_workers = ctx.Value('i', new.n_processes) - new.event = ctx.Event() - new.queue_in = ctx.Queue(3 * new.n_processes) - new.queue_out = ctx.Queue(new.n_processes) - new.manager = ctx.Manager() - new.shared_memory = SharedMemory(new.manager) - new.pool = ctx.Pool(new.n_processes, - Worker(new.shared_memory, new.queue_in, new.queue_out, new.n_workers, new.event)) - new.is_alive = True - new.handle = 0 - new.pools = {} - new.time_out = None - cls.instance = new - return cls.instance - - def __init__(self, n_processes: int = None, parpool: Parpool = None) -> None: # noqa - if parpool is not None: - self.pools[parpool.id] = parpool - if self.time_out is not None: - self.time_out.cancel() - self.time_out = None - - def __getstate__(self) -> NoReturn: - raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.') - - def remove_pool(self, pool_id: int) -> None: - self.shared_memory.remove_pool(pool_id) - if pool_id in self.pools: - self.pools.pop(pool_id) - if len(self.pools) == 0: - try: - self.time_out = asyncio.get_running_loop().call_later(600, self.close) # noqa - except RuntimeError: - self.time_out = asyncio.new_event_loop().call_later(600, self.close) # noqa - - def error(self, error: Exception) -> NoReturn: - self.close() - raise Exception(f'Error occurred in worker: {error}') - - def process_queue(self) -> None: - while self.get_from_queue(): - pass - - def get_from_queue(self) -> bool: - """ Get an item from the queue and store it, return True if more messages are waiting. """ - try: - code, pool_id, *args = self.queue_out.get(True, 0.02) - if pool_id is None: - getattr(self, code)(*args) - elif pool_id in self.pools: - getattr(self.pools[pool_id], code)(*args) - return True - except multiprocessing.queues.Empty: # noqa - for pool in self.pools.values(): - for handle, task in pool.tasks.items(): # retry a task if the process doing it was killed - if task.pid is not None \ - and task.pid not in [child.pid for child in multiprocessing.active_children()]: - self.queue_in.put(task) - warn(f'Task {task.handle} was restarted because process {task.pid} was probably killed.') - return False - - def add_task(self, task: Task) -> None: - """ Add new iteration, using optional manually defined handle.""" - if self.is_alive and not self.event.is_set(): - while self.queue_in.full(): - self.get_from_queue() - self.queue_in.put(task) - self.shared_memory.garbage_collect() - - def get_newest_for_pool(self, pool: ParPool) -> tuple[Hashable, Any]: - """ Request the newest key and result and delete its record. Wait if result not yet available. 
""" - while len(pool.tasks): - self.get_from_queue() - for task in pool.tasks.values(): - if task.done: - handle, result = task.handle, task.result - pool.tasks.pop(handle) - return handle, result - - @classmethod - def close(cls) -> None: - if cls.instance is not None: - instance = cls.instance - cls.instance = None - if instance.time_out is not None: - instance.time_out.cancel() - - def empty_queue(queue): - try: - if not queue._closed: # noqa - while not queue.empty(): - try: - queue.get(True, 0.02) - except multiprocessing.queues.Empty: # noqa - pass - except OSError: - pass - - def close_queue(queue: multiprocessing.queues.Queue) -> None: - empty_queue(queue) # noqa - if not queue._closed: # noqa - queue.close() - queue.join_thread() - - if instance.is_alive: - instance.is_alive = False - instance.event.set() - instance.pool.close() - t = time() - while instance.n_workers.value: - empty_queue(instance.queue_in) - empty_queue(instance.queue_out) - if time() - t > 10: - warn(f'Parfor: Closing pool timed out, {instance.n_workers.value} processes still alive.') - instance.pool.terminate() - break - empty_queue(instance.queue_in) - empty_queue(instance.queue_out) - instance.pool.join() - close_queue(instance.queue_in) - close_queue(instance.queue_out) - instance.manager.shutdown() - instance.handle = 0 - - -class Worker: - """ Manages executing the target function which will be executed in different processes. """ - nested = False - - def __init__(self, shared_memory: SharedMemory, queue_in: multiprocessing.queues.Queue, - queue_out: multiprocessing.queues.Queue, n_workers: multiprocessing.Value, - event: multiprocessing.Event) -> None: - self.shared_memory = shared_memory - self.queue_in = queue_in - self.queue_out = queue_out - self.n_workers = n_workers - self.event = event - - def add_to_queue(self, *args: Any) -> None: - while not self.event.is_set(): - try: - self.queue_out.put(args, timeout=0.1) - break - except multiprocessing.queues.Full: # noqa - continue - - def __call__(self) -> None: - Worker.nested = True - pid = getpid() - last_active_time = time() - while not self.event.is_set() and time() - last_active_time < 600: - try: - with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')): - task = self.queue_in.get(True, 0.02) - try: - self.add_to_queue('started', task.pool_id, task.handle, pid) - self.add_to_queue('done', task.pool_id, task(self.shared_memory)) - except Exception: # noqa - self.add_to_queue('task_error', task.pool_id, task.handle, format_exc()) - self.event.set() - self.shared_memory.garbage_collect() - last_active_time = time() - except (multiprocessing.queues.Empty, KeyboardInterrupt): # noqa - pass - except Exception: # noqa - self.add_to_queue('error', None, format_exc()) - self.event.set() - self.shared_memory.garbage_collect() - for child in multiprocessing.active_children(): - child.kill() - with self.n_workers: - self.n_workers.value -= 1 diff --git a/parfor/nogil.py b/parfor/nogil.py deleted file mode 100644 index 0eedf20..0000000 --- a/parfor/nogil.py +++ /dev/null @@ -1,158 +0,0 @@ -from __future__ import annotations - -import queue -import threading -from os import cpu_count -from typing import Any, Callable, Hashable, NoReturn, Optional - -from .common import Bar - - -class Worker: - nested = False - - def __init__(self, *args, **kwargs): - pass - - -class PoolSingleton: - cpu_count = cpu_count() - - def __init__(self, *args, **kwargs): - pass - - def close(self): - pass - - -class Task: - def __init__(self, queue: queue.Queue, 
handle: Hashable, fun: Callable[[Any, ...], Any], # noqa - args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None) -> None: - self.queue = queue - self.handle = handle - self.fun = fun - self.args = args - self.kwargs = {} if kwargs is None else kwargs - self.name = fun.__name__ if hasattr(fun, '__name__') else None - self.started = False - self.done = False - self.result = None - - def __call__(self): - if not self.done: - self.result = self.fun(*self.args, **self.kwargs) - try: - self.queue.put(self.handle) - except queue.ShutDown: - pass - - def __repr__(self) -> str: - if self.done: - return f'Task {self.handle}, result: {self.result}' - else: - return f'Task {self.handle}' - - -class ParPool: - """ Parallel processing with addition of iterations at any time and request of that result any time after that. - The target function and its argument can be changed at any time. - """ - def __init__(self, fun: Callable[[Any, ...], Any] = None, - args: tuple[Any] = None, kwargs: dict[str, Any] = None, n_processes: int = None, bar: Bar = None): - self.queue = queue.Queue() - self.handle = 0 - self.tasks = {} - self.bar = bar - self.bar_lengths = {} - self.fun = fun - self.args = args - self.kwargs = kwargs - self.n_processes = n_processes or PoolSingleton.cpu_count - self.threads = {} - - def __getstate__(self) -> NoReturn: - raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.') - - def __enter__(self) -> ParPool: - return self - - def __exit__(self, *args: Any, **kwargs: Any) -> None: - self.close() - - def close(self) -> None: - self.queue.shutdown() # noqa python3.13 - for thread in self.threads.values(): - thread.join() - - def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None: - self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength) - - def add_task(self, fun: Callable[[Any, ...], Any] = None, args: tuple[Any, ...] = None, - kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]: - if handle is None: - new_handle = self.handle - self.handle += 1 - else: - new_handle = handle - if new_handle in self: - raise ValueError(f'handle {new_handle} already present') - task = Task(self.queue, new_handle, fun or self.fun, args or self.args, kwargs or self.kwargs) - while len(self.threads) > self.n_processes: - self.get_from_queue() - thread = threading.Thread(target=task) - thread.start() - self.threads[new_handle] = thread - self.tasks[new_handle] = task - self.bar_lengths[new_handle] = barlength - if handle is None: - return new_handle - - def __setitem__(self, handle: Hashable, n: Any) -> None: - """ Add new iteration. """ - self(n, handle=handle) - - def __getitem__(self, handle: Hashable) -> Any: - """ Request result and delete its record. Wait if result not yet available. """ - if handle not in self: - raise ValueError(f'No handle: {handle} in pool') - while not self.tasks[handle].done: - self.get_from_queue() - task = self.tasks.pop(handle) - return task.result - - def __contains__(self, handle: Hashable) -> bool: - return handle in self.tasks - - def __delitem__(self, handle: Hashable) -> None: - self.tasks.pop(handle) - - def get_from_queue(self) -> bool: - """ Get an item from the queue and store it, return True if more messages are waiting. 
""" - try: - handle = self.queue.get(True, 0.02) - self.done(handle) - return True - except (queue.Empty, queue.ShutDown): - return False - - def get_newest(self) -> Any: - """ Request the newest key and result and delete its record. Wait if result not yet available. """ - while len(self.tasks): - self.get_from_queue() - for task in self.tasks.values(): - if task.done: - handle, result = task.handle, task.result - self.tasks.pop(handle) - return handle, result - - def process_queue(self) -> None: - while self.get_from_queue(): - pass - - def done(self, handle: Hashable) -> None: - thread = self.threads.pop(handle) - thread.join() - task = self.tasks[handle] - task.done = True - if hasattr(self.bar, 'update'): - self.bar.update(self.bar_lengths.pop(handle)) diff --git a/parfor/pickler.py b/parfor/pickler.py deleted file mode 100644 index 27bb447..0000000 --- a/parfor/pickler.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import annotations - -import copyreg -from io import BytesIO -from pickle import PicklingError -from typing import Any, Callable - -import dill - -loads = dill.loads - - -class CouldNotBePickled: - def __init__(self, class_name: str) -> None: - self.class_name = class_name - - def __repr__(self) -> str: - return f"Item of type '{self.class_name}' could not be pickled and was omitted." - - @classmethod - def reduce(cls, item: Any) -> tuple[Callable[[str], CouldNotBePickled], tuple[str]]: - return cls, (type(item).__name__,) - - -class Pickler(dill.Pickler): - """ Overload dill to ignore unpicklable parts of objects. - You probably didn't want to use these parts anyhow. - However, if you did, you'll have to find some way to make them picklable. - """ - def save(self, obj: Any, save_persistent_id: bool = True) -> None: - """ Copied from pickle and amended. 
""" - self.framer.commit_frame() - - # Check for persistent id (defined by a subclass) - pid = self.persistent_id(obj) - if pid is not None and save_persistent_id: - self.save_pers(pid) - return - - # Check the memo - x = self.memo.get(id(obj)) - if x is not None: - self.write(self.get(x[0])) - return - - rv = NotImplemented - reduce = getattr(self, "reducer_override", None) - if reduce is not None: - rv = reduce(obj) - - if rv is NotImplemented: - # Check the type dispatch table - t = type(obj) - f = self.dispatch.get(t) - if f is not None: - f(self, obj) # Call unbound method with explicit self - return - - # Check private dispatch table if any, or else - # copyreg.dispatch_table - reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t) - if reduce is not None: - rv = reduce(obj) - else: - # Check for a class with a custom metaclass; treat as regular - # class - if issubclass(t, type): - self.save_global(obj) - return - - # Check for a __reduce_ex__ method, fall back to __reduce__ - reduce = getattr(obj, "__reduce_ex__", None) - try: - if reduce is not None: - rv = reduce(self.proto) - else: - reduce = getattr(obj, "__reduce__", None) - if reduce is not None: - rv = reduce() - else: - raise PicklingError("Can't pickle %r object: %r" % - (t.__name__, obj)) - except Exception: # noqa - rv = CouldNotBePickled.reduce(obj) - - # Check for string returned by reduce(), meaning "save as global" - if isinstance(rv, str): - try: - self.save_global(obj, rv) - except Exception: # noqa - self.save_global(obj, CouldNotBePickled.reduce(obj)) - return - - # Assert that reduce() returned a tuple - if not isinstance(rv, tuple): - raise PicklingError("%s must return string or tuple" % reduce) - - # Assert that it returned an appropriately sized tuple - length = len(rv) - if not (2 <= length <= 6): - raise PicklingError("Tuple returned by %s must have " - "two to six elements" % reduce) - - # Save the reduce() output and finally memoize the object - try: - self.save_reduce(obj=obj, *rv) - except Exception: # noqa - self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj)) - - -def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, - **kwds: Any) -> bytes: - """pickle an object to a string""" - protocol = dill.settings['protocol'] if protocol is None else int(protocol) - _kwds = kwds.copy() - _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse)) - with BytesIO() as file: - Pickler(file, protocol, **_kwds).dump(obj) - return file.getvalue() diff --git a/pyproject.toml b/pyproject.toml index 5d1be80..5b5008a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "parfor" -version = "2025.1.0" +version = "2026.1.0" description = "A package to mimic the use of parfor as done in Matlab." 
authors = ["Wim Pomp "] license = "GPLv3" @@ -17,6 +17,10 @@ pytest = { version = "*", optional = true } [tool.poetry.extras] test = ["pytest", "numpy"] +[tool.ruff] +line-length = 119 +indent-width = 4 + [tool.isort] line_length = 119 diff --git a/tests/test_parfor.py b/tests/test_parfor.py index a0e9708..2d178ab 100644 --- a/tests/test_parfor.py +++ b/tests/test_parfor.py @@ -1,9 +1,6 @@ from __future__ import annotations -import sys from dataclasses import dataclass -from os import getpid -from time import sleep from typing import Any, Iterator, Optional, Sequence import numpy as np @@ -11,14 +8,6 @@ import pytest from parfor import Chunks, ParPool, SharedArray, parfor, pmap -try: - if sys._is_gil_enabled(): # noqa - gil = True - else: - gil = False -except Exception: # noqa - gil = True - class SequenceIterator: def __init__(self, sequence: Sequence) -> None: @@ -56,7 +45,7 @@ def iterators() -> tuple[Iterator, Optional[int]]: yield Iterable(range(10)), 10 -@pytest.mark.parametrize('iterator', iterators()) +@pytest.mark.parametrize("iterator", iterators()) def test_chunks(iterator: tuple[Iterator, Optional[int]]) -> None: chunks = Chunks(iterator[0], size=2, length=iterator[1]) assert list(chunks) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] @@ -66,7 +55,7 @@ def test_parpool() -> None: def fun(i, j, k) -> int: # noqa return i * j * k - with ParPool(fun, (3,), {'k': 2}) as pool: # noqa + with ParPool(fun, (3,), {"k": 2}) as pool: # noqa for i in range(10): pool[i] = i @@ -74,40 +63,66 @@ def test_parpool() -> None: def test_parfor() -> None: - @parfor(range(10), (3,), {'k': 2}) + @parfor(range(10), (3,), {"k": 2}) def fun(i, j, k): return i * j * k assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54] -@pytest.mark.parametrize('serial', (True, False)) +@pytest.mark.parametrize("serial", (True, False)) def test_pmap(serial) -> None: def fun(i, j, k): return i * j * k - assert pmap(fun, range(10), (3,), {'k': 2}, serial=serial) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54] + assert pmap(fun, range(10), (3,), {"k": 2}, serial=serial) == [ + 0, + 6, + 12, + 18, + 24, + 30, + 36, + 42, + 48, + 54, + ] -@pytest.mark.parametrize('serial', (True, False)) +@pytest.mark.parametrize("serial", (True, False)) def test_pmap_with_idx(serial) -> None: def fun(i, j, k): return i * j * k - assert (pmap(fun, range(10), (3,), {'k': 2}, serial=serial, yield_index=True) == - [(0, 0), (1, 6), (2, 12), (3, 18), (4, 24), (5, 30), (6, 36), (7, 42), (8, 48), (9, 54)]) + assert pmap(fun, range(10), (3,), {"k": 2}, serial=serial, yield_index=True) == [ + (0, 0), + (1, 6), + (2, 12), + (3, 18), + (4, 24), + (5, 30), + (6, 36), + (7, 42), + (8, 48), + (9, 54), + ] -@pytest.mark.parametrize('serial', (True, False)) +@pytest.mark.parametrize("serial", (True, False)) def test_pmap_chunks(serial) -> None: def fun(i, j, k): return [i_ * j * k for i_ in i] chunks = Chunks(range(10), size=2) - assert pmap(fun, chunks, (3,), {'k': 2}, serial=serial) == [[0, 6], [12, 18], [24, 30], [36, 42], [48, 54]] + assert pmap(fun, chunks, (3,), {"k": 2}, serial=serial) == [ + [0, 6], + [12, 18], + [24, 30], + [36, 42], + [48, 54], + ] -@pytest.mark.skipif(not gil, reason='test if gil enabled only') def test_id_reuse() -> None: def fun(i): return i[0].a @@ -126,18 +141,6 @@ def test_id_reuse() -> None: assert all([i == j for i, j in enumerate(a)]) -@pytest.mark.skipif(not gil, reason='test if gil enabled only') -@pytest.mark.parametrize('n_processes', (2, 4, 6)) -def test_n_processes(n_processes) -> None: - - @parfor(range(12), 
n_processes=n_processes) - def fun(i): # noqa - sleep(0.25) - return getpid() - - assert len(set(fun)) == n_processes - - def test_shared_array() -> None: def fun(i, a): a[i] = i