- add gmap: function like pmap, but returning a generator instead of a list

- add arguments for returning results out/in order and returning result indices
This commit is contained in:
Wim Pomp
2024-09-05 18:37:47 +02:00
parent 29203dd128
commit 4d80316244
4 changed files with 58 additions and 16 deletions

View File

@@ -8,7 +8,7 @@ from importlib.metadata import version
from os import devnull, getpid
from time import time
from traceback import format_exc
from typing import Any, Callable, Hashable, Iterable, Iterator, NoReturn, Optional, Protocol, Sized, TypeVar
from typing import Any, Callable, Generator, Hashable, Iterable, Iterator, NoReturn, Optional, Protocol, Sized, TypeVar
from warnings import warn
from tqdm.auto import tqdm
@@ -155,7 +155,7 @@ class Chunks(Iterable):
if len(self.iterators) == 1:
yield [next(self.iterators[0]) for _ in range(q - p)]
else:
yield [[next(iterator) for _ in range(q-p)] for iterator in self.iterators]
yield [[next(iterator) for _ in range(q - p)] for iterator in self.iterators]
def __len__(self) -> int:
return self.length
@@ -551,10 +551,11 @@ class Worker:
self.n_workers.value -= 1
def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
n_processes: int = None, **bar_kwargs: Any) -> list[Result]:
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
**bar_kwargs: Any) -> Generator[Result, None, None]:
""" map a function fun to each iteration in iterable
use as a function: pmap
use as a decorator: parfor
@@ -574,10 +575,13 @@ def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
length: deprecated alias for total
n_processes: number of processes to use,
the parallel pool will be restarted if the current pool does not have the right number of processes
yield_ordered: return the result in the same order as the iterable
yield_index: return the index of the result too
**bar_kwargs: keywords arguments for tqdm.tqdm
output:
list with results from applying the function \'fun\' to each iteration of the iterable / iterator
list (pmap) or generator (gmap) with results from applying the function \'fun\' to each iteration
of the iterable / iterator
examples:
<< from time import sleep
@@ -663,16 +667,49 @@ def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
bar = stack.enter_context(tqdm(**bar_kwargs))
with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p:
for i, (j, l) in enumerate(zip(iterable, iterable.lengths)): # add work to the queue
p(j, handle=i, barlength=iterable.lengths[i])
if bar.total is None or bar.total < i+1:
bar.total = i+1
p(j, handle=i, barlength=l)
if bar.total is None or bar.total < i + 1:
bar.total = i + 1
if is_chunked:
return [p[i] for i in range(len(iterable))]
if yield_ordered:
if yield_index:
for i in range(len(iterable)):
yield i, p[i]
else:
for i in range(len(iterable)):
yield p[i]
else:
if yield_index:
for _ in range(len(iterable)):
yield p.get_newest()
else:
for _ in range(len(iterable)):
yield p.get_newest()[1]
else:
return sum([p[i] for i in range(len(iterable))], []) # collect the results
if yield_ordered:
if yield_index:
for i in range(len(iterable)):
yield i, p[i][0]
else:
for i in range(len(iterable)):
yield p[i][0]
else:
if yield_index:
for _ in range(len(iterable)):
i, n = p.get_newest()
yield i, n[0]
else:
for _ in range(len(iterable)):
yield p.get_newest()[1][0]
@wraps(pmap)
@wraps(gmap)
def pmap(*args, **kwargs) -> list[Result]:
return list(gmap(*args, **kwargs)) # type: ignore
@wraps(gmap)
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]:
def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
return pmap(fun, *args, **kwargs)