Files
parfor/parfor/nogil.py
Wim Pomp eb92ce006d - make cpu_count a field to PoolSingleton for easy global configuration of number of processes
- remove TypeVars
- manually wrap parfor and pmap
- also redirect output when retrieving task
2024-12-20 16:43:02 +01:00

159 lines
5.2 KiB
Python

from __future__ import annotations
import queue
import threading
from os import cpu_count
from typing import Any, Callable, Hashable, NoReturn, Optional
from .common import Bar
class Worker:
    """ Stub worker: it can be constructed with any arguments, all of which are ignored.
        NOTE(review): presumably kept for interface parity with another backend of this package - confirm.
    """
    # this stub never reports itself as being nested
    nested = False

    def __init__(self, *args, **kwargs):
        """ Accept and discard every positional and keyword argument. """
class PoolSingleton:
    """ Holder of the global configuration of the pool; creating or closing an instance does nothing.
        ParPool uses PoolSingleton.cpu_count as the number of simultaneous tasks when n_processes is not given,
        so assigning to this attribute changes the default for every pool created afterwards.
    """
    # os.cpu_count returns None when the number of cpus cannot be determined, fall back to 1 in that case so
    # the value can always be used in comparisons
    cpu_count = cpu_count() or 1

    def __init__(self, *args, **kwargs):
        """ Accept and ignore any arguments. """
        pass

    def close(self):
        """ Nothing to clean up. """
        pass
class Task:
    """ A single function call which reports its handle to a queue once the call has finished.

        queue: object with a put method (a queue.Queue) receiving the handle when the call has finished
        handle: identifier of this task
        fun: function to call
        args: positional arguments for fun, None is treated as no arguments
        kwargs: keyword arguments for fun, None is treated as no arguments
    """

    def __init__(self, queue: queue.Queue, handle: Hashable, fun: Callable[..., Any],
                 args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None) -> None:
        self.queue = queue
        self.handle = handle
        self.fun = fun
        # callers may pass None explicitly, unpacking None in __call__ would raise a TypeError
        self.args = () if args is None else args
        self.kwargs = {} if kwargs is None else kwargs
        self.name = getattr(fun, '__name__', None)
        self.started = False  # never changed by the task itself
        self.done = False  # not set by the task itself, ParPool.done sets it after receiving the handle
        self.result = None

    def __call__(self) -> None:
        """ Call fun, unless the task is already marked as done, store the result and put the handle in the queue.
            An exception raised by fun still propagates to the caller, but the handle is put in the queue first,
            otherwise anything waiting for this task would wait forever.
        """
        try:
            if not self.done:
                self.result = self.fun(*self.args, **self.kwargs)
        finally:
            try:
                self.queue.put(self.handle)
            except queue.ShutDown:  # python3.13: the queue was shut down, nobody is listening anymore
                pass

    def __repr__(self) -> str:
        if self.done:
            return f'Task {self.handle}, result: {self.result}'
        else:
            return f'Task {self.handle}'
class ParPool:
    """ Parallel processing with addition of iterations at any time and request of that result any time after that.
        The target function and its argument can be changed at any time.

        Every task runs in its own thread. A finished task puts its handle in a queue, the pool reads the queue
        whenever it has to wait for something, and marks the corresponding tasks as done.

        fun: default function for new tasks
        args: default positional arguments, appended after the iteration value when using pool(n) or pool[handle] = n
        kwargs: default keyword arguments
        n_processes: maximum number of tasks running simultaneously, default: PoolSingleton.cpu_count
        bar: optional progress bar, its update method is called with the bar length of each finished task
    """

    def __init__(self, fun: Callable[..., Any] = None,
                 args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, n_processes: int = None,
                 bar: Bar = None):
        self.queue = queue.Queue()  # finished tasks put their handle here
        self.handle = 0  # next automatically assigned handle
        self.tasks = {}  # handle -> Task, kept until the result is requested or the task is deleted
        self.bar = bar
        self.bar_lengths = {}  # handle -> amount to advance the bar with once the task has finished
        self.fun = fun
        self.args = args
        self.kwargs = kwargs
        self.n_processes = n_processes or PoolSingleton.cpu_count
        self.threads = {}  # handle -> thread, kept until the task is registered as finished

    def __getstate__(self) -> NoReturn:
        raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')

    def __enter__(self) -> ParPool:
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self.close()

    def close(self) -> None:
        """ Shut down the queue, wait for all running tasks and register them as finished, so results which
            were not requested yet can still be requested afterwards.
        """
        self.queue.shutdown()  # noqa python3.13
        for thread in list(self.threads.values()):
            thread.join()
        # handles which were put in the queue before the shutdown can still be retrieved
        self.process_queue()
        # tasks finishing after the shutdown could not report their handle, but their threads have been joined
        for handle in list(self.threads):
            self.done(handle)

    def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
        """ Add new iteration: call the default function with n followed by the default arguments. """
        self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)

    def add_task(self, fun: Callable[..., Any] = None, args: tuple[Any, ...] = None,
                 kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]:
        """ Start a new task in its own thread, waiting first while n_processes tasks are already running.

            fun, args, kwargs: override the defaults of the pool when given (and not empty)
            handle: identifier to request the result with, a new integer handle is generated when omitted
            barlength: amount to advance the bar with once the task has finished
            returns: the generated handle, or None if a handle was given
            raises: ValueError if the handle is already in use or if there is no function to call
        """
        fun = fun or self.fun
        if fun is None:
            raise ValueError('no function to call: pass fun to add_task or to the pool')
        if handle is None:
            new_handle = self.handle
            self.handle += 1
        else:
            new_handle = handle
        if new_handle in self:
            raise ValueError(f'handle {new_handle} already present')
        # fall back to an empty tuple, a task can not unpack None as arguments
        task = Task(self.queue, new_handle, fun, args or self.args or (), kwargs or self.kwargs)
        # >= instead of >: never run more than n_processes threads at once
        while len(self.threads) >= self.n_processes:
            self.get_from_queue()
        thread = threading.Thread(target=task)
        thread.start()
        self.threads[new_handle] = thread
        self.tasks[new_handle] = task
        self.bar_lengths[new_handle] = barlength
        if handle is None:
            return new_handle
        return None

    def __setitem__(self, handle: Hashable, n: Any) -> None:
        """ Add new iteration. """
        self(n, handle=handle)

    def __getitem__(self, handle: Hashable) -> Any:
        """ Request result and delete its record. Wait if result not yet available. """
        if handle not in self:
            raise ValueError(f'No handle: {handle} in pool')
        while not self.tasks[handle].done:
            self.get_from_queue()
        task = self.tasks.pop(handle)
        return task.result

    def __contains__(self, handle: Hashable) -> bool:
        return handle in self.tasks

    def __delitem__(self, handle: Hashable) -> None:
        """ Forget a task, its thread is not interrupted and is still cleaned up when it finishes. """
        self.tasks.pop(handle)

    def get_from_queue(self) -> bool:
        """ Wait at most 0.02 s for the handle of a finished task and register the task as done,
            return True if a handle was received, False if nothing arrived or the queue was shut down.
        """
        try:
            handle = self.queue.get(True, 0.02)
        except (queue.Empty, queue.ShutDown):
            return False
        self.done(handle)
        return True

    def get_newest(self) -> Optional[tuple[Hashable, Any]]:
        """ Request handle and result of a finished task and delete its record. Wait if no result is available yet.
            Returns None if the pool does not hold any tasks.
        """
        while self.tasks:
            self.get_from_queue()
            for task in self.tasks.values():
                if task.done:
                    handle, result = task.handle, task.result
                    # changing the dict is safe here because the iteration stops immediately
                    self.tasks.pop(handle)
                    return handle, result
        return None

    def process_queue(self) -> None:
        """ Register finished tasks until no more handles arrive. """
        while self.get_from_queue():
            pass

    def done(self, handle: Hashable) -> None:
        """ Register the task with this handle as finished: join its thread, mark it as done and update the bar.
            The task itself may have been deleted from the pool before it finished, this is not an error.
        """
        thread = self.threads.pop(handle, None)
        if thread is not None:
            thread.join()
        task = self.tasks.get(handle)
        if task is not None:
            task.done = True
        # always release the bar length, also when there is no bar to update
        if handle in self.bar_lengths:
            bar_length = self.bar_lengths.pop(handle)
            if hasattr(self.bar, 'update'):
                self.bar.update(bar_length)