diff --git a/parfor/__init__.py b/parfor/__init__.py index 86fbf35..5255ba4 100644 --- a/parfor/__init__.py +++ b/parfor/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import multiprocessing from collections import UserDict from contextlib import ExitStack @@ -5,6 +7,7 @@ from functools import wraps from os import getpid from time import time from traceback import format_exc +from typing import Any, Callable, Hashable, Iterable, Iterator, NoReturn, Optional, Protocol, Sized, TypeVar from warnings import warn from tqdm.auto import tqdm @@ -14,8 +17,14 @@ from .pickler import dumps, loads cpu_count = int(multiprocessing.cpu_count()) +Result = TypeVar('Result') +Iteration = TypeVar('Iteration') +Arg = TypeVar('Arg') +Return = TypeVar('Return') + + class SharedMemory(UserDict): - def __init__(self, manager): + def __init__(self, manager: multiprocessing.Manager) -> None: super().__init__() self.data = manager.dict() # item_id: dilled representation of object self.references = manager.dict() # item_id: counter @@ -24,10 +33,10 @@ class SharedMemory(UserDict): self.trash_can = {} self.pool_ids = {} # item_id: {(pool_id, task_handle), ...} - def __getstate__(self): + def __getstate__(self) -> tuple[dict[int, bytes], dict[int, int], multiprocessing.Lock]: return self.data, self.references, self.references_lock - def __setitem__(self, item_id, value): + def __setitem__(self, item_id: int, value: Any) -> None: if item_id not in self: # values will not be changed try: self.data[item_id] = False, value @@ -40,7 +49,7 @@ class SharedMemory(UserDict): self.references[item_id] = 1 self.cache[item_id] = value # the id of the object will not be reused as long as the object exists - def add_item(self, item, pool_id, task_handle): + def add_item(self, item: Any, pool_id: int, task_handle: Hashable) -> int: item_id = id(item) self[item_id] = item if item_id in self.pool_ids: @@ -49,14 +58,14 @@ class SharedMemory(UserDict): self.pool_ids[item_id] = {(pool_id, 
task_handle)} return item_id - def remove_pool(self, pool_id): + def remove_pool(self, pool_id: int) -> None: """ remove objects used by a pool that won't be needed anymore """ self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})} for item_id in set(self.data.keys()) - set(self.pool_ids): del self[item_id] self.garbage_collect() - def remove_task(self, pool_id, task): + def remove_task(self, pool_id: int, task: Task) -> None: """ remove objects used by a task that won't be needed anymore """ self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})} for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids): @@ -64,12 +73,12 @@ class SharedMemory(UserDict): self.garbage_collect() # worker functions - def __setstate__(self, state): + def __setstate__(self, state: dict) -> None: self.data, self.references, self.references_lock = state self.cache = {} self.trash_can = None - def __getitem__(self, item_id): + def __getitem__(self, item_id: int) -> Any: if item_id not in self.cache: dilled, value = self.data[item_id] if dilled: @@ -82,7 +91,7 @@ class SharedMemory(UserDict): self.cache[item_id] = value return self.cache[item_id] - def garbage_collect(self): + def garbage_collect(self) -> None: """ clean up the cache """ for item_id in set(self.cache) - set(self.data.keys()): with self.references_lock: @@ -102,7 +111,7 @@ class SharedMemory(UserDict): del self.trash_can[item_id] -class Chunks: +class Chunks(Iterable): """ Yield successive chunks from lists. Usage: chunks(list0, list1, ...) 
chunks(list0, list1, ..., size=s) @@ -115,25 +124,13 @@ class Chunks: both ratio and number are given: use ratio """ - def __init__(self, *iterators, size=None, number=None, ratio=None, length=None, s=None, n=None, r=None): - # s, r and n are deprecated - if s is not None: - warn('parfor: use of \'s\' is deprecated, use \'size\' instead', DeprecationWarning, stacklevel=2) - warn('parfor: use of \'s\' is deprecated, use \'size\' instead', DeprecationWarning, stacklevel=3) - size = s - if n is not None: - warn('parfor: use of \'n\' is deprecated, use \'number\' instead', DeprecationWarning, stacklevel=2) - warn('parfor: use of \'n\' is deprecated, use \'number\' instead', DeprecationWarning, stacklevel=3) - number = n - if r is not None: - warn('parfor: use of \'r\' is deprecated, use \'ratio\' instead', DeprecationWarning, stacklevel=2) - warn('parfor: use of \'r\' is deprecated, use \'ratio\' instead', DeprecationWarning, stacklevel=3) - ratio = r + def __init__(self, *iterables: Iterable[Any] | Sized, size: int = None, number: int = None, + ratio: float = None, length: int = None) -> None: if length is None: try: - length = min(*[len(iterator) for iterator in iterators]) if len(iterators) > 1 else len(iterators[0]) + length = min(*[len(iterable) for iterable in iterables]) if len(iterables) > 1 else len(iterables[0]) except TypeError: - raise TypeError('Cannot determine the length of the iterator(s), so the length must be provided as an' + raise TypeError('Cannot determine the length of the iterable(s), so the length must be provided as an' ' argument.') if size is not None and (number is not None or ratio is not None): if number is None: @@ -144,13 +141,13 @@ class Chunks: number = round(length / size) elif ratio is not None: # number of chunks number = int(cpu_count * ratio) - self.iterators = [iter(arg) for arg in iterators] + self.iterators = [iter(arg) for arg in iterables] self.number_of_items = length self.length = max(1, min(length, number)) 
self.lengths = [((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length) for i in range(self.length)] - def __iter__(self): + def __iter__(self) -> Iterator[Any]: for i in range(self.length): p, q = (i * self.number_of_items // self.length), ((i + 1) * self.number_of_items // self.length) if len(self.iterators) == 1: @@ -158,37 +155,41 @@ class Chunks: else: yield [[next(iterator) for _ in range(q-p)] for iterator in self.iterators] - def __len__(self): + def __len__(self) -> int: return self.length -class ExternalBar: - def __init__(self, iterable=None, callback=None, total=0): +class Bar(Protocol): + def update(self, n: int = 1) -> None: ... + + +class ExternalBar(Iterable): + def __init__(self, iterable: Iterable = None, callback: Callable[[int], None] = None, total: int = 0) -> None: self.iterable = iterable self.callback = callback self.total = total self._n = 0 - def __enter__(self): + def __enter__(self) -> ExternalBar: return self - def __exit__(self, *args, **kwargs): + def __exit__(self, *args: Any, **kwargs: Any) -> None: return - def __iter__(self): + def __iter__(self) -> Iterator[Any]: for n, item in enumerate(self.iterable): yield item self.n = n + 1 - def update(self, n=1): + def update(self, n: int = 1) -> None: self.n += n @property - def n(self): + def n(self) -> int: return self._n @n.setter - def n(self, n): + def n(self, n: int) -> None: if n != self._n: self._n = n if self.callback is not None: @@ -196,7 +197,8 @@ class ExternalBar: class Task: - def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: int, fun=None, args=(), kwargs=None): + def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: Hashable, fun: Callable[[Any, ...], Any], + args: tuple[Any, ...] 
= (), kwargs: dict[str, Any] = None) -> None: self.pool_id = pool_id self.handle = handle self.fun = shared_memory.add_item(fun, pool_id, handle) @@ -208,20 +210,20 @@ class Task: self.result = None self.pid = None - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: state = self.__dict__ if self.result is not None: state['result'] = dumps(self.result, recurse=True) return state - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update({key: value for key, value in state.items() if key != 'result'}) if state['result'] is None: self.result = None else: self.result = loads(state['result']) - def __call__(self, shared_memory: SharedMemory): + def __call__(self, shared_memory: SharedMemory) -> Task: if not self.done: fun = shared_memory[self.fun] or (lambda *args, **kwargs: None) # noqa args = [shared_memory[arg] for arg in self.args] @@ -230,7 +232,7 @@ class Task: self.done = True return self - def __repr__(self): + def __repr__(self) -> str: if self.done: return f'Task {self.handle}, result: {self.result}' else: @@ -241,11 +243,11 @@ class Context(multiprocessing.context.SpawnContext): """ Provide a context where child processes never are daemonic. """ class Process(multiprocessing.context.SpawnProcess): @property - def daemon(self): + def daemon(self) -> bool: return False @daemon.setter - def daemon(self, value): + def daemon(self, value: bool) -> None: pass @@ -253,7 +255,8 @@ class ParPool: """ Parallel processing with addition of iterations at any time and request of that result any time after that. The target function and its argument can be changed at any time. 
""" - def __init__(self, fun=None, args=None, kwargs=None, bar=None): + def __init__(self, fun: Callable[[Any, ...], Any] = None, + args: tuple[Any] = None, kwargs: dict[str, Any] = None, bar: Bar = None): self.id = id(self) self.handle = 0 self.tasks = {} @@ -267,22 +270,23 @@ class ParPool: self.is_started = False self.last_task = None - def __getstate__(self): + def __getstate__(self) -> NoReturn: raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.') - def __enter__(self, *args, **kwargs): + def __enter__(self) -> ParPool: return self - def __exit__(self, *args, **kwargs): + def __exit__(self, *args: Any, **kwargs: Any) -> None: self.close() - def close(self): + def close(self) -> None: self.spool.remove_pool(self.id) - def __call__(self, n, handle=None, barlength=1): + def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None: self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength) - def add_task(self, fun=None, args=None, kwargs=None, handle=None, barlength=1): + def add_task(self, fun: Callable[[Any, ...], Any] = None, args: tuple[Any, ...] = None, + kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]: if self.id not in self.spool.pools: raise ValueError(f'this pool is not registered (anymore) with the pool singleton') if handle is None: @@ -301,11 +305,11 @@ class ParPool: if handle is None: return new_handle - def __setitem__(self, handle, n): + def __setitem__(self, handle: Hashable, n: Any) -> None: """ Add new iteration. """ self(n, handle=handle) - def __getitem__(self, handle): + def __getitem__(self, handle: Hashable) -> Any: """ Request result and delete its record. Wait if result not yet available. 
""" if handle not in self: raise ValueError(f'No handle: {handle} in pool') @@ -323,53 +327,56 @@ class ParPool: self.tasks.pop(handle) return result - def __contains__(self, handle): + def __contains__(self, handle: Hashable) -> bool: return handle in self.tasks - def __delitem__(self, handle): + def __delitem__(self, handle: Hashable) -> None: self.tasks.pop(handle) - def get_newest(self): + def get_newest(self) -> Any: return self.spool.get_newest_for_pool(self) - def process_queue(self): + def process_queue(self) -> None: self.spool.process_queue() - def task_error(self, handle, error): + def task_error(self, handle: Hashable, error: Exception) -> None: if handle in self: task = self.tasks[handle] print(f'Error from process working on iteration {handle}:\n') print(error) print('Retrying in main thread...') task(self.spool.shared_memory) + self.spool.shared_memory.remove_task(self.id, task) raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.') - self.spool.shared_memory.remove_task(self.id, self.tasks[handle]) - def done(self, task): + def done(self, task: Task) -> None: if task.handle in self: # if not, the task was restarted erroneously self.tasks[task.handle] = task if hasattr(self.bar, 'update'): self.bar.update(self.bar_lengths.pop(task.handle)) self.spool.shared_memory.remove_task(self.id, task) - def started(self, handle, pid): + def started(self, handle: Hashable, pid: int) -> None: self.is_started = True if handle in self: # if not, the task was restarted erroneously self.tasks[handle].pid = pid @property - def working(self): + def working(self) -> bool: return not all([task.pid is None for task in self.tasks.values()]) class PoolSingleton: """ There can be only one pool at a time, but the pool can be restarted by calling close() and then constructing a new pool. The pool will close itself after 10 minutes of idle time. 
""" - def __new__(cls, *args, **kwargs): - if hasattr(cls, 'instance') and cls.instance is not None: # noqa restart if any workers have shut down + + instance = None + + def __new__(cls, *args: Any, **kwargs: Any) -> PoolSingleton: + if cls.instance is not None: # restart if any workers have shut down if cls.instance.n_workers.value < cls.instance.n_processes: cls.instance.close() - if not hasattr(cls, 'instance') or cls.instance is None or not cls.instance.is_alive: # noqa + if cls.instance is None or not cls.instance.is_alive: new = super().__new__(cls) new.n_processes = cpu_count new.instance = new @@ -387,32 +394,32 @@ class PoolSingleton: new.handle = 0 new.pools = {} cls.instance = new - return cls.instance # noqa + return cls.instance - def __init__(self, parpool=None): # noqa + def __init__(self, parpool: Parpool = None) -> None: # noqa if parpool is not None: self.pools[parpool.id] = parpool - def __getstate__(self): + def __getstate__(self) -> NoReturn: raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.') # def __del__(self): # self.close() - def remove_pool(self, pool_id): + def remove_pool(self, pool_id: int) -> None: self.shared_memory.remove_pool(pool_id) if pool_id in self.pools: self.pools.pop(pool_id) - def error(self, error): + def error(self, error: Exception) -> NoReturn: self.close() raise Exception(f'Error occurred in worker: {error}') - def process_queue(self): + def process_queue(self) -> None: while self.get_from_queue(): pass - def get_from_queue(self): + def get_from_queue(self) -> bool: """ Get an item from the queue and store it, return True if more messages are waiting. 
""" try: code, pool_id, *args = self.queue_out.get(True, 0.02) @@ -430,7 +437,7 @@ class PoolSingleton: warn(f'Task {task.handle} was restarted because process {task.pid} was probably killed.') return False - def add_task(self, task): + def add_task(self, task: Task) -> None: """ Add new iteration, using optional manually defined handle.""" if self.is_alive and not self.event.is_set(): while self.queue_in.full(): @@ -438,7 +445,7 @@ class PoolSingleton: self.queue_in.put(task) self.shared_memory.garbage_collect() - def get_newest_for_pool(self, pool): + def get_newest_for_pool(self, pool: ParPool) -> tuple[Hashable, Any]: """ Request the newest key and result and delete its record. Wait if result not yet available. """ while len(pool.tasks): self.get_from_queue() @@ -449,7 +456,7 @@ class PoolSingleton: return handle, result @classmethod - def close(cls): + def close(cls) -> None: if hasattr(cls, 'instance') and cls.instance is not None: instance = cls.instance cls.instance = None @@ -465,45 +472,47 @@ class PoolSingleton: except OSError: pass - def close_queue(queue): - empty_queue(queue) + def close_queue(queue: multiprocessing.queues.Queue) -> None: + empty_queue(queue) # noqa if not queue._closed: # noqa queue.close() queue.join_thread() if instance.is_alive: - instance.is_alive = False # noqa + instance.is_alive = False instance.event.set() instance.pool.close() t = time() - while instance.n_workers.value: # noqa - empty_queue(instance.queue_in) # noqa - empty_queue(instance.queue_out) # noqa + while instance.n_workers.value: + empty_queue(instance.queue_in) + empty_queue(instance.queue_out) if time() - t > 10: - warn(f'Parfor: Closing pool timed out, {instance.n_workers.value} processes still alive.') # noqa + warn(f'Parfor: Closing pool timed out, {instance.n_workers.value} processes still alive.') instance.pool.terminate() break - empty_queue(instance.queue_in) # noqa - empty_queue(instance.queue_out) # noqa + empty_queue(instance.queue_in) + 
empty_queue(instance.queue_out) instance.pool.join() - close_queue(instance.queue_in) # noqa - close_queue(instance.queue_out) # noqa + close_queue(instance.queue_in) + close_queue(instance.queue_out) instance.manager.shutdown() - instance.handle = 0 # noqa + instance.handle = 0 class Worker: """ Manages executing the target function which will be executed in different processes. """ nested = False - def __init__(self, shared_memory: SharedMemory, queue_in, queue_out, n_workers, event): + def __init__(self, shared_memory: SharedMemory, queue_in: multiprocessing.queues.Queue, + queue_out: multiprocessing.queues.Queue, n_workers: multiprocessing.Value, + event: multiprocessing.Event) -> None: self.shared_memory = shared_memory self.queue_in = queue_in self.queue_out = queue_out self.n_workers = n_workers self.event = event - def add_to_queue(self, *args): + def add_to_queue(self, *args: Any) -> None: while not self.event.is_set(): try: self.queue_out.put(args, timeout=0.1) @@ -511,7 +520,7 @@ class Worker: except multiprocessing.queues.Full: # noqa continue - def __call__(self): + def __call__(self) -> None: Worker.nested = True pid = getpid() last_active_time = time() @@ -537,8 +546,10 @@ class Worker: self.n_workers.value -= 1 -def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=True, terminator=None, - serial=None, length=None, **bar_kwargs): +def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None, + args: tuple[Any, ...] 
= None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None, + bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None, + **bar_kwargs: Any) -> list[Result]: """ map a function fun to each iteration in iterable use as a function: pmap use as a decorator: parfor @@ -620,8 +631,8 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar= iterable = Chunks(iterable, ratio=5, length=total) @wraps(fun) - def chunk_fun(iterator, *args, **kwargs): # noqa - return [fun(i, *args, **kwargs) for i in iterator] # noqa + def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Result]: # noqa + return [fun(iteration, *args, **kwargs) for iteration in iterable] args = args or () kwargs = kwargs or {} @@ -636,13 +647,13 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar= if callable(bar): return sum([chunk_fun(c, *args, **kwargs) for c in ExternalBar(iterable, bar)], []) else: - return sum([chunk_fun(c, *args, **kwargs) for c in tqdm(iterable, **bar_kwargs)], []) # noqa + return sum([chunk_fun(c, *args, **kwargs) for c in tqdm(iterable, **bar_kwargs)], []) else: # parallel case with ExitStack() as stack: if callable(bar): - bar = stack.enter_context(ExternalBar(callback=bar)) # noqa + bar = stack.enter_context(ExternalBar(callback=bar)) else: - bar = stack.enter_context(tqdm(**bar_kwargs)) # noqa + bar = stack.enter_context(tqdm(**bar_kwargs)) with ParPool(chunk_fun, args, kwargs, bar) as p: for i, (j, l) in enumerate(zip(iterable, iterable.lengths)): # add work to the queue p(j, handle=i, barlength=iterable.lengths[i]) @@ -655,26 +666,7 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar= @wraps(pmap) -def parfor(*args, **kwargs): - def decfun(fun): +def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]: + def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> 
list[Result]: return pmap(fun, *args, **kwargs) return decfun - - -def deprecated(cls, name): - """ This is a decorator which can be used to mark functions and classes as deprecated. It will result in a warning - being emitted when the function or class is used.""" - @wraps(cls) - def wrapper(*args, **kwargs): - warn(f'parfor: use of \'{name}\' is deprecated, use \'{cls.__name__}\' instead', - category=DeprecationWarning, stacklevel=2) - warn(f'parfor: use of \'{name}\' is deprecated, use \'{cls.__name__}\' instead', - category=DeprecationWarning, stacklevel=3) - return cls(*args, **kwargs) - return wrapper - - -# backwards compatibility -parpool = deprecated(ParPool, 'parpool') -Parpool = deprecated(ParPool, 'Parpool') -chunks = deprecated(Chunks, 'chunks') diff --git a/parfor/pickler.py b/parfor/pickler.py index 6142e25..27bb447 100644 --- a/parfor/pickler.py +++ b/parfor/pickler.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import copyreg from io import BytesIO from pickle import PicklingError +from typing import Any, Callable import dill @@ -8,14 +11,14 @@ loads = dill.loads class CouldNotBePickled: - def __init__(self, class_name): + def __init__(self, class_name: str) -> None: self.class_name = class_name - def __repr__(self): + def __repr__(self) -> str: return f"Item of type '{self.class_name}' could not be pickled and was omitted." @classmethod - def reduce(cls, item): + def reduce(cls, item: Any) -> tuple[Callable[[str], CouldNotBePickled], tuple[str]]: return cls, (type(item).__name__,) @@ -24,7 +27,7 @@ class Pickler(dill.Pickler): You probably didn't want to use these parts anyhow. However, if you did, you'll have to find some way to make them picklable. """ - def save(self, obj, save_persistent_id=True): + def save(self, obj: Any, save_persistent_id: bool = True) -> None: """ Copied from pickle and amended. 
""" self.framer.commit_frame() @@ -93,8 +96,8 @@ class Pickler(dill.Pickler): raise PicklingError("%s must return string or tuple" % reduce) # Assert that it returned an appropriately sized tuple - l = len(rv) - if not (2 <= l <= 6): + length = len(rv) + if not (2 <= length <= 6): raise PicklingError("Tuple returned by %s must have " "two to six elements" % reduce) @@ -105,11 +108,12 @@ class Pickler(dill.Pickler): self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj)) -def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds): +def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, + **kwds: Any) -> bytes: """pickle an object to a string""" protocol = dill.settings['protocol'] if protocol is None else int(protocol) _kwds = kwds.copy() _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse)) - file = BytesIO() - Pickler(file, protocol, **_kwds).dump(obj) - return file.getvalue() + with BytesIO() as file: + Pickler(file, protocol, **_kwds).dump(obj) + return file.getvalue() diff --git a/pyproject.toml b/pyproject.toml index 42ff24b..b01f7ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "parfor" -version = "2024.3.0" +version = "2024.4.0" description = "A package to mimic the use of parfor as done in Matlab." 
authors = ["Wim Pomp "] license = "GPLv3" @@ -9,7 +9,7 @@ keywords = ["parfor", "concurrency", "multiprocessing", "parallel"] repository = "https://github.com/wimpomp/parfor" [tool.poetry.dependencies] -python = "^3.8" +python = "^3.10" tqdm = ">=4.50.0" dill = ">=0.3.0" pytest = { version = "*", optional = true } @@ -17,6 +17,9 @@ pytest = { version = "*", optional = true } [tool.poetry.extras] test = ["pytest"] +[tool.isort] +line_length = 119 + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/tests/test_parfor.py b/tests/test_parfor.py index 4f648a1..91cb90e 100644 --- a/tests/test_parfor.py +++ b/tests/test_parfor.py @@ -1,7 +1,9 @@ -import pytest -from parfor import Chunks, ParPool, parfor, pmap from dataclasses import dataclass +import pytest + +from parfor import Chunks, ParPool, parfor, pmap + class SequenceIterator: def __init__(self, sequence):