- make cpu_count a field to PoolSingleton for easy global configuration of number of processes
- remove TypeVars - manually wrap parfor and pmap - also redirect output when retrieving task
This commit is contained in:
@@ -4,7 +4,7 @@ import sys
|
|||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from typing import Any, Callable, Generator, Iterable, Iterator, Sized, TypeVar
|
from typing import Any, Callable, Generator, Iterable, Iterator, Sized
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
@@ -21,10 +21,6 @@ else:
|
|||||||
__version__ = version('parfor')
|
__version__ = version('parfor')
|
||||||
|
|
||||||
|
|
||||||
Result = TypeVar('Result')
|
|
||||||
Iteration = TypeVar('Iteration')
|
|
||||||
|
|
||||||
|
|
||||||
class Chunks(Iterable):
|
class Chunks(Iterable):
|
||||||
""" Yield successive chunks from lists.
|
""" Yield successive chunks from lists.
|
||||||
Usage: chunks(list0, list1, ...)
|
Usage: chunks(list0, list1, ...)
|
||||||
@@ -106,11 +102,11 @@ class ExternalBar(Iterable):
|
|||||||
self.callback(n)
|
self.callback(n)
|
||||||
|
|
||||||
|
|
||||||
def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
|
def gmap(fun: Callable[[Any, ...], Any], iterable: Iterable[Any] = None,
|
||||||
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
|
args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
|
||||||
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
|
bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
|
||||||
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
|
n_processes: int = None, yield_ordered: bool = True, yield_index: bool = False,
|
||||||
**bar_kwargs: Any) -> Generator[Result, None, None] | list[Result]:
|
**bar_kwargs: Any) -> Generator[Any, None, None]:
|
||||||
""" map a function fun to each iteration in iterable
|
""" map a function fun to each iteration in iterable
|
||||||
use as a function: pmap
|
use as a function: pmap
|
||||||
use as a decorator: parfor
|
use as a decorator: parfor
|
||||||
@@ -197,7 +193,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
|
|||||||
iterable = Chunks(iterable, ratio=5, length=total)
|
iterable = Chunks(iterable, ratio=5, length=total)
|
||||||
|
|
||||||
@wraps(fun)
|
@wraps(fun)
|
||||||
def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Result]: # noqa
|
def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Any]: # noqa
|
||||||
return [fun(iteration, *args, **kwargs) for iteration in iterable]
|
return [fun(iteration, *args, **kwargs) for iteration in iterable]
|
||||||
|
|
||||||
if args is None:
|
if args is None:
|
||||||
@@ -213,7 +209,7 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
|
|||||||
bar_kwargs['disable'] = not bar
|
bar_kwargs['disable'] = not bar
|
||||||
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
|
if serial is True or (serial is None and len(iterable) < min(cpu_count, 4)) or Worker.nested: # serial case
|
||||||
|
|
||||||
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Iteration]: # noqa
|
def tqdm_chunks(chunks: Chunks, *args, **kwargs) -> Iterable[Any]: # noqa
|
||||||
with tqdm(*args, **kwargs) as b:
|
with tqdm(*args, **kwargs) as b:
|
||||||
for chunk, length in zip(chunks, chunks.lengths): # noqa
|
for chunk, length in zip(chunks, chunks.lengths): # noqa
|
||||||
yield chunk
|
yield chunk
|
||||||
@@ -284,13 +280,19 @@ def gmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iterat
|
|||||||
yield from p.get_newest()[1]
|
yield from p.get_newest()[1]
|
||||||
|
|
||||||
|
|
||||||
@wraps(gmap)
|
def pmap(*args, **kwargs) -> list[Any]:
|
||||||
def pmap(*args, **kwargs) -> list[Result]:
|
|
||||||
return list(gmap(*args, **kwargs))
|
return list(gmap(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
@wraps(gmap)
|
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Any, ...], Any]], list[Any]]:
|
||||||
def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]:
|
def decfun(fun: Callable[[Any, ...], Any]) -> list[Any]:
|
||||||
def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
|
|
||||||
return pmap(fun, *args, **kwargs)
|
return pmap(fun, *args, **kwargs)
|
||||||
return decfun
|
return decfun
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
parfor.__doc__ = pmap.__doc__ = gmap.__doc__
|
||||||
|
pmap.__annotations__ = gmap.__annotations__ | pmap.__annotations__
|
||||||
|
parfor.__annotations__ = {key: value for key, value in pmap.__annotations__.items() if key != 'fun'}
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ import asyncio
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
from contextlib import redirect_stderr, redirect_stdout
|
from contextlib import redirect_stderr, redirect_stdout
|
||||||
from os import devnull, getpid
|
from os import cpu_count, devnull, getpid
|
||||||
from time import time
|
from time import time
|
||||||
from traceback import format_exc
|
from traceback import format_exc
|
||||||
from typing import Any, Callable, Hashable, NoReturn, Optional
|
from typing import Any, Callable, Hashable, NoReturn, Optional
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
from .common import Bar, cpu_count
|
from .common import Bar
|
||||||
from .pickler import dumps, loads
|
from .pickler import dumps, loads
|
||||||
|
|
||||||
|
|
||||||
@@ -275,16 +275,17 @@ class PoolSingleton:
|
|||||||
new pool. The pool will close itself after 10 minutes of idle time. """
|
new pool. The pool will close itself after 10 minutes of idle time. """
|
||||||
|
|
||||||
instance = None
|
instance = None
|
||||||
|
cpu_count = cpu_count()
|
||||||
|
|
||||||
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
|
def __new__(cls, n_processes: int = None, *args: Any, **kwargs: Any) -> PoolSingleton:
|
||||||
# restart if any workers have shut down or if we want to have a different number of processes
|
# restart if any workers have shut down or if we want to have a different number of processes
|
||||||
if cls.instance is not None:
|
if cls.instance is not None:
|
||||||
if (cls.instance.n_workers.value < cls.instance.n_processes or
|
if (cls.instance.n_workers.value < cls.instance.n_processes or
|
||||||
cls.instance.n_processes != (n_processes or cpu_count)):
|
cls.instance.n_processes != (n_processes or cls.cpu_count)):
|
||||||
cls.instance.close()
|
cls.instance.close()
|
||||||
if cls.instance is None or not cls.instance.is_alive:
|
if cls.instance is None or not cls.instance.is_alive:
|
||||||
new = super().__new__(cls)
|
new = super().__new__(cls)
|
||||||
new.n_processes = n_processes or cpu_count
|
new.n_processes = n_processes or cls.cpu_count
|
||||||
new.instance = new
|
new.instance = new
|
||||||
new.is_started = False
|
new.is_started = False
|
||||||
ctx = Context()
|
ctx = Context()
|
||||||
@@ -437,14 +438,14 @@ class Worker:
|
|||||||
last_active_time = time()
|
last_active_time = time()
|
||||||
while not self.event.is_set() and time() - last_active_time < 600:
|
while not self.event.is_set() and time() - last_active_time < 600:
|
||||||
try:
|
try:
|
||||||
task = self.queue_in.get(True, 0.02)
|
with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
|
||||||
try:
|
task = self.queue_in.get(True, 0.02)
|
||||||
self.add_to_queue('started', task.pool_id, task.handle, pid)
|
try:
|
||||||
with redirect_stdout(open(devnull, 'w')), redirect_stderr(open(devnull, 'w')):
|
self.add_to_queue('started', task.pool_id, task.handle, pid)
|
||||||
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
|
self.add_to_queue('done', task.pool_id, task(self.shared_memory))
|
||||||
except Exception: # noqa
|
except Exception: # noqa
|
||||||
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
|
self.add_to_queue('task_error', task.pool_id, task.handle, format_exc())
|
||||||
self.event.set()
|
self.event.set()
|
||||||
self.shared_memory.garbage_collect()
|
self.shared_memory.garbage_collect()
|
||||||
last_active_time = time()
|
last_active_time = time()
|
||||||
except (multiprocessing.queues.Empty, KeyboardInterrupt): # noqa
|
except (multiprocessing.queues.Empty, KeyboardInterrupt): # noqa
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
|
from os import cpu_count
|
||||||
from typing import Any, Callable, Hashable, NoReturn, Optional
|
from typing import Any, Callable, Hashable, NoReturn, Optional
|
||||||
|
|
||||||
from .common import Bar, cpu_count
|
from .common import Bar
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
@@ -15,6 +16,8 @@ class Worker:
|
|||||||
|
|
||||||
|
|
||||||
class PoolSingleton:
|
class PoolSingleton:
|
||||||
|
cpu_count = cpu_count()
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -64,7 +67,7 @@ class ParPool:
|
|||||||
self.fun = fun
|
self.fun = fun
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
self.n_processes = n_processes or cpu_count
|
self.n_processes = n_processes or PoolSingleton.cpu_count
|
||||||
self.threads = {}
|
self.threads = {}
|
||||||
|
|
||||||
def __getstate__(self) -> NoReturn:
|
def __getstate__(self) -> NoReturn:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "parfor"
|
name = "parfor"
|
||||||
version = "2024.12.0"
|
version = "2024.12.1"
|
||||||
description = "A package to mimic the use of parfor as done in Matlab."
|
description = "A package to mimic the use of parfor as done in Matlab."
|
||||||
authors = ["Wim Pomp <wimpomp@gmail.com>"]
|
authors = ["Wim Pomp <wimpomp@gmail.com>"]
|
||||||
license = "GPLv3"
|
license = "GPLv3"
|
||||||
|
|||||||
Reference in New Issue
Block a user