- Dill (almost) everything.

- Change dill a little to automatically omit undillable parts of objects.
- README.me Limitations
This commit is contained in:
w.pomp
2020-09-03 11:29:15 +02:00
parent 881496b8f4
commit da70cf7a2f
3 changed files with 130 additions and 13 deletions

View File

@@ -1,9 +1,118 @@
from __future__ import print_function
import sys
import multiprocessing
import warnings
import dill
from tqdm.auto import tqdm
from dill import dumps, loads
from traceback import format_exc
from pickle import PicklingError, dispatch_table
PY3 = (sys.hexversion >= 0x3000000)
try:
from cStringIO import StringIO
except ImportError:
if PY3:
from io import BytesIO as StringIO
else:
from StringIO import StringIO
failed_rv = (lambda *args, **kwargs: None, ())
class Pickler(dill.Pickler):
""" Overload dill to ignore unpickleble parts of objects.
You probably didn't want to use these parts anyhow.
However, if you did, you'll have to find some way to make them pickleble.
"""
def save(self, obj, save_persistent_id=True):
""" Copied from pickle and amended. """
if PY3:
self.framer.commit_frame()
# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
if pid is not None and save_persistent_id:
self.save_pers(pid)
return
# Check the memo
x = self.memo.get(id(obj))
if x is not None:
self.write(self.get(x[0]))
return
rv = NotImplemented
reduce = getattr(self, "reducer_override", None)
if reduce is not None:
rv = reduce(obj)
if rv is NotImplemented:
# Check the type dispatch table
t = type(obj)
f = self.dispatch.get(t)
if f is not None:
f(self, obj) # Call unbound method with explicit self
return
# Check private dispatch table if any, or else
# copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
if reduce is not None:
rv = reduce(obj)
else:
# Check for a class with a custom metaclass; treat as regular
# class
if issubclass(t, type):
self.save_global(obj)
return
# Check for a __reduce_ex__ method, fall back to __reduce__
reduce = getattr(obj, "__reduce_ex__", None)
try:
if reduce is not None:
rv = reduce(self.proto)
else:
reduce = getattr(obj, "__reduce__", None)
if reduce is not None:
rv = reduce()
else:
raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj))
except:
rv = failed_rv
# Check for string returned by reduce(), meaning "save as global"
if isinstance(rv, str):
try:
self.save_global(obj, rv)
except:
self.save_global(obj, failed_rv)
return
# Assert that reduce() returned a tuple
if not isinstance(rv, tuple):
raise PicklingError("%s must return string or tuple" % reduce)
# Assert that it returned an appropriately sized tuple
l = len(rv)
if not (2 <= l <= 6):
raise PicklingError("Tuple returned by %s must have "
"two to six elements" % reduce)
# Save the reduce() output and finally memoize the object
try:
self.save_reduce(obj=obj, *rv)
except:
self.save_reduce(obj=obj, *failed_rv)
def dumps(obj, protocol=None, byref=None, fmode=None, recurse=None, **kwds):#, strictio=None):
"""pickle an object to a string"""
protocol = dill.settings['protocol'] if protocol is None else int(protocol)
_kwds = kwds.copy()
_kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
file = StringIO()
Pickler(file, protocol, **_kwds).dump(obj)
return file.getvalue()
def chunks(n, *args):
""" Yield successive n-sized chunks from lists. """
@@ -205,9 +314,9 @@ class parpool(object):
'debuggable error. If it doesn\'t, it\'s an error specific to parallel execution.'
.format(r[1]))
warnings.formatwarning = pfmt
fun, args, kwargs = [loads(f[1]) for f in r[2][1:]]
r = (False, r[1], fun(r[2][0], *args, **kwargs))
self.res[r[1]] = r[2]
fun, args, kwargs = [dill.loads(f[1]) for f in r[2][1:]]
r = (False, r[1], fun(dill.loads(r[2][0]), *args, **kwargs))
self.res[r[1]] = dill.loads(r[2])
if not self.bar is None:
self.bar.update()
self._qbar_update()
@@ -216,6 +325,7 @@ class parpool(object):
def __call__(self, n, fun=None, args=None, kwargs=None, handle=None):
""" Add new iteration, using optional manually defined handle."""
n = dumps(n, recurse=True)
if not fun is None:
self.fun = fun
if not args is None:
@@ -306,7 +416,7 @@ class parpool(object):
fun = self.get_from_cache(*Fun)
args = self.get_from_cache(*Args)
kwargs = self.get_from_cache(*Kwargs)
self.Qo.put((False, i, fun(n, *args, **kwargs)))
self.Qo.put((False, i, dumps(fun(dill.loads(n), *args, **kwargs), recurse=True)))
except multiprocessing.queues.Empty:
continue
except:
@@ -320,7 +430,7 @@ class parpool(object):
hs, objs = zip(*self.cache)
if h in hs:
return objs[hs.index(h)]
obj = loads(ser)
obj = dill.loads(ser)
self.cache.append((h, obj))
while len(self.cache) > self.cachesize:
self.cache.pop(0)