- add custom pickler back in

Wim Pomp
2026-01-08 16:25:24 +01:00
parent f3737ffd44
commit 2cff1a1515
4 changed files with 166 additions and 16 deletions

README.md

@@ -12,6 +12,8 @@ Tested on linux, Windows and OSX with python 3.10 and 3.12.
 - Easy to use
 - Progress bars are built-in
 - Retry the task in the main process upon failure for easy debugging
+- Using a modified version of dill when ray fails to serialize an object:
+  a lot more objects can be used when parallelizing
 ## How it works
 [Ray](https://pypi.org/project/ray/) does all the heavy lifting. Parfor now is just a wrapper around ray, adding
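
A minimal usage sketch of the new README bullet (not part of the diff; the decorator call follows the package README, and the Config class, its fields and the exact worker-side behaviour are illustrative assumptions). The shared argument holds a generator, which ray's own serializer refuses; with this commit such an argument should be routed through the modified dill pickler instead of aborting the loop:

from parfor import parfor

class Config:
    def __init__(self):
        self.scale = 2
        self.stream = (i for i in range(3))  # generators defeat ray/cloudpickle serialization

# per the package README, the decorated name is bound to the list of results, one per
# element of range(10); cfg.stream should arrive in the workers as a CouldNotBePickled
# placeholder, while cfg.scale is transferred intact
@parfor(range(10), (Config(),))
def result(i, cfg):
    return cfg.scale * i ** 2

print(result)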


@@ -28,6 +28,8 @@ import ray
 from numpy.typing import ArrayLike, DTypeLike
 from tqdm.auto import tqdm
+
+from .pickler import dumps, loads
 __version__ = version("parfor")
 cpu_count = int(os.cpu_count())
@@ -229,7 +231,7 @@ class Task:
         self.handle = handle
         self.fun = fun
         self.args = args
-        self.kwargs = kwargs
+        self.kwargs = kwargs or {}
         self.name = fun.__name__ if hasattr(fun, "__name__") else None
         self.done = False
         self.result = None
@@ -239,37 +241,51 @@ class Task:
         self.status = "starting"
         self.allow_output = allow_output
+    @staticmethod
+    def get(item: tuple[bool, Any]) -> Any:
+        if item[0]:
+            return loads(ray.get(item[1]))
+        else:
+            return ray.get(item[1])
+
+    @staticmethod
+    def put(item: Any) -> tuple[bool, Any]:
+        try:
+            return False, ray.put(item)
+        except Exception:  # noqa
+            return True, ray.put(dumps(item))
+
     @property
     def fun(self) -> Callable[[Any, ...], Any]:
-        return ray.get(self._fun)
+        return self.get(self._fun)
     @fun.setter
     def fun(self, fun: Callable[[Any, ...], Any]):
-        self._fun = ray.put(fun)
+        self._fun = self.put(fun)
     @property
     def args(self) -> tuple[Any, ...]:
-        return tuple([ray.get(arg) for arg in self._args])
+        return tuple([self.get(arg) for arg in self._args])
     @args.setter
     def args(self, args: tuple[Any, ...]) -> None:
-        self._args = [ray.put(arg) for arg in args]
+        self._args = [self.put(arg) for arg in args]
     @property
     def kwargs(self) -> dict[str, Any]:
-        return {key: ray.get(value) for key, value in self._kwargs.items()}
+        return {key: self.get(value) for key, value in self._kwargs.items()}
     @kwargs.setter
     def kwargs(self, kwargs: dict[str, Any]) -> None:
-        self._kwargs = {key: ray.put(value) for key, value in kwargs.items()}
+        self._kwargs = {key: self.put(value) for key, value in kwargs.items()}
     @property
     def result(self) -> Any:
-        return ray.get(self._result)
+        return self.get(self._result)
     @result.setter
     def result(self, result: Any) -> None:
-        self._result = ray.put(result)
+        self._result = self.put(result)
     def __call__(self) -> Task:
         if not self.done:
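
The hunk above routes every function, argument and result of a Task through a (used_dill, ObjectRef) pair. A standalone sketch of that convention, restated outside the class (the put/get helpers mirror the diff; the ray session settings and the generator example are illustrative assumptions):

import logging

import ray

from parfor.pickler import dumps, loads  # module added by this commit

ray.init(num_cpus=1, logging_level=logging.ERROR, log_to_driver=False)

def put(item):
    try:
        return False, ray.put(item)        # ray/cloudpickle serialized it directly
    except Exception:  # noqa
        return True, ray.put(dumps(item))  # fallback: dill bytes go into the object store

def get(item):
    used_dill, ref = item
    payload = ray.get(ref)
    return loads(payload) if used_dill else payload

plain = put({"a": 1})              # (False, ObjectRef): stored as-is
tricky = put(n for n in range(3))  # generator: ray refuses it, so the dill path is used
print(get(plain))                  # {'a': 1}
print(get(tricky))                 # the generator comes back as a CouldNotBePickled placeholder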
@@ -324,6 +340,9 @@ class ParPool:
             barlength=barlength,
         )
+
+    def close(self) -> None:
+        pass
     def add_task(
         self,
         fun: Callable[[Any, ...], Any] = None,
@@ -363,8 +382,12 @@ class ParPool:
"""Request result and delete its record. Wait if result not yet available.""" """Request result and delete its record. Wait if result not yet available."""
if handle not in self: if handle not in self:
raise ValueError(f"No handle: {handle} in pool") raise ValueError(f"No handle: {handle} in pool")
task = ray.get(self.tasks[handle].future) task = self.tasks[handle]
return self.finalize_task(task) if task.future is None:
return task.result
else:
task = ray.get(self.tasks[handle].future)
return self.finalize_task(task)
def __contains__(self, handle: Hashable) -> bool: def __contains__(self, handle: Hashable) -> bool:
return handle in self.tasks return handle in self.tasks
@@ -387,10 +410,13 @@ class ParPool:
         while True:
             if self.tasks:
                 for handle, task in self.tasks.items():
-                    if handle in self.bar_lengths:
+                    if handle in self.tasks:
                         try:
-                            task = ray.get(task.future, timeout=0.01)
-                            return task.handle, self.finalize_task(task)
+                            if task.future is None:
+                                return task.handle, task.result
+                            else:
+                                task = ray.get(task.future, timeout=0.01)
+                                return task.handle, self.finalize_task(task)
                         except ray.exceptions.GetTimeoutError:
                             pass
@@ -421,7 +447,9 @@ class PoolSingleton:
             cls.instance = super().__new__(cls)
             cls.instance.n_processes = n_processes
             if ray.is_initialized():
-                warnings.warn("not setting n_processes because parallel pool was already initialized")
+                if cls.instance.n_processes != n_processes:
+                    warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
+                                  f"probably with n_processes={cls.instance.n_processes}")
             else:
                 os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
                 ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)

parfor/pickler.py (new file, 119 lines added)

from __future__ import annotations

import copyreg
from io import BytesIO
from pickle import PicklingError
from typing import Any, Callable

import dill

loads = dill.loads


class CouldNotBePickled:
    def __init__(self, class_name: str) -> None:
        self.class_name = class_name

    def __repr__(self) -> str:
        return f"Item of type '{self.class_name}' could not be pickled and was omitted."

    @classmethod
    def reduce(cls, item: Any) -> tuple[Callable[[str], CouldNotBePickled], tuple[str]]:
        return cls, (type(item).__name__,)


class Pickler(dill.Pickler):
    """ Overload dill to ignore unpicklable parts of objects.
        You probably didn't want to use these parts anyhow.
        However, if you did, you'll have to find some way to make them picklable.
    """

    def save(self, obj: Any, save_persistent_id: bool = True) -> None:
        """ Copied from pickle and amended. """
        self.framer.commit_frame()

        # Check for persistent id (defined by a subclass)
        pid = self.persistent_id(obj)
        if pid is not None and save_persistent_id:
            self.save_pers(pid)
            return

        # Check the memo
        x = self.memo.get(id(obj))
        if x is not None:
            self.write(self.get(x[0]))
            return

        rv = NotImplemented
        reduce = getattr(self, "reducer_override", None)
        if reduce is not None:
            rv = reduce(obj)

        if rv is NotImplemented:
            # Check the type dispatch table
            t = type(obj)
            f = self.dispatch.get(t)
            if f is not None:
                f(self, obj)  # Call unbound method with explicit self
                return

            # Check private dispatch table if any, or else
            # copyreg.dispatch_table
            reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t)
            if reduce is not None:
                rv = reduce(obj)
            else:
                # Check for a class with a custom metaclass; treat as regular
                # class
                if issubclass(t, type):
                    self.save_global(obj)
                    return

                # Check for a __reduce_ex__ method, fall back to __reduce__
                reduce = getattr(obj, "__reduce_ex__", None)
                try:
                    if reduce is not None:
                        rv = reduce(self.proto)
                    else:
                        reduce = getattr(obj, "__reduce__", None)
                        if reduce is not None:
                            rv = reduce()
                        else:
                            raise PicklingError("Can't pickle %r object: %r" %
                                                (t.__name__, obj))
                except Exception:  # noqa
                    rv = CouldNotBePickled.reduce(obj)

        # Check for string returned by reduce(), meaning "save as global"
        if isinstance(rv, str):
            try:
                self.save_global(obj, rv)
            except Exception:  # noqa
                self.save_global(obj, CouldNotBePickled.reduce(obj))
            return

        # Assert that reduce() returned a tuple
        if not isinstance(rv, tuple):
            raise PicklingError("%s must return string or tuple" % reduce)

        # Assert that it returned an appropriately sized tuple
        length = len(rv)
        if not (2 <= length <= 6):
            raise PicklingError("Tuple returned by %s must have "
                                "two to six elements" % reduce)

        # Save the reduce() output and finally memoize the object
        try:
            self.save_reduce(obj=obj, *rv)
        except Exception:  # noqa
            self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj))


def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True,
          **kwds: Any) -> bytes:
    """pickle an object to a string"""
    protocol = dill.settings['protocol'] if protocol is None else int(protocol)
    _kwds = kwds.copy()
    _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
    with BytesIO() as file:
        Pickler(file, protocol, **_kwds).dump(obj)
        return file.getvalue()
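
A short round trip through the module above, as a hedged sketch (the Job class and its fields are invented for illustration): the generator attribute defeats both pickle and dill, so instead of failing the whole dump this Pickler should swap it for a CouldNotBePickled placeholder while the rest of the object survives.

from parfor.pickler import dumps, loads

class Job:
    def __init__(self):
        self.values = [1, 2, 3]
        self.stream = (v ** 2 for v in self.values)  # generator objects cannot be pickled

job = loads(dumps(Job()))
print(job.values)   # [1, 2, 3] -- the picklable part round-trips normally
print(job.stream)   # Item of type 'generator' could not be pickled and was omitted.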

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "parfor"
-version = "2026.1.1"
+version = "2026.1.2"
 description = "A package to mimic the use of parfor as done in Matlab."
 authors = [
     { name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }
@@ -10,6 +10,7 @@ readme = "README.md"
 keywords = ["parfor", "concurrency", "multiprocessing", "parallel"]
 requires-python = ">=3.10"
 dependencies = [
+    "dill >= 0.3.0",
     "numpy",
     "tqdm >= 4.50.0",
     "ray",