Files
parfor/tests/test_parfor.py
Wim Pomp 326a06f9da - add SharedArray to share numpy arrays
- fix async call_later when no event loop is available
2025-01-06 18:01:55 +01:00

150 lines
3.8 KiB
Python

from __future__ import annotations
import sys
from dataclasses import dataclass
from os import getpid
from time import sleep
from typing import Any, Iterator, Optional, Sequence
import numpy as np
import pytest
from parfor import Chunks, ParPool, SharedArray, parfor, pmap
# Detect whether this interpreter runs with the GIL enabled, so that tests
# which only make sense with a GIL can be skipped on free-threaded builds.
try:
    # sys._is_gil_enabled() is only available on Python 3.13 and newer.
    gil = bool(sys._is_gil_enabled())  # noqa
except AttributeError:
    # Interpreters lacking the function predate free-threading: GIL is on.
    gil = True
class SequenceIterator:
    """Single-pass iterator over a sequence that also supports ``len()``.

    ``len()`` always reports the length of the wrapped sequence, no matter
    how many items have already been consumed.
    """

    def __init__(self, sequence: Sequence) -> None:
        """Wrap *sequence* and start iterating from its first item."""
        self._sequence = sequence
        self._index = 0  # position of the next item to hand out

    def __iter__(self) -> SequenceIterator:
        """Return the iterator itself."""
        return self

    def __next__(self) -> Any:
        """Return the next item, or raise ``StopIteration`` when exhausted."""
        position = self._index
        if position >= len(self._sequence):
            raise StopIteration
        value = self._sequence[position]
        self._index = position + 1
        return value

    def __len__(self) -> int:
        """Return the length of the wrapped sequence."""
        return len(self._sequence)
class Iterable:
    """Re-iterable wrapper around a sequence that does not define ``__len__``.

    Every call to ``iter()`` hands out a fresh ``SequenceIterator`` over the
    wrapped sequence.
    """

    def __init__(self, sequence: Sequence) -> None:
        """Store *sequence* for later iteration."""
        self.sequence = sequence

    def __iter__(self) -> SequenceIterator:
        """Return a new iterator positioned at the start of the sequence."""
        fresh_iterator = SequenceIterator(self.sequence)
        return fresh_iterator
def iterators() -> Iterator[tuple[Any, Optional[int]]]:
    """Yield ``(iterable, length)`` pairs used to parametrize ``test_chunks``.

    Every iterable produces the integers 0 through 9. ``length`` is 10 for
    the inputs that have no ``__len__`` of their own (the generator
    expression and ``Iterable``) and ``None`` for the others.

    The first tuple element is annotated as ``Any`` because the inputs are
    a mix of sequences, iterators and plain iterables.
    """
    yield range(10), None
    yield list(range(10)), None
    yield (i for i in range(10)), 10
    yield SequenceIterator(range(10)), None
    yield Iterable(range(10)), 10
@pytest.mark.parametrize('iterator', iterators())
def test_chunks(iterator: tuple[Iterator, Optional[int]]) -> None:
    """Chunks of size 2 split each kind of input into five consecutive pairs."""
    iterable, length = iterator
    expected = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    chunks = Chunks(iterable, size=2, length=length)
    assert list(chunks) == expected
def test_parpool() -> None:
    """Submit ten inputs to a ParPool by index and read the results back."""
    expected = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]

    def fun(i, j, k) -> int:  # noqa
        return i * j * k

    with ParPool(fun, (3,), {'k': 2}) as pool:  # noqa
        # Queue every input first, then collect all results.
        for index in range(10):
            pool[index] = index
        results = [pool[index] for index in range(10)]
        assert results == expected
def test_parfor() -> None:
    """The name decorated with ``parfor`` compares equal to the list of results."""
    expected = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]

    @parfor(range(10), (3,), {'k': 2})
    def fun(i, j, k):
        return i * j * k

    assert fun == expected
@pytest.mark.parametrize('serial', (True, False))
def test_pmap(serial) -> None:
    """``pmap`` gives the same ordered results for both values of ``serial``."""
    expected = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]

    def fun(i, j, k):
        return i * j * k

    result = pmap(fun, range(10), (3,), {'k': 2}, serial=serial)
    assert result == expected
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_with_idx(serial) -> None:
    """With ``yield_index=True`` each result is paired with its input index."""
    expected = [
        (0, 0), (1, 6), (2, 12), (3, 18), (4, 24),
        (5, 30), (6, 36), (7, 42), (8, 48), (9, 54),
    ]

    def fun(i, j, k):
        return i * j * k

    result = pmap(fun, range(10), (3,), {'k': 2}, serial=serial, yield_index=True)
    assert result == expected
@pytest.mark.parametrize('serial', (True, False))
def test_pmap_chunks(serial) -> None:
    """``pmap`` over ``Chunks`` passes a whole chunk to each function call."""
    expected = [[0, 6], [12, 18], [24, 30], [36, 42], [48, 54]]

    def fun(i, j, k):
        # ``i`` is one chunk: process each of its items.
        return [i_ * j * k for i_ in i]

    chunks = Chunks(range(10), size=2)
    result = pmap(fun, chunks, (3,), {'k': 2}, serial=serial)
    assert result == expected
@pytest.mark.skipif(not gil, reason='test if gil enabled only')
def test_id_reuse() -> None:
    """Results must match their inputs when the input objects are short-lived.

    The generator drops its own reference to each object right after
    yielding it. Presumably (per the test name) this guards against results
    being mixed up when object ids are reused.
    """
    def fun(i):
        # ``i`` is a chunk of size 1 holding a single ``T``.
        return i[0].a

    @dataclass
    class T:
        a: int = 3

    def gen(total):
        for i in range(total):
            t = T(i)
            yield t
            del t

    a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000)  # noqa
    # Check the length explicitly: the element-wise check below would pass
    # vacuously on an empty or truncated result.
    assert len(a) == 1000
    assert all(i == j for i, j in enumerate(a))
@pytest.mark.skipif(not gil, reason='test if gil enabled only')
@pytest.mark.parametrize('n_processes', (2, 4, 6))
def test_n_processes(n_processes) -> None:
    """The number of distinct worker pids equals the requested ``n_processes``."""
    @parfor(range(12), n_processes=n_processes)
    def fun(i):  # noqa
        sleep(0.25)
        return getpid()

    distinct_pids = set(fun)
    assert len(distinct_pids) == n_processes
def test_shared_array() -> None:
    """Writes made by ``pmap`` workers into a SharedArray are visible afterwards."""
    def fun(i, a):
        a[i] = i

    with SharedArray(100, int) as arr:
        size = len(arr)
        pmap(fun, range(size), (arr,))
        # Snapshot and compare while the shared array is still open.
        # NOTE(review): presumably the buffer is released on exit - confirm.
        b = np.array(arr)
        assert np.all(b == np.arange(len(arr)))