- Use parfor to take care of the parallel part.
- Use sha1 hash because it's consistent between processes.
This commit is contained in:
@@ -11,7 +11,7 @@ good compression.
|
||||
- Compresses even more by referencing tag or image data which otherwise would have been saved several times.
|
||||
For example, empty frames or a long string tag repeated on every frame.
|
||||
- Enables memory-efficient scripts by saving frames as soon as they are ready to be saved, instead of waiting for the whole stack.
|
||||
- Colormaps, extra tags globally or frame dependent.
|
||||
- Colormaps, extra tags, globally or frame dependent.
|
||||
|
||||
## Installation
|
||||
pip install tiffwrite
|
||||
@@ -92,6 +92,3 @@ or
|
||||
be opened as a correctly ordered hyperstack.
|
||||
- Using the colormap parameter you can make ImageJ open the file and apply the colormap. colormap='glasbey' is very
|
||||
useful.
|
||||
- IJTiffFile does not allow more than one pool of parallel processes to be open at a time. Therefore, when writing
|
||||
multiple tiffs simultaneously you have to open all of them before you start saving any frame; this way all files share the
|
||||
same pool.
|
||||
|
||||
+2
-1
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "tiffwrite"
|
||||
version = "2023.3.0"
|
||||
version = "2023.8.0"
|
||||
description = "Parallel tiff writer compatible with ImageJ."
|
||||
authors = ["Wim Pomp, Lenstra lab NKI <w.pomp@nki.nl>"]
|
||||
license = "GPL-3.0-or-later"
|
||||
@@ -15,6 +15,7 @@ numpy = "*"
|
||||
tqdm = "*"
|
||||
colorcet = "*"
|
||||
matplotlib = "*"
|
||||
parfor = ">=2023.8.3"
|
||||
pytest = { version = "*", optional = true }
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from contextlib import ExitStack
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
from tiffwrite import IJTiffFile
|
||||
from itertools import product
|
||||
from contextlib import ExitStack
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def test_mult(tmp_path):
|
||||
shape = (3, 5, 12)
|
||||
paths = [tmp_path / f'test{i}.tif' for i in range(8)]
|
||||
shape = (2, 3, 5)
|
||||
paths = [tmp_path / f'test{i}.tif' for i in range(6)]
|
||||
with ExitStack() as stack:
|
||||
tifs = [stack.enter_context(IJTiffFile(path, shape)) for path in paths]
|
||||
for c, z, t in tqdm(product(range(shape[0]), range(shape[1]), range(shape[2])), total=np.prod(shape)):
|
||||
for tif in tifs:
|
||||
tif.save(np.random.randint(0, 255, (1024, 1024)), c, z, t)
|
||||
tif.save(np.random.randint(0, 255, (64, 64)), c, z, t)
|
||||
assert all([path.exists() for path in paths])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
from tiffwrite import IJTiffFile
|
||||
from itertools import product
|
||||
|
||||
|
||||
def test_single(tmp_path):
|
||||
|
||||
+58
-227
@@ -1,24 +1,22 @@
|
||||
import os
|
||||
import tifffile
|
||||
import colorcet
|
||||
import struct
|
||||
import numpy as np
|
||||
import multiprocessing
|
||||
from multiprocessing import queues
|
||||
from io import BytesIO
|
||||
from tqdm.auto import tqdm
|
||||
from itertools import product
|
||||
from collections.abc import Iterable
|
||||
from numbers import Number
|
||||
from fractions import Fraction
|
||||
from traceback import print_exc, format_exc
|
||||
from functools import cached_property
|
||||
from datetime import datetime
|
||||
from matplotlib import colors as mpl_colors
|
||||
from contextlib import contextmanager
|
||||
from warnings import warn
|
||||
from datetime import datetime
|
||||
from fractions import Fraction
|
||||
from functools import cached_property
|
||||
from hashlib import sha1
|
||||
from importlib.metadata import version
|
||||
from io import BytesIO
|
||||
from itertools import product
|
||||
from numbers import Number
|
||||
|
||||
import colorcet
|
||||
import numpy as np
|
||||
import tifffile
|
||||
from matplotlib import colors as mpl_colors
|
||||
from parfor import ParPool, PoolSingleton
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
__all__ = ["IJTiffFile", "Tag", "tiffwrite"]
|
||||
|
||||
@@ -40,12 +38,12 @@ def tiffwrite(file, data, axes='TZCXY', dtype=None, bar=False, *args, **kwargs):
|
||||
|
||||
axes = axes[-np.ndim(data):].upper()
|
||||
if not axes == 'CZTXY':
|
||||
T = [axes.find(i) for i in 'CZTXY']
|
||||
E = [i for i, j in enumerate(T) if j < 0]
|
||||
T = [i for i in T if i >= 0]
|
||||
data = np.transpose(data, T)
|
||||
for e in E:
|
||||
data = np.expand_dims(data, e)
|
||||
axes_shuffle = [axes.find(i) for i in 'CZTXY']
|
||||
axes_add = [i for i, j in enumerate(axes_shuffle) if j < 0]
|
||||
axes_shuffle = [i for i in axes_shuffle if i >= 0]
|
||||
data = np.transpose(data, axes_shuffle)
|
||||
for axis in axes_add:
|
||||
data = np.expand_dims(data, axis)
|
||||
|
||||
shape = data.shape[:3]
|
||||
with IJTiffFile(file, shape, data.dtype if dtype is None else dtype, *args, **kwargs) as f:
|
||||
@@ -340,7 +338,7 @@ class IJTiffFile:
|
||||
wp@tl20200214
|
||||
"""
|
||||
def __init__(self, path, shape, dtype='uint16', colors=None, colormap=None, pxsize=None, deltaz=None,
|
||||
timeinterval=None, compression=(8, 9), comment=None, processes=None, **extratags):
|
||||
timeinterval=None, compression=(8, 9), comment=None, **extratags):
|
||||
assert len(shape) >= 3, 'please specify all c, z, t for the shape'
|
||||
assert len(shape) <= 3, 'please specify only c, z, t for the shape'
|
||||
assert np.dtype(dtype).char in 'BbHhf', 'datatype not supported'
|
||||
@@ -365,25 +363,22 @@ class IJTiffFile:
|
||||
self.frames = []
|
||||
self.spp = self.shape[0] if self.colormap is None and self.colors is None else 1 # samples/pixel
|
||||
self.nframes = np.prod(self.shape[1:]) if self.colormap is None and self.colors is None else np.prod(self.shape)
|
||||
self.offsets = {}
|
||||
self.strips = {}
|
||||
self.ifds = {}
|
||||
self.frame_extra_tags = {}
|
||||
self.frames_added = []
|
||||
self.frames_written = []
|
||||
self.pool_manager = PoolManager(self, processes)
|
||||
self.fh = FileHandle(path)
|
||||
self.hashes = Manager().manager.dict()
|
||||
self.main_pid = Manager().mp.current_process().pid
|
||||
self.hashes = PoolSingleton().manager.dict()
|
||||
self.pool = ParPool(self.compress_frame)
|
||||
self.main_process = True
|
||||
|
||||
with self.fh.lock() as fh:
|
||||
self.header.write(fh)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self.main_process = False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.path)
|
||||
|
||||
def update(self):
|
||||
""" To be overloaded, will be called when a frame has been written. """
|
||||
|
||||
def get_frame_number(self, n):
|
||||
if self.colormap is None and self.colors is None:
|
||||
return n[1] + n[2] * self.shape[1], n[0]
|
||||
@@ -400,53 +395,13 @@ class IJTiffFile:
|
||||
def save(self, frame, c, z, t, **extratags):
|
||||
""" save a 2d numpy array to the tiff at channel=c, slice=z, time=t, with optional extra tif tags
|
||||
"""
|
||||
assert (c, z, t) not in self.frames_written, f'frame {c} {z} {t} is added already'
|
||||
assert (c, z, t) not in self.pool.tasks, f'frame {c} {z} {t} is added already'
|
||||
assert all([0 <= i < s for i, s in zip((c, z, t), self.shape)]), \
|
||||
'frame {} {} {} is outside shape {} {} {}'.format(c, z, t, *self.shape)
|
||||
self.frames_added.append((c, z, t))
|
||||
self.pool_manager.add_frame(self.path,
|
||||
frame.astype(self.dtype) if isinstance(frame, np.ndarray) else frame, (c, z, t))
|
||||
self.pool(frame.astype(self.dtype) if isinstance(frame, np.ndarray) else frame, handle=(c, z, t))
|
||||
if extratags:
|
||||
self.frame_extra_tags[(c, z, t)] = Tag.to_tags(extratags)
|
||||
|
||||
@cached_property
|
||||
def loader(self):
|
||||
tif = self.ij_tiff_frame(np.zeros((8, 8), self.dtype))
|
||||
with BytesIO(tif) as fh:
|
||||
return tif, Header(fh), IFD(fh)
|
||||
|
||||
def load(self, *n):
|
||||
""" read back a frame that's just written
|
||||
useful for simulating a large number of frames without using much memory
|
||||
"""
|
||||
if n not in self.frames_added:
|
||||
raise KeyError(f'frame {n} has not been added yet')
|
||||
while n not in self.frames_written:
|
||||
self.pool_manager.get_ifds_from_queue()
|
||||
tif, header, ifd = self.loader
|
||||
framenumber, channel = self.get_frame_number(n)
|
||||
with BytesIO(tif) as fh:
|
||||
fh.seek(0, 2)
|
||||
fstripbyteoffsets, fstripbytecounts = [], []
|
||||
with open(self.path, 'rb') as file:
|
||||
for stripbyteoffset, stripbytecount in zip(*self.strips[(framenumber, channel)]):
|
||||
file.seek(stripbyteoffset)
|
||||
bdata = file.read(stripbytecount)
|
||||
fstripbyteoffsets.append(fh.tell())
|
||||
fstripbytecounts.append(len(bdata))
|
||||
fh.write(bdata)
|
||||
ifd = ifd.copy()
|
||||
for key, value in zip((257, 256, 278, 270, 273, 279),
|
||||
(*self.frame_shape, self.frame_shape[0] // len(fstripbyteoffsets),
|
||||
f'{{"shape": [{self.frame_shape[0]}, {self.frame_shape[1]}]}}',
|
||||
fstripbyteoffsets, fstripbytecounts)):
|
||||
tag = ifd[key]
|
||||
tag.value = value
|
||||
tag.write_tag(fh, key, header, tag.offset)
|
||||
tag.write_data()
|
||||
fh.seek(0)
|
||||
return tifffile.TiffFile(fh).asarray()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
desc = ['ImageJ=1.11a']
|
||||
@@ -473,21 +428,6 @@ class IJTiffFile:
|
||||
desc.append(bytes(self.comment, 'ascii'))
|
||||
return b'\n'.join(desc)
|
||||
|
||||
@cached_property
|
||||
def empty_frame(self):
|
||||
return self.compress_frame(np.zeros(self.frame_shape, self.dtype))
|
||||
|
||||
@cached_property
|
||||
def frame_shape(self):
|
||||
ifd = self.ifds[list(self.ifds.keys())[-1]].copy()
|
||||
return ifd[257].value[0], ifd[256].value[0]
|
||||
|
||||
def add_empty_frame(self, n):
|
||||
framenr, channel = self.get_frame_number(n)
|
||||
ifd, strips = self.empty_frame
|
||||
self.ifds[framenr] = ifd.copy()
|
||||
self.strips[(framenr, channel)] = strips
|
||||
|
||||
@cached_property
|
||||
def colormap_bytes(self):
|
||||
colormap = getattr(colorcet, self.colormap)
|
||||
@@ -507,48 +447,44 @@ class IJTiffFile:
|
||||
dtype=int).T.flatten()]) for color in self.colors]
|
||||
|
||||
def close(self):
|
||||
if Manager().mp.current_process().pid == self.main_pid:
|
||||
self.pool_manager.close(self)
|
||||
with self.fh.lock() as fh:
|
||||
if len(self.frames_added) == 0:
|
||||
warn('At least one frame should be added to the tiff, removing file.')
|
||||
os.remove(self.path)
|
||||
else:
|
||||
if len(self.frames_written) < np.prod(self.shape): # add empty frames if needed
|
||||
for n in product(*[range(i) for i in self.shape]):
|
||||
if n not in self.frames_written:
|
||||
self.add_empty_frame(n)
|
||||
if self.main_process:
|
||||
ifds, strips = {}, {}
|
||||
for n in list(self.pool.tasks):
|
||||
framenr, channel = self.get_frame_number(n)
|
||||
ifds[framenr], strips[(framenr, channel)] = self.pool[n]
|
||||
|
||||
self.pool.close()
|
||||
with self.fh.lock() as fh:
|
||||
for n, tags in self.frame_extra_tags.items():
|
||||
framenr, channel = self.get_frame_number(n)
|
||||
self.ifds[framenr].update(tags)
|
||||
ifds[framenr].update(tags)
|
||||
if self.colormap is not None:
|
||||
self.ifds[0][320] = Tag('SHORT', self.colormap_bytes)
|
||||
self.ifds[0][262] = Tag('SHORT', 3)
|
||||
ifds[0][320] = Tag('SHORT', self.colormap_bytes)
|
||||
ifds[0][262] = Tag('SHORT', 3)
|
||||
if self.colors is not None:
|
||||
for c, color in enumerate(self.colors_bytes):
|
||||
self.ifds[c][320] = Tag('SHORT', color)
|
||||
self.ifds[c][262] = Tag('SHORT', 3)
|
||||
if 306 not in self.ifds[0]:
|
||||
self.ifds[0][306] = Tag('ASCII', datetime.now().strftime('%Y:%m:%d %H:%M:%S'))
|
||||
ifds[c][320] = Tag('SHORT', color)
|
||||
ifds[c][262] = Tag('SHORT', 3)
|
||||
if 306 not in ifds[0]:
|
||||
ifds[0][306] = Tag('ASCII', datetime.now().strftime('%Y:%m:%d %H:%M:%S'))
|
||||
for framenr in range(self.nframes):
|
||||
stripbyteoffsets, stripbytecounts = zip(*[self.strips[(framenr, channel)]
|
||||
stripbyteoffsets, stripbytecounts = zip(*[strips[(framenr, channel)]
|
||||
for channel in range(self.spp)])
|
||||
self.ifds[framenr][258].value = self.spp * self.ifds[framenr][258].value
|
||||
self.ifds[framenr][270] = Tag('ASCII', self.description)
|
||||
self.ifds[framenr][273] = Tag('LONG8', sum(stripbyteoffsets, []))
|
||||
self.ifds[framenr][277] = Tag('SHORT', self.spp)
|
||||
self.ifds[framenr][279] = Tag('LONG8', sum(stripbytecounts, []))
|
||||
self.ifds[framenr][305] = Tag('ASCII', 'tiffwrite_tllab_NKI')
|
||||
ifds[framenr][258].value = self.spp * ifds[framenr][258].value
|
||||
ifds[framenr][270] = Tag('ASCII', self.description)
|
||||
ifds[framenr][273] = Tag('LONG8', sum(stripbyteoffsets, []))
|
||||
ifds[framenr][277] = Tag('SHORT', self.spp)
|
||||
ifds[framenr][279] = Tag('LONG8', sum(stripbytecounts, []))
|
||||
ifds[framenr][305] = Tag('ASCII', 'tiffwrite_tllab_NKI')
|
||||
if self.extratags is not None:
|
||||
self.ifds[framenr].update(self.extratags)
|
||||
ifds[framenr].update(self.extratags)
|
||||
if self.colormap is None and self.colors is None and self.shape[0] > 1:
|
||||
self.ifds[framenr][284] = Tag('SHORT', 2)
|
||||
self.ifds[framenr].write(fh, self.header, self.write)
|
||||
ifds[framenr][284] = Tag('SHORT', 2)
|
||||
ifds[framenr].write(fh, self.header, self.write)
|
||||
if framenr:
|
||||
self.ifds[framenr].write_offset(self.ifds[framenr - 1].where_to_write_next_ifd_offset)
|
||||
ifds[framenr].write_offset(ifds[framenr - 1].where_to_write_next_ifd_offset)
|
||||
else:
|
||||
self.ifds[framenr].write_offset(self.header.offset - self.header.offsetsize)
|
||||
ifds[framenr].write_offset(self.header.offset - self.header.offsetsize)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -556,9 +492,6 @@ class IJTiffFile:
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self.close()
|
||||
|
||||
def clean(self, *args, **kwargs):
|
||||
""" To be overloaded, will be called when the parallel pool is closing. """
|
||||
|
||||
@staticmethod
|
||||
def hash_check(fh, bvalue, offset):
|
||||
addr = fh.tell()
|
||||
@@ -568,7 +501,7 @@ class IJTiffFile:
|
||||
return same
|
||||
|
||||
def write(self, fh, bvalue):
|
||||
hash_value = hash(bvalue)
|
||||
hash_value = sha1(bvalue).hexdigest() # hash uses a random seed making hashes different in different processes
|
||||
if hash_value in self.hashes and self.hash_check(fh, bvalue, self.hashes[hash_value]):
|
||||
return self.hashes[hash_value] # reuse previously saved data
|
||||
else:
|
||||
@@ -601,112 +534,10 @@ class IJTiffFile:
|
||||
return stripbytecounts, ifd, chunks
|
||||
|
||||
|
||||
class Manager:
|
||||
instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls.instance is None:
|
||||
cls.instance = super().__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, 'mp'):
|
||||
self.mp = multiprocessing.get_context('spawn')
|
||||
self.manager = self.mp.Manager()
|
||||
|
||||
|
||||
class PoolManager:
|
||||
instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls.instance is None:
|
||||
cls.instance = super().__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
def __init__(self, tif, processes=None):
|
||||
if not hasattr(self, 'tifs'):
|
||||
self.tifs = {}
|
||||
if not hasattr(self, 'is_alive'):
|
||||
self.is_alive = False
|
||||
if self.is_alive:
|
||||
raise ValueError('Cannot start new tifffile until previous tifffiles have been closed.')
|
||||
self.tifs[tif.path] = tif
|
||||
self.processes = processes
|
||||
|
||||
def close(self, tif):
|
||||
while len(tif.frames_written) < len(tif.frames_added):
|
||||
self.get_ifds_from_queue()
|
||||
self.tifs.pop(tif.path)
|
||||
if not self.tifs:
|
||||
self.__class__.instance = None
|
||||
if self.is_alive:
|
||||
self.is_alive = False
|
||||
self.done.set()
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
self.queue.close()
|
||||
self.queue.join_thread()
|
||||
while not self.error_queue.empty():
|
||||
print(self.error_queue.get())
|
||||
self.error_queue.close()
|
||||
self.ifd_queue.close()
|
||||
self.ifd_queue.join_thread()
|
||||
self.pool.close()
|
||||
self.pool.join()
|
||||
|
||||
def get_ifds_from_queue(self):
|
||||
while not self.ifd_queue.empty():
|
||||
file, n, ifd, strip = self.ifd_queue.get()
|
||||
framenr, channel = self.tifs[file].get_frame_number(n)
|
||||
self.tifs[file].ifds[framenr] = ifd
|
||||
self.tifs[file].strips[(framenr, channel)] = strip
|
||||
self.tifs[file].frames_written.append(n)
|
||||
self.tifs[file].update()
|
||||
|
||||
def add_frame(self, *args):
|
||||
if not self.is_alive:
|
||||
self.start_pool()
|
||||
self.get_ifds_from_queue()
|
||||
self.queue.put(args)
|
||||
|
||||
def start_pool(self):
|
||||
mp = Manager().mp
|
||||
self.is_alive = True
|
||||
nframes = sum([np.prod(tif.shape) for tif in self.tifs.values()])
|
||||
if self.processes is None:
|
||||
self.processes = max(2, min(mp.cpu_count() // 6, nframes))
|
||||
elif self.processes == 'all':
|
||||
self.processes = max(2, min(mp.cpu_count(), nframes))
|
||||
else:
|
||||
self.processes = self.processes
|
||||
self.queue = mp.Queue(10 * self.processes)
|
||||
self.ifd_queue = mp.Queue(10 * self.processes)
|
||||
self.error_queue = mp.Queue()
|
||||
self.offsets_queue = mp.Queue()
|
||||
self.done = mp.Event()
|
||||
self.pool = mp.Pool(self.processes, self.run)
|
||||
|
||||
def run(self):
|
||||
""" Only this is run in parallel processes. """
|
||||
try:
|
||||
while not self.done.is_set():
|
||||
try:
|
||||
file, frame, n = self.queue.get(True, 0.02)
|
||||
self.ifd_queue.put((file, n, *self.tifs[file].compress_frame(frame)))
|
||||
except queues.Empty:
|
||||
continue
|
||||
except Exception:
|
||||
print_exc()
|
||||
self.error_queue.put(format_exc())
|
||||
finally:
|
||||
for tif in self.tifs.values():
|
||||
tif.clean()
|
||||
|
||||
|
||||
class FileHandle:
|
||||
""" Process safe file handle """
|
||||
def __init__(self, name):
|
||||
manager = Manager().manager
|
||||
manager = PoolSingleton().manager
|
||||
if os.path.exists(name):
|
||||
os.remove(name)
|
||||
with open(name, 'xb'):
|
||||
|
||||
Reference in New Issue
Block a user