- Dill (almost) everything.
- Change dill a little to automatically omit undillable parts of objects.
- README.md: Limitations
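To illustrate the intended effect (a minimal sketch, not part of the commit; `parallel` is a stand-in import name for this module, and a generator is just one example of a value dill cannot serialize):

import dill
from parallel import dumps          # hypothetical import name for the patched dumps below

class Job(object):
    def __init__(self):
        self.data = [1, 2, 3]
        self.stream = (x * x for x in self.data)    # not serializable by dill

restored = dill.loads(dumps(Job(), recurse=True))
print(restored.data)      # [1, 2, 3]
print(restored.stream)    # None -- the undillable part was silently omitted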
@@ -1,9 +1,118 @@
from __future__ import print_function
import sys
import multiprocessing
import warnings
import dill
from tqdm.auto import tqdm
from dill import dumps, loads
from traceback import format_exc
from pickle import PicklingError, dispatch_table

PY3 = (sys.hexversion >= 0x3000000)

try:
    from cStringIO import StringIO
except ImportError:
    if PY3:
        from io import BytesIO as StringIO
    else:
        from StringIO import StringIO

failed_rv = (lambda *args, **kwargs: None, ())
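# Note: pickle rebuilds an object as rv[0](*rv[1]), so anything replaced by
# failed_rv simply loads back as None.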
class Pickler(dill.Pickler):
    """ Overload dill to ignore unpicklable parts of objects.
    You probably didn't want to use these parts anyhow.
    However, if you did, you'll have to find some way to make them picklable.
    """
    def save(self, obj, save_persistent_id=True):
        """ Copied from pickle and amended. """
        if PY3:
            self.framer.commit_frame()

        # Check for persistent id (defined by a subclass)
        pid = self.persistent_id(obj)
        if pid is not None and save_persistent_id:
            self.save_pers(pid)
            return

        # Check the memo
        x = self.memo.get(id(obj))
        if x is not None:
            self.write(self.get(x[0]))
            return

        rv = NotImplemented
        reduce = getattr(self, "reducer_override", None)
        if reduce is not None:
            rv = reduce(obj)

        if rv is NotImplemented:
            # Check the type dispatch table
            t = type(obj)
            f = self.dispatch.get(t)
            if f is not None:
                f(self, obj)  # Call unbound method with explicit self
                return

            # Check private dispatch table if any, or else
            # copyreg.dispatch_table
            reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
            if reduce is not None:
                rv = reduce(obj)
            else:
                # Check for a class with a custom metaclass; treat as regular
                # class
                if issubclass(t, type):
                    self.save_global(obj)
                    return

                # Check for a __reduce_ex__ method, fall back to __reduce__
                reduce = getattr(obj, "__reduce_ex__", None)
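                # Amendment to the stock pickle code: if computing the reduce
                # value fails for any reason, fall back to failed_rv so the
                # offending part is dropped instead of aborting the dump.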
                try:
                    if reduce is not None:
                        rv = reduce(self.proto)
                    else:
                        reduce = getattr(obj, "__reduce__", None)
                        if reduce is not None:
                            rv = reduce()
                        else:
                            raise PicklingError("Can't pickle %r object: %r" %
                                                (t.__name__, obj))
                except:
                    rv = failed_rv

        # Check for string returned by reduce(), meaning "save as global"
        if isinstance(rv, str):
            try:
                self.save_global(obj, rv)
            except:
                self.save_global(obj, failed_rv)
            return

        # Assert that reduce() returned a tuple
        if not isinstance(rv, tuple):
            raise PicklingError("%s must return string or tuple" % reduce)

        # Assert that it returned an appropriately sized tuple
        l = len(rv)
        if not (2 <= l <= 6):
            raise PicklingError("Tuple returned by %s must have "
                                "two to six elements" % reduce)

        # Save the reduce() output and finally memoize the object
        try:
            self.save_reduce(obj=obj, *rv)
        except:
            self.save_reduce(obj=obj, *failed_rv)
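For reference, a minimal sketch (not part of the diff, assuming the Pickler class above is in scope) of what the fallback buys you: anything in the object graph that cannot be reduced is replaced by failed_rv and comes back as None after loading with stock dill.loads, so no custom unpickler is needed.

import dill
from io import BytesIO

buf = BytesIO()
payload = {"ok": 1, "gen": (i for i in range(3))}   # a generator is not dill-able
Pickler(buf, dill.settings['protocol']).dump(payload)

print(dill.loads(buf.getvalue()))   # {'ok': 1, 'gen': None}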
def dumps(obj, protocol=None, byref=None, fmode=None, recurse=None, **kwds):#, strictio=None):
    """pickle an object to a string"""
    protocol = dill.settings['protocol'] if protocol is None else int(protocol)
    _kwds = kwds.copy()
    _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
    file = StringIO()
    Pickler(file, protocol, **_kwds).dump(obj)
    return file.getvalue()

def chunks(n, *args):
    """ Yield successive n-sized chunks from lists. """
@@ -205,9 +314,9 @@ class parpool(object):
'debuggable error. If it doesn\'t, it\'s an error specific to parallel execution.'
.format(r[1]))
warnings.formatwarning = pfmt
fun, args, kwargs = [loads(f[1]) for f in r[2][1:]]
r = (False, r[1], fun(r[2][0], *args, **kwargs))
self.res[r[1]] = r[2]
fun, args, kwargs = [dill.loads(f[1]) for f in r[2][1:]]
r = (False, r[1], fun(dill.loads(r[2][0]), *args, **kwargs))
self.res[r[1]] = dill.loads(r[2])
if not self.bar is None:
self.bar.update()
self._qbar_update()
@@ -216,6 +325,7 @@ class parpool(object):
def __call__(self, n, fun=None, args=None, kwargs=None, handle=None):
""" Add new iteration, using optional manually defined handle."""
n = dumps(n, recurse=True)
if not fun is None:
self.fun = fun
if not args is None:
@@ -306,7 +416,7 @@ class parpool(object):
fun = self.get_from_cache(*Fun)
args = self.get_from_cache(*Args)
kwargs = self.get_from_cache(*Kwargs)
self.Qo.put((False, i, fun(n, *args, **kwargs)))
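# the result is now pre-serialized with the patched dumps, so multiprocessing's
# own pickler only ever sees plain bytes on the output queue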
self.Qo.put((False, i, dumps(fun(dill.loads(n), *args, **kwargs), recurse=True)))
except multiprocessing.queues.Empty:
continue
except:
@@ -320,7 +430,7 @@ class parpool(object):
hs, objs = zip(*self.cache)
if h in hs:
return objs[hs.index(h)]
obj = loads(ser)
obj = dill.loads(ser)
self.cache.append((h, obj))
while len(self.cache) > self.cachesize:
self.cache.pop(0)