- make cpu_count a field of PoolSingleton for easy global configuration of the number of processes

- remove TypeVars
- manually wrap parfor and pmap
- also redirect output when retrieving a task
This commit is contained in:
Wim Pomp
2024-12-20 16:43:02 +01:00
parent 31e07b49eb
commit eb92ce006d
4 changed files with 34 additions and 28 deletions

View File

@@ -4,7 +4,7 @@ import sys
from contextlib import ExitStack
from functools import wraps
from importlib.metadata import version
from typing import Any, Callable, Generator, Iterable, Iterator, Sized, TypeVar
from typing import Any, Callable, Generator, Iterable, Iterator, Sized
from warnings import warn
from tqdm.auto import tqdm
@@ -21,10 +21,6 @@ else:
__version__ = version('parfor')
Result = TypeVar('Result')
Iteration = TypeVar('Iteration')
class Chunks(Iterable):
""" Yield successive chunks from lists.
Usage: chunks(list0, list1, ...)
@@ -106,11 +102,11 @@ class ExternalBar(Iterable):
self.callback(n)
def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
**bar_kwargs: Any) -> Generator[Result, None, None] | list[Result]:
**bar_kwargs: Any) -> Generator[Any, None, None]:
""" map a function fun to each iteration in iterable
use as a function: pmap
use as a decorator: parfor
@@ -197,7 +193,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
iterable = Chunks(iterable, ratio=5, length=total)
@wraps(fun)
def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Result]: # noqa
def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Any]: # noqa
return [fun(iteration, *args, **kwargs) for iteration in iterable]
if args is None:
@@ -213,7 +209,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
bar_kwargs['disable'] = not bar
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]: # noqa
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Any]: # noqa
with tqdm(*args, **kwargs) as b:
for chunk, length in zip(chunks, chunks.lengths): # noqa
yield chunk
@@ -284,13 +280,19 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
yield from p.get_newest()[1]
@wraps(gmap)
def pmap(*args, **kwargs) -> list[Result]:
def pmap(*args, **kwargs) -> list[Any]:
return list(gmap(*args, **kwargs))
@wraps(gmap)
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]:
def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Any, ...], Any]], list[Any]]:
def decfun(fun: Callable[[Any, ...], Any]) -> list[Any]:
return pmap(fun, *args, **kwargs)
return decfun
try:
parfor.__doc__ = pmap.__doc__ = gmap.__doc__
pmap.__annotations__ = gmap.__annotations__ | pmap.__annotations__
parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != 'fun'}
except AttributeError:
pass

View File

@@ -4,13 +4,13 @@ import asyncio
import multiprocessing
from collections import UserDict
from contextlib import redirect_stderr, redirect_stdout
from os import devnull, getpid
from os import cpu_count, devnull, getpid
from time import time
from traceback import format_exc
from typing import Any, Callable, Hashable, NoReturn, Optional
from warnings import warn
from .common import Bar, cpu_count
from .common import Bar
from .pickler import dumps, loads
@@ -275,16 +275,17 @@ class PoolSingleton:
new pool. The pool will close itself after 10 minutes of idle time. """
instance = None
cpu_count = cpu_count()
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
# restart if any workers have shut down or if we want to have a different number of processes
if cls.instance is not None:
if (cls.instance.n_workers.value < cls.instance.n_processes or
cls.instance.n_processes != (n_processes or cpu_count)):
cls.instance.n_processes != (n_processes or cls.cpu_count)):
cls.instance.close()
if cls.instance is None or not cls.instance.is_alive:
new = super().__new__(cls)
new.n_processes = n_processes or cpu_count
new.n_processes = n_processes or cls.cpu_count
new.instance = new
new.is_started = False
ctx = Context()
@@ -437,14 +438,14 @@ class Worker:
last_active_time = time()
while not self.event.is_set() and time() - last_active_time < 600:
try:
task = self.queue_in.get(True, 0.02)
try:
self.add_to_queue('started', task.pool_id, task.handle, pid)
with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
task = self.queue_in.get(True, 0.02)
try:
self.add_to_queue('started', task.pool_id, task.handle, pid)
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
except Exception: # noqa
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
self.event.set()
except Exception: # noqa
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
self.event.set()
self.shared_memory.garbage_collect()
last_active_time = time()
except (multiprocessing.queues.Empty, KeyboardInterrupt): # noqa

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
import queue
import threading
from os import cpu_count
from typing import Any, Callable, Hashable, NoReturn, Optional
from .common import Bar, cpu_count
from .common import Bar
class Worker:
@@ -15,6 +16,8 @@ class Worker:
class PoolSingleton:
cpu_count = cpu_count()
def __init__(self, *args, **kwargs):
pass
@@ -64,7 +67,7 @@ class ParPool:
self.fun = fun
self.args = args
self.kwargs = kwargs
self.n_processes = n_processes or cpu_count
self.n_processes = n_processes or PoolSingleton.cpu_count
self.threads = {}
def __getstate__(self) -> NoReturn: