- add some caching to prevent repeated ray.put and pickling

This commit is contained in:
w.pomp
2026-01-22 20:11:13 +01:00
parent 879da12628
commit 8619c0fb34
2 changed files with 41 additions and 7 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import logging import logging
import os import os
from collections import OrderedDict
from contextlib import ExitStack, redirect_stderr, redirect_stdout from contextlib import ExitStack, redirect_stderr, redirect_stdout
from functools import wraps from functools import wraps
from importlib.metadata import version from importlib.metadata import version
@@ -231,6 +232,28 @@ def get_worker(n_processes) -> RemoteFunction:
return ray.remote(worker) # type: ignore return ray.remote(worker) # type: ignore
class DequeDict(OrderedDict):
def __init__(self, maxlen: int = None, *args: Any, **kwargs: Any) -> None:
self.maxlen = maxlen
super().__init__(*args, **kwargs)
def __setitem__(self, *args: Any, **kwargs: Any) -> None:
super().__setitem__(*args, **kwargs)
self.truncate()
def truncate(self) -> None:
if self.maxlen is not None:
while len(self) > self.maxlen:
self.popitem(False)
def update(self, *args: Any, **kwargs: Any) -> None:
super().update(*args, **kwargs) # type: ignore
self.truncate()
# Module-level bounded cache (at most 128 entries): Task.put keys it by hash(item) and ParPool.close() clears it.
cache = DequeDict(128)
class Task: class Task:
def __init__( def __init__(
self, self,
@@ -262,6 +285,17 @@ class Task:
@staticmethod @staticmethod
def put(item: Any) -> tuple[bool, Any]: def put(item: Any) -> tuple[bool, Any]:
try:
h = hash(item)
if not h in cache:
try:
cache[h] = False, ray.put(item)
except Exception: # noqa
cache[h] = True, ray.put(dumps(item, recurse=True))
else:
cache.move_to_end(h)
return cache[h]
except TypeError:
try: try:
return False, ray.put(item) return False, ray.put(item)
except Exception: # noqa except Exception: # noqa
@@ -344,7 +378,7 @@ class ParPool:
return self return self
def __exit__(self, *args: Any, **kwargs: Any) -> None: def __exit__(self, *args: Any, **kwargs: Any) -> None:
pass self.close()
def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None: def __call__(self, n: Any, handle: Hashable = None, barlength: int = 1) -> None:
self.add_task( self.add_task(
@@ -354,7 +388,7 @@ class ParPool:
) )
def close(self) -> None: def close(self) -> None:
pass cache.clear()
def add_task( def add_task(
self, self,

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "parfor" name = "parfor"
version = "2026.2.1" version = "2026.2.2"
description = "A package to mimic the use of parfor as done in Matlab." description = "A package to mimic the use of parfor as done in Matlab."
authors = [ authors = [
{ name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" } { name = "Wim Pomp-Pervova", email = "wimpomp@gmail.com" }