- statically determine gil/nogil, otherwise we cannot subclass some things

This commit is contained in:
Wim Pomp
2024-12-05 12:57:46 +01:00
parent 7291468fb7
commit 31e07b49eb
2 changed files with 11 additions and 28 deletions

View File

@@ -12,6 +12,12 @@ from tqdm.auto import tqdm
from . import gil, nogil
from .common import Bar, cpu_count
if hasattr(sys, '_is_gil_enabled') and not sys._is_gil_enabled(): # noqa
from .nogil import ParPool, PoolSingleton, Task, Worker
else:
from .gil import ParPool, PoolSingleton, Task, Worker
__version__ = version('parfor')
@@ -19,29 +25,6 @@ Result = TypeVar('Result')
Iteration = TypeVar('Iteration')
def select():
return nogil if hasattr(sys, '_is_gil_enabled') and not sys._is_gil_enabled() else gil # noqa
class ParPool:
def __new__(cls, *args, **kwargs):
return select().ParPool(*args, **kwargs)
class PoolSingleton:
def __new__(cls, *args, **kwargs):
return select().PoolSingleton(*args, **kwargs)
@staticmethod
def close():
return select().PoolSingleton.close()
class Task:
def __new__(cls, *args, **kwargs):
return select().Task(*args, **kwargs)
class Chunks(Iterable):
""" Yield successive chunks from lists.
Usage: chunks(list0, list1, ...)
@@ -127,7 +110,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
**bar_kwargs: Any) -> Generator[Result, None, None]:
**bar_kwargs: Any) -> Generator[Result, None, None] | list[Result]:
""" map a function fun to each iteration in iterable
use as a function: pmap
use as a decorator: parfor
@@ -228,7 +211,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
bar_kwargs['desc'] = desc
if 'disable' not in bar_kwargs:
bar_kwargs['disable'] = not bar
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or select().Worker.nested: # serial case
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]: # noqa
with tqdm(*args, **kwargs) as b:
@@ -260,7 +243,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
bar = stack.enter_context(ExternalBar(callback=bar))
else:
bar = stack.enter_context(tqdm(**bar_kwargs))
with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p:
with ParPool(chunk_fun, args, kwargs, n_processes, bar) as p: # type: ignore
for i, (j, l) in enumerate(zip(iterable, iterable.lengths)): # add work to the queue
p(j, handle=i, barlength=l)
if bar.total is None or bar.total < i + 1:
@@ -303,7 +286,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
@wraps(gmap)
def pmap(*args, **kwargs) -> list[Result]:
return list(gmap(*args, **kwargs)) # type: ignore
return list(gmap(*args, **kwargs))
@wraps(gmap)

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "parfor"
version = "2024.11.1"
version = "2024.12.0"
description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3"