- fix serial computation

- some more tests
This commit is contained in:
Wim Pomp
2024-09-11 13:08:03 +02:00
parent 4d80316244
commit fb7757828f
4 changed files with 52 additions and 13 deletions

View File

@@ -71,11 +71,30 @@ def test_parfor() -> None:
assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
def test_pmap() -> None:
@pytest.mark.parametrize('serial', (True, False))
def test_pmap(serial) -> None:
def fun(i, j, k):
return i * j * k
assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
assert pmap(fun, range(10), (3,), {'k': 2}, serial=serial) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_with_idx(serial) -> None:
def fun(i, j, k):
return i * j * k
assert (pmap(fun, range(10), (3,), {'k': 2}, serial=serial, yield_index=True) ==
[(0, 0), (1, 6), (2, 12), (3, 18), (4, 24), (5, 30), (6, 36), (7, 42), (8, 48), (9, 54)])
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_chunks(serial) -> None:
def fun(i, j, k):
return [i_ * j * k for i_ in i]
chunks = Chunks(range(10), size=2)
assert pmap(fun, chunks, (3,), {'k': 2}, serial=serial) == [[0, 6], [12, 18], [24, 30], [36, 42], [48, 54]]
def test_id_reuse() -> None: