- add gmap: function like pmap, but returning a generator instead of a list
- add arguments for returning results out/in order and returning result indices
This commit is contained in:
@@ -8,7 +8,7 @@ from importlib.metadata import version
|
||||
from os import devnull, getpid
|
||||
from time import time
|
||||
from traceback import format_exc
|
||||
from typing import Any, Callable, Hashable, Iterable, Iterator, NoReturn, Optional, Protocol, Sized, TypeVar
|
||||
from typing import Any, Callable, Generator, Hashable, Iterable, Iterator, NoReturn, Optional, Protocol, Sized, TypeVar
|
||||
from warnings import warn
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
@@ -155,7 +155,7 @@ class Chunks(Iterable):
|
||||
if len(self.iterators) == 1:
|
||||
yield [next(self.iterators[0]) for _ in range(q - p)]
|
||||
else:
|
||||
yield [[next(iterator) for _ in range(q-p)] for iterator in self.iterators]
|
||||
yield [[next(iterator) for _ in range(q - p)] for iterator in self.iterators]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.length
|
||||
@@ -551,10 +551,11 @@ class Worker:
|
||||
self.n_workers.value -= 1
|
||||
|
||||
|
||||
def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
|
||||
def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
|
||||
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
|
||||
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
|
||||
n_processes: int = None, **bar_kwargs: Any) -> list[Result]:
|
||||
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
|
||||
**bar_kwargs: Any) -> Generator[Result, None, None]:
|
||||
""" map a function fun to each iteration in iterable
|
||||
use as a function: pmap
|
||||
use as a decorator: parfor
|
||||
@@ -574,10 +575,13 @@ def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
|
||||
length: deprecated alias for total
|
||||
n_processes: number of processes to use,
|
||||
the parallel pool will be restarted if the current pool does not have the right number of processes
|
||||
yield_ordered: return the result in the same order as the iterable
|
||||
yield_index: return the index of the result too
|
||||
**bar_kwargs: keywords arguments for tqdm.tqdm
|
||||
|
||||
output:
|
||||
list with results from applying the function \'fun\' to each iteration of the iterable / iterator
|
||||
list (pmap) or generator (gmap) with results from applying the function \'fun\' to each iteration
|
||||
of the iterable / iterator
|
||||
|
||||
examples:
|
||||
<< from time import sleep
|
||||
@@ -663,16 +667,49 @@ def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
|
||||
bar = stack.enter_context(tqdm(**bar_kwargs))
|
||||
with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p:
|
||||
for i, (j, l) in enumerate(zip(iterable, iterable.lengths)): # add work to the queue
|
||||
p(j, handle=i, barlength=iterable.lengths[i])
|
||||
if bar.total is None or bar.total < i+1:
|
||||
bar.total = i+1
|
||||
p(j, handle=i, barlength=l)
|
||||
if bar.total is None or bar.total < i + 1:
|
||||
bar.total = i + 1
|
||||
|
||||
if is_chunked:
|
||||
return [p[i] for i in range(len(iterable))]
|
||||
if yield_ordered:
|
||||
if yield_index:
|
||||
for i in range(len(iterable)):
|
||||
yield i, p[i]
|
||||
else:
|
||||
for i in range(len(iterable)):
|
||||
yield p[i]
|
||||
else:
|
||||
if yield_index:
|
||||
for _ in range(len(iterable)):
|
||||
yield p.get_newest()
|
||||
else:
|
||||
for _ in range(len(iterable)):
|
||||
yield p.get_newest()[1]
|
||||
else:
|
||||
return sum([p[i] for i in range(len(iterable))], []) # collect the results
|
||||
if yield_ordered:
|
||||
if yield_index:
|
||||
for i in range(len(iterable)):
|
||||
yield i, p[i][0]
|
||||
else:
|
||||
for i in range(len(iterable)):
|
||||
yield p[i][0]
|
||||
else:
|
||||
if yield_index:
|
||||
for _ in range(len(iterable)):
|
||||
i, n = p.get_newest()
|
||||
yield i, n[0]
|
||||
else:
|
||||
for _ in range(len(iterable)):
|
||||
yield p.get_newest()[1][0]
|
||||
|
||||
|
||||
@wraps(pmap)
|
||||
@wraps(gmap)
|
||||
def pmap(*args, **kwargs) -> list[Result]:
|
||||
return list(gmap(*args, **kwargs)) # type: ignore
|
||||
|
||||
|
||||
@wraps(gmap)
|
||||
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]:
|
||||
def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
|
||||
return pmap(fun, *args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user