Now based on Ray, which enables nested parallel computations.

This commit is contained in:
Wim Pomp
2026-01-08 00:48:14 +01:00
parent a84bf2d29e
commit 5528b54ede
8 changed files with 463 additions and 953 deletions
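The headline change: Ray workers may themselves submit Ray work, so pmap calls can now be nested instead of falling back to serial execution. A minimal sketch of what this enables (the workload here is illustrative; pmap is documented under gmap below):

from parfor import pmap

def inner(j, i):
    return i * j

def outer(i):
    # runs inside a worker; nested pmap used to be forced serial (Worker.nested)
    return sum(pmap(inner, range(5), (i,)))

print(pmap(outer, range(4)))  # [0, 10, 20, 30]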

View File

@@ -11,9 +11,9 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install

View File

@@ -1,28 +1,88 @@
from __future__ import annotations
import sys
from contextlib import ExitStack
import logging
import os
import warnings
from contextlib import ExitStack, redirect_stdout, redirect_stderr
from functools import wraps
from importlib.metadata import version
from typing import Any, Callable, Generator, Iterable, Iterator, Sized
from warnings import warn
from multiprocessing.shared_memory import SharedMemory
from traceback import format_exc
from typing import Any, Callable, Generator, Iterable, Iterator, Sized, Hashable, NoReturn, Optional, Protocol, Sequence
import numpy as np
import ray
from numpy.typing import ArrayLike, DTypeLike
from tqdm.auto import tqdm
from . import gil, nogil
from .common import Bar, SharedArray, cpu_count
if hasattr(sys, '_is_gil_enabled') and not sys._is_gil_enabled(): # noqa
from .nogil import ParPool, PoolSingleton, Task, Worker
else:
from .gil import ParPool, PoolSingleton, Task, Worker
__version__ = version("parfor")
cpu_count = int(os.cpu_count())
__version__ = version('parfor')
class Bar(Protocol):
def update(self, n: int = 1) -> None: ...
class SharedArray(np.ndarray):
"""Numpy array whose memory can be shared between processes, so that memory use is reduced and changes in one
process are reflected in all other processes. Changes are not atomic, so protect changes with a lock to prevent
race conditions!
"""
def __new__(
cls,
shape: int | Sequence[int],
dtype: DTypeLike = float,
shm: str | SharedMemory = None,
offset: int = 0,
strides: tuple[int, int] = None,
order: str = None,
) -> SharedArray:
if isinstance(shm, str):
shm = SharedMemory(shm)
elif shm is None:
shm = SharedMemory(create=True, size=np.prod(shape) * np.dtype(dtype).itemsize) # type: ignore
new = super().__new__(cls, shape, dtype, shm.buf, offset, strides, order)
new.shm = shm
return new
def __reduce__(
self,
) -> tuple[
Callable[[int | Sequence[int], DTypeLike, str], SharedArray],
tuple[int | tuple[int, ...], np.dtype, str],
]:
return self.__class__, (self.shape, self.dtype, self.shm.name)
def __enter__(self) -> SharedArray:
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
if hasattr(self, "shm"):
self.shm.close()
self.shm.unlink()
def __del__(self) -> None:
if hasattr(self, "shm"):
self.shm.close()
def __array_finalize__(self, obj: np.ndarray | None) -> None:
if isinstance(obj, np.ndarray) and not isinstance(obj, SharedArray):
raise TypeError("view casting to SharedArray is not implemented because right now we need to make a copy")
@classmethod
def from_array(cls, array: ArrayLike) -> SharedArray:
"""copy existing array into a SharedArray"""
array = np.asarray(array)
new = cls(array.shape, array.dtype)
new[:] = array[:]
return new
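A short usage sketch (mirroring test_shared_array at the bottom of this commit): copy an array into shared memory, let workers write to it, and unlink it via the context manager:

import numpy as np
from parfor import SharedArray, pmap

def fill(i, a):
    a[i] = i  # writes are visible in every process; use a lock if writes can race

with SharedArray.from_array(np.zeros(10)) as shared:
    pmap(fill, range(10), (shared,))
    print(shared)  # [0. 1. 2. ... 9.]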
class Chunks(Iterable):
""" Yield successive chunks from lists.
"""Yield successive chunks from lists.
Usage: chunks(list0, list1, ...)
chunks(list0, list1, ..., size=s)
chunks(list0, list1, ..., number=n)
@@ -34,14 +94,21 @@ class Chunks(Iterable):
both ratio and number are given: use ratio
"""
def __init__(self, *iterables: Iterable[Any] | Sized[Any], size: int = None, number: int = None,
ratio: float = None, length: int = None) -> None:
def __init__(
self,
*iterables: Iterable[Any] | Sized,
size: int = None,
number: int = None,
ratio: float = None,
length: int = None,
) -> None:
if length is None:
try:
length = min(*[len(iterable) for iterable in iterables]) if len(iterables) > 1 else len(iterables[0])
except TypeError:
raise TypeError('Cannot determine the length of the iterables(s), so the length must be provided as an'
' argument.')
raise TypeError(
"Cannot determine the length of the iterables(s), so the length must be provided as an argument."
)
if size is not None and (number is not None or ratio is not None):
if number is None:
number = int(cpu_count * ratio)
@@ -54,12 +121,17 @@ class Chunks(Iterable):
self.iterators = [iter(arg) for arg in iterables]
self.number_of_items = length
self.length = min(length, number)
self.lengths = [((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length)
for i in range(self.length)]
self.lengths = [
((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length)
for i in range(self.length)
]
def __iter__(self) -> Iterator[Any]:
for i in range(self.length):
p, q = (i * self.number_of_items // self.length), ((i + 1) * self.number_of_items // self.length)
p, q = (
(i * self.number_of_items // self.length),
((i + 1) * self.number_of_items // self.length),
)
if len(self.iterators) == 1:
yield [next(self.iterators[0]) for _ in range(q - p)]
else:
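Concretely, the size-based chunking looks like this (the same case exercised by test_chunks below):

from parfor import Chunks

print(list(Chunks(range(10), size=2)))
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]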
@@ -70,7 +142,12 @@ class Chunks(Iterable):
class ExternalBar(Iterable):
def __init__(self, iterable: Iterable = None, callback: Callable[[int], None] = None, total: int = 0) -> None:
def __init__(
self,
iterable: Iterable = None,
callback: Callable[[int], None] = None,
total: int = 0,
) -> None:
self.iterable = iterable
self.callback = callback
self.total = total
@@ -102,12 +179,243 @@ class ExternalBar(Iterable):
self.callback(n)
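gmap wraps a callable bar in ExternalBar, so a plain function can receive progress updates in place of a tqdm bar. A hedged sketch (the values passed to the callback are whatever ExternalBar.update forwards):

from parfor import pmap

def fun(i):
    return i ** 2

ticks = []
result = pmap(fun, range(100), bar=ticks.append)  # the callable is invoked as the bar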
def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
**bar_kwargs: Any) -> Generator[Any, None, None]:
""" map a function fun to each iteration in iterable
@ray.remote
def worker(task):
try:
with (
warnings.catch_warnings(),
redirect_stdout(open(os.devnull, "w")),
redirect_stderr(open(os.devnull, "w")),
):
warnings.simplefilter("ignore", category=FutureWarning)
try:
task()
task.status = "done",
except Exception: # noqa
task.status = "task_error", format_exc()
except KeyboardInterrupt: # noqa
pass
return task
class Task:
def __init__(
self,
handle: Hashable,
fun: Callable[[Any, ...], Any],
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] = None,
) -> None:
self.handle = handle
self.fun = fun
self.args = args
        self.kwargs = {} if kwargs is None else kwargs  # guard: the kwargs setter iterates items()
self.name = fun.__name__ if hasattr(fun, "__name__") else None
self.done = False
self.result = None
self.future = None
self.status = "starting"
@property
def fun(self) -> Callable[[Any, ...], Any]:
return ray.get(self._fun)
@fun.setter
def fun(self, fun: Callable[[Any, ...], Any]):
self._fun = ray.put(fun)
@property
def args(self) -> tuple[Any, ...]:
return tuple([ray.get(arg) for arg in self._args])
@args.setter
def args(self, args: tuple[Any, ...]) -> None:
self._args = [ray.put(arg) for arg in args]
@property
def kwargs(self) -> dict[str, Any]:
return {key: ray.get(value) for key, value in self._kwargs.items()}
@kwargs.setter
def kwargs(self, kwargs: dict[str, Any]) -> None:
self._kwargs = {key: ray.put(value) for key, value in kwargs.items()}
@property
def result(self) -> Any:
return ray.get(self._result)
@result.setter
def result(self, result: Any) -> None:
self._result = ray.put(result)
def __call__(self) -> Task:
if not self.done:
self.result = self.fun(*self.args, **self.kwargs) # noqa
self.done = True
return self
def __repr__(self) -> str:
if self.done:
return f"Task {self.handle}, result: {self.result}"
else:
return f"Task {self.handle}"
class ParPool:
"""Parallel processing with addition of iterations at any time and request of that result any time after that.
The target function and its argument can be changed at any time.
"""
def __init__(
self,
fun: Callable[[Any, ...], Any] = None,
args: tuple[Any] = None,
kwargs: dict[str, Any] = None,
n_processes: int = None,
bar: Bar = None,
):
self.handle = 0
self.tasks = {}
self.bar = bar
self.bar_lengths = {}
self.fun = fun
self.args = args
self.kwargs = kwargs
PoolSingleton(n_processes)
def __getstate__(self) -> NoReturn:
raise RuntimeError(f"Cannot pickle {self.__class__.__name__} object.")
def __enter__(self) -> ParPool:
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
pass
def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
self.add_task(
args=(n, *(() if self.args is None else self.args)),
handle=handle,
barlength=barlength,
)
def add_task(
self,
fun: Callable[[Any, ...], Any] = None,
args: tuple[Any, ...] = None,
kwargs: dict[str, Any] = None,
handle: Hashable = None,
barlength: int = 1,
) -> Optional[int]:
if handle is None:
new_handle = self.handle
self.handle += 1
else:
new_handle = handle
if new_handle in self:
raise ValueError(f"handle {new_handle} already present")
task = Task(
new_handle,
fun or self.fun,
args or self.args,
kwargs or self.kwargs,
)
task.future = worker.remote(task)
self.tasks[new_handle] = task
self.bar_lengths[new_handle] = barlength
if handle is None:
return new_handle
else:
return None
def __setitem__(self, handle: Hashable, n: Any) -> None:
"""Add new iteration."""
self(n, handle=handle)
def __getitem__(self, handle: Hashable) -> Any:
"""Request result and delete its record. Wait if result not yet available."""
if handle not in self:
raise ValueError(f"No handle: {handle} in pool")
task = ray.get(self.tasks[handle].future)
return self.finalize_task(task)
def __contains__(self, handle: Hashable) -> bool:
return handle in self.tasks
def __delitem__(self, handle: Hashable) -> None:
self.tasks.pop(handle)
def finalize_task(self, task: Task) -> Any:
code, *args = task.status
getattr(self, code)(task, *args)
self.tasks.pop(task.handle)
return task.result
    def get_newest(self) -> tuple[Hashable, Any]:
        """Request the newest handle and result and delete its record. Wait if result not yet available."""
        while True:  # poll each outstanding future briefly until one finishes
for handle, task in self.tasks.items():
if handle in self.bar_lengths:
try:
task = ray.get(task.future, timeout=0.01)
return task.handle, self.finalize_task(task)
except ray.exceptions.GetTimeoutError:
pass
    def task_error(self, task: Task, error: str) -> None:
if task.handle in self:
task = self.tasks[task.handle]
print(f"Error from process working on iteration {task.handle}:\n")
print(error)
print("Retrying in main process...")
task()
raise Exception(f"Function '{task.name}' cannot be executed by parfor, amend or execute in serial.")
def done(self, task: Task) -> None:
if task.handle in self: # if not, the task was restarted erroneously
self.tasks[task.handle] = task
if hasattr(self.bar, "update"):
self.bar.update(self.bar_lengths.pop(task.handle))
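Usage is unchanged by the Ray port (this mirrors test_parpool below): submit iterations under handles, then index the pool to block on each result:

from parfor import ParPool

def fun(i, j, k):
    return i * j * k

with ParPool(fun, (3,), {"k": 2}) as pool:
    for i in range(10):
        pool[i] = i                         # enqueue iteration i under handle i
    results = [pool[i] for i in range(10)]  # waits for each result
# results == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]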
class PoolSingleton:
instance: PoolSingleton = None
cpu_count: int = int(os.cpu_count())
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
        # recreate the singleton when a different number of processes is requested
n_processes = n_processes or cls.cpu_count
if cls.instance is None or cls.instance.n_processes != n_processes:
cls.instance = super().__new__(cls)
cls.instance.n_processes = n_processes
if ray.is_initialized():
warnings.warn("not setting n_processes because parallel pool was already initialized")
else:
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
return cls.instance
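Note the singleton semantics: asking for a different n_processes constructs a new instance, but once ray.init has run the worker count cannot actually change, hence the warning. Roughly (PoolSingleton is module-internal):

PoolSingleton(4)  # first call: ray.init(num_cpus=4, ...)
PoolSingleton(4)  # same count: the existing instance is returned
PoolSingleton(8)  # new instance, but Ray is already up, so it only warns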
class Worker:
    # vestigial: gmap no longer checks Worker.nested (Ray allows nesting); retained for compatibility
    nested: bool = False
def gmap(
fun: Callable[[Any, ...], Any],
iterable: Iterable[Any] = None,
args: tuple[Any, ...] = None,
kwargs: dict[str, Any] = None,
total: int = None,
desc: str = None,
bar: Bar | bool = True,
serial: bool = None,
n_processes: int = None,
yield_ordered: bool = True,
yield_index: bool = False,
**bar_kwargs: Any,
) -> Generator[Any, None, None]:
"""map a function fun to each iteration in iterable
use as a function: pmap
use as a decorator: parfor
best use: iterable is a generator and length is given to this function as 'total'
@@ -177,15 +485,6 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
fun
>> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
"""
if total is None and length is not None:
total = length
warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=2)
warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=3)
if terminator is not None:
warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
DeprecationWarning, stacklevel=2)
warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
DeprecationWarning, stacklevel=3)
is_chunked = isinstance(iterable, Chunks)
if is_chunked:
chunk_fun = fun
@@ -201,13 +500,13 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
if kwargs is None:
kwargs = {}
if 'total' not in bar_kwargs:
bar_kwargs['total'] = sum(iterable.lengths)
if 'desc' not in bar_kwargs:
bar_kwargs['desc'] = desc
if 'disable' not in bar_kwargs:
bar_kwargs['disable'] = not bar
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
if "total" not in bar_kwargs:
bar_kwargs["total"] = sum(iterable.lengths)
if "desc" not in bar_kwargs:
bar_kwargs["desc"] = desc
if "disable" not in bar_kwargs:
bar_kwargs["disable"] = not bar
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)): # serial case
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Any]: # noqa
with tqdm(*args, **kwargs) as b:
@@ -215,8 +514,9 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
yield chunk
b.update(length)
iterable = (ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar)
else tqdm_chunks(iterable, **bar_kwargs))
iterable = (
ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar) else tqdm_chunks(iterable, **bar_kwargs) # type: ignore
)
if is_chunked:
if yield_index:
for i, c in enumerate(iterable):
@@ -234,9 +534,9 @@ def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
yield from chunk_fun(c, *args, **kwargs)
else: # parallel case
with ExitStack() as stack:
with ExitStack() as stack: # noqa
if callable(bar):
bar = stack.enter_context(ExternalBar(callback=bar))
bar = stack.enter_context(ExternalBar(callback=bar)) # noqa
else:
bar = stack.enter_context(tqdm(**bar_kwargs))
with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p: # type: ignore
@@ -287,12 +587,13 @@ def pmap(*args, **kwargs) -> list[Any]:
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Any, ...], Any]], list[Any]]:
def decfun(fun: Callable[[Any, ...], Any]) -> list[Any]:
return pmap(fun, *args, **kwargs)
return decfun
try:
parfor.__doc__ = pmap.__doc__ = gmap.__doc__
pmap.__annotations__ = gmap.__annotations__ | pmap.__annotations__
parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != 'fun'}
parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != "fun"}
except AttributeError:
pass
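For completeness, the decorator form, reproducing the example from the gmap docstring above:

from parfor import parfor

@parfor(range(10), (3,))
def fun(i, a):
    return a * i ** 2

print(fun)  # [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]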

View File

@@ -1,58 +0,0 @@
from __future__ import annotations
import os
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Protocol, Sequence
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
cpu_count = int(os.cpu_count())
class Bar(Protocol):
def update(self, n: int = 1) -> None: ...
class SharedArray(np.ndarray):
""" Numpy array whose memory can be shared between processes, so that memory use is reduced and changes in one
process are reflected in all other processes. Changes are not atomic, so protect changes with a lock to prevent
race conditions!
"""
def __new__(cls, shape: int | Sequence[int], dtype: DTypeLike = float, shm: str | SharedMemory = None,
offset: int = 0, strides: tuple[int, int] = None, order: str = None) -> SharedArray:
if isinstance(shm, str):
shm = SharedMemory(shm)
elif shm is None:
shm = SharedMemory(create=True, size=np.prod(shape) * np.dtype(dtype).itemsize)
new = super().__new__(cls, shape, dtype, shm.buf, offset, strides, order)
new.shm = shm
return new
def __reduce__(self) -> tuple[Callable[[int | Sequence[int], DTypeLike, str], SharedArray],
tuple[int | tuple[int, ...], np.dtype, str]]:
return self.__class__, (self.shape, self.dtype, self.shm.name)
def __enter__(self) -> SharedArray:
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
if hasattr(self, 'shm'):
self.shm.close()
self.shm.unlink()
def __del__(self) -> None:
if hasattr(self, 'shm'):
self.shm.close()
def __array_finalize__(self, obj: np.ndarray | None) -> None:
if isinstance(obj, np.ndarray) and not isinstance(obj, SharedArray):
raise TypeError('view casting to SharedArray is not implemented because right now we need to make a copy')
@classmethod
def from_array(cls, array: ArrayLike) -> SharedArray:
""" copy existing array into a SharedArray """
new = cls(array.shape, array.dtype)
new[:] = array[:]
return new

View File

@@ -1,463 +0,0 @@
from __future__ import annotations
import asyncio
import multiprocessing
from collections import UserDict
from contextlib import redirect_stderr, redirect_stdout
from os import cpu_count, devnull, getpid
from time import time
from traceback import format_exc
from typing import Any, Callable, Hashable, NoReturn, Optional
from warnings import warn
from .common import Bar
from .pickler import dumps, loads
class SharedMemory(UserDict):
def __init__(self, manager: multiprocessing.Manager) -> None:
super().__init__()
self.data = manager.dict() # item_id: dilled representation of object
self.references = manager.dict() # item_id: counter
self.references_lock = manager.Lock()
self.cache = {} # item_id: object
self.trash_can = {}
self.pool_ids = {} # item_id: {(pool_id, task_handle), ...}
def __getstate__(self) -> tuple[dict[int, bytes], dict[int, int], multiprocessing.Lock]:
return self.data, self.references, self.references_lock
def __setitem__(self, item_id: int, value: Any) -> None:
if item_id not in self: # values will not be changed
try:
self.data[item_id] = False, value
except Exception: # only use our pickler when necessary # noqa
self.data[item_id] = True, dumps(value, recurse=True)
with self.references_lock:
try:
self.references[item_id] += 1
except KeyError:
self.references[item_id] = 1
self.cache[item_id] = value # the id of the object will not be reused as long as the object exists
def add_item(self, item: Any, pool_id: int, task_handle: Hashable) -> int:
item_id = id(item)
self[item_id] = item
if item_id in self.pool_ids:
self.pool_ids[item_id].add((pool_id, task_handle))
else:
self.pool_ids[item_id] = {(pool_id, task_handle)}
return item_id
def remove_pool(self, pool_id: int) -> None:
""" remove objects used by a pool that won't be needed anymore """
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
for item_id in set(self.data.keys()) - set(self.pool_ids):
del self[item_id]
self.garbage_collect()
def remove_task(self, pool_id: int, task: Task) -> None:
""" remove objects used by a task that won't be needed anymore """
self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
del self[item_id]
self.garbage_collect()
# worker functions
def __setstate__(self, state: dict) -> None:
self.data, self.references, self.references_lock = state
self.cache = {}
self.trash_can = None
def __getitem__(self, item_id: int) -> Any:
if item_id not in self.cache:
dilled, value = self.data[item_id]
if dilled:
value = loads(value)
with self.references_lock:
if item_id in self.references:
self.references[item_id] += 1
else:
self.references[item_id] = 1
self.cache[item_id] = value
return self.cache[item_id]
def garbage_collect(self) -> None:
""" clean up the cache """
for item_id in set(self.cache) - set(self.data.keys()):
with self.references_lock:
try:
self.references[item_id] -= 1
except KeyError:
self.references[item_id] = 0
if self.trash_can is not None and item_id not in self.trash_can:
self.trash_can[item_id] = self.cache[item_id]
del self.cache[item_id]
if self.trash_can:
for item_id in set(self.trash_can):
if self.references[item_id] == 0:
# make sure every process removed the object before removing it in the parent
del self.references[item_id]
del self.trash_can[item_id]
class Task:
def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: Hashable, fun: Callable[[Any, ...], Any],
args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None) -> None:
self.pool_id = pool_id
self.handle = handle
self.fun = shared_memory.add_item(fun, pool_id, handle)
self.args = [shared_memory.add_item(arg, pool_id, handle) for arg in args]
self.kwargs = [] if kwargs is None else [shared_memory.add_item(item, pool_id, handle)
for item in kwargs.items()]
self.name = fun.__name__ if hasattr(fun, '__name__') else None
self.done = False
self.result = None
self.pid = None
def __getstate__(self) -> dict[str, Any]:
state = self.__dict__
if self.result is not None:
state['result'] = dumps(self.result, recurse=True)
return state
def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
if state['result'] is None:
self.result = None
else:
self.result = loads(state['result'])
def __call__(self, shared_memory: SharedMemory) -> Task:
if not self.done:
fun = shared_memory[self.fun] or (lambda *args, **kwargs: None) # noqa
args = [shared_memory[arg] for arg in self.args]
kwargs = dict([shared_memory[kwarg] for kwarg in self.kwargs])
self.result = fun(*args, **kwargs) # noqa
self.done = True
return self
def __repr__(self) -> str:
if self.done:
return f'Task {self.handle}, result: {self.result}'
else:
return f'Task {self.handle}'
class Context(multiprocessing.context.SpawnContext):
""" Provide a context where child processes never are daemonic. """
class Process(multiprocessing.context.SpawnProcess):
@property
def daemon(self) -> bool:
return False
@daemon.setter
def daemon(self, value: bool) -> None:
pass
class ParPool:
""" Parallel processing with addition of iterations at any time and request of that result any time after that.
The target function and its argument can be changed at any time.
"""
def __init__(self, fun: Callable[[Any, ...], Any] = None,
args: tuple[Any] = None, kwargs: dict[str, Any] = None, n_processes: int = None, bar: Bar = None):
self.id = id(self)
self.handle = 0
self.tasks = {}
self.bar = bar
self.bar_lengths = {}
self.spool = PoolSingleton(n_processes, self)
self.manager = self.spool.manager
self.fun = fun
self.args = args
self.kwargs = kwargs
self.is_started = False
def __getstate__(self) -> NoReturn:
raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
def __enter__(self) -> ParPool:
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
self.close()
def close(self) -> None:
self.spool.remove_pool(self.id)
def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)
def add_task(self, fun: Callable[[Any, ...], Any] = None, args: tuple[Any, ...] = None,
kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]:
if self.id not in self.spool.pools:
raise ValueError(f'this pool is not registered (anymore) with the pool singleton')
if handle is None:
new_handle = self.handle
self.handle += 1
else:
new_handle = handle
if new_handle in self:
raise ValueError(f'handle {new_handle} already present')
task = Task(self.spool.shared_memory, self.id, new_handle,
fun or self.fun, args or self.args, kwargs or self.kwargs)
self.tasks[new_handle] = task
self.spool.add_task(task)
self.bar_lengths[new_handle] = barlength
if handle is None:
return new_handle
def __setitem__(self, handle: Hashable, n: Any) -> None:
""" Add new iteration. """
self(n, handle=handle)
def __getitem__(self, handle: Hashable) -> Any:
""" Request result and delete its record. Wait if result not yet available. """
if handle not in self:
raise ValueError(f'No handle: {handle} in pool')
while not self.tasks[handle].done:
if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
and not self.working:
for _ in range(10): # wait some time while processing possible new messages
self.spool.get_from_queue()
if not self.spool.get_from_queue() and not self.tasks[handle].done and self.is_started \
and not self.working:
# retry a task if the process was killed while working on a task
self.spool.add_task(self.tasks[handle])
warn(f'Task {handle} was restarted because the process working on it was probably killed.')
result = self.tasks[handle].result
self.tasks.pop(handle)
return result
def __contains__(self, handle: Hashable) -> bool:
return handle in self.tasks
def __delitem__(self, handle: Hashable) -> None:
self.tasks.pop(handle)
def get_newest(self) -> Any:
return self.spool.get_newest_for_pool(self)
def process_queue(self) -> None:
self.spool.process_queue()
def task_error(self, handle: Hashable, error: Exception) -> None:
if handle in self:
task = self.tasks[handle]
print(f'Error from process working on iteration {handle}:\n')
print(error)
print('Retrying in main process...')
task(self.spool.shared_memory)
self.spool.shared_memory.remove_task(self.id, task)
raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.')
def done(self, task: Task) -> None:
if task.handle in self: # if not, the task was restarted erroneously
self.tasks[task.handle] = task
if hasattr(self.bar, 'update'):
self.bar.update(self.bar_lengths.pop(task.handle))
self.spool.shared_memory.remove_task(self.id, task)
def started(self, handle: Hashable, pid: int) -> None:
self.is_started = True
if handle in self: # if not, the task was restarted erroneously
self.tasks[handle].pid = pid
@property
def working(self) -> bool:
return not all([task.pid is None for task in self.tasks.values()])
class PoolSingleton:
""" There can be only one pool at a time, but the pool can be restarted by calling close() and then constructing a
new pool. The pool will close itself after 10 minutes of idle time. """
instance = None
cpu_count = cpu_count()
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
# restart if any workers have shut down or if we want to have a different number of processes
if cls.instance is not None:
if (cls.instance.n_workers.value < cls.instance.n_processes or
cls.instance.n_processes != (n_processes or cls.cpu_count)):
cls.instance.close()
if cls.instance is None or not cls.instance.is_alive:
new = super().__new__(cls)
new.n_processes = n_processes or cls.cpu_count
new.instance = new
new.is_started = False
ctx = Context()
new.n_workers = ctx.Value('i', new.n_processes)
new.event = ctx.Event()
new.queue_in = ctx.Queue(3 * new.n_processes)
new.queue_out = ctx.Queue(new.n_processes)
new.manager = ctx.Manager()
new.shared_memory = SharedMemory(new.manager)
new.pool = ctx.Pool(new.n_processes,
Worker(new.shared_memory, new.queue_in, new.queue_out, new.n_workers, new.event))
new.is_alive = True
new.handle = 0
new.pools = {}
new.time_out = None
cls.instance = new
return cls.instance
def __init__(self, n_processes: int = None, parpool: ParPool = None) -> None: # noqa
if parpool is not None:
self.pools[parpool.id] = parpool
if self.time_out is not None:
self.time_out.cancel()
self.time_out = None
def __getstate__(self) -> NoReturn:
raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
def remove_pool(self, pool_id: int) -> None:
self.shared_memory.remove_pool(pool_id)
if pool_id in self.pools:
self.pools.pop(pool_id)
if len(self.pools) == 0:
try:
self.time_out = asyncio.get_running_loop().call_later(600, self.close) # noqa
except RuntimeError:
self.time_out = asyncio.new_event_loop().call_later(600, self.close) # noqa
def error(self, error: Exception) -> NoReturn:
self.close()
raise Exception(f'Error occurred in worker: {error}')
def process_queue(self) -> None:
while self.get_from_queue():
pass
def get_from_queue(self) -> bool:
""" Get an item from the queue and store it, return True if more messages are waiting. """
try:
code, pool_id, *args = self.queue_out.get(True, 0.02)
if pool_id is None:
getattr(self, code)(*args)
elif pool_id in self.pools:
getattr(self.pools[pool_id], code)(*args)
return True
except multiprocessing.queues.Empty: # noqa
for pool in self.pools.values():
for handle, task in pool.tasks.items(): # retry a task if the process doing it was killed
if task.pid is not None \
and task.pid not in [child.pid for child in multiprocessing.active_children()]:
self.queue_in.put(task)
warn(f'Task {task.handle} was restarted because process {task.pid} was probably killed.')
return False
def add_task(self, task: Task) -> None:
""" Add new iteration, using optional manually defined handle."""
if self.is_alive and not self.event.is_set():
while self.queue_in.full():
self.get_from_queue()
self.queue_in.put(task)
self.shared_memory.garbage_collect()
def get_newest_for_pool(self, pool: ParPool) -> tuple[Hashable, Any]:
""" Request the newest key and result and delete its record. Wait if result not yet available. """
while len(pool.tasks):
self.get_from_queue()
for task in pool.tasks.values():
if task.done:
handle, result = task.handle, task.result
pool.tasks.pop(handle)
return handle, result
@classmethod
def close(cls) -> None:
if cls.instance is not None:
instance = cls.instance
cls.instance = None
if instance.time_out is not None:
instance.time_out.cancel()
def empty_queue(queue):
try:
if not queue._closed: # noqa
while not queue.empty():
try:
queue.get(True, 0.02)
except multiprocessing.queues.Empty: # noqa
pass
except OSError:
pass
def close_queue(queue: multiprocessing.queues.Queue) -> None:
empty_queue(queue) # noqa
if not queue._closed: # noqa
queue.close()
queue.join_thread()
if instance.is_alive:
instance.is_alive = False
instance.event.set()
instance.pool.close()
t = time()
while instance.n_workers.value:
empty_queue(instance.queue_in)
empty_queue(instance.queue_out)
if time() - t > 10:
warn(f'Parfor: Closing pool timed out, {instance.n_workers.value} processes still alive.')
instance.pool.terminate()
break
empty_queue(instance.queue_in)
empty_queue(instance.queue_out)
instance.pool.join()
close_queue(instance.queue_in)
close_queue(instance.queue_out)
instance.manager.shutdown()
instance.handle = 0
class Worker:
""" Manages executing the target function which will be executed in different processes. """
nested = False
def __init__(self, shared_memory: SharedMemory, queue_in: multiprocessing.queues.Queue,
queue_out: multiprocessing.queues.Queue, n_workers: multiprocessing.Value,
event: multiprocessing.Event) -> None:
self.shared_memory = shared_memory
self.queue_in = queue_in
self.queue_out = queue_out
self.n_workers = n_workers
self.event = event
def add_to_queue(self, *args: Any) -> None:
while not self.event.is_set():
try:
self.queue_out.put(args, timeout=0.1)
break
except multiprocessing.queues.Full: # noqa
continue
def __call__(self) -> None:
Worker.nested = True
pid = getpid()
last_active_time = time()
while not self.event.is_set() and time() - last_active_time < 600:
try:
with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
task = self.queue_in.get(True, 0.02)
try:
self.add_to_queue('started', task.pool_id, task.handle, pid)
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
except Exception: # noqa
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
self.event.set()
self.shared_memory.garbage_collect()
last_active_time = time()
except (multiprocessing.queues.Empty, KeyboardInterrupt): # noqa
pass
except Exception: # noqa
self.add_to_queue('error', None, format_exc())
self.event.set()
self.shared_memory.garbage_collect()
for child in multiprocessing.active_children():
child.kill()
with self.n_workers:
self.n_workers.value -= 1

View File

@@ -1,158 +0,0 @@
from __future__ import annotations
import queue
import threading
from os import cpu_count
from typing import Any, Callable, Hashable, NoReturn, Optional
from .common import Bar
class Worker:
nested = False
def __init__(self, *args, **kwargs):
pass
class PoolSingleton:
cpu_count = cpu_count()
def __init__(self, *args, **kwargs):
pass
def close(self):
pass
class Task:
def __init__(self, queue: queue.Queue, handle: Hashable, fun: Callable[[Any, ...], Any], # noqa
args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None) -> None:
self.queue = queue
self.handle = handle
self.fun = fun
self.args = args
self.kwargs = {} if kwargs is None else kwargs
self.name = fun.__name__ if hasattr(fun, '__name__') else None
self.started = False
self.done = False
self.result = None
def __call__(self):
if not self.done:
self.result = self.fun(*self.args, **self.kwargs)
try:
self.queue.put(self.handle)
except queue.ShutDown:
pass
def __repr__(self) -> str:
if self.done:
return f'Task {self.handle}, result: {self.result}'
else:
return f'Task {self.handle}'
class ParPool:
""" Parallel processing with addition of iterations at any time and request of that result any time after that.
The target function and its argument can be changed at any time.
"""
def __init__(self, fun: Callable[[Any, ...], Any] = None,
args: tuple[Any] = None, kwargs: dict[str, Any] = None, n_processes: int = None, bar: Bar = None):
self.queue = queue.Queue()
self.handle = 0
self.tasks = {}
self.bar = bar
self.bar_lengths = {}
self.fun = fun
self.args = args
self.kwargs = kwargs
self.n_processes = n_processes or PoolSingleton.cpu_count
self.threads = {}
def __getstate__(self) -> NoReturn:
raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')
def __enter__(self) -> ParPool:
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
self.close()
def close(self) -> None:
self.queue.shutdown() # noqa python3.13
for thread in self.threads.values():
thread.join()
def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)
def add_task(self, fun: Callable[[Any, ...], Any] = None, args: tuple[Any, ...] = None,
kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]:
if handle is None:
new_handle = self.handle
self.handle += 1
else:
new_handle = handle
if new_handle in self:
raise ValueError(f'handle {new_handle} already present')
task = Task(self.queue, new_handle, fun or self.fun, args or self.args, kwargs or self.kwargs)
while len(self.threads) > self.n_processes:
self.get_from_queue()
thread = threading.Thread(target=task)
thread.start()
self.threads[new_handle] = thread
self.tasks[new_handle] = task
self.bar_lengths[new_handle] = barlength
if handle is None:
return new_handle
def __setitem__(self, handle: Hashable, n: Any) -> None:
""" Add new iteration. """
self(n, handle=handle)
def __getitem__(self, handle: Hashable) -> Any:
""" Request result and delete its record. Wait if result not yet available. """
if handle not in self:
raise ValueError(f'No handle: {handle} in pool')
while not self.tasks[handle].done:
self.get_from_queue()
task = self.tasks.pop(handle)
return task.result
def __contains__(self, handle: Hashable) -> bool:
return handle in self.tasks
def __delitem__(self, handle: Hashable) -> None:
self.tasks.pop(handle)
def get_from_queue(self) -> bool:
""" Get an item from the queue and store it, return True if more messages are waiting. """
try:
handle = self.queue.get(True, 0.02)
self.done(handle)
return True
except (queue.Empty, queue.ShutDown):
return False
def get_newest(self) -> Any:
""" Request the newest key and result and delete its record. Wait if result not yet available. """
while len(self.tasks):
self.get_from_queue()
for task in self.tasks.values():
if task.done:
handle, result = task.handle, task.result
self.tasks.pop(handle)
return handle, result
def process_queue(self) -> None:
while self.get_from_queue():
pass
def done(self, handle: Hashable) -> None:
thread = self.threads.pop(handle)
thread.join()
task = self.tasks[handle]
task.done = True
if hasattr(self.bar, 'update'):
self.bar.update(self.bar_lengths.pop(handle))

View File

@@ -1,119 +0,0 @@
from __future__ import annotations
import copyreg
from io import BytesIO
from pickle import PicklingError
from typing import Any, Callable
import dill
loads = dill.loads
class CouldNotBePickled:
def __init__(self, class_name: str) -> None:
self.class_name = class_name
def __repr__(self) -> str:
return f"Item of type '{self.class_name}' could not be pickled and was omitted."
@classmethod
def reduce(cls, item: Any) -> tuple[Callable[[str], CouldNotBePickled], tuple[str]]:
return cls, (type(item).__name__,)
class Pickler(dill.Pickler):
""" Overload dill to ignore unpicklable parts of objects.
You probably didn't want to use these parts anyhow.
However, if you did, you'll have to find some way to make them picklable.
"""
def save(self, obj: Any, save_persistent_id: bool = True) -> None:
""" Copied from pickle and amended. """
self.framer.commit_frame()
# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
if pid is not None and save_persistent_id:
self.save_pers(pid)
return
# Check the memo
x = self.memo.get(id(obj))
if x is not None:
self.write(self.get(x[0]))
return
rv = NotImplemented
reduce = getattr(self, "reducer_override", None)
if reduce is not None:
rv = reduce(obj)
if rv is NotImplemented:
# Check the type dispatch table
t = type(obj)
f = self.dispatch.get(t)
if f is not None:
f(self, obj) # Call unbound method with explicit self
return
# Check private dispatch table if any, or else
# copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
if reduce is not None:
rv = reduce(obj)
else:
# Check for a class with a custom metaclass; treat as regular
# class
if issubclass(t, type):
self.save_global(obj)
return
# Check for a __reduce_ex__ method, fall back to __reduce__
reduce = getattr(obj, "__reduce_ex__", None)
try:
if reduce is not None:
rv = reduce(self.proto)
else:
reduce = getattr(obj, "__reduce__", None)
if reduce is not None:
rv = reduce()
else:
raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj))
except Exception: # noqa
rv = CouldNotBePickled.reduce(obj)
# Check for string returned by reduce(), meaning "save as global"
if isinstance(rv, str):
try:
self.save_global(obj, rv)
except Exception: # noqa
self.save_global(obj, CouldNotBePickled.reduce(obj))
return
# Assert that reduce() returned a tuple
if not isinstance(rv, tuple):
raise PicklingError("%s must return string or tuple" % reduce)
# Assert that it returned an appropriately sized tuple
length = len(rv)
if not (2 <= length <= 6):
raise PicklingError("Tuple returned by %s must have "
"two to six elements" % reduce)
# Save the reduce() output and finally memoize the object
try:
self.save_reduce(obj=obj, *rv)
except Exception: # noqa
self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))
def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True,
**kwds: Any) -> bytes:
"""pickle an object to a string"""
protocol = dill.settings['protocol'] if protocol is None else int(protocol)
_kwds = kwds.copy()
_kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
with BytesIO() as file:
Pickler(file, protocol, **_kwds).dump(obj)
return file.getvalue()
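For context on what the Ray port makes unnecessary: this Pickler silently replaced anything unpicklable with a CouldNotBePickled marker. A rough sketch of the removed behavior (the parfor.pickler module no longer exists after this commit):

from parfor.pickler import dumps, loads  # removed by this commit

gen = (i for i in range(3))  # generators cannot be pickled
restored = loads(dumps({"n": 3, "gen": gen}))
print(restored["n"])    # 3
print(restored["gen"])  # Item of type 'generator' could not be pickled and was omitted.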

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "parfor"
version = "2025.1.0"
version = "2026.1.0"
description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3"
@@ -17,6 +17,10 @@ pytest = { version = "*", optional = true }
[tool.poetry.extras]
test = ["pytest", "numpy"]
[tool.ruff]
line-length = 119
indent-width = 4
[tool.isort]
line_length = 119

View File

@@ -1,9 +1,6 @@
from __future__ import annotations
import sys
from dataclasses import dataclass
from os import getpid
from time import sleep
from typing import Any, Iterator, Optional, Sequence
import numpy as np
@@ -11,14 +8,6 @@ import pytest
from parfor import Chunks, ParPool, SharedArray, parfor, pmap
try:
if sys._is_gil_enabled(): # noqa
gil = True
else:
gil = False
except Exception: # noqa
gil = True
class SequenceIterator:
def __init__(self, sequence: Sequence) -> None:
@@ -56,7 +45,7 @@ def iterators() -> tuple[Iterator, Optional[int]]:
yield Iterable(range(10)), 10
@pytest.mark.parametrize('iterator', iterators())
@pytest.mark.parametrize("iterator", iterators())
def test_chunks(iterator: tuple[Iterator, Optional[int]]) -> None:
chunks = Chunks(iterator[0], size=2, length=iterator[1])
assert list(chunks) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
@@ -66,7 +55,7 @@ def test_parpool() -> None:
def fun(i, j, k) -> int: # noqa
return i * j * k
with ParPool(fun, (3,), {'k': 2}) as pool: # noqa
with ParPool(fun, (3,), {"k": 2}) as pool: # noqa
for i in range(10):
pool[i] = i
@@ -74,40 +63,66 @@ def test_parpool() -> None:
def test_parfor() -> None:
@parfor(range(10), (3,), {'k': 2})
@parfor(range(10), (3,), {"k": 2})
def fun(i, j, k):
return i * j * k
assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
@pytest.mark.parametrize('serial', (True, False))
@pytest.mark.parametrize("serial", (True, False))
def test_pmap(serial) -> None:
def fun(i, j, k):
return i * j * k
assert pmap(fun, range(10), (3,), {'k': 2}, serial=serial) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
assert pmap(fun, range(10), (3,), {"k": 2}, serial=serial) == [
0,
6,
12,
18,
24,
30,
36,
42,
48,
54,
]
@pytest.mark.parametrize('serial', (True, False))
@pytest.mark.parametrize("serial", (True, False))
def test_pmap_with_idx(serial) -> None:
def fun(i, j, k):
return i * j * k
assert (pmap(fun, range(10), (3,), {'k': 2}, serial=serial, yield_index=True) ==
[(0, 0), (1, 6), (2, 12), (3, 18), (4, 24), (5, 30), (6, 36), (7, 42), (8, 48), (9, 54)])
assert pmap(fun, range(10), (3,), {"k": 2}, serial=serial, yield_index=True) == [
(0, 0),
(1, 6),
(2, 12),
(3, 18),
(4, 24),
(5, 30),
(6, 36),
(7, 42),
(8, 48),
(9, 54),
]
@pytest.mark.parametrize('serial', (True, False))
@pytest.mark.parametrize("serial", (True, False))
def test_pmap_chunks(serial) -> None:
def fun(i, j, k):
return [i_ * j * k for i_ in i]
chunks = Chunks(range(10), size=2)
assert pmap(fun, chunks, (3,), {'k': 2}, serial=serial) == [[0, 6], [12, 18], [24, 30], [36, 42], [48, 54]]
assert pmap(fun, chunks, (3,), {"k": 2}, serial=serial) == [
[0, 6],
[12, 18],
[24, 30],
[36, 42],
[48, 54],
]
@pytest.mark.skipif(not gil, reason='test if gil enabled only')
def test_id_reuse() -> None:
def fun(i):
return i[0].a
@@ -126,18 +141,6 @@ def test_id_reuse() -> None:
assert all([i == j for i, j in enumerate(a)])
@pytest.mark.skipif(not gil, reason='test if gil enabled only')
@pytest.mark.parametrize('n_processes', (2, 4, 6))
def test_n_processes(n_processes) -> None:
@parfor(range(12), n_processes=n_processes)
def fun(i): # noqa
sleep(0.25)
return getpid()
assert len(set(fun)) == n_processes
def test_shared_array() -> None:
def fun(i, a):
a[i] = i