- fix serial computation

- some more tests
This commit is contained in:
Wim Pomp
2024-09-11 13:08:03 +02:00
parent 4d80316244
commit fb7757828f
4 changed files with 52 additions and 13 deletions

View File

@@ -170,5 +170,5 @@ Split a long iterator in bite-sized chunks to parallelize
## `ParPool`
More low-level accessibility to parallel execution. Submit tasks and request the result at any time,
(although necessarily submit first, then request a specific task), use different functions and function
(although to avoid breaking causality, submit first, then request), use different functions and function
arguments for different tasks.

View File

@@ -655,10 +655,30 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
if 'disable' not in bar_kwargs:
bar_kwargs['disable'] = not bar
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
if callable(bar):
return sum([chunk_fun(c, *args, **kwargs) for c in ExternalBar(iterable, bar)], [])
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]: # noqa
with tqdm(*args, **kwargs) as b:
for chunk, length in zip(chunks, chunks.lengths): # noqa
yield chunk
b.update(length)
iterable = (ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar)
else tqdm_chunks(iterable, **bar_kwargs))
if is_chunked:
if yield_index:
for i, c in enumerate(iterable):
yield i, chunk_fun(c, *args, **kwargs)
else:
return sum([chunk_fun(c, *args, **kwargs) for c in tqdm(iterable, **bar_kwargs)], [])
for c in iterable:
yield chunk_fun(c, *args, **kwargs)
else:
if yield_index:
for i, c in enumerate(iterable):
yield i, chunk_fun(c, *args, **kwargs)[0]
else:
for c in iterable:
yield chunk_fun(c, *args, **kwargs)[0]
else: # parallel case
with ExitStack() as stack:
if callable(bar):

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "parfor"
version = "2024.9.0"
version = "2024.9.1"
description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3"

View File

@@ -71,11 +71,30 @@ def test_parfor() -> None:
assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
def test_pmap() -> None:
@pytest.mark.parametrize('serial', (True, False))
def test_pmap(serial) -> None:
def fun(i, j, k):
return i * j * k
assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
assert pmap(fun, range(10), (3,), {'k': 2}, serial=serial) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_with_idx(serial) -> None:
def fun(i, j, k):
return i * j * k
assert (pmap(fun, range(10), (3,), {'k': 2}, serial=serial, yield_index=True) ==
[(0, 0), (1, 6), (2, 12), (3, 18), (4, 24), (5, 30), (6, 36), (7, 42), (8, 48), (9, 54)])
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_chunks(serial) -> None:
def fun(i, j, k):
return [i_ * j * k for i_ in i]
chunks = Chunks(range(10), size=2)
assert pmap(fun, chunks, (3,), {'k': 2}, serial=serial) == [[0, 6], [12, 18], [24, 30], [36, 42], [48, 54]]
def test_id_reuse() -> None: