# Tests for parfor: Chunks, ParPool, the parfor decorator, pmap and SharedArray.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Iterator, Optional, Sequence

import numpy as np
import pytest

from parfor import Chunks, ParPool, SharedArray, parfor, pmap

# fun(i, j, k) = i * j * k for i in range(10), called with j=3 (positional
# arg) and k=2 (keyword arg) as in the ParPool / parfor / pmap tests below.
EXPECTED = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]


class SequenceIterator:
    """Single-use iterator over a sequence that also supports ``len()``.

    ``iterators()`` yields it with ``length=None``, so Chunks has to cope with
    an iterator whose length is not passed explicitly (presumably obtained
    through ``__len__`` — confirm against Chunks).
    """

    def __init__(self, sequence: Sequence) -> None:
        self._sequence = sequence
        self._index = 0

    def __iter__(self) -> SequenceIterator:
        return self

    def __next__(self) -> Any:
        if self._index >= len(self._sequence):
            raise StopIteration
        item = self._sequence[self._index]
        self._index += 1
        return item

    def __len__(self) -> int:
        # Length of the whole underlying sequence, not of the items remaining.
        return len(self._sequence)


class Iterable:
    """Re-iterable wrapper without ``__len__``; every ``iter()`` starts afresh."""

    def __init__(self, sequence: Sequence) -> None:
        self.sequence = sequence

    def __iter__(self) -> SequenceIterator:
        return SequenceIterator(self.sequence)


def iterators() -> Iterator[tuple[Any, Optional[int]]]:
    """Yield ``(iterable, length)`` input cases for ``test_chunks``.

    ``length`` is ``None`` for inputs that define ``__len__`` (range, list,
    SequenceIterator) and ``10`` for those that do not (a generator and
    :class:`Iterable`).

    Note: pytest materialises this generator once at collection time, and the
    generator expression and SequenceIterator cases are single-use, so each
    of those cases can only be consumed once per session.
    """
    yield range(10), None
    yield list(range(10)), None
    yield (i for i in range(10)), 10
    yield SequenceIterator(range(10)), None
    yield Iterable(range(10)), 10


@pytest.mark.parametrize("iterator", iterators())
def test_chunks(iterator: tuple[Any, Optional[int]]) -> None:
    """Chunks splits every kind of input into consecutive pairs."""
    source, length = iterator
    chunks = Chunks(source, size=2, length=length)
    assert list(chunks) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]


def test_parpool() -> None:
    """Items assigned to a ParPool are computed as fun(item, 3, k=2) and read back by key."""

    def fun(i, j, k) -> int:  # noqa
        return i * j * k

    with ParPool(fun, (3,), {"k": 2}) as pool:  # noqa
        for i in range(10):
            pool[i] = i
        assert [pool[i] for i in range(10)] == EXPECTED


def test_parfor() -> None:
    """The parfor decorator replaces the decorated function by its list of results."""

    @parfor(range(10), (3,), {"k": 2})
    def fun(i, j, k):
        return i * j * k

    assert fun == EXPECTED


@pytest.mark.parametrize("serial", (True, False))
def test_pmap(serial: bool) -> None:
    """pmap returns fun(i, 3, k=2) for each i, in input order."""

    def fun(i, j, k):
        return i * j * k

    assert pmap(fun, range(10), (3,), {"k": 2}, serial=serial) == EXPECTED


@pytest.mark.parametrize("serial", (True, False))
def test_pmap_with_idx(serial: bool) -> None:
    """With ``yield_index=True`` each result is paired with its input index."""

    def fun(i, j, k):
        return i * j * k

    result = pmap(fun, range(10), (3,), {"k": 2}, serial=serial, yield_index=True)
    assert result == list(enumerate(EXPECTED))


@pytest.mark.parametrize("serial", (True, False))
def test_pmap_chunks(serial: bool) -> None:
    """Mapping over Chunks hands fun one chunk at a time; results stay per chunk."""

    def fun(i, j, k):
        return [i_ * j * k for i_ in i]

    chunks = Chunks(range(10), size=2)
    assert pmap(fun, chunks, (3,), {"k": 2}, serial=serial) == [
        [0, 6],
        [12, 18],
        [24, 30],
        [36, 42],
        [48, 54],
    ]


def test_id_reuse() -> None:
    """Results stay correct when input objects are freed and their ids reused.

    ``gen`` drops its own reference to each ``T`` right after yielding it, so
    a later object may receive the same ``id()``; presumably this guards
    against parfor keying anything on object identity.
    """

    def fun(i):
        return i[0].a

    @dataclass
    class T:
        a: int = 3

    def gen(total):
        for i in range(total):
            t = T(i)
            yield t
            del t  # release the generator's reference so the object can be freed

    a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000)  # noqa
    # Compare the whole list so missing or extra results fail the test too,
    # not only results that are out of place.
    assert a == list(range(1000))


def test_shared_array() -> None:
    """Writes made to a SharedArray inside pmap workers are visible afterwards."""

    def fun(i, a):
        a[i] = i

    with SharedArray(100, int) as arr:
        pmap(fun, range(len(arr)), (arr,))
        b = np.array(arr)
        np.testing.assert_array_equal(b, np.arange(len(arr)))