Files
parfor/tests/test_parfor.py
Wim Pomp 9fdbf49b5c - n_processes can now be changed each time
- block adding tasks if pool is busy
2026-01-09 11:37:55 +01:00

175 lines
3.9 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from os import getpid
from time import sleep
from typing import Any, Iterator, Optional, Sequence
import numpy as np
import pytest
from parfor import Chunks, ParPool, SharedArray, parfor, pmap
class SequenceIterator:
    """Single-pass iterator over a sequence that also supports ``len()``.

    ``len()`` always reports the length of the wrapped sequence, independent
    of how many items have already been consumed.
    """

    def __init__(self, sequence: Sequence) -> None:
        self._sequence = sequence
        self._index = 0  # position of the next item to hand out

    def __iter__(self) -> SequenceIterator:
        return self

    def __next__(self) -> Any:
        position = self._index
        if position >= len(self._sequence):
            # The index is never reset, so once exhausted it stays exhausted.
            raise StopIteration
        item = self._sequence[position]
        self._index = position + 1
        return item

    def __len__(self) -> int:
        return len(self._sequence)
class Iterable:
    """Re-iterable wrapper around a sequence that exposes only ``__iter__``.

    It has no ``__len__``, so ``len()`` cannot be used on it. Every ``iter()``
    call returns a fresh ``SequenceIterator``, so it can be iterated repeatedly.
    """

    def __init__(self, sequence: Sequence) -> None:
        self.sequence = sequence

    def __iter__(self) -> SequenceIterator:
        fresh_iterator = SequenceIterator(self.sequence)
        return fresh_iterator
def iterators() -> Iterator[tuple[Any, Optional[int]]]:
    """Yield ``(iterable, length)`` cases used to parametrize ``test_chunks``.

    Every case produces the integers 0..9.

    ``length`` is ``None`` for the objects that implement ``__len__``
    (``range``, ``list``, ``SequenceIterator``). It is an explicit ``10`` for
    those that do not (a generator expression and ``Iterable``). Presumably
    this is because ``Chunks`` cannot call ``len()`` on them; confirm against
    ``Chunks``.

    Note: this is a generator function. It yields several tuples rather than
    returning one, so the return type is an iterator of tuples.
    """
    yield range(10), None
    yield list(range(10)), None
    yield (i for i in range(10)), 10
    yield SequenceIterator(range(10)), None
    yield Iterable(range(10)), 10
@pytest.mark.parametrize("iterator", iterators())
def test_chunks(iterator: tuple[Iterator, Optional[int]]) -> None:
    """Chunks of size 2 split 0..9 into five pairs, whatever the input type."""
    source, length = iterator
    expected = [[n, n + 1] for n in range(0, 10, 2)]
    assert list(Chunks(source, size=2, length=length)) == expected
def test_parpool() -> None:
    """Values submitted with ``pool[i] = i`` come back as ``i * 3 * 2``.

    The pool is created with extra positional args ``(3,)`` and keyword args
    ``{"k": 2}``.
    """

    def fun(i, j, k) -> int:  # noqa
        return i * j * k

    keys = range(10)
    with ParPool(fun, (3,), {"k": 2}) as pool:  # noqa
        for key in keys:
            pool[key] = key
        # Read results back while the pool is still open.
        results = [pool[key] for key in keys]
        assert results == [6 * key for key in keys]
def test_parfor() -> None:
    """The ``parfor`` decorator replaces ``fun`` with its list of results."""

    @parfor(range(10), (3,), {"k": 2})
    def fun(i, j, k):
        return i * j * k

    assert fun == [i * 3 * 2 for i in range(10)]
@pytest.mark.parametrize("serial", (True, False))
def test_pmap(serial) -> None:
    """``pmap`` yields ``fun(i, 3, k=2)`` for each ``i``, serially or not."""

    def fun(i, j, k):
        return i * j * k

    expected = [6 * i for i in range(10)]
    assert pmap(fun, range(10), (3,), {"k": 2}, serial=serial) == expected
@pytest.mark.parametrize("serial", (True, False))
def test_pmap_with_idx(serial) -> None:
    """With ``yield_index=True`` each result is paired with its input index."""

    def fun(i, j, k):
        return i * j * k

    result = pmap(fun, range(10), (3,), {"k": 2}, serial=serial, yield_index=True)
    assert result == [(idx, idx * 6) for idx in range(10)]
@pytest.mark.parametrize("serial", (True, False))
def test_pmap_chunks(serial) -> None:
    """Each chunk of two inputs reaches ``fun`` as one list and maps per item."""

    def fun(i, j, k):
        return [x * j * k for x in i]

    pairs = Chunks(range(10), size=2)
    expected = [[6 * start, 6 * (start + 1)] for start in range(0, 10, 2)]
    assert pmap(fun, pairs, (3,), {"k": 2}, serial=serial) == expected
def test_id_reuse() -> None:
    """Results stay correct when input objects are freed and recreated.

    The generator drops its reference to each ``T`` right after yielding it.
    CPython may then hand the same ``id()`` to the next ``T``, which is the
    scenario this test targets. NOTE(review): presumably parfor tracks items
    by ``id()`` somewhere; confirm.
    """

    def fun(i):
        # Each chunk has size 1, so i[0] is the single T in the chunk.
        return i[0].a

    @dataclass
    class T:
        a: int = 3

    def gen(total):
        for i in range(total):
            t = T(i)
            yield t
            del t  # release our reference so the object (and its id) can be reused

    a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000)  # noqa
    # Exact comparison: the old `all(i == j for i, j in enumerate(a))` also
    # passed when results were missing (e.g. an empty or truncated list).
    assert a == list(range(1000))
@pytest.mark.parametrize("n_processes", (2, 4, 6))
def test_n_processes(n_processes) -> None:
    """No more than ``n_processes`` distinct process ids appear in the results."""

    @parfor(range(12), n_processes=n_processes)
    def fun(i):  # noqa
        # Presumably slows each task enough that work spreads over the workers.
        sleep(0.25)
        return getpid()

    distinct_pids = set(fun)
    assert len(distinct_pids) <= n_processes
def test_shared_array() -> None:
    """Writes made inside ``fun`` through ``pmap`` are visible in ``arr`` afterwards."""

    def fun(i, a):
        a[i] = i

    with SharedArray(100, int) as arr:
        pmap(fun, range(len(arr)), (arr,))
        # Copy the shared data out before the context manager closes it.
        snapshot = np.array(arr)
        assert np.all(snapshot == np.arange(len(arr)))
def test_nesting() -> None:
    """``pmap`` works when called from a function that ``pmap`` itself runs."""

    def square(i):
        return i**2

    def inner_map(i):
        return pmap(square, range(i, i + 50))

    expected = [[n**2 for n in range(start, start + 50)] for start in range(10)]
    assert pmap(inner_map, range(10)) == expected