from __future__ import annotations

import sys
from contextlib import ExitStack
from functools import wraps
from importlib.metadata import version
from typing import Any, Callable, Generator, Iterable, Iterator, Sized, TypeVar
from warnings import warn

from tqdm.auto import tqdm

from . import gil, nogil
from .common import Bar, cpu_count

__version__ = version('parfor')


Result = TypeVar('Result')
Iteration = TypeVar('Iteration')


class ParPool:
    def __new__(cls, *args, **kwargs):
        try:
            if not sys._is_gil_enabled():  # noqa
                return nogil.ParPool(*args, **kwargs)
        except AttributeError:
            pass
        return gil.ParPool(*args, **kwargs)
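
# Usage sketch (an illustration only, not part of the package's documented examples): gmap() below
# drives ParPool as a context manager, submits one chunk of work per handle and collects the results
# by handle. The concrete ParPool/Worker API lives in the gil and nogil submodules.
#
#     with ParPool(chunk_fun, args, kwargs, n_processes, bar) as pool:
#         for handle, chunk in enumerate(chunks):
#             pool(chunk, handle=handle, barlength=len(chunk))
#         results = [pool[handle] for handle in range(len(chunks))]
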
def nested():
    try:
        if not sys._is_gil_enabled():  # noqa
            return nogil.Worker.nested
    except AttributeError:
        pass
    return gil.Worker.nested


class Chunks(Iterable):
    """ Yield successive chunks from lists.
        Usage: chunks(list0, list1, ...)
               chunks(list0, list1, ..., size=s)
               chunks(list0, list1, ..., number=n)
               chunks(list0, list1, ..., ratio=r)
        size: size of the chunks; the actual size may be adjusted to divide the items evenly over the chunks
        number: number of chunks, coerced to 1 <= n <= len(list0)
        ratio: number of chunks / number of cpus, coerced to 1 <= n <= len(list0)
        if both size and number or ratio are given: number or ratio is used, unless that would make the chunks
            bigger than size
        if both ratio and number are given: ratio is used
    """

    def __init__(self, *iterables: Iterable[Any] | Sized, size: int = None, number: int = None,
                 ratio: float = None, length: int = None) -> None:
        if length is None:
            try:
                length = min(*[len(iterable) for iterable in iterables]) if len(iterables) > 1 else len(iterables[0])
            except TypeError:
                raise TypeError('Cannot determine the length of the iterable(s), so the length must be provided as an'
                                ' argument.')
        if size is not None and (number is not None or ratio is not None):
            if number is None:
                number = int(cpu_count * ratio)
            if length >= size * number:
                number = round(length / size)
        elif size is not None:  # number of chunks from size
            number = round(length / size)
        elif ratio is not None:  # number of chunks from ratio
            number = int(cpu_count * ratio)
        self.iterators = [iter(arg) for arg in iterables]
        self.number_of_items = length
        self.length = min(length, number)
        self.lengths = [((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length)
                        for i in range(self.length)]

    def __iter__(self) -> Iterator[Any]:
        for i in range(self.length):
            p, q = (i * self.number_of_items // self.length), ((i + 1) * self.number_of_items // self.length)
            if len(self.iterators) == 1:
                yield [next(self.iterators[0]) for _ in range(q - p)]
            else:
                yield [[next(iterator) for _ in range(q - p)] for iterator in self.iterators]

    def __len__(self) -> int:
        return self.length
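
# Illustration of how Chunks divides work (comment only, not part of the API docs):
#     list(Chunks(range(10), number=3))       -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
#     list(Chunks('abcd', range(4), size=2))  -> [[['a', 'b'], [0, 1]], [['c', 'd'], [2, 3]]]
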
class ExternalBar(Iterable):
    def __init__(self, iterable: Iterable = None, callback: Callable[[int], None] = None, total: int = 0) -> None:
        self.iterable = iterable
        self.callback = callback
        self.total = total
        self._n = 0

    def __enter__(self) -> ExternalBar:
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        return

    def __iter__(self) -> Iterator[Any]:
        for n, item in enumerate(self.iterable):
            yield item
            self.n = n + 1

    def update(self, n: int = 1) -> None:
        self.n += n

    @property
    def n(self) -> int:
        return self._n

    @n.setter
    def n(self, n: int) -> None:
        if n != self._n:
            self._n = n
            if self.callback is not None:
                self.callback(n)
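
# Example of the callback mechanism (illustration only): every change of the counter is forwarded
# to the callback, e.g. to drive a GUI progress bar.
#     progress = []
#     with ExternalBar(range(5), callback=progress.append, total=5) as bar:
#         for _ in bar:
#             pass
#     # progress == [1, 2, 3, 4, 5]
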
def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
         args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
         bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
         n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
         **bar_kwargs: Any) -> Generator[Result, None, None]:
    """ map a function fun to each iteration in iterable
        use as a function: pmap
        use as a decorator: parfor
        best use: iterable is a generator and its length is given to this function as 'total'

        required:
            fun: function taking arguments: iteration from iterable, other arguments defined in args & kwargs
            iterable: iterable or iterator from which an item is given to fun as a first argument
        optional:
            args: tuple with other unnamed arguments to fun
            kwargs: dict with other named arguments to fun
            total: give the length of the iterator in cases where len(iterator) results in an error
            desc: string with a description of the progress bar
            bar: bool to enable the progress bar,
                or a callback function taking the number of passed iterations as an argument
            serial: execute in series instead of parallel if True, None (default): let pmap decide
            length: deprecated alias for total
            n_processes: number of processes to use,
                the parallel pool will be restarted if the current pool does not have the right number of processes
            yield_ordered: yield the results in the same order as the iterable
            yield_index: yield the index of the result too
            **bar_kwargs: keyword arguments for tqdm.tqdm

        output:
            list (pmap) or generator (gmap) with results from applying the function 'fun' to each iteration
            of the iterable / iterator

        examples:
            << from time import sleep
            <<
            @parfor(range(10), (3,))
            def fun(i, a):
                sleep(1)
                return a * i ** 2
            fun
            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]

            <<
            def fun(i, a):
                sleep(1)
                return a * i ** 2
            pmap(fun, range(10), (3,))
            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]

            equivalent to using the deco module:
            <<
            @concurrent
            def fun(i, a):
                time.sleep(1)
                return a * i ** 2

            @synchronized
            def run(iterator, a):
                res = []
                for i in iterator:
                    res.append(fun(i, a))
                return res
            run(range(10), 3)
            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]

            all equivalent to the serial for-loop:
            <<
            a = 3
            fun = []
            for i in range(10):
                sleep(1)
                fun.append(a * i ** 2)
            fun
            >> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
    """
    if total is None and length is not None:
        total = length
        warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=2)
        warn('parfor: use of \'length\' is deprecated, use \'total\' instead', DeprecationWarning, stacklevel=3)
    if terminator is not None:
        warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
             DeprecationWarning, stacklevel=2)
        warn('parfor: use of \'terminator\' is deprecated, workers are terminated automatically',
             DeprecationWarning, stacklevel=3)
    is_chunked = isinstance(iterable, Chunks)
    if is_chunked:
        chunk_fun = fun
    else:
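        # Explanatory note (an assumption about intent): wrapping the iterable in Chunks with ratio=5
        # creates roughly 5 chunks per cpu, so each task handed to a worker covers a whole chunk of
        # iterations rather than a single one.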
        iterable = Chunks(iterable, ratio=5, length=total)

        @wraps(fun)
        def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Result]:  # noqa
            return [fun(iteration, *args, **kwargs) for iteration in iterable]

    args = args or ()
    kwargs = kwargs or {}

    if 'total' not in bar_kwargs:
        bar_kwargs['total'] = sum(iterable.lengths)
    if 'desc' not in bar_kwargs:
        bar_kwargs['desc'] = desc
    if 'disable' not in bar_kwargs:
        bar_kwargs['disable'] = not bar
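    # Run serially when explicitly requested, when there are only a few chunks, or when this call
    # appears to originate from within a parfor worker already (nested()), to avoid nested pools.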
    if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or nested():  # serial case

        def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]:  # noqa
            with tqdm(*args, **kwargs) as b:
                for chunk, length in zip(chunks, chunks.lengths):  # noqa
                    yield chunk
                    b.update(length)

        iterable = (ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar)
                    else tqdm_chunks(iterable, **bar_kwargs))
        if is_chunked:
            if yield_index:
                for i, c in enumerate(iterable):
                    yield i, chunk_fun(c, *args, **kwargs)
            else:
                for c in iterable:
                    yield chunk_fun(c, *args, **kwargs)
        else:
            if yield_index:
                for i, c in enumerate(iterable):
                    for q in chunk_fun(c, *args, **kwargs):
                        yield i, q
            else:
                for c in iterable:
                    yield from chunk_fun(c, *args, **kwargs)

    else:  # parallel case
        with ExitStack() as stack:
            if callable(bar):
                bar = stack.enter_context(ExternalBar(callback=bar))
            else:
                bar = stack.enter_context(tqdm(**bar_kwargs))
            with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p:
                for i, (j, l) in enumerate(zip(iterable, iterable.lengths)):  # add work to the queue
                    p(j, handle=i, barlength=l)
                    if bar.total is None or bar.total < i + 1:
                        bar.total = i + 1

                if is_chunked:
                    if yield_ordered:
                        if yield_index:
                            for i in range(len(iterable)):
                                yield i, p[i]
                        else:
                            for i in range(len(iterable)):
                                yield p[i]
                    else:
                        if yield_index:
                            for _ in range(len(iterable)):
                                yield p.get_newest()
                        else:
                            for _ in range(len(iterable)):
                                yield p.get_newest()[1]
                else:
                    if yield_ordered:
                        if yield_index:
                            for i in range(len(iterable)):
                                for q in p[i]:
                                    yield i, q
                        else:
                            for i in range(len(iterable)):
                                yield from p[i]
                    else:
                        if yield_index:
                            for _ in range(len(iterable)):
                                i, n = p.get_newest()
                                for q in n:
                                    yield i, q
                        else:
                            for _ in range(len(iterable)):
                                yield from p.get_newest()[1]
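
# gmap is a generator and yields results lazily; pmap below simply collects them into a list.
# Illustration (not from the original docstring): with yield_ordered=False the results need not
# arrive in the order of the input iterable:
#     for result in gmap(fun, range(100), (3,), yield_ordered=False):
#         ...
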
@wraps(gmap)
def pmap(*args, **kwargs) -> list[Result]:
    return list(gmap(*args, **kwargs))  # type: ignore


@wraps(gmap)
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]:
    def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
        return pmap(fun, *args, **kwargs)
    return decfun