- make cpu_count a field of PoolSingleton for easy global configuration of the number of processes

- remove TypeVars
- manually wrap parfor and pmap
- also redirect output when retrieving task
This commit is contained in:
Wim Pomp
2024-12-20 16:43:02 +01:00
parent 31e07b49eb
commit eb92ce006d
4 changed files with 34 additions and 28 deletions

View File

@@ -4,7 +4,7 @@ import sys
from contextlib import ExitStack from contextlib import ExitStack
from functools import wraps from functools import wraps
from importlib.metadata import version from importlib.metadata import version
from typing import Any, Callable, Generator, Iterable, Iterator, Sized, TypeVar from typing import Any, Callable, Generator, Iterable, Iterator, Sized
from warnings import warn from warnings import warn
from tqdm.auto import tqdm from tqdm.auto import tqdm
@@ -21,10 +21,6 @@ else:
__version__ = version('parfor') __version__ = version('parfor')
Result = TypeVar('Result')
Iteration = TypeVar('Iteration')
class Chunks(Iterable): class Chunks(Iterable):
""" Yield successive chunks from lists. """ Yield successive chunks from lists.
Usage: chunks(list0, list1, ...) Usage: chunks(list0, list1, ...)
@@ -106,11 +102,11 @@ class ExternalBar(Iterable):
self.callback(n) self.callback(n)
def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None, def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None, args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None, bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False, n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
**bar_kwargs: Any) -> Generator[Result, None, None] | list[Result]: **bar_kwargs: Any) -> Generator[Any, None, None]:
""" map a function fun to each iteration in iterable """ map a function fun to each iteration in iterable
use as a function: pmap use as a function: pmap
use as a decorator: parfor use as a decorator: parfor
@@ -197,7 +193,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
iterable = Chunks(iterable, ratio=5, length=total) iterable = Chunks(iterable, ratio=5, length=total)
@wraps(fun) @wraps(fun)
def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Result]: # noqa def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Any]: # noqa
return [fun(iteration, *args, **kwargs) for iteration in iterable] return [fun(iteration, *args, **kwargs) for iteration in iterable]
if args is None: if args is None:
@@ -213,7 +209,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
bar_kwargs['disable'] = not bar bar_kwargs['disable'] = not bar
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]: # noqa def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Any]: # noqa
with tqdm(*args, **kwargs) as b: with tqdm(*args, **kwargs) as b:
for chunk, length in zip(chunks, chunks.lengths): # noqa for chunk, length in zip(chunks, chunks.lengths): # noqa
yield chunk yield chunk
@@ -284,13 +280,19 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
yield from p.get_newest()[1] yield from p.get_newest()[1]
@wraps(gmap) def pmap(*args, **kwargs) -> list[Any]:
def pmap(*args, **kwargs) -> list[Result]:
return list(gmap(*args, **kwargs)) return list(gmap(*args, **kwargs))
@wraps(gmap) def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Any, ...], Any]], list[Any]]:
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]: def decfun(fun: Callable[[Any, ...], Any]) -> list[Any]:
def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
return pmap(fun, *args, **kwargs) return pmap(fun, *args, **kwargs)
return decfun return decfun
try:
parfor.__doc__ = pmap.__doc__ = gmap.__doc__
pmap.__annotations__ = gmap.__annotations__ | pmap.__annotations__
parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != 'fun'}
except AttributeError:
pass

View File

@@ -4,13 +4,13 @@ import asyncio
import multiprocessing import multiprocessing
from collections import UserDict from collections import UserDict
from contextlib import redirect_stderr, redirect_stdout from contextlib import redirect_stderr, redirect_stdout
from os import devnull, getpid from os import cpu_count, devnull, getpid
from time import time from time import time
from traceback import format_exc from traceback import format_exc
from typing import Any, Callable, Hashable, NoReturn, Optional from typing import Any, Callable, Hashable, NoReturn, Optional
from warnings import warn from warnings import warn
from .common import Bar, cpu_count from .common import Bar
from .pickler import dumps, loads from .pickler import dumps, loads
@@ -275,16 +275,17 @@ class PoolSingleton:
new pool. The pool will close itself after 10 minutes of idle time. """ new pool. The pool will close itself after 10 minutes of idle time. """
instance = None instance = None
cpu_count = cpu_count()
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton: def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
# restart if any workers have shut down or if we want to have a different number of processes # restart if any workers have shut down or if we want to have a different number of processes
if cls.instance is not None: if cls.instance is not None:
if (cls.instance.n_workers.value < cls.instance.n_processes or if (cls.instance.n_workers.value < cls.instance.n_processes or
cls.instance.n_processes != (n_processes or cpu_count)): cls.instance.n_processes != (n_processes or cls.cpu_count)):
cls.instance.close() cls.instance.close()
if cls.instance is None or not cls.instance.is_alive: if cls.instance is None or not cls.instance.is_alive:
new = super().__new__(cls) new = super().__new__(cls)
new.n_processes = n_processes or cpu_count new.n_processes = n_processes or cls.cpu_count
new.instance = new new.instance = new
new.is_started = False new.is_started = False
ctx = Context() ctx = Context()
@@ -437,14 +438,14 @@ class Worker:
last_active_time = time() last_active_time = time()
while not self.event.is_set() and time() - last_active_time < 600: while not self.event.is_set() and time() - last_active_time < 600:
try: try:
task = self.queue_in.get(True, 0.02) with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
try: task = self.queue_in.get(True, 0.02)
self.add_to_queue('started', task.pool_id, task.handle, pid) try:
with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')): self.add_to_queue('started', task.pool_id, task.handle, pid)
self.add_to_queue('done', task.pool_id, task(self.shared_memory)) self.add_to_queue('done', task.pool_id, task(self.shared_memory))
except Exception: # noqa except Exception: # noqa
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc()) self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
self.event.set() self.event.set()
self.shared_memory.garbage_collect() self.shared_memory.garbage_collect()
last_active_time = time() last_active_time = time()
except (multiprocessing.queues.Empty, KeyboardInterrupt): # noqa except (multiprocessing.queues.Empty, KeyboardInterrupt): # noqa

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
import queue import queue
import threading import threading
from os import cpu_count
from typing import Any, Callable, Hashable, NoReturn, Optional from typing import Any, Callable, Hashable, NoReturn, Optional
from .common import Bar, cpu_count from .common import Bar
class Worker: class Worker:
@@ -15,6 +16,8 @@ class Worker:
class PoolSingleton: class PoolSingleton:
cpu_count = cpu_count()
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@@ -64,7 +67,7 @@ class ParPool:
self.fun = fun self.fun = fun
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
self.n_processes = n_processes or cpu_count self.n_processes = n_processes or PoolSingleton.cpu_count
self.threads = {} self.threads = {}
def __getstate__(self) -> NoReturn: def __getstate__(self) -> NoReturn:

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "parfor" name = "parfor"
version = "2024.12.0" version = "2024.12.1"
description = "A package to mimic the use of parfor as done in Matlab." description = "A package to mimic the use of parfor as done in Matlab."
authors = ["Wim Pomp <wimpomp@gmail.com>"] authors = ["Wim Pomp <wimpomp@gmail.com>"]
license = "GPLv3" license = "GPLv3"