- add custom pickler back in
This commit is contained in:
@@ -12,6 +12,8 @@ Tested on linux, Windows and OSX with python 3.10 and 3.12.
|
|||||||
- Easy to use
|
- Easy to use
|
||||||
- Progress bars are built-in
|
- Progress bars are built-in
|
||||||
- Retry the task in the main process upon failure for easy debugging
|
- Retry the task in the main process upon failure for easy debugging
|
||||||
|
- Using a modified version of dill when ray fails to serialize an object:
|
||||||
|
a lot more objects can be used when parallelizing
|
||||||
|
|
||||||
## How it works
|
## How it works
|
||||||
[Ray](https://pypi.org/project/ray/) does all the heavy lifting. Parfor now is just a wrapper around ray, adding
|
[Ray](https://pypi.org/project/ray/) does all the heavy lifting. Parfor now is just a wrapper around ray, adding
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ import ray
|
|||||||
from numpy.typing import ArrayLike, DTypeLike
|
from numpy.typing import ArrayLike, DTypeLike
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from .pickler import dumps, loads
|
||||||
|
|
||||||
|
|
||||||
__version__ = version("parfor")
|
__version__ = version("parfor")
|
||||||
cpu_count = int(os.cpu_count())
|
cpu_count = int(os.cpu_count())
|
||||||
@@ -229,7 +231,7 @@ class Task:
|
|||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.fun = fun
|
self.fun = fun
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs or {}
|
||||||
self.name = fun.__name__ if hasattr(fun, "__name__") else None
|
self.name = fun.__name__ if hasattr(fun, "__name__") else None
|
||||||
self.done = False
|
self.done = False
|
||||||
self.result = None
|
self.result = None
|
||||||
@@ -239,37 +241,51 @@ class Task:
|
|||||||
self.status = "starting"
|
self.status = "starting"
|
||||||
self.allow_output = allow_output
|
self.allow_output = allow_output
|
||||||
|
|
||||||
|
@staticmethod
def get(item: tuple[bool, Any]) -> Any:
    """Fetch the value behind a (flag, reference) pair from ray.

    item[1] is handed to ray.get; when item[0] is truthy the fetched value
    is additionally decoded with loads before being returned, otherwise it
    is returned as fetched.
    """
    needs_decoding = item[0]
    stored = ray.get(item[1])
    return loads(stored) if needs_decoding else stored
|
||||||
|
|
||||||
|
@staticmethod
def put(item: Any) -> tuple[bool, Any]:
    """Store item with ray.put and return a (flag, reference) pair.

    The flag is False when ray stored the item directly. If ray.put raises,
    the item is first serialized with dumps, the resulting bytes are stored
    instead and the flag is True.
    """
    try:
        reference = ray.put(item)
    except Exception:  # noqa
        # ray could not store the item as is: store the dumps() output instead
        return True, ray.put(dumps(item))
    return False, reference
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fun(self) -> Callable[[Any, ...], Any]:
|
def fun(self) -> Callable[[Any, ...], Any]:
|
||||||
return ray.get(self._fun)
|
return self.get(self._fun)
|
||||||
|
|
||||||
@fun.setter
|
@fun.setter
|
||||||
def fun(self, fun: Callable[[Any, ...], Any]):
|
def fun(self, fun: Callable[[Any, ...], Any]):
|
||||||
self._fun = ray.put(fun)
|
self._fun = self.put(fun)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> tuple[Any, ...]:
|
def args(self) -> tuple[Any, ...]:
|
||||||
return tuple([ray.get(arg) for arg in self._args])
|
return tuple([self.get(arg) for arg in self._args])
|
||||||
|
|
||||||
@args.setter
|
@args.setter
|
||||||
def args(self, args: tuple[Any, ...]) -> None:
|
def args(self, args: tuple[Any, ...]) -> None:
|
||||||
self._args = [ray.put(arg) for arg in args]
|
self._args = [self.put(arg) for arg in args]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kwargs(self) -> dict[str, Any]:
|
def kwargs(self) -> dict[str, Any]:
|
||||||
return {key: ray.get(value) for key, value in self._kwargs.items()}
|
return {key: self.get(value) for key, value in self._kwargs.items()}
|
||||||
|
|
||||||
@kwargs.setter
|
@kwargs.setter
|
||||||
def kwargs(self, kwargs: dict[str, Any]) -> None:
|
def kwargs(self, kwargs: dict[str, Any]) -> None:
|
||||||
self._kwargs = {key: ray.put(value) for key, value in kwargs.items()}
|
self._kwargs = {key: self.put(value) for key, value in kwargs.items()}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def result(self) -> Any:
|
def result(self) -> Any:
|
||||||
return ray.get(self._result)
|
return self.get(self._result)
|
||||||
|
|
||||||
@result.setter
|
@result.setter
|
||||||
def result(self, result: Any) -> None:
|
def result(self, result: Any) -> None:
|
||||||
self._result = ray.put(result)
|
self._result = self.put(result)
|
||||||
|
|
||||||
def __call__(self) -> Task:
|
def __call__(self) -> Task:
|
||||||
if not self.done:
|
if not self.done:
|
||||||
@@ -324,6 +340,9 @@ class ParPool:
|
|||||||
barlength=barlength,
|
barlength=barlength,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def close(self) -> None:
    """Do nothing and return None.

    NOTE(review): presumably kept so callers can still call close() on the
    pool; confirm against the callers.
    """
    return None
|
||||||
|
|
||||||
def add_task(
|
def add_task(
|
||||||
self,
|
self,
|
||||||
fun: Callable[[Any, ...], Any] = None,
|
fun: Callable[[Any, ...], Any] = None,
|
||||||
@@ -363,8 +382,12 @@ class ParPool:
|
|||||||
"""Request result and delete its record. Wait if result not yet available."""
|
"""Request result and delete its record. Wait if result not yet available."""
|
||||||
if handle not in self:
|
if handle not in self:
|
||||||
raise ValueError(f"No handle: {handle} in pool")
|
raise ValueError(f"No handle: {handle} in pool")
|
||||||
task = ray.get(self.tasks[handle].future)
|
task = self.tasks[handle]
|
||||||
return self.finalize_task(task)
|
if task.future is None:
|
||||||
|
return task.result
|
||||||
|
else:
|
||||||
|
task = ray.get(self.tasks[handle].future)
|
||||||
|
return self.finalize_task(task)
|
||||||
|
|
||||||
def __contains__(self, handle: Hashable) -> bool:
|
def __contains__(self, handle: Hashable) -> bool:
|
||||||
return handle in self.tasks
|
return handle in self.tasks
|
||||||
@@ -387,10 +410,13 @@ class ParPool:
|
|||||||
while True:
|
while True:
|
||||||
if self.tasks:
|
if self.tasks:
|
||||||
for handle, task in self.tasks.items():
|
for handle, task in self.tasks.items():
|
||||||
if handle in self.bar_lengths:
|
if handle in self.tasks:
|
||||||
try:
|
try:
|
||||||
task = ray.get(task.future, timeout=0.01)
|
if task.future is None:
|
||||||
return task.handle, self.finalize_task(task)
|
return task.handle, task.result
|
||||||
|
else:
|
||||||
|
task = ray.get(task.future, timeout=0.01)
|
||||||
|
return task.handle, self.finalize_task(task)
|
||||||
except ray.exceptions.GetTimeoutError:
|
except ray.exceptions.GetTimeoutError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -421,7 +447,9 @@ class PoolSingleton:
|
|||||||
cls.instance = super().__new__(cls)
|
cls.instance = super().__new__(cls)
|
||||||
cls.instance.n_processes = n_processes
|
cls.instance.n_processes = n_processes
|
||||||
if ray.is_initialized():
|
if ray.is_initialized():
|
||||||
warnings.warn("not setting n_processes because parallel pool was already initialized")
|
if cls.instance.n_processes != n_processes:
|
||||||
|
warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
|
||||||
|
f"probably with n_processes={cls.instance.n_processes}")
|
||||||
else:
|
else:
|
||||||
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
|
os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
|
||||||
ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
|
ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
|
||||||
|
|||||||
119
parfor/pickler.py
Normal file
119
parfor/pickler.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copyreg
|
||||||
|
from io import BytesIO
|
||||||
|
from pickle import PicklingError
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import dill
|
||||||
|
|
||||||
|
loads = dill.loads
|
||||||
|
|
||||||
|
|
||||||
|
class CouldNotBePickled:
    """Placeholder for an item that could not be pickled.

    Only the type name of the original item is kept.
    """

    def __init__(self, class_name: str) -> None:
        # name of the type of the item this placeholder stands in for
        self.class_name = class_name

    def __repr__(self) -> str:
        name = self.class_name
        return f"Item of type '{name}' could not be pickled and was omitted."

    @classmethod
    def reduce(cls, item: Any) -> tuple[Callable[[str], CouldNotBePickled], tuple[str]]:
        """Return a (callable, args) pair that rebuilds a placeholder for item.

        Calling the first element with the second element unpacked yields a
        CouldNotBePickled holding the type name of item.
        """
        type_name = type(item).__name__
        return cls, (type_name,)
|
||||||
|
|
||||||
|
|
||||||
|
class Pickler(dill.Pickler):
    """Overload dill to ignore unpicklable parts of objects.

    When computing the reduce value of an object, or saving that value,
    raises, a CouldNotBePickled placeholder carrying the type name of the
    object is written instead of propagating the error.
    You probably didn't want to use these parts anyhow.
    However, if you did, you'll have to find some way to make them picklable.
    """

    def save(self, obj: Any, save_persistent_id: bool = True) -> None:
        """Write obj to the pickle stream. Copied from pickle and amended.

        The amendments are the try/except blocks: failures while reducing
        obj or while saving its reduce value are replaced by a
        CouldNotBePickled placeholder.

        Args:
            obj: the object to pickle.
            save_persistent_id: when True and persistent_id(obj) is not None,
                only the persistent id is saved.

        Raises:
            PicklingError: if the reduce value is neither a string nor a
                tuple, or is a tuple that does not have two to six elements.
        """
        self.framer.commit_frame()

        # Check for persistent id (defined by a subclass)
        pid = self.persistent_id(obj)
        if pid is not None and save_persistent_id:
            self.save_pers(pid)
            return

        # Check the memo
        memoized = self.memo.get(id(obj))
        if memoized is not None:
            self.write(self.get(memoized[0]))
            return

        rv = NotImplemented
        reduce = getattr(self, "reducer_override", None)
        if reduce is not None:
            rv = reduce(obj)

        if rv is NotImplemented:
            # Check the type dispatch table
            t = type(obj)
            f = self.dispatch.get(t)
            if f is not None:
                f(self, obj)  # Call unbound method with explicit self
                return

            # Check private dispatch table if any, or else
            # copyreg.dispatch_table
            reduce = getattr(self, "dispatch_table", copyreg.dispatch_table).get(t)
            if reduce is not None:
                rv = reduce(obj)
            else:
                # Check for a class with a custom metaclass; treat as regular
                # class
                if issubclass(t, type):
                    self.save_global(obj)
                    return

                # Check for a __reduce_ex__ method, fall back to __reduce__
                reduce = getattr(obj, "__reduce_ex__", None)
                try:
                    if reduce is not None:
                        rv = reduce(self.proto)
                    else:
                        reduce = getattr(obj, "__reduce__", None)
                        if reduce is not None:
                            rv = reduce()
                        else:
                            raise PicklingError(f"Can't pickle {t.__name__!r} object: {obj!r}")
                except Exception:  # noqa
                    # amendment: the object cannot be reduced, save a placeholder
                    rv = CouldNotBePickled.reduce(obj)

        # Check for string returned by reduce(), meaning "save as global"
        if isinstance(rv, str):
            try:
                self.save_global(obj, rv)
            except Exception:  # noqa
                # CouldNotBePickled.reduce returns a (callable, args) tuple, not
                # a name string, so the placeholder has to be written with
                # save_reduce instead of save_global
                self.save_reduce(*CouldNotBePickled.reduce(obj), obj=obj)
            return

        # Assert that reduce() returned a tuple
        if not isinstance(rv, tuple):
            raise PicklingError(f"{reduce} must return string or tuple")

        # Assert that it returned an appropriately sized tuple
        if not (2 <= len(rv) <= 6):
            raise PicklingError(f"Tuple returned by {reduce} must have two to six elements")

        # Save the reduce() output and finally memoize the object
        try:
            self.save_reduce(*rv, obj=obj)
        except Exception:  # noqa
            # NOTE(review): anything save_reduce wrote before it failed stays in
            # the stream; confirm that this cannot corrupt the output
            self.save_reduce(*CouldNotBePickled.reduce(obj), obj=obj)
|
||||||
|
|
||||||
|
|
||||||
|
def dumps(obj: Any, protocol: int | None = None, byref: bool | None = None, fmode: int | None = None,
          recurse: bool | None = True, **kwds: Any) -> bytes:
    """Pickle an object to bytes using the amended Pickler of this module.

    Args:
        obj: the object to pickle.
        protocol: pickle protocol; dill.settings['protocol'] is used when None,
            any other value is converted with int().
        byref: forwarded to the Pickler as keyword argument.
        fmode: forwarded to the Pickler as keyword argument.
        recurse: forwarded to the Pickler as keyword argument.
        **kwds: extra keyword arguments for the Pickler; byref, fmode and
            recurse given here are overridden by the explicit parameters.

    Returns:
        The pickled representation of obj.
    """
    protocol = dill.settings['protocol'] if protocol is None else int(protocol)
    # explicit parameters come last so they win over same-named entries in kwds
    options = {**kwds, "byref": byref, "fmode": fmode, "recurse": recurse}
    with BytesIO() as file:
        Pickler(file, protocol, **options).dump(obj)
        return file.getvalue()
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "parfor"
|
name = "parfor"
|
||||||
version = "2026.1.1"
|
version = "2026.1.2"
|
||||||
description = "A package to mimic the use of parfor as done in Matlab."
|
description = "A package to mimic the use of parfor as done in Matlab."
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }
|
{ name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }
|
||||||
@@ -10,6 +10,7 @@ readme = "README.md"
|
|||||||
keywords = ["parfor", "concurrency", "multiprocessing", "parallel"]
|
keywords = ["parfor", "concurrency", "multiprocessing", "parallel"]
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"dill >= 0.3.0",
|
||||||
"numpy",
|
"numpy",
|
||||||
"tqdm >= 4.50.0",
|
"tqdm >= 4.50.0",
|
||||||
"ray",
|
"ray",
|
||||||
|
|||||||
Reference in New Issue
Block a user