diff --git a/README.md b/README.md index af1980a..b4e75ff 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ good compression. - Compresses even more by referencing tag or image data which otherwise would have been saved several times. For example empty frames, or a long string tag on every frame. - Enables memory efficient scripts by saving frames whenever they're ready to be saved, not waiting for the whole stack. -- Colormaps, extra tags globally or frame dependent. +- Colormaps, extra tags, globally or frame dependent. ## Installation pip install tiffwrite @@ -92,6 +92,3 @@ or be opened as a correctly ordered hyperstack. - Using the colormap parameter you can make ImageJ open the file and apply the colormap. colormap='glasbey' is very useful. -- IJTiffFile does not allow more than one pool of parallel processes to be open at a time. Therefore, when writing -multiple tiff's simultaneously you have to open all before you start saving any frame, in this way all files share the -same pool. diff --git a/pyproject.toml b/pyproject.toml index 59660f2..290b8b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tiffwrite" -version = "2023.3.0" +version = "2023.8.0" description = "Parallel tiff writer compatible with ImageJ." authors = ["Wim Pomp, Lenstra lab NKI "] license = "GPL-3.0-or-later" @@ -15,6 +15,7 @@ numpy = "*" tqdm = "*" colorcet = "*" matplotlib = "*" +parfor = ">=2023.8.3" pytest = { version = "*", optional = true } [tool.poetry.extras] diff --git a/tests/test_multiple.py b/tests/test_multiple.py index 324caf0..c134f74 100644 --- a/tests/test_multiple.py +++ b/tests/test_multiple.py @@ -1,16 +1,17 @@ +from contextlib import ExitStack +from itertools import product + import numpy as np from tiffwrite import IJTiffFile -from itertools import product -from contextlib import ExitStack from tqdm.auto import tqdm def test_mult(tmp_path): - shape = (3, 5, 12) - paths = [tmp_path / f'test{i}.tif' for i in range(8)] + shape = (2, 3, 5) + paths = [tmp_path / f'test{i}.tif' for i in range(6)] with ExitStack() as stack: tifs = [stack.enter_context(IJTiffFile(path, shape)) for path in paths] for c, z, t in tqdm(product(range(shape[0]), range(shape[1]), range(shape[2])), total=np.prod(shape)): for tif in tifs: - tif.save(np.random.randint(0, 255, (1024, 1024)), c, z, t) + tif.save(np.random.randint(0, 255, (64, 64)), c, z, t) assert all([path.exists() for path in paths]) diff --git a/tests/test_single.py b/tests/test_single.py index 15d8952..b703a33 100644 --- a/tests/test_single.py +++ b/tests/test_single.py @@ -1,6 +1,7 @@ +from itertools import product + import numpy as np from tiffwrite import IJTiffFile -from itertools import product def test_single(tmp_path): diff --git a/tiffwrite/__init__.py b/tiffwrite/__init__.py index 766adf1..6d13fe6 100755 --- a/tiffwrite/__init__.py +++ b/tiffwrite/__init__.py @@ -1,24 +1,22 @@ import os -import tifffile -import colorcet import struct -import numpy as np -import multiprocessing -from multiprocessing import queues -from io import BytesIO -from tqdm.auto import tqdm -from itertools import product from collections.abc import Iterable -from numbers import Number -from fractions import Fraction -from traceback import print_exc, format_exc -from functools import cached_property -from datetime import datetime -from matplotlib import colors as mpl_colors from contextlib import contextmanager -from warnings import warn +from datetime import datetime +from fractions import Fraction +from functools import cached_property +from hashlib import sha1 from importlib.metadata import version +from io import BytesIO +from itertools import product +from numbers import Number +import colorcet +import numpy as np +import tifffile +from matplotlib import colors as mpl_colors +from parfor import ParPool, PoolSingleton +from tqdm.auto import tqdm __all__ = ["IJTiffFile", "Tag", "tiffwrite"] @@ -40,12 +38,12 @@ def tiffwrite(file, data, axes='TZCXY', dtype=None, bar=False, *args, **kwargs): axes = axes[-np.ndim(data):].upper() if not axes == 'CZTXY': - T = [axes.find(i) for i in 'CZTXY'] - E = [i for i, j in enumerate(T) if j < 0] - T = [i for i in T if i >= 0] - data = np.transpose(data, T) - for e in E: - data = np.expand_dims(data, e) + axes_shuffle = [axes.find(i) for i in 'CZTXY'] + axes_add = [i for i, j in enumerate(axes_shuffle) if j < 0] + axes_shuffle = [i for i in axes_shuffle if i >= 0] + data = np.transpose(data, axes_shuffle) + for axis in axes_add: + data = np.expand_dims(data, axis) shape = data.shape[:3] with IJTiffFile(file, shape, data.dtype if dtype is None else dtype, *args, **kwargs) as f: @@ -340,7 +338,7 @@ class IJTiffFile: wp@tl20200214 """ def __init__(self, path, shape, dtype='uint16', colors=None, colormap=None, pxsize=None, deltaz=None, - timeinterval=None, compression=(8, 9), comment=None, processes=None, **extratags): + timeinterval=None, compression=(8, 9), comment=None, **extratags): assert len(shape) >= 3, 'please specify all c, z, t for the shape' assert len(shape) <= 3, 'please specify only c, z, t for the shape' assert np.dtype(dtype).char in 'BbHhf', 'datatype not supported' @@ -365,25 +363,22 @@ class IJTiffFile: self.frames = [] self.spp = self.shape[0] if self.colormap is None and self.colors is None else 1 # samples/pixel self.nframes = np.prod(self.shape[1:]) if self.colormap is None and self.colors is None else np.prod(self.shape) - self.offsets = {} - self.strips = {} - self.ifds = {} self.frame_extra_tags = {} - self.frames_added = [] - self.frames_written = [] - self.pool_manager = PoolManager(self, processes) self.fh = FileHandle(path) - self.hashes = Manager().manager.dict() - self.main_pid = Manager().mp.current_process().pid + self.hashes = PoolSingleton().manager.dict() + self.pool = ParPool(self.compress_frame) + self.main_process = True + with self.fh.lock() as fh: self.header.write(fh) + def __setstate__(self, state): + self.__dict__.update(state) + self.main_process = False + def __hash__(self): return hash(self.path) - def update(self): - """ To be overloaded, will be called when a frame has been written. """ - def get_frame_number(self, n): if self.colormap is None and self.colors is None: return n[1] + n[2] * self.shape[1], n[0] @@ -400,53 +395,13 @@ class IJTiffFile: def save(self, frame, c, z, t, **extratags): """ save a 2d numpy array to the tiff at channel=c, slice=z, time=t, with optional extra tif tags """ - assert (c, z, t) not in self.frames_written, f'frame {c} {z} {t} is added already' + assert (c, z, t) not in self.pool.tasks, f'frame {c} {z} {t} is added already' assert all([0 <= i < s for i, s in zip((c, z, t), self.shape)]), \ 'frame {} {} {} is outside shape {} {} {}'.format(c, z, t, *self.shape) - self.frames_added.append((c, z, t)) - self.pool_manager.add_frame(self.path, - frame.astype(self.dtype) if isinstance(frame, np.ndarray) else frame, (c, z, t)) + self.pool(frame.astype(self.dtype) if isinstance(frame, np.ndarray) else frame, handle=(c, z, t)) if extratags: self.frame_extra_tags[(c, z, t)] = Tag.to_tags(extratags) - @cached_property - def loader(self): - tif = self.ij_tiff_frame(np.zeros((8, 8), self.dtype)) - with BytesIO(tif) as fh: - return tif, Header(fh), IFD(fh) - - def load(self, *n): - """ read back a frame that's just written - useful for simulating a large number of frames without using much memory - """ - if n not in self.frames_added: - raise KeyError(f'frame {n} has not been added yet') - while n not in self.frames_written: - self.pool_manager.get_ifds_from_queue() - tif, header, ifd = self.loader - framenumber, channel = self.get_frame_number(n) - with BytesIO(tif) as fh: - fh.seek(0, 2) - fstripbyteoffsets, fstripbytecounts = [], [] - with open(self.path, 'rb') as file: - for stripbyteoffset, stripbytecount in zip(*self.strips[(framenumber, channel)]): - file.seek(stripbyteoffset) - bdata = file.read(stripbytecount) - fstripbyteoffsets.append(fh.tell()) - fstripbytecounts.append(len(bdata)) - fh.write(bdata) - ifd = ifd.copy() - for key, value in zip((257, 256, 278, 270, 273, 279), - (*self.frame_shape, self.frame_shape[0] // len(fstripbyteoffsets), - f'{{"shape": [{self.frame_shape[0]}, {self.frame_shape[1]}]}}', - fstripbyteoffsets, fstripbytecounts)): - tag = ifd[key] - tag.value = value - tag.write_tag(fh, key, header, tag.offset) - tag.write_data() - fh.seek(0) - return tifffile.TiffFile(fh).asarray() - @property def description(self): desc = ['ImageJ=1.11a'] @@ -473,21 +428,6 @@ class IJTiffFile: desc.append(bytes(self.comment, 'ascii')) return b'\n'.join(desc) - @cached_property - def empty_frame(self): - return self.compress_frame(np.zeros(self.frame_shape, self.dtype)) - - @cached_property - def frame_shape(self): - ifd = self.ifds[list(self.ifds.keys())[-1]].copy() - return ifd[257].value[0], ifd[256].value[0] - - def add_empty_frame(self, n): - framenr, channel = self.get_frame_number(n) - ifd, strips = self.empty_frame - self.ifds[framenr] = ifd.copy() - self.strips[(framenr, channel)] = strips - @cached_property def colormap_bytes(self): colormap = getattr(colorcet, self.colormap) @@ -507,48 +447,44 @@ class IJTiffFile: dtype=int).T.flatten()]) for color in self.colors] def close(self): - if Manager().mp.current_process().pid == self.main_pid: - self.pool_manager.close(self) - with self.fh.lock() as fh: - if len(self.frames_added) == 0: - warn('At least one frame should be added to the tiff, removing file.') - os.remove(self.path) - else: - if len(self.frames_written) < np.prod(self.shape): # add empty frames if needed - for n in product(*[range(i) for i in self.shape]): - if n not in self.frames_written: - self.add_empty_frame(n) + if self.main_process: + ifds, strips = {}, {} + for n in list(self.pool.tasks): + framenr, channel = self.get_frame_number(n) + ifds[framenr], strips[(framenr, channel)] = self.pool[n] - for n, tags in self.frame_extra_tags.items(): - framenr, channel = self.get_frame_number(n) - self.ifds[framenr].update(tags) - if self.colormap is not None: - self.ifds[0][320] = Tag('SHORT', self.colormap_bytes) - self.ifds[0][262] = Tag('SHORT', 3) - if self.colors is not None: - for c, color in enumerate(self.colors_bytes): - self.ifds[c][320] = Tag('SHORT', color) - self.ifds[c][262] = Tag('SHORT', 3) - if 306 not in self.ifds[0]: - self.ifds[0][306] = Tag('ASCII', datetime.now().strftime('%Y:%m:%d %H:%M:%S')) - for framenr in range(self.nframes): - stripbyteoffsets, stripbytecounts = zip(*[self.strips[(framenr, channel)] - for channel in range(self.spp)]) - self.ifds[framenr][258].value = self.spp * self.ifds[framenr][258].value - self.ifds[framenr][270] = Tag('ASCII', self.description) - self.ifds[framenr][273] = Tag('LONG8', sum(stripbyteoffsets, [])) - self.ifds[framenr][277] = Tag('SHORT', self.spp) - self.ifds[framenr][279] = Tag('LONG8', sum(stripbytecounts, [])) - self.ifds[framenr][305] = Tag('ASCII', 'tiffwrite_tllab_NKI') - if self.extratags is not None: - self.ifds[framenr].update(self.extratags) - if self.colormap is None and self.colors is None and self.shape[0] > 1: - self.ifds[framenr][284] = Tag('SHORT', 2) - self.ifds[framenr].write(fh, self.header, self.write) - if framenr: - self.ifds[framenr].write_offset(self.ifds[framenr - 1].where_to_write_next_ifd_offset) - else: - self.ifds[framenr].write_offset(self.header.offset - self.header.offsetsize) + self.pool.close() + with self.fh.lock() as fh: + for n, tags in self.frame_extra_tags.items(): + framenr, channel = self.get_frame_number(n) + ifds[framenr].update(tags) + if self.colormap is not None: + ifds[0][320] = Tag('SHORT', self.colormap_bytes) + ifds[0][262] = Tag('SHORT', 3) + if self.colors is not None: + for c, color in enumerate(self.colors_bytes): + ifds[c][320] = Tag('SHORT', color) + ifds[c][262] = Tag('SHORT', 3) + if 306 not in ifds[0]: + ifds[0][306] = Tag('ASCII', datetime.now().strftime('%Y:%m:%d %H:%M:%S')) + for framenr in range(self.nframes): + stripbyteoffsets, stripbytecounts = zip(*[strips[(framenr, channel)] + for channel in range(self.spp)]) + ifds[framenr][258].value = self.spp * ifds[framenr][258].value + ifds[framenr][270] = Tag('ASCII', self.description) + ifds[framenr][273] = Tag('LONG8', sum(stripbyteoffsets, [])) + ifds[framenr][277] = Tag('SHORT', self.spp) + ifds[framenr][279] = Tag('LONG8', sum(stripbytecounts, [])) + ifds[framenr][305] = Tag('ASCII', 'tiffwrite_tllab_NKI') + if self.extratags is not None: + ifds[framenr].update(self.extratags) + if self.colormap is None and self.colors is None and self.shape[0] > 1: + ifds[framenr][284] = Tag('SHORT', 2) + ifds[framenr].write(fh, self.header, self.write) + if framenr: + ifds[framenr].write_offset(ifds[framenr - 1].where_to_write_next_ifd_offset) + else: + ifds[framenr].write_offset(self.header.offset - self.header.offsetsize) def __enter__(self): return self @@ -556,9 +492,6 @@ class IJTiffFile: def __exit__(self, *args, **kwargs): self.close() - def clean(self, *args, **kwargs): - """ To be overloaded, will be called when the parallel pool is closing. """ - @staticmethod def hash_check(fh, bvalue, offset): addr = fh.tell() @@ -568,7 +501,7 @@ class IJTiffFile: return same def write(self, fh, bvalue): - hash_value = hash(bvalue) + hash_value = sha1(bvalue).hexdigest() # hash uses a random seed making hashes different in different processes if hash_value in self.hashes and self.hash_check(fh, bvalue, self.hashes[hash_value]): return self.hashes[hash_value] # reuse previously saved data else: @@ -601,112 +534,10 @@ class IJTiffFile: return stripbytecounts, ifd, chunks -class Manager: - instance = None - - def __new__(cls, *args, **kwargs): - if cls.instance is None: - cls.instance = super().__new__(cls) - return cls.instance - - def __init__(self): - if not hasattr(self, 'mp'): - self.mp = multiprocessing.get_context('spawn') - self.manager = self.mp.Manager() - - -class PoolManager: - instance = None - - def __new__(cls, *args, **kwargs): - if cls.instance is None: - cls.instance = super().__new__(cls) - return cls.instance - - def __init__(self, tif, processes=None): - if not hasattr(self, 'tifs'): - self.tifs = {} - if not hasattr(self, 'is_alive'): - self.is_alive = False - if self.is_alive: - raise ValueError('Cannot start new tifffile until previous tifffiles have been closed.') - self.tifs[tif.path] = tif - self.processes = processes - - def close(self, tif): - while len(tif.frames_written) < len(tif.frames_added): - self.get_ifds_from_queue() - self.tifs.pop(tif.path) - if not self.tifs: - self.__class__.instance = None - if self.is_alive: - self.is_alive = False - self.done.set() - while not self.queue.empty(): - self.queue.get() - self.queue.close() - self.queue.join_thread() - while not self.error_queue.empty(): - print(self.error_queue.get()) - self.error_queue.close() - self.ifd_queue.close() - self.ifd_queue.join_thread() - self.pool.close() - self.pool.join() - - def get_ifds_from_queue(self): - while not self.ifd_queue.empty(): - file, n, ifd, strip = self.ifd_queue.get() - framenr, channel = self.tifs[file].get_frame_number(n) - self.tifs[file].ifds[framenr] = ifd - self.tifs[file].strips[(framenr, channel)] = strip - self.tifs[file].frames_written.append(n) - self.tifs[file].update() - - def add_frame(self, *args): - if not self.is_alive: - self.start_pool() - self.get_ifds_from_queue() - self.queue.put(args) - - def start_pool(self): - mp = Manager().mp - self.is_alive = True - nframes = sum([np.prod(tif.shape) for tif in self.tifs.values()]) - if self.processes is None: - self.processes = max(2, min(mp.cpu_count() // 6, nframes)) - elif self.processes == 'all': - self.processes = max(2, min(mp.cpu_count(), nframes)) - else: - self.processes = self.processes - self.queue = mp.Queue(10 * self.processes) - self.ifd_queue = mp.Queue(10 * self.processes) - self.error_queue = mp.Queue() - self.offsets_queue = mp.Queue() - self.done = mp.Event() - self.pool = mp.Pool(self.processes, self.run) - - def run(self): - """ Only this is run in parallel processes. """ - try: - while not self.done.is_set(): - try: - file, frame, n = self.queue.get(True, 0.02) - self.ifd_queue.put((file, n, *self.tifs[file].compress_frame(frame))) - except queues.Empty: - continue - except Exception: - print_exc() - self.error_queue.put(format_exc()) - finally: - for tif in self.tifs.values(): - tif.clean() - - class FileHandle: """ Process safe file handle """ def __init__(self, name): - manager = Manager().manager + manager = PoolSingleton().manager if os.path.exists(name): os.remove(name) with open(name, 'xb'):