- minimum Python: 3.10
- add type annotations
- fix a bug in ParPool.task_error
- remove some deprecated functions
Author: Wim Pomp
Date: 2024-04-26 18:32:12 +02:00
parent 42746d21eb
commit ac4d599646

4 changed files with 133 additions and 132 deletions

parfor/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import multiprocessing
 from collections import UserDict
 from contextlib import ExitStack
@@ -5,6 +7,7 @@ from functools import wraps
 from os import getpid
 from time import time
 from traceback import format_exc
+from typing import Any, Callable, Hashable, Iterable, Iterator, NoReturn, Optional, Protocol, Sized, TypeVar
 from warnings import warn

 from tqdm.auto import tqdm
@@ -14,8 +17,14 @@ from .pickler import dumps, loads
 cpu_count = int(multiprocessing.cpu_count())

+Result = TypeVar('Result')
+Iteration = TypeVar('Iteration')
+Arg = TypeVar('Arg')
+Return = TypeVar('Return')
+

 class SharedMemory(UserDict):
-    def __init__(self, manager):
+    def __init__(self, manager: multiprocessing.Manager) -> None:
         super().__init__()
         self.data = manager.dict()  # item_id: dilled representation of object
         self.references = manager.dict()  # item_id: counter
@@ -24,10 +33,10 @@ class SharedMemory(UserDict):
         self.trash_can = {}
         self.pool_ids = {}  # item_id: {(pool_id, task_handle), ...}

-    def __getstate__(self):
+    def __getstate__(self) -> tuple[dict[int, bytes], dict[int, int], multiprocessing.Lock]:
         return self.data, self.references, self.references_lock

-    def __setitem__(self, item_id, value):
+    def __setitem__(self, item_id: int, value: Any) -> None:
         if item_id not in self:  # values will not be changed
             try:
                 self.data[item_id] = False, value
@@ -40,7 +49,7 @@ class SharedMemory(UserDict):
             self.references[item_id] = 1
         self.cache[item_id] = value  # the id of the object will not be reused as long as the object exists

-    def add_item(self, item, pool_id, task_handle):
+    def add_item(self, item: Any, pool_id: int, task_handle: Hashable) -> int:
         item_id = id(item)
         self[item_id] = item
         if item_id in self.pool_ids:
@@ -49,14 +58,14 @@ class SharedMemory(UserDict):
             self.pool_ids[item_id] = {(pool_id, task_handle)}
         return item_id

-    def remove_pool(self, pool_id):
+    def remove_pool(self, pool_id: int) -> None:
         """ remove objects used by a pool that won't be needed anymore """
         self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := {i for i in value if i[0] != pool_id})}
         for item_id in set(self.data.keys()) - set(self.pool_ids):
             del self[item_id]
         self.garbage_collect()

-    def remove_task(self, pool_id, task):
+    def remove_task(self, pool_id: int, task: Task) -> None:
         """ remove objects used by a task that won't be needed anymore """
         self.pool_ids = {key: v for key, value in self.pool_ids.items() if (v := value - {(pool_id, task.handle)})}
         for item_id in {task.fun, *task.args, *task.kwargs} - set(self.pool_ids):
@@ -64,12 +73,12 @@ class SharedMemory(UserDict):
         self.garbage_collect()

     # worker functions
-    def __setstate__(self, state):
+    def __setstate__(self, state: dict) -> None:
         self.data, self.references, self.references_lock = state
         self.cache = {}
         self.trash_can = None

-    def __getitem__(self, item_id):
+    def __getitem__(self, item_id: int) -> Any:
         if item_id not in self.cache:
             dilled, value = self.data[item_id]
             if dilled:
@@ -82,7 +91,7 @@ class SharedMemory(UserDict):
             self.cache[item_id] = value
         return self.cache[item_id]

-    def garbage_collect(self):
+    def garbage_collect(self) -> None:
         """ clean up the cache """
         for item_id in set(self.cache) - set(self.data.keys()):
             with self.references_lock:
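SharedMemory is internal plumbing, but the idea in the hunks above is easy to see in isolation: every object is stored once under its id() and reference-counted per pool and task, so a large argument shared by many tasks is pickled once instead of once per task. A minimal sketch of that contract, assuming direct construction with a Manager is acceptable (it is not a documented entry point):

import multiprocessing
from parfor import SharedMemory

if __name__ == '__main__':
    manager = multiprocessing.Manager()
    shared = SharedMemory(manager)
    big = list(range(3))
    item_id = shared.add_item(big, pool_id=0, task_handle=0)  # stored under id(big)
    shared.add_item(big, pool_id=0, task_handle=1)  # same id: only registers the extra task
    print(shared[item_id])  # [0, 1, 2]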
@@ -102,7 +111,7 @@ class SharedMemory(UserDict):
                 del self.trash_can[item_id]


-class Chunks:
+class Chunks(Iterable):
     """ Yield successive chunks from lists.
         Usage: chunks(list0, list1, ...)
                chunks(list0, list1, ..., size=s)
@@ -115,25 +124,13 @@ class Chunks:
                both ratio and number are given: use ratio
     """

-    def __init__(self, *iterators, size=None, number=None, ratio=None, length=None, s=None, n=None, r=None):
-        # s, r and n are deprecated
-        if s is not None:
-            warn('parfor: use of \'s\' is deprecated, use \'size\' instead', DeprecationWarning, stacklevel=2)
-            warn('parfor: use of \'s\' is deprecated, use \'size\' instead', DeprecationWarning, stacklevel=3)
-            size = s
-        if n is not None:
-            warn('parfor: use of \'n\' is deprecated, use \'number\' instead', DeprecationWarning, stacklevel=2)
-            warn('parfor: use of \'n\' is deprecated, use \'number\' instead', DeprecationWarning, stacklevel=3)
-            number = n
-        if r is not None:
-            warn('parfor: use of \'r\' is deprecated, use \'ratio\' instead', DeprecationWarning, stacklevel=2)
-            warn('parfor: use of \'r\' is deprecated, use \'ratio\' instead', DeprecationWarning, stacklevel=3)
-            ratio = r
+    def __init__(self, *iterables: Iterable[Any] | Sized[Any], size: int = None, number: int = None,
+                 ratio: float = None, length: int = None) -> None:
         if length is None:
             try:
-                length = min(*[len(iterator) for iterator in iterators]) if len(iterators) > 1 else len(iterators[0])
+                length = min(*[len(iterable) for iterable in iterables]) if len(iterables) > 1 else len(iterables[0])
             except TypeError:
-                raise TypeError('Cannot determine the length of the iterator(s), so the length must be provided as an'
+                raise TypeError('Cannot determine the length of the iterables(s), so the length must be provided as an'
                                 ' argument.')
         if size is not None and (number is not None or ratio is not None):
             if number is None:
@@ -144,13 +141,13 @@ class Chunks:
                 number = round(length / size)
         elif ratio is not None:  # number of chunks
             number = int(cpu_count * ratio)
-        self.iterators = [iter(arg) for arg in iterators]
+        self.iterators = [iter(arg) for arg in iterables]
         self.number_of_items = length
         self.length = max(1, min(length, number))
         self.lengths = [((i + 1) * self.number_of_items // self.length) - (i * self.number_of_items // self.length)
                         for i in range(self.length)]

-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         for i in range(self.length):
             p, q = (i * self.number_of_items // self.length), ((i + 1) * self.number_of_items // self.length)
             if len(self.iterators) == 1:
@@ -158,37 +155,41 @@ class Chunks:
             else:
                 yield [[next(iterator) for _ in range(q-p)] for iterator in self.iterators]

-    def __len__(self):
+    def __len__(self) -> int:
         return self.length
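Chunks is what pmap uses to cut iterables into per-task work units, but it also works standalone. A small sketch (the chunk sizes are the rounded divisions listed in the lengths attribute):

from parfor import Chunks

for chunk in Chunks(list(range(10)), size=3):
    print(chunk)  # [0, 1, 2], then [3, 4, 5], then [6, 7, 8, 9]

# several iterables yield parallel sub-lists; a bare iterator has no len(),
# so the total length must then be passed explicitly
for chunk in Chunks(range(5), iter('abcde'), number=2, length=5):
    print(chunk)  # [[0, 1], ['a', 'b']], then [[2, 3, 4], ['c', 'd', 'e']]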
-class ExternalBar:
-    def __init__(self, iterable=None, callback=None, total=0):
+class Bar(Protocol):
+    def update(self, n: int = 1) -> None: ...
+
+
+class ExternalBar(Iterable):
+    def __init__(self, iterable: Iterable = None, callback: Callable[[int], None] = None, total: int = 0) -> None:
         self.iterable = iterable
         self.callback = callback
         self.total = total
         self._n = 0

-    def __enter__(self):
+    def __enter__(self) -> ExternalBar:
         return self

-    def __exit__(self, *args, **kwargs):
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
         return

-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         for n, item in enumerate(self.iterable):
             yield item
             self.n = n + 1

-    def update(self, n=1):
+    def update(self, n: int = 1) -> None:
         self.n += n

     @property
-    def n(self):
+    def n(self) -> int:
         return self._n

     @n.setter
-    def n(self, n):
+    def n(self, n: int) -> None:
         if n != self._n:
             self._n = n
             if self.callback is not None:
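The new Bar protocol and ExternalBar make progress reporting pluggable: anything satisfying update(n) can stand in for tqdm, and a bare callable gets wrapped in an ExternalBar so it simply receives the running count of finished iterations. A sketch of the callback route through pmap (defined further down):

from parfor import pmap

def fun(x):
    return x ** 2

def progress(n):  # called with the cumulative number of finished iterations
    print(f'{n} done')

if __name__ == '__main__':
    results = pmap(fun, range(20), bar=progress)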
@@ -196,7 +197,8 @@ class ExternalBar:


 class Task:
-    def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: int, fun=None, args=(), kwargs=None):
+    def __init__(self, shared_memory: SharedMemory, pool_id: int, handle: Hashable, fun: Callable[[Any, ...], Any],
+                 args: tuple[Any, ...] = (), kwargs: dict[str, Any] = None) -> None:
         self.pool_id = pool_id
         self.handle = handle
         self.fun = shared_memory.add_item(fun, pool_id, handle)
@@ -208,20 +210,20 @@ class Task:
         self.result = None
         self.pid = None

-    def __getstate__(self):
+    def __getstate__(self) -> dict[str, Any]:
         state = self.__dict__
         if self.result is not None:
             state['result'] = dumps(self.result, recurse=True)
         return state

-    def __setstate__(self, state):
+    def __setstate__(self, state: dict[str, Any]) -> None:
         self.__dict__.update({key: value for key, value in state.items() if key != 'result'})
         if state['result'] is None:
             self.result = None
         else:
             self.result = loads(state['result'])

-    def __call__(self, shared_memory: SharedMemory):
+    def __call__(self, shared_memory: SharedMemory) -> Task:
         if not self.done:
             fun = shared_memory[self.fun] or (lambda *args, **kwargs: None)  # noqa
             args = [shared_memory[arg] for arg in self.args]
@@ -230,7 +232,7 @@ class Task:
             self.done = True
         return self

-    def __repr__(self):
+    def __repr__(self) -> str:
         if self.done:
             return f'Task {self.handle}, result: {self.result}'
         else:
@@ -241,11 +243,11 @@ class Context(multiprocessing.context.SpawnContext):
     """ Provide a context where child processes never are daemonic. """

     class Process(multiprocessing.context.SpawnProcess):
         @property
-        def daemon(self):
+        def daemon(self) -> bool:
             return False

         @daemon.setter
-        def daemon(self, value):
+        def daemon(self, value: bool) -> None:
             pass
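The Context subclass exists because multiprocessing marks pool workers daemonic, and daemonic processes may not have children; by reporting daemon=False and ignoring the setter, parfor's workers stay free to spawn processes themselves. The same trick in isolation (class and function names here are illustrative, not parfor's):

import multiprocessing
import multiprocessing.pool

class NoDaemonProcess(multiprocessing.context.SpawnProcess):
    @property
    def daemon(self):
        return False  # always report non-daemonic ...

    @daemon.setter
    def daemon(self, value):
        pass  # ... and ignore the pool setting the flag

class NoDaemonContext(multiprocessing.context.SpawnContext):
    Process = NoDaemonProcess

def grandchild():
    print('grandchild ran')

def worker(_):
    p = multiprocessing.Process(target=grandchild)  # allowed: worker is not daemonic
    p.start()
    p.join()

if __name__ == '__main__':
    with multiprocessing.pool.Pool(2, context=NoDaemonContext()) as pool:
        pool.map(worker, range(2))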
@@ -253,7 +255,8 @@ class ParPool:
     """ Parallel processing with addition of iterations at any time and request of that result any time after that.
        The target function and its argument can be changed at any time.
    """
-    def __init__(self, fun=None, args=None, kwargs=None, bar=None):
+    def __init__(self, fun: Callable[[Any, ...], Any] = None,
+                 args: tuple[Any] = None, kwargs: dict[str, Any] = None, bar: Bar = None):
         self.id = id(self)
         self.handle = 0
         self.tasks = {}
@@ -267,22 +270,23 @@ class ParPool:
         self.is_started = False
         self.last_task = None

-    def __getstate__(self):
+    def __getstate__(self) -> NoReturn:
         raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')

-    def __enter__(self, *args, **kwargs):
+    def __enter__(self) -> ParPool:
         return self

-    def __exit__(self, *args, **kwargs):
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
         self.close()

-    def close(self):
+    def close(self) -> None:
         self.spool.remove_pool(self.id)

-    def __call__(self, n, handle=None, barlength=1):
+    def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
         self.add_task(args=(n, *(() if self.args is None else self.args)), handle=handle, barlength=barlength)

-    def add_task(self, fun=None, args=None, kwargs=None, handle=None, barlength=1):
+    def add_task(self, fun: Callable[[Any, ...], Any] = None, args: tuple[Any, ...] = None,
+                 kwargs: dict[str, Any] = None, handle: Hashable = None, barlength: int = 1) -> Optional[int]:
         if self.id not in self.spool.pools:
             raise ValueError(f'this pool is not registered (anymore) with the pool singleton')
         if handle is None:
@@ -301,11 +305,11 @@ class ParPool:
         if handle is None:
             return new_handle

-    def __setitem__(self, handle, n):
+    def __setitem__(self, handle: Hashable, n: Any) -> None:
         """ Add new iteration. """
         self(n, handle=handle)

-    def __getitem__(self, handle):
+    def __getitem__(self, handle: Hashable) -> Any:
         """ Request result and delete its record. Wait if result not yet available. """
         if handle not in self:
             raise ValueError(f'No handle: {handle} in pool')
@@ -323,53 +327,56 @@ class ParPool:
         self.tasks.pop(handle)
         return result

-    def __contains__(self, handle):
+    def __contains__(self, handle: Hashable) -> bool:
         return handle in self.tasks

-    def __delitem__(self, handle):
+    def __delitem__(self, handle: Hashable) -> None:
         self.tasks.pop(handle)

-    def get_newest(self):
+    def get_newest(self) -> Any:
         return self.spool.get_newest_for_pool(self)

-    def process_queue(self):
+    def process_queue(self) -> None:
         self.spool.process_queue()

-    def task_error(self, handle, error):
+    def task_error(self, handle: Hashable, error: Exception) -> None:
         if handle in self:
             task = self.tasks[handle]
             print(f'Error from process working on iteration {handle}:\n')
             print(error)
             print('Retrying in main thread...')
             task(self.spool.shared_memory)
+            self.spool.shared_memory.remove_task(self.id, task)
             raise Exception(f'Function \'{task.name}\' cannot be executed by parfor, amend or execute in serial.')
-        self.spool.shared_memory.remove_task(self.id, self.tasks[handle])

-    def done(self, task):
+    def done(self, task: Task) -> None:
         if task.handle in self:  # if not, the task was restarted erroneously
             self.tasks[task.handle] = task
             if hasattr(self.bar, 'update'):
                 self.bar.update(self.bar_lengths.pop(task.handle))
             self.spool.shared_memory.remove_task(self.id, task)

-    def started(self, handle, pid):
+    def started(self, handle: Hashable, pid: int) -> None:
         self.is_started = True
         if handle in self:  # if not, the task was restarted erroneously
             self.tasks[handle].pid = pid

     @property
-    def working(self):
+    def working(self) -> bool:
         return not all([task.pid is None for task in self.tasks.values()])
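The handle-based interface above can be driven directly: submit iterations whenever convenient and fetch each result later by its handle. A sketch, assuming as elsewhere that the target function is defined at module level so dill can ship it to the workers:

from parfor import ParPool

def square(x):
    return x * x

if __name__ == '__main__':
    with ParPool(square) as pool:
        for i in range(10):
            pool[i] = i  # __setitem__: submit iteration i under handle i
        print([pool[i] for i in range(10)])  # __getitem__ waits per handle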
 class PoolSingleton:
     """ There can be only one pool at a time, but the pool can be restarted by calling close() and then constructing a
         new pool. The pool will close itself after 10 minutes of idle time. """

-    def __new__(cls, *args, **kwargs):
-        if hasattr(cls, 'instance') and cls.instance is not None:  # noqa restart if any workers have shut down
+    instance = None
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> PoolSingleton:
+        if cls.instance is not None:  # restart if any workers have shut down
             if cls.instance.n_workers.value < cls.instance.n_processes:
                 cls.instance.close()
-        if not hasattr(cls, 'instance') or cls.instance is None or not cls.instance.is_alive:  # noqa
+        if cls.instance is None or not cls.instance.is_alive:
             new = super().__new__(cls)
             new.n_processes = cpu_count
             new.instance = new
@@ -387,32 +394,32 @@ class PoolSingleton:
             new.handle = 0
             new.pools = {}
             cls.instance = new
-        return cls.instance  # noqa
+        return cls.instance

-    def __init__(self, parpool=None):  # noqa
+    def __init__(self, parpool: Parpool = None) -> None:  # noqa
         if parpool is not None:
             self.pools[parpool.id] = parpool

-    def __getstate__(self):
+    def __getstate__(self) -> NoReturn:
         raise RuntimeError(f'Cannot pickle {self.__class__.__name__} object.')

     # def __del__(self):
     #     self.close()

-    def remove_pool(self, pool_id):
+    def remove_pool(self, pool_id: int) -> None:
         self.shared_memory.remove_pool(pool_id)
         if pool_id in self.pools:
             self.pools.pop(pool_id)

-    def error(self, error):
+    def error(self, error: Exception) -> NoReturn:
         self.close()
         raise Exception(f'Error occurred in worker: {error}')

-    def process_queue(self):
+    def process_queue(self) -> None:
         while self.get_from_queue():
             pass

-    def get_from_queue(self):
+    def get_from_queue(self) -> bool:
         """ Get an item from the queue and store it, return True if more messages are waiting. """
         try:
             code, pool_id, *args = self.queue_out.get(True, 0.02)
@@ -430,7 +437,7 @@ class PoolSingleton:
             warn(f'Task {task.handle} was restarted because process {task.pid} was probably killed.')
             return False

-    def add_task(self, task):
+    def add_task(self, task: Task) -> None:
         """ Add new iteration, using optional manually defined handle."""
         if self.is_alive and not self.event.is_set():
             while self.queue_in.full():
@@ -438,7 +445,7 @@ class PoolSingleton:
             self.queue_in.put(task)
         self.shared_memory.garbage_collect()

-    def get_newest_for_pool(self, pool):
+    def get_newest_for_pool(self, pool: ParPool) -> tuple[Hashable, Any]:
         """ Request the newest key and result and delete its record. Wait if result not yet available. """
         while len(pool.tasks):
             self.get_from_queue()
@@ -449,7 +456,7 @@ class PoolSingleton:
             return handle, result

     @classmethod
-    def close(cls):
+    def close(cls) -> None:
         if hasattr(cls, 'instance') and cls.instance is not None:
             instance = cls.instance
             cls.instance = None
@@ -465,45 +472,47 @@ class PoolSingleton:
                 except OSError:
                     pass

-            def close_queue(queue):
-                empty_queue(queue)
+            def close_queue(queue: multiprocessing.queues.Queue) -> None:
+                empty_queue(queue)  # noqa
                 if not queue._closed:  # noqa
                     queue.close()
                     queue.join_thread()

             if instance.is_alive:
-                instance.is_alive = False  # noqa
+                instance.is_alive = False
                 instance.event.set()
                 instance.pool.close()
                 t = time()
-                while instance.n_workers.value:  # noqa
-                    empty_queue(instance.queue_in)  # noqa
-                    empty_queue(instance.queue_out)  # noqa
+                while instance.n_workers.value:
+                    empty_queue(instance.queue_in)
+                    empty_queue(instance.queue_out)
                     if time() - t > 10:
-                        warn(f'Parfor: Closing pool timed out, {instance.n_workers.value} processes still alive.')  # noqa
+                        warn(f'Parfor: Closing pool timed out, {instance.n_workers.value} processes still alive.')
                         instance.pool.terminate()
                         break
-                empty_queue(instance.queue_in)  # noqa
-                empty_queue(instance.queue_out)  # noqa
+                empty_queue(instance.queue_in)
+                empty_queue(instance.queue_out)
                 instance.pool.join()
-                close_queue(instance.queue_in)  # noqa
-                close_queue(instance.queue_out)  # noqa
+                close_queue(instance.queue_in)
+                close_queue(instance.queue_out)
                 instance.manager.shutdown()
-                instance.handle = 0  # noqa
+                instance.handle = 0
 class Worker:
     """ Manages executing the target function which will be executed in different processes. """
     nested = False

-    def __init__(self, shared_memory: SharedMemory, queue_in, queue_out, n_workers, event):
+    def __init__(self, shared_memory: SharedMemory, queue_in: multiprocessing.queues.Queue,
+                 queue_out: multiprocessing.queues.Queue, n_workers: multiprocessing.Value,
+                 event: multiprocessing.Event) -> None:
         self.shared_memory = shared_memory
         self.queue_in = queue_in
         self.queue_out = queue_out
         self.n_workers = n_workers
         self.event = event

-    def add_to_queue(self, *args):
+    def add_to_queue(self, *args: Any) -> None:
         while not self.event.is_set():
             try:
                 self.queue_out.put(args, timeout=0.1)
@@ -511,7 +520,7 @@ class Worker:
             except multiprocessing.queues.Full:  # noqa
                 continue

-    def __call__(self):
+    def __call__(self) -> None:
         Worker.nested = True
         pid = getpid()
         last_active_time = time()
@@ -537,8 +546,10 @@ class Worker:
         self.n_workers.value -= 1


-def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=True, terminator=None,
-         serial=None, length=None, **bar_kwargs):
+def pmap(fun: Callable[[Iteration, Any, ...], Result], iterable: Iterable[Iteration] = None,
+         args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None, total: int = None, desc: str = None,
+         bar: Bar | bool = True, terminator: Callable[[], None] = None, serial: bool = None, length: int = None,
+         **bar_kwargs: Any) -> list[Result]:
     """ map a function fun to each iteration in iterable
         use as a function: pmap
        use as a decorator: parfor
@@ -620,8 +631,8 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=
     iterable = Chunks(iterable, ratio=5, length=total)

     @wraps(fun)
-    def chunk_fun(iterator, *args, **kwargs):  # noqa
-        return [fun(i, *args, **kwargs) for i in iterator]  # noqa
+    def chunk_fun(iterable: Iterable, *args: Any, **kwargs: Any) -> list[Result]:  # noqa
+        return [fun(iteration, *args, **kwargs) for iteration in iterable]

     args = args or ()
     kwargs = kwargs or {}
@@ -636,13 +647,13 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=
         if callable(bar):
             return sum([chunk_fun(c, *args, **kwargs) for c in ExternalBar(iterable, bar)], [])
         else:
-            return sum([chunk_fun(c, *args, **kwargs) for c in tqdm(iterable, **bar_kwargs)], [])  # noqa
+            return sum([chunk_fun(c, *args, **kwargs) for c in tqdm(iterable, **bar_kwargs)], [])
     else:  # parallel case
         with ExitStack() as stack:
             if callable(bar):
-                bar = stack.enter_context(ExternalBar(callback=bar))  # noqa
+                bar = stack.enter_context(ExternalBar(callback=bar))
             else:
-                bar = stack.enter_context(tqdm(**bar_kwargs))  # noqa
+                bar = stack.enter_context(tqdm(**bar_kwargs))
             with ParPool(chunk_fun, args, kwargs, bar) as p:
                 for i, (j, l) in enumerate(zip(iterable, iterable.lengths)):  # add work to the queue
                     p(j, handle=i, barlength=iterable.lengths[i])
@@ -655,26 +666,7 @@ def pmap(fun, iterable=None, args=None, kwargs=None, total=None, desc=None, bar=

 @wraps(pmap)
-def parfor(*args, **kwargs):
-    def decfun(fun):
+def parfor(*args: Any, **kwargs: Any) -> Callable[[Callable[[Iteration, Any, ...], Result]], list[Result]]:
+    def decfun(fun: Callable[[Iteration, Any, ...], Result]) -> list[Result]:
         return pmap(fun, *args, **kwargs)
     return decfun
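The two public entry points mirror each other: pmap maps a function over an iterable, and parfor is its decorator form, which replaces the decorated function by the list of results (the Matlab-style spelling). For example:

from parfor import parfor, pmap

def add(x, y):
    return x + y

if __name__ == '__main__':
    print(pmap(add, range(10), args=(100,)))  # [100, 101, ..., 109]

    @parfor(range(10), args=(100,))
    def results(x, y):
        return x + y
    print(results)  # same list: the name is bound to the results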
-def deprecated(cls, name):
-    """ This is a decorator which can be used to mark functions and classes as deprecated. It will result in a warning
-        being emitted when the function or class is used."""
-    @wraps(cls)
-    def wrapper(*args, **kwargs):
-        warn(f'parfor: use of \'{name}\' is deprecated, use \'{cls.__name__}\' instead',
-             category=DeprecationWarning, stacklevel=2)
-        warn(f'parfor: use of \'{name}\' is deprecated, use \'{cls.__name__}\' instead',
-             category=DeprecationWarning, stacklevel=3)
-        return cls(*args, **kwargs)
-    return wrapper
-
-
-# backwards compatibility
-parpool = deprecated(ParPool, 'parpool')
-Parpool = deprecated(ParPool, 'Parpool')
-chunks = deprecated(Chunks, 'chunks')

parfor/pickler.py

@@ -1,6 +1,9 @@
+from __future__ import annotations
+
 import copyreg
 from io import BytesIO
 from pickle import PicklingError
+from typing import Any, Callable

 import dill
@@ -8,14 +11,14 @@ loads = dill.loads

 class CouldNotBePickled:
-    def __init__(self, class_name):
+    def __init__(self, class_name: str) -> None:
         self.class_name = class_name

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"Item of type '{self.class_name}' could not be pickled and was omitted."

     @classmethod
-    def reduce(cls, item):
+    def reduce(cls, item: Any) -> tuple[Callable[[str], CouldNotBePickled], tuple[str]]:
         return cls, (type(item).__name__,)
@@ -24,7 +27,7 @@ class Pickler(dill.Pickler):
        You probably didn't want to use these parts anyhow.
        However, if you did, you'll have to find some way to make them picklable.
    """
-    def save(self, obj, save_persistent_id=True):
+    def save(self, obj: Any, save_persistent_id: bool = True) -> None:
         """ Copied from pickle and amended. """
         self.framer.commit_frame()
@@ -93,8 +96,8 @@ class Pickler(dill.Pickler):
             raise PicklingError("%s must return string or tuple" % reduce)

         # Assert that it returned an appropriately sized tuple
-        l = len(rv)
-        if not (2 <= l <= 6):
+        length = len(rv)
+        if not (2 <= length <= 6):
             raise PicklingError("Tuple returned by %s must have "
                                 "two to six elements" % reduce)
@@ -105,11 +108,12 @@ class Pickler(dill.Pickler):
         self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))


-def dumps(obj, protocol=None, byref=None, fmode=None, recurse=True, **kwds):
+def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True,
+          **kwds: Any) -> bytes:
     """pickle an object to a string"""
     protocol = dill.settings['protocol'] if protocol is None else int(protocol)
     _kwds = kwds.copy()
     _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
-    file = BytesIO()
-    Pickler(file, protocol, **_kwds).dump(obj)
-    return file.getvalue()
+    with BytesIO() as file:
+        Pickler(file, protocol, **_kwds).dump(obj)
+        return file.getvalue()
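These dumps/loads are what parfor uses to ship functions and data between processes: dill-based, so lambdas and locally defined functions survive, and the amended save() substitutes a CouldNotBePickled placeholder for members that cannot be pickled at all instead of failing the whole dump. A minimal round-trip sketch:

from parfor.pickler import dumps, loads

payload = {'x': 3, 'fun': lambda x: x ** 2}  # stdlib pickle would refuse the lambda
restored = loads(dumps(payload))
print(restored['fun'](restored['x']))  # 9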

pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "parfor"
-version = "2024.3.0"
+version = "2024.4.0"
 description = "A package to mimic the use of parfor as done in Matlab."
 authors = ["Wim Pomp <wimpomp@gmail.com>"]
 license = "GPLv3"
@@ -9,7 +9,7 @@ keywords = ["parfor", "concurrency", "multiprocessing", "parallel"]
 repository = "https://github.com/wimpomp/parfor"

 [tool.poetry.dependencies]
-python = "^3.8"
+python = "^3.10"
 tqdm = ">=4.50.0"
 dill = ">=0.3.0"
 pytest = { version = "*", optional = true }
@@ -17,6 +17,9 @@ pytest = { version = "*", optional = true }
 [tool.poetry.extras]
 test = ["pytest"]

+[tool.isort]
+line_length = 119
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"

tests/test_parfor.py

@@ -1,7 +1,9 @@
-import pytest
-from parfor import Chunks, ParPool, parfor, pmap
 from dataclasses import dataclass

+import pytest
+
+from parfor import Chunks, ParPool, parfor, pmap


 class SequenceIterator:
     def __init__(self, sequence):