- introduce n_processes to change the number of processes in the pool

This commit is contained in:
Wim Pomp
2024-05-24 16:57:35 +02:00
parent ac4d599646
commit 9783c1d1f2
5 changed files with 62 additions and 34 deletions

View File

@@ -1,4 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from os import getpid
from time import sleep
from typing import Any, Iterator, Optional, Sequence
import pytest
@@ -6,14 +11,14 @@ from parfor import Chunks, ParPool, parfor, pmap
class SequenceIterator:
def __init__(self, sequence):
def __init__(self, sequence: Sequence) -> None:
self._sequence = sequence
self._index = 0
def __iter__(self):
def __iter__(self) -> SequenceIterator:
return self
def __next__(self):
def __next__(self) -> Any:
if self._index < len(self._sequence):
item = self._sequence[self._index]
self._index += 1
@@ -21,19 +26,19 @@ class SequenceIterator:
else:
raise StopIteration
def __len__(self):
def __len__(self) -> int:
return len(self._sequence)
class Iterable:
def __init__(self, sequence):
def __init__(self, sequence: Sequence) -> None:
self.sequence = sequence
def __iter__(self):
def __iter__(self) -> SequenceIterator:
return SequenceIterator(self.sequence)
def iterators():
def iterators() -> tuple[Iterator, Optional[int]]:
yield range(10), None
yield list(range(10)), None
yield (i for i in range(10)), 10
@@ -42,23 +47,23 @@ def iterators():
@pytest.mark.parametrize('iterator', iterators())
def test_chunks(iterator):
def test_chunks(iterator: tuple[Iterator, Optional[int]]) -> None:
chunks = Chunks(iterator[0], size=2, length=iterator[1])
assert list(chunks) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
def test_parpool():
def fun(i, j, k): # noqa
def test_parpool() -> None:
def fun(i, j, k) -> int: # noqa
return i * j * k
with ParPool(fun, (3,), {'k': 2}) as pool:
with ParPool(fun, (3,), {'k': 2}) as pool: # noqa
for i in range(10):
pool[i] = i
assert [pool[i] for i in range(10)] == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
def test_parfor():
def test_parfor() -> None:
@parfor(range(10), (3,), {'k': 2})
def fun(i, j, k):
return i * j * k
@@ -66,14 +71,14 @@ def test_parfor():
assert fun == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
def test_pmap():
def test_pmap() -> None:
def fun(i, j, k):
return i * j * k
assert pmap(fun, range(10), (3,), {'k': 2}) == [0, 6, 12, 18, 24, 30, 36, 42, 48, 54]
def test_id_reuse():
def test_id_reuse() -> None:
def fun(i):
return i[0].a
@@ -87,5 +92,16 @@ def test_id_reuse():
yield t
del t
a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000)
a = pmap(fun, Chunks(gen(1000), size=1, length=1000), total=1000) # noqa
assert all([i == j for i, j in enumerate(a)])
@pytest.mark.parametrize('n_processes', (2, 4, 6))
def test_n_processes(n_processes) -> None:
@parfor(range(12), n_processes=n_processes)
def fun(i): # noqa
sleep(0.25)
return getpid()
assert len(set(fun)) == n_processes