- make cpu_count a field of PoolSingleton for easy global configuration of the number of processes
- remove TypeVars
- manually wrap parfor and pmap
- also redirect output when retrieving a task
This commit is contained in:
@@ -4,7 +4,7 @@ import sys
|
||||
from contextlib import ExitStack
|
||||
from functools import wraps
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Callable, Generator, Iterable, Iterator, Sized, TypeVar
|
||||
from typing import Any, Callable, Generator, Iterable, Iterator, Sized
|
||||
from warnings import warn
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
@@ -21,10 +21,6 @@ else:
|
||||
__version__ = version('parfor')
|
||||
|
||||
|
||||
Result = TypeVar('Result')
|
||||
Iteration = TypeVar('Iteration')
|
||||
|
||||
|
||||
class Chunks(Iterable):
|
||||
""" Yield successive chunks from lists.
|
||||
Usage: chunks(list0, list1, ...)
|
||||
@@ -106,11 +102,11 @@ class ExternalBar(Iterable):
|
||||
self.callback(n)
|
||||
|
||||
|
||||
def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
|
||||
def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
|
||||
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
|
||||
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
|
||||
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
|
||||
**bar_kwargs: Any) -> Generator[Result, None, None] | list[Result]:
|
||||
**bar_kwargs: Any) -> Generator[Any, None, None]:
|
||||
""" map a function fun to each iteration in iterable
|
||||
use as a function: pmap
|
||||
use as a decorator: parfor
|
||||
@@ -197,7 +193,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
|
||||
iterable = Chunks(iterable, ratio=5, length=total)
|
||||
|
||||
@wraps(fun)
|
||||
def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Result]: # noqa
|
||||
def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Any]: # noqa
|
||||
return [fun(iteration, *args, **kwargs) for iteration in iterable]
|
||||
|
||||
if args is None:
|
||||
@@ -213,7 +209,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
|
||||
bar_kwargs['disable'] = not bar
|
||||
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
|
||||
|
||||
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]: # noqa
|
||||
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Any]: # noqa
|
||||
with tqdm(*args, **kwargs) as b:
|
||||
for chunk, length in zip(chunks, chunks.lengths): # noqa
|
||||
yield chunk
|
||||
@@ -284,13 +280,19 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
|
||||
yield from p.get_newest()[1]
|
||||
|
||||
|
||||
@wraps(gmap)
|
||||
def pmap(*args, **kwargs) -> list[Result]:
|
||||
def pmap(*args, **kwargs) -> list[Any]:
|
||||
return list(gmap(*args, **kwargs))
|
||||
|
||||
|
||||
@wraps(gmap)
|
||||
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]:
|
||||
def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
|
||||
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Any, ...], Any]], list[Any]]:
|
||||
def decfun(fun: Callable[[Any, ...], Any]) -> list[Any]:
|
||||
return pmap(fun, *args, **kwargs)
|
||||
return decfun
|
||||
|
||||
|
||||
try:
|
||||
parfor.__doc__ = pmap.__doc__ = gmap.__doc__
|
||||
pmap.__annotations__ = gmap.__annotations__ | pmap.__annotations__
|
||||
parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != 'fun'}
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -4,13 +4,13 @@ import asyncio
|
||||
import multiprocessing
|
||||
from collections import UserDict
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from os import devnull, getpid
|
||||
from os import cpu_count, devnull, getpid
|
||||
from time import time
|
||||
from traceback import format_exc
|
||||
from typing import Any, Callable, Hashable, NoReturn, Optional
|
||||
from warnings import warn
|
||||
|
||||
from .common import Bar, cpu_count
|
||||
from .common import Bar
|
||||
from .pickler import dumps, loads
|
||||
|
||||
|
||||
@@ -275,16 +275,17 @@ class PoolSingleton:
|
||||
new pool. The pool will close itself after 10 minutes of idle time. """
|
||||
|
||||
instance = None
|
||||
cpu_count = cpu_count()
|
||||
|
||||
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
|
||||
# restart if any workers have shut down or if we want to have a different number of processes
|
||||
if cls.instance is not None:
|
||||
if (cls.instance.n_workers.value < cls.instance.n_processes or
|
||||
cls.instance.n_processes != (n_processes or cpu_count)):
|
||||
cls.instance.n_processes != (n_processes or cls.cpu_count)):
|
||||
cls.instance.close()
|
||||
if cls.instance is None or not cls.instance.is_alive:
|
||||
new = super().__new__(cls)
|
||||
new.n_processes = n_processes or cpu_count
|
||||
new.n_processes = n_processes or cls.cpu_count
|
||||
new.instance = new
|
||||
new.is_started = False
|
||||
ctx = Context()
|
||||
@@ -437,14 +438,14 @@ class Worker:
|
||||
last_active_time = time()
|
||||
while not self.event.is_set() and time() - last_active_time < 600:
|
||||
try:
|
||||
task = self.queue_in.get(True, 0.02)
|
||||
try:
|
||||
self.add_to_queue('started', task.pool_id, task.handle, pid)
|
||||
with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
|
||||
with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
|
||||
task = self.queue_in.get(True, 0.02)
|
||||
try:
|
||||
self.add_to_queue('started', task.pool_id, task.handle, pid)
|
||||
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
|
||||
except Exception: # noqa
|
||||
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
|
||||
self.event.set()
|
||||
except Exception: # noqa
|
||||
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
|
||||
self.event.set()
|
||||
self.shared_memory.garbage_collect()
|
||||
last_active_time = time()
|
||||
except (multiprocessing.queues.Empty, KeyboardInterrupt): # noqa
|
||||
|
||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from os import cpu_count
|
||||
from typing import Any, Callable, Hashable, NoReturn, Optional
|
||||
|
||||
from .common import Bar, cpu_count
|
||||
from .common import Bar
|
||||
|
||||
|
||||
class Worker:
|
||||
@@ -15,6 +16,8 @@ class Worker:
|
||||
|
||||
|
||||
class PoolSingleton:
|
||||
cpu_count = cpu_count()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -64,7 +67,7 @@ class ParPool:
|
||||
self.fun = fun
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.n_processes = n_processes or cpu_count
|
||||
self.n_processes = n_processes or PoolSingleton.cpu_count
|
||||
self.threads = {}
|
||||
|
||||
def __getstate__(self) -> NoReturn:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "parfor"
|
||||
version = "2024.12.0"
|
||||
version = "2024.12.1"
|
||||
description = "A package to mimic the use of parfor as done in Matlab."
|
||||
authors = ["Wim Pomp <wimpomp@gmail.com>"]
|
||||
license = "GPLv3"
|
||||
|
||||
Reference in New Issue
Block a user