- fix serial computation

- some more tests
This commit is contained in:
Wim Pomp
2024-09-11 13:08:03 +02:00
parent 4d80316244
commit fb7757828f
4 changed files with 52 additions and 13 deletions

View File

@@ -72,7 +72,7 @@ iterations need to be dillable. You might be able to make objects dillable anyho
fun = []
for i in range(10):
sleep(1)
fun.append(a*i**2)
fun.append(a * i ** 2)
print(fun)
>> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
@@ -84,7 +84,7 @@ iterations need to be dillable. You might be able to make objects dillable anyho
@parfor(range(10), (3,))
def fun(i, a):
sleep(1)
return a*i**2
return a * i ** 2
print(fun)
>> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
@@ -93,7 +93,7 @@ iterations need to be dillable. You might be able to make objects dillable anyho
@parfor(range(10), (3,), bar=False)
def fun(i, a):
sleep(1)
return a*i**2
return a * i ** 2
print(fun)
>> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
@@ -110,7 +110,7 @@ use the `if __name__ == '__main__':` structure:
@parfor(range(10), (3,))
def fun(i, a):
sleep(1)
return a*i**2
return a * i ** 2
print(fun)
>> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
@@ -124,7 +124,7 @@ or:
@parfor(range(10), (3,))
def fun(i, a):
sleep(1)
return a*i**2
return a * i ** 2
return fun
if __name__ == '__main__':
@@ -140,7 +140,7 @@ pmap maps an iterator to a function like map does, but in parallel
from time import sleep
def fun(i, a):
sleep(1)
return a*i**2
return a * i ** 2
print(pmap(fun, range(10), (3,)))
>> [0, 3, 12, 27, 48, 75, 108, 147, 192, 243]
@@ -170,5 +170,5 @@ Split a long iterator in bite-sized chunks to parallelize
## `ParPool`
More low-level accessibility to parallel execution. Submit tasks and request the result at any time,
(although necessarily submit first, then request a specific task), use different functions and function
(although to avoid breaking causality, submit first, then request), use different functions and function
arguments for different tasks.

View File

@@ -655,10 +655,30 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
if 'disable' not in bar_kwargs:
bar_kwargs['disable'] = not bar
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
if callable(bar):
return sum([chunk_fun(c, *args, **kwargs) for c in ExternalBar(iterable, bar)], [])
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]: # noqa
with tqdm(*args, **kwargs) as b:
for chunk, length in zip(chunks, chunks.lengths): # noqa
yield chunk
b.update(length)
iterable = (ExternalBar(iterable, bar, sum(iterable.lengths)) if callable(bar)
else tqdm_chunks(iterable, **bar_kwargs))
if is_chunked:
if yield_index:
for i, c in enumerate(iterable):
yield i, chunk_fun(c, *args, **kwargs)
else:
for c in iterable:
yield chunk_fun(c, *args, **kwargs)
else:
return sum([chunk_fun(c, *args, **kwargs) for c in tqdm(iterable, **bar_kwargs)], [])
if yield_index:
for i, c in enumerate(iterable):
yield i, chunk_fun(c, *args, **kwargs)[0]
else:
for c in iterable:
yield chunk_fun(c, *args, **kwargs)[0]
else: # parallel case
with ExitStack() as stack:
if callable(bar):

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "parfor"
version = "2024.9.0"
version = "2024.9.1"
description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3"

View File

@@ -71,11 +71,30 @@ def test_parfor() -> None:
assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
def test_pmap() -> None:
@pytest.mark.parametrize('serial', (True, False))
def test_pmap(serial) -> None:
def fun(i, j, k):
return i * j * k
assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
assert pmap(fun, range(10), (3,), {'k': 2}, serial=serial) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_with_idx(serial) -> None:
def fun(i, j, k):
return i * j * k
assert (pmap(fun, range(10), (3,), {'k': 2}, serial=serial, yield_index=True) ==
[(0, 0), (1, 6), (2, 12), (3, 18), (4, 24), (5, 30), (6, 36), (7, 42), (8, 48), (9, 54)])
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_chunks(serial) -> None:
def fun(i, j, k):
return [i_ * j * k for i_ in i]
chunks = Chunks(range(10), size=2)
assert pmap(fun, chunks, (3,), {'k': 2}, serial=serial) == [[0, 6], [12, 18], [24, 30], [36, 42], [48, 54]]
def test_id_reuse() -> None: