- add custom pickler back in
This commit is contained in:
@@ -12,6 +12,8 @@ Tested on linux, Windows and OSX with python 3.10 and 3.12.
|
||||
- Easy to use
|
||||
- Progress bars are built-in
|
||||
- Retry the task in the main process upon failure for easy debugging
|
||||
- Using a modified version of dill when ray fails to serialize an object:
|
||||
a lot more objects can be used when parallelizing
|
||||
|
||||
## How it works
|
||||
[Ray](https://pypi.org/project/ray/) does all the heavy lifting. Parfor is now just a wrapper around ray, adding
|
||||
|
||||
@@ -28,6 +28,8 @@ import ray
|
||||
from numpy.typing import ArrayLike, DTypeLike
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .pickler import dumps, loads
|
||||
|
||||
|
||||
__version__ = version("parfor")
|
||||
cpu_count = int(os.cpu_count())
|
||||
@@ -229,7 +231,7 @@ class Task:
|
||||
self.handle = handle
|
||||
self.fun = fun
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.kwargs = kwargs or {}
|
||||
self.name = fun.__name__ if hasattr(fun, "__name__") else None
|
||||
self.done = False
|
||||
self.result = None
|
||||
@@ -239,37 +241,51 @@ class Task:
|
||||
self.status = "starting"
|
||||
self.allow_output = allow_output
|
||||
|
||||
@staticmethod
def get(item: tuple[bool, Any]) -> Any:
    """Fetch an object previously stored with Task.put.

    item is a (dill_pickled, object_ref) pair: when the flag is True the
    stored payload is dill-serialized bytes that must be decoded with loads.
    """
    dill_pickled, ref = item
    payload = ray.get(ref)
    return loads(payload) if dill_pickled else payload
|
||||
|
||||
@staticmethod
def put(item: Any) -> tuple[bool, Any]:
    """Store item in ray's object store.

    Returns (False, ref) when ray can serialize the object itself; when
    ray.put fails, falls back to the custom dill pickler and returns
    (True, ref) where ref points at the serialized bytes.
    """
    try:
        ref = ray.put(item)
    except Exception:  # noqa — ray raises assorted errors for unserializable objects
        return True, ray.put(dumps(item))
    return False, ref
|
||||
|
||||
@property
|
||||
def fun(self) -> Callable[[Any, ...], Any]:
|
||||
return ray.get(self._fun)
|
||||
return self.get(self._fun)
|
||||
|
||||
@fun.setter
|
||||
def fun(self, fun: Callable[[Any, ...], Any]):
|
||||
self._fun = ray.put(fun)
|
||||
self._fun = self.put(fun)
|
||||
|
||||
@property
|
||||
def args(self) -> tuple[Any, ...]:
|
||||
return tuple([ray.get(arg) for arg in self._args])
|
||||
return tuple([self.get(arg) for arg in self._args])
|
||||
|
||||
@args.setter
|
||||
def args(self, args: tuple[Any, ...]) -> None:
|
||||
self._args = [ray.put(arg) for arg in args]
|
||||
self._args = [self.put(arg) for arg in args]
|
||||
|
||||
@property
|
||||
def kwargs(self) -> dict[str, Any]:
|
||||
return {key: ray.get(value) for key, value in self._kwargs.items()}
|
||||
return {key: self.get(value) for key, value in self._kwargs.items()}
|
||||
|
||||
@kwargs.setter
|
||||
def kwargs(self, kwargs: dict[str, Any]) -> None:
|
||||
self._kwargs = {key: ray.put(value) for key, value in kwargs.items()}
|
||||
self._kwargs = {key: self.put(value) for key, value in kwargs.items()}
|
||||
|
||||
@property
|
||||
def result(self) -> Any:
|
||||
return ray.get(self._result)
|
||||
return self.get(self._result)
|
||||
|
||||
@result.setter
|
||||
def result(self, result: Any) -> None:
|
||||
self._result = ray.put(result)
|
||||
self._result = self.put(result)
|
||||
|
||||
def __call__(self) -> Task:
|
||||
if not self.done:
|
||||
@@ -324,6 +340,9 @@ class ParPool:
|
||||
barlength=barlength,
|
||||
)
|
||||
|
||||
def close(self) -> None:
    # Intentional no-op: kept so callers can treat ParPool like a
    # multiprocessing pool.  NOTE(review): presumably nothing per-pool needs
    # releasing because ray worker lifetime is managed globally
    # (see PoolSingleton) — confirm.
    pass
|
||||
|
||||
def add_task(
|
||||
self,
|
||||
fun: Callable[[Any, ...], Any] = None,
|
||||
@@ -363,6 +382,10 @@ class ParPool:
|
||||
"""Request result and delete its record. Wait if result not yet available."""
|
||||
if handle not in self:
|
||||
raise ValueError(f"No handle: {handle} in pool")
|
||||
task = self.tasks[handle]
|
||||
if task.future is None:
|
||||
return task.result
|
||||
else:
|
||||
task = ray.get(self.tasks[handle].future)
|
||||
return self.finalize_task(task)
|
||||
|
||||
@@ -387,8 +410,11 @@ class ParPool:
|
||||
while True:
|
||||
if self.tasks:
|
||||
for handle, task in self.tasks.items():
|
||||
if handle in self.bar_lengths:
|
||||
if handle in self.tasks:
|
||||
try:
|
||||
if task.future is None:
|
||||
return task.handle, task.result
|
||||
else:
|
||||
task = ray.get(task.future, timeout=0.01)
|
||||
return task.handle, self.finalize_task(task)
|
||||
except ray.exceptions.GetTimeoutError:
|
||||
@@ -421,7 +447,9 @@ class PoolSingleton:
|
||||
cls.instance = super().__new__(cls)
|
||||
cls.instance.n_processes = n_processes
|
||||
if ray.is_initialized():
|
||||
warnings.warn("not setting n_processes because parallel pool was already initialized")
|
||||
if cls.instance.n_processes != n_processes:
|
||||
warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
|
||||
f"probably with n_processes={cls.instance.n_processes}")
|
||||
else:
|
||||
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
|
||||
ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
|
||||
|
||||
119
parfor/pickler.py
Normal file
119
parfor/pickler.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copyreg
|
||||
from io import BytesIO
|
||||
from pickle import PicklingError
|
||||
from typing import Any, Callable
|
||||
|
||||
import dill
|
||||
|
||||
loads = dill.loads
|
||||
|
||||
|
||||
class CouldNotBePickled:
    """Stand-in recorded when an object cannot be pickled.

    Only the class name of the original object is kept, so the failure can
    be reported on the receiving side instead of aborting the whole dump.
    """

    def __init__(self, class_name: str) -> None:
        self.class_name = class_name

    def __repr__(self) -> str:
        template = "Item of type '{}' could not be pickled and was omitted."
        return template.format(self.class_name)

    @classmethod
    def reduce(cls, item: Any) -> tuple[Callable[[str], CouldNotBePickled], tuple[str]]:
        """Build a pickle reduce tuple that replaces item with a placeholder."""
        original_type = type(item).__name__
        return cls, (original_type,)
|
||||
|
||||
|
||||
class Pickler(dill.Pickler):
    """ Overload dill to ignore unpicklable parts of objects.
    You probably didn't want to use these parts anyhow.
    However, if you did, you'll have to find some way to make them picklable.

    The amendment relative to dill/pickle: every failure while reducing or
    saving an object is caught and the offending object is replaced by a
    CouldNotBePickled placeholder instead of aborting the whole dump.
    """
    def save(self, obj: Any, save_persistent_id: bool = True) -> None:
        """ Copied from pickle and amended. """
        self.framer.commit_frame()

        # Check for persistent id (defined by a subclass)
        pid = self.persistent_id(obj)
        if pid is not None and save_persistent_id:
            self.save_pers(pid)
            return

        # Check the memo: already-saved objects are written as back-references
        x = self.memo.get(id(obj))
        if x is not None:
            self.write(self.get(x[0]))
            return

        rv = NotImplemented
        reduce = getattr(self, "reducer_override", None)
        if reduce is not None:
            rv = reduce(obj)

        if rv is NotImplemented:
            # Check the type dispatch table
            t = type(obj)
            f = self.dispatch.get(t)
            if f is not None:
                f(self, obj)  # Call unbound method with explicit self
                return

            # Check private dispatch table if any, or else
            # copyreg.dispatch_table
            reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
            if reduce is not None:
                rv = reduce(obj)
            else:
                # Check for a class with a custom metaclass; treat as regular
                # class
                if issubclass(t, type):
                    self.save_global(obj)
                    return

                # Check for a __reduce_ex__ method, fall back to __reduce__
                reduce = getattr(obj, "__reduce_ex__", None)
                try:
                    if reduce is not None:
                        rv = reduce(self.proto)
                    else:
                        reduce = getattr(obj, "__reduce__", None)
                        if reduce is not None:
                            rv = reduce()
                        else:
                            raise PicklingError("Can't pickle %r object: %r" %
                                                (t.__name__, obj))
                # AMENDMENT vs stock pickle: a failing __reduce_ex__/__reduce__
                # substitutes a CouldNotBePickled placeholder instead of raising.
                except Exception:  # noqa
                    rv = CouldNotBePickled.reduce(obj)

        # Check for string returned by reduce(), meaning "save as global"
        if isinstance(rv, str):
            try:
                self.save_global(obj, rv)
            # NOTE(review): save_global's second argument is expected to be a
            # name string, but CouldNotBePickled.reduce(obj) returns a tuple —
            # this fallback likely raises itself; confirm whether save_reduce
            # was intended here.
            except Exception:  # noqa
                self.save_global(obj, CouldNotBePickled.reduce(obj))
            return

        # Assert that reduce() returned a tuple
        if not isinstance(rv, tuple):
            raise PicklingError("%s must return string or tuple" % reduce)

        # Assert that it returned an appropriately sized tuple
        length = len(rv)
        if not (2 <= length <= 6):
            raise PicklingError("Tuple returned by %s must have "
                                "two to six elements" % reduce)

        # Save the reduce() output and finally memoize the object.
        # AMENDMENT vs stock pickle: if writing the reduce value fails
        # (e.g. an unpicklable element inside rv), store a placeholder instead.
        try:
            self.save_reduce(obj=obj, *rv)
        except Exception:  # noqa
            self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))
|
||||
|
||||
|
||||
def dumps(obj: Any, protocol: int | None = None, byref: bool | None = None, fmode: int | None = None,
          recurse: bool = True, **kwds: Any) -> bytes:
    """Pickle an object to bytes using the tolerant Pickler.

    Unpicklable parts of obj are silently replaced by CouldNotBePickled
    placeholders (see Pickler) instead of raising.

    Parameters mirror dill.dumps:
        protocol: pickle protocol number; defaults to dill's configured protocol.
        byref: if set, pickle selected objects by reference instead of by value.
        fmode: file-handling mode passed through to dill.
        recurse: recursively trace referenced globals (defaults to True here,
            unlike dill, so functions carry their globals along — presumably
            needed when shipping tasks to ray workers).
        **kwds: any further keyword arguments for dill.Pickler.

    Returns:
        The pickled payload as bytes.
    """
    protocol = dill.settings['protocol'] if protocol is None else int(protocol)
    # None values are passed through: dill interprets None as "use settings".
    _kwds = kwds.copy()
    _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
    with BytesIO() as file:
        Pickler(file, protocol, **_kwds).dump(obj)
        return file.getvalue()
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "parfor"
|
||||
version = "2026.1.1"
|
||||
version = "2026.1.2"
|
||||
description = "A package to mimic the use of parfor as done in Matlab."
|
||||
authors = [
|
||||
{ name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }
|
||||
@@ -10,6 +10,7 @@ readme = "README.md"
|
||||
keywords = ["parfor", "concurrency", "multiprocessing", "parallel"]
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"dill >= 0.3.0",
|
||||
"numpy",
|
||||
"tqdm >= 4.50.0",
|
||||
"ray",
|
||||
|
||||
Reference in New Issue
Block a user