# parfor/parfor/__init__.py
from __future__ import annotations
import sys
from contextlib import ExitStack
from functools import wraps
from importlib.metadata import version
from typing import Any, Callable, Generator, Iterable, Iterator, Sized, TypeVar
from warnings import warn
from tqdm.auto import tqdm
from . import gil, nogil
from .common import Bar, cpu_count
__version__ = version('parfor')
Result = TypeVar('Result')
Iteration = TypeVar('Iteration')
class ParPool:
    """ Factory selecting the parallel pool implementation matching the interpreter:
        the free-threaded pool when the GIL is disabled, the GIL-based pool otherwise.
        All arguments are passed through unchanged to the selected implementation. """

    def __new__(cls, *args, **kwargs):
        # sys._is_gil_enabled only exists on free-threading capable interpreters
        # (presumably CPython 3.13+ — confirm); keep the try block minimal so an
        # AttributeError raised while constructing nogil.ParPool is not silently
        # swallowed and misrouted to gil.ParPool
        try:
            use_nogil = not sys._is_gil_enabled()  # noqa
        except AttributeError:
            use_nogil = False
        if use_nogil:
            return nogil.ParPool(*args, **kwargs)
        return gil.ParPool(*args, **kwargs)
def nested():
    """ Return the 'nested' flag of the Worker flavour matching the interpreter
        (no-GIL worker when the GIL is disabled, GIL-based worker otherwise). """
    # keep the try minimal: only the feature check may raise AttributeError;
    # an error coming from nogil.Worker itself must not silently select gil.Worker
    try:
        use_nogil = not sys._is_gil_enabled()  # noqa
    except AttributeError:  # interpreter without free-threading support
        use_nogil = False
    if use_nogil:
        return nogil.Worker.nested
    return gil.Worker.nested
class Chunks(Iterable):
    """ Yield successive chunks from lists.
        Usage: Chunks(list0, list1, ...)
               Chunks(list0, list1, ..., size=s)
               Chunks(list0, list1, ..., number=n)
               Chunks(list0, list1, ..., ratio=r)
        size:   size of chunks, might change to optimize division between chunks
        number: number of chunks, coerced to 1 <= n <= len(list0)
        ratio:  number of chunks / number of cpus, coerced to 1 <= n <= len(list0)
        length: common length of the iterables, required when they do not support len()
        when both size and number or ratio are given: use number or ratio,
            unless the chunk size would be bigger than size
        when both ratio and number are given: use ratio
    """

    def __init__(self, *iterables: Iterable[Any] | Sized, size: int = None, number: int = None,
                 ratio: float = None, length: int = None) -> None:
        if length is None:
            try:
                # the iterables are consumed in lockstep, so cut at the shortest one
                length = min(len(iterable) for iterable in iterables)
            except TypeError as err:
                raise TypeError('Cannot determine the length of the iterable(s), so the length must be provided'
                                ' as an argument.') from err
        if size is not None and (number is not None or ratio is not None):
            # number/ratio wins, unless that would make the chunks bigger than size
            if number is None:
                number = int(cpu_count * ratio)
            if length >= size * number:
                number = round(length / size)
        elif size is not None:  # derive the number of chunks from the chunk size
            number = round(length / size)
        elif ratio is not None:  # number of chunks relative to the cpu count
            number = int(cpu_count * ratio)
        elif number is None:  # explicit error instead of an opaque TypeError below
            raise TypeError('Provide one of size, number or ratio to define the chunking.')
        self.iterators = [iter(arg) for arg in iterables]
        self.number_of_items = length
        # coerce to 1 <= number of chunks <= number of items; size/ratio may have
        # produced number == 0, which previously yielded no chunks at all,
        # silently dropping every item (0 chunks only when there are 0 items)
        self.length = min(length, max(1, number))
        self.lengths = [((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length)
                        for i in range(self.length)]

    def __iter__(self) -> Iterator[Any]:
        """ Yield one list per chunk (one list of lists when chunking several iterables in parallel). """
        for i in range(self.length):
            # integer arithmetic spreads the items as evenly as possible over the chunks
            p, q = (i * self.number_of_items // self.length), ((i + 1) * self.number_of_items // self.length)
            if len(self.iterators) == 1:
                yield [next(self.iterators[0]) for _ in range(q - p)]
            else:
                yield [[next(iterator) for _ in range(q - p)] for iterator in self.iterators]

    def __len__(self) -> int:
        """ Number of chunks (not the number of items). """
        return self.length
class ExternalBar(Iterable):
    """ Progress-bar stand-in that reports progress through a callback.

        Mimics the subset of the tqdm interface used here (context manager,
        iteration, update(), n, total), but instead of drawing anything it
        invokes `callback` with the current count whenever that count changes.
    """

    def __init__(self, iterable: Iterable = None, callback: Callable[[int], None] = None, total: int = 0) -> None:
        self.iterable = iterable
        self.callback = callback
        self.total = total
        self._n = 0  # current count, exposed via the `n` property

    def __enter__(self) -> ExternalBar:
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        # nothing to clean up; defined only so this can be used in a with-statement
        return

    def __iter__(self) -> Iterator[Any]:
        """ Yield the wrapped iterable's items, bumping the count after each one. """
        count = 0
        for item in self.iterable:
            yield item
            count += 1
            self.n = count

    def update(self, n: int = 1) -> None:
        """ Advance the counter by `n` (mirrors tqdm's update). """
        self.n = self._n + n

    @property
    def n(self) -> int:
        """ Current iteration count. """
        return self._n

    @n.setter
    def n(self, value: int) -> None:
        # fire the callback only on an actual change
        if value != self._n:
            self._n = value
            if self.callback is not None:
                self.callback(value)
def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
         args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
         bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
         n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
         **bar_kwargs: Any) -> Generator[Result, None, None]:
    """ map a function fun to each iteration in iterable
        use as a function: pmap
        use as a decorator: parfor
        best use: iterable is a generator and length is given to this function as 'total'

        required:
            fun: function taking arguments: iteration from iterable, other arguments defined in args & kwargs
            iterable: iterable or iterator from which an item is given to fun as a first argument
        optional:
            args: tuple with other unnamed arguments to fun
            kwargs: dict with other named arguments to fun
            total: give the length of the iterator in cases where len(iterator) results in an error
            desc: string with description of the progress bar
            bar: bool enable progress bar,
                or a callback function taking the number of passed iterations as an argument
            terminator: deprecated, workers are terminated automatically
            serial: execute in series instead of parallel if True, None (default): let pmap decide
            length: deprecated alias for total
            n_processes: number of processes to use,
                the parallel pool will be restarted if the current pool does not have the right number of processes
            yield_ordered: return the result in the same order as the iterable
            yield_index: return the index of the result too
            **bar_kwargs: keyword arguments for tqdm.tqdm

        output:
            list (pmap) or generator (gmap) with results from applying the function \'fun\' to each iteration
            of the iterable / iterator

        examples:
            << from time import sleep
            <<
            @parfor(range(10), (3,))
            def fun(i, a):
                sleep(1)
                return a * i ** 2
            fun
            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
            <<
            def fun(i, a):
                sleep(1)
                return a * i ** 2
            pmap(fun, range(10), (3,))
            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]

            equivalent to using the deco module:
            <<
            @concurrent
            def fun(i, a):
                time.sleep(1)
                return a * i ** 2
            @synchronized
            def run(iterator, a):
                res = []
                for i in iterator:
                    res.append(fun(i, a))
                return res
            run(range(10), 3)
            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]

            all equivalent to the serial for-loop:
            <<
            a = 3
            fun = []
            for i in range(10):
                sleep(1)
                fun.append(a * i ** 2)
            fun
            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
        """
    if total is None and length is not None:
        total = length
        # NOTE(review): warned at two stacklevels, presumably so the message points at the
        # caller both when gmap is called directly and when reached through pmap/parfor — confirm
        warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=2)
        warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=3)
    if terminator is not None:
        warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
             DeprecationWarning, stacklevel=2)
        warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
             DeprecationWarning, stacklevel=3)
    # work is distributed in chunks; if the caller did not chunk the iterable already,
    # chunk it here (ratio=5 chunks per cpu) and wrap fun so it maps over a whole chunk
    is_chunked = isinstance(iterable, Chunks)
    if is_chunked:
        chunk_fun = fun  # fun already operates on whole chunks
    else:
        iterable = Chunks(iterable, ratio=5, length=total)

        @wraps(fun)
        def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Result]:  # noqa
            return [fun(iteration, *args, **kwargs) for iteration in iterable]
    args = args or ()
    kwargs = kwargs or {}
    # fill in progress-bar defaults without overriding anything passed explicitly
    if 'total' not in bar_kwargs:
        bar_kwargs['total'] = sum(iterable.lengths)
    if 'desc' not in bar_kwargs:
        bar_kwargs['desc'] = desc
    if 'disable' not in bar_kwargs:
        bar_kwargs['disable'] = not bar
    # run serially when requested, when the job is small, or when already inside a parfor worker
    if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or nested():  # serial case
        def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]:  # noqa
            # yield the chunks while advancing a tqdm bar by each chunk's item count
            with tqdm(*args, **kwargs) as b:
                for chunk, length in zip(chunks, chunks.lengths):  # noqa
                    yield chunk
                    b.update(length)

        # progress reporting: a callback gets an ExternalBar, otherwise a real tqdm bar
        iterable = (ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar)
                    else tqdm_chunks(iterable, **bar_kwargs))
        if is_chunked:
            if yield_index:
                for i, c in enumerate(iterable):
                    yield i, chunk_fun(c, *args, **kwargs)
            else:
                for c in iterable:
                    yield chunk_fun(c, *args, **kwargs)
        else:
            # the iterable was chunked internally: flatten chunk results back to single iterations
            if yield_index:
                for i, c in enumerate(iterable):
                    for q in chunk_fun(c, *args, **kwargs):
                        yield i, q
            else:
                for c in iterable:
                    yield from chunk_fun(c, *args, **kwargs)
    else:  # parallel case
        with ExitStack() as stack:
            if callable(bar):
                bar = stack.enter_context(ExternalBar(callback=bar))
            else:
                bar = stack.enter_context(tqdm(**bar_kwargs))
            with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p:
                for i, (j, l) in enumerate(zip(iterable, iterable.lengths)):  # add work to the queue
                    p(j, handle=i, barlength=l)
                    # keep the bar's total at least as large as the number of submitted chunks
                    if bar.total is None or bar.total < i + 1:
                        bar.total = i + 1
                # collect results: p[i] waits for the chunk with handle i (ordered),
                # p.get_newest() returns (handle, result) of whichever chunk finishes first
                if is_chunked:
                    if yield_ordered:
                        if yield_index:
                            for i in range(len(iterable)):
                                yield i, p[i]
                        else:
                            for i in range(len(iterable)):
                                yield p[i]
                    else:
                        if yield_index:
                            for _ in range(len(iterable)):
                                yield p.get_newest()
                        else:
                            for _ in range(len(iterable)):
                                yield p.get_newest()[1]
                else:
                    # internally chunked: flatten each chunk result back to single iterations
                    if yield_ordered:
                        if yield_index:
                            for i in range(len(iterable)):
                                for q in p[i]:
                                    yield i, q
                        else:
                            for i in range(len(iterable)):
                                yield from p[i]
                    else:
                        if yield_index:
                            for _ in range(len(iterable)):
                                i, n = p.get_newest()
                                for q in n:
                                    yield i, q
                        else:
                            for _ in range(len(iterable)):
                                yield from p.get_newest()[1]
@wraps(gmap)
def pmap(*args, **kwargs) -> list[Result]:
    """ Eager counterpart of gmap: run the mapping and collect every result in a list. """
    return [*gmap(*args, **kwargs)]  # type: ignore
@wraps(gmap)
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]:
    """ Decorator flavour of pmap: the decorated function is immediately mapped over the
        iterable given to the decorator and its name is bound to the list of results. """
    def decorator(func: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
        return pmap(func, *args, **kwargs)
    return decorator