- fix serial computation

- some more tests
This commit is contained in:
Wim Pomp
2024-09-11 13:08:03 +02:00
parent 4d80316244
commit fb7757828f
4 changed files with 52 additions and 13 deletions

View File

@@ -170,5 +170,5 @@ Split a long iterator in bite-sized chunks to parallelize
## `ParPool` ## `ParPool`
More low-level accessibility to parallel execution. Submit tasks and request the result at any time, More low-level accessibility to parallel execution. Submit tasks and request the result at any time,
(although necessarily submit first, then request a specific task), use different functions and function (although to avoid breaking causality, submit first, then request), use different functions and function
arguments for different tasks. arguments for different tasks.

View File

@@ -655,10 +655,30 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
if 'disable' not in bar_kwargs: if 'disable' not in bar_kwargs:
bar_kwargs['disable'] = not bar bar_kwargs['disable'] = not bar
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
if callable(bar):
return sum([chunk_fun(c, *args, **kwargs) for c in ExternalBar(iterable, bar)], []) def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]: # noqa
with tqdm(*args, **kwargs) as b:
for chunk, length in zip(chunks, chunks.lengths): # noqa
yield chunk
b.update(length)
iterable = (ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar)
else tqdm_chunks(iterable, **bar_kwargs))
if is_chunked:
if yield_index:
for i, c in enumerate(iterable):
yield i, chunk_fun(c, *args, **kwargs)
else: else:
return sum([chunk_fun(c, *args, **kwargs) for c in tqdm(iterable, **bar_kwargs)], []) for c in iterable:
yield chunk_fun(c, *args, **kwargs)
else:
if yield_index:
for i, c in enumerate(iterable):
yield i, chunk_fun(c, *args, **kwargs)[0]
else:
for c in iterable:
yield chunk_fun(c, *args, **kwargs)[0]
else: # parallel case else: # parallel case
with ExitStack() as stack: with ExitStack() as stack:
if callable(bar): if callable(bar):

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "parfor" name = "parfor"
version = "2024.9.0" version = "2024.9.1"
description = "A package to mimic the use of parfor as done in Matlab." description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"] authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3" license = "GPLv3"

View File

@@ -71,11 +71,30 @@ def test_parfor() -> None:
assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54] assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
def test_pmap() -> None: @pytest.mark.parametrize('serial', (True, False))
def test_pmap(serial) -> None:
def fun(i, j, k): def fun(i, j, k):
return i * j * k return i * j * k
assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54] assert pmap(fun, range(10), (3,), {'k': 2}, serial=serial) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_with_idx(serial) -> None:
def fun(i, j, k):
return i * j * k
assert (pmap(fun, range(10), (3,), {'k': 2}, serial=serial, yield_index=True) ==
[(0, 0), (1, 6), (2, 12), (3, 18), (4, 24), (5, 30), (6, 36), (7, 42), (8, 48), (9, 54)])
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_chunks(serial) -> None:
def fun(i, j, k):
return [i_ * j * k for i_ in i]
chunks = Chunks(range(10), size=2)
assert pmap(fun, chunks, (3,), {'k': 2}, serial=serial) == [[0, 6], [12, 18], [24, 30], [36, 42], [48, 54]]
def test_id_reuse() -> None: def test_id_reuse() -> None: