diff --git a/README.md b/README.md
index 07a4149..1b0059a 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,8 @@ Tested on linux, Windows and OSX with python 3.10 and 3.12.
 - Easy to use
 - Progress bars are built-in
 - Retry the task in the main process upon failure for easy debugging
+- Falls back to a modified version of dill when ray fails to serialize an object:
+  many more kinds of objects can be used when parallelizing
 
 ## How it works
 [Ray](https://pypi.org/project/ray/) does all the heavy lifting. Parfor now is just a wrapper around ray, adding
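What the new bullet buys is easiest to see with the pickler module this patch adds (parfor/pickler.py, below). A minimal sketch, not part of the patch: the Job class is hypothetical, and it relies on generators being unpicklable even for dill.

    from parfor.pickler import dumps, loads

    class Job:
        def __init__(self):
            self.data = [1, 2, 3]
            self.stream = (x * x for x in self.data)  # generators cannot be pickled

    # The patched Pickler swaps the generator for a CouldNotBePickled placeholder
    # instead of raising, so the rest of the object survives the round trip.
    job = loads(dumps(Job()))
    print(job.data)    # [1, 2, 3]
    print(job.stream)  # Item of type 'generator' could not be pickled and was omitted.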
diff --git a/parfor/__init__.py b/parfor/__init__.py
index bca2ddb..1a02d84 100644
--- a/parfor/__init__.py
+++ b/parfor/__init__.py
@@ -28,6 +28,8 @@ import ray
 from numpy.typing import ArrayLike, DTypeLike
 from tqdm.auto import tqdm
 
+from .pickler import dumps, loads
+
 __version__ = version("parfor")
 
 cpu_count = int(os.cpu_count())
@@ -229,7 +231,7 @@ class Task:
         self.handle = handle
         self.fun = fun
         self.args = args
-        self.kwargs = kwargs
+        self.kwargs = kwargs or {}
         self.name = fun.__name__ if hasattr(fun, "__name__") else None
         self.done = False
         self.result = None
@@ -239,37 +241,51 @@ class Task:
         self.status = "starting"
         self.allow_output = allow_output
 
+    @staticmethod
+    def get(item: tuple[bool, Any]) -> Any:
+        if item[0]:
+            return loads(ray.get(item[1]))
+        else:
+            return ray.get(item[1])
+
+    @staticmethod
+    def put(item: Any) -> tuple[bool, Any]:
+        try:
+            return False, ray.put(item)
+        except Exception:  # noqa
+            return True, ray.put(dumps(item))
+
     @property
     def fun(self) -> Callable[[Any, ...], Any]:
-        return ray.get(self._fun)
+        return self.get(self._fun)
 
     @fun.setter
     def fun(self, fun: Callable[[Any, ...], Any]):
-        self._fun = ray.put(fun)
+        self._fun = self.put(fun)
 
     @property
     def args(self) -> tuple[Any, ...]:
-        return tuple([ray.get(arg) for arg in self._args])
+        return tuple([self.get(arg) for arg in self._args])
 
     @args.setter
     def args(self, args: tuple[Any, ...]) -> None:
-        self._args = [ray.put(arg) for arg in args]
+        self._args = [self.put(arg) for arg in args]
 
     @property
     def kwargs(self) -> dict[str, Any]:
-        return {key: ray.get(value) for key, value in self._kwargs.items()}
+        return {key: self.get(value) for key, value in self._kwargs.items()}
 
     @kwargs.setter
     def kwargs(self, kwargs: dict[str, Any]) -> None:
-        self._kwargs = {key: ray.put(value) for key, value in kwargs.items()}
+        self._kwargs = {key: self.put(value) for key, value in kwargs.items()}
 
     @property
     def result(self) -> Any:
-        return ray.get(self._result)
+        return self.get(self._result)
 
     @result.setter
     def result(self, result: Any) -> None:
-        self._result = ray.put(result)
+        self._result = self.put(result)
 
     def __call__(self) -> Task:
         if not self.done:
@@ -324,6 +340,9 @@ class ParPool:
             barlength=barlength,
         )
 
+    def close(self) -> None:
+        pass
+
     def add_task(
         self,
         fun: Callable[[Any, ...], Any] = None,
@@ -363,8 +382,12 @@ class ParPool:
         """Request result and delete its record.
        Wait if result not yet available."""
         if handle not in self:
             raise ValueError(f"No handle: {handle} in pool")
-        task = ray.get(self.tasks[handle].future)
-        return self.finalize_task(task)
+        task = self.tasks[handle]
+        if task.future is None:
+            return task.result
+        else:
+            task = ray.get(self.tasks[handle].future)
+            return self.finalize_task(task)
 
     def __contains__(self, handle: Hashable) -> bool:
         return handle in self.tasks
@@ -387,10 +410,13 @@ class ParPool:
         while True:
             if self.tasks:
                 for handle, task in self.tasks.items():
-                    if handle in self.bar_lengths:
+                    if handle in self.tasks:
                         try:
-                            task = ray.get(task.future, timeout=0.01)
-                            return task.handle, self.finalize_task(task)
+                            if task.future is None:
+                                return task.handle, task.result
+                            else:
+                                task = ray.get(task.future, timeout=0.01)
+                                return task.handle, self.finalize_task(task)
                         except ray.exceptions.GetTimeoutError:
                             pass
 
@@ -421,7 +447,9 @@ class PoolSingleton:
             cls.instance = super().__new__(cls)
             cls.instance.n_processes = n_processes
             if ray.is_initialized():
-                warnings.warn("not setting n_processes because parallel pool was already initialized")
+                if cls.instance.n_processes != n_processes:
+                    warnings.warn(f"not setting n_processes={n_processes} because parallel pool was already initialized, "
+                                  f"probably with n_processes={cls.instance.n_processes}")
             else:
                 os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
                 ray.init(num_cpus=n_processes, logging_level=logging.ERROR, log_to_driver=False)
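The core of the __init__.py change is the put/get pair above: everything bound for ray's object store is kept as a (bool, ObjectRef) tuple whose flag records whether dill had to step in. Restated as standalone functions (a sketch of the same idea; assumes ray is already initialized):

    import ray
    from parfor.pickler import dumps, loads

    def put(item):
        # Prefer ray's native serialization; store dill bytes only when ray refuses.
        try:
            return False, ray.put(item)
        except Exception:
            return True, ray.put(dumps(item))

    def get(flagged):
        # Inverse of put: fetch the object, undoing the dill round trip if flagged.
        dilled, ref = flagged
        return loads(ray.get(ref)) if dilled else ray.get(ref)

Because the flag travels with the ObjectRef, the consumer never has to guess which decoding path applies.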
""" + self.framer.commit_frame() + + # Check for persistent id (defined by a subclass) + pid = self.persistent_id(obj) + if pid is not None and save_persistent_id: + self.save_pers(pid) + return + + # Check the memo + x = self.memo.get(id(obj)) + if x is not None: + self.write(self.get(x[0])) + return + + rv = NotImplemented + reduce = getattr(self, "reducer_override", None) + if reduce is not None: + rv = reduce(obj) + + if rv is NotImplemented: + # Check the type dispatch table + t = type(obj) + f = self.dispatch.get(t) + if f is not None: + f(self, obj) # Call unbound method with explicit self + return + + # Check private dispatch table if any, or else + # copyreg.dispatch_table + reduce = getattr(self, 'dispatch_table', copyreg.dispatch_table).get(t) + if reduce is not None: + rv = reduce(obj) + else: + # Check for a class with a custom metaclass; treat as regular + # class + if issubclass(t, type): + self.save_global(obj) + return + + # Check for a __reduce_ex__ method, fall back to __reduce__ + reduce = getattr(obj, "__reduce_ex__", None) + try: + if reduce is not None: + rv = reduce(self.proto) + else: + reduce = getattr(obj, "__reduce__", None) + if reduce is not None: + rv = reduce() + else: + raise PicklingError("Can't pickle %r object: %r" % + (t.__name__, obj)) + except Exception: # noqa + rv = CouldNotBePickled.reduce(obj) + + # Check for string returned by reduce(), meaning "save as global" + if isinstance(rv, str): + try: + self.save_global(obj, rv) + except Exception: # noqa + self.save_global(obj, CouldNotBePickled.reduce(obj)) + return + + # Assert that reduce() returned a tuple + if not isinstance(rv, tuple): + raise PicklingError("%s must return string or tuple" % reduce) + + # Assert that it returned an appropriately sized tuple + length = len(rv) + if not (2 <= length <= 6): + raise PicklingError("Tuple returned by %s must have " + "two to six elements" % reduce) + + # Save the reduce() output and finally memoize the object + try: + self.save_reduce(obj=obj, *rv) + except Exception: # noqa + self.save_reduce(obj=obj, *CouldNotBePickled.reduce(obj)) + + +def dumps(obj: Any, protocol: str = None, byref: bool = None, fmode: str = None, recurse: bool = True, + **kwds: Any) -> bytes: + """pickle an object to a string""" + protocol = dill.settings['protocol'] if protocol is None else int(protocol) + _kwds = kwds.copy() + _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse)) + with BytesIO() as file: + Pickler(file, protocol, **_kwds).dump(obj) + return file.getvalue() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2fabce6..5bfd5d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "parfor" -version = "2026.1.1" +version = "2026.1.2" description = "A package to mimic the use of parfor as done in Matlab." authors = [ { name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" } @@ -10,6 +10,7 @@ readme = "README.md" keywords = ["parfor", "concurrency", "multiprocessing", "parallel"] requires-python = ">=3.10" dependencies = [ + "dill >= 0.3.0", "numpy", "tqdm >= 4.50.0", "ray",