- Use parfor to handle the parallel processing.

- Use a sha1 hash because it is consistent across processes, unlike Python's built-in hash() (see the sketch below).
Wim Pomp
2023-09-11 17:12:04 +02:00
parent f68afd0a1b
commit e736770512
5 changed files with 80 additions and 249 deletions
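The second bullet refers to the fact that Python salts its built-in hash() per interpreter process, so a hash computed in one worker does not match the value computed in another, whereas hashlib.sha1 is deterministic. A small standalone sketch of that point (not part of this commit; the function name and test bytes are made up, and the built-in hashes would only agree if PYTHONHASHSEED were fixed):

import multiprocessing
from hashlib import sha1

def report(data):
    # built-in hash() depends on the per-process hash seed; sha1 does not
    return hash(data), sha1(data).hexdigest()

if __name__ == '__main__':
    data = b'strip of pixel data'
    with multiprocessing.get_context('spawn').Pool(1) as pool:
        child_hash, child_digest = pool.apply(report, (data,))
    parent_hash, parent_digest = report(data)
    print(parent_digest == child_digest)  # True: sha1 agrees across processes
    print(parent_hash == child_hash)      # usually False with a spawned worker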

View File

@@ -11,7 +11,7 @@ good compression.
- Compresses even more by referencing tag or image data which otherwise would have been saved several times.
For example empty frames, or a long string tag on every frame.
- Enables memory efficient scripts by saving frames whenever they're ready to be saved, not waiting for the whole stack.
- Colormaps, extra tags globally or frame dependent.
- Colormaps, extra tags, globally or frame dependent.
## Installation
pip install tiffwrite
@@ -92,6 +92,3 @@ or
be opened as a correctly ordered hyperstack.
- Using the colormap parameter you can make ImageJ open the file and apply the colormap. colormap='glasbey' is very
useful.
- IJTiffFile does not allow more than one pool of parallel processes to be open at a time. Therefore, when writing
multiple tiffs simultaneously you have to open them all before you start saving any frame; that way all files share the
same pool.
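For reference, a minimal usage sketch of the points above, based on the constructor signature and tests in this commit (the output path, frame size and colormap choice are illustrative):

from itertools import product

import numpy as np

from tiffwrite import IJTiffFile

shape = (2, 3, 5)  # channels, z-slices, time points
with IJTiffFile('example.tif', shape, dtype='uint16', colormap='glasbey') as tif:
    for c, z, t in product(*(range(s) for s in shape)):
        # frames are compressed and written in parallel as soon as they are saved
        tif.save(np.random.randint(0, 255, (64, 64)), c, z, t)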

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "tiffwrite"
version = "2023.3.0"
version = "2023.8.0"
description = "Parallel tiff writer compatible with ImageJ."
authors = ["Wim Pomp, Lenstra lab NKI <w.pomp@nki.nl>"]
license = "GPL-3.0-or-later"
@@ -15,6 +15,7 @@ numpy = "*"
tqdm = "*"
colorcet = "*"
matplotlib = "*"
parfor = ">=2023.8.3"
pytest = { version = "*", optional = true }
[tool.poetry.extras]

View File

@@ -1,16 +1,17 @@
from contextlib import ExitStack
from itertools import product
import numpy as np
from tiffwrite import IJTiffFile
from itertools import product
from contextlib import ExitStack
from tqdm.auto import tqdm
def test_mult(tmp_path):
shape = (3, 5, 12)
paths = [tmp_path / f'test{i}.tif' for i in range(8)]
shape = (2, 3, 5)
paths = [tmp_path / f'test{i}.tif' for i in range(6)]
with ExitStack() as stack:
tifs = [stack.enter_context(IJTiffFile(path, shape)) for path in paths]
for c, z, t in tqdm(product(range(shape[0]), range(shape[1]), range(shape[2])), total=np.prod(shape)):
for tif in tifs:
tif.save(np.random.randint(0, 255, (1024, 1024)), c, z, t)
tif.save(np.random.randint(0, 255, (64, 64)), c, z, t)
assert all([path.exists() for path in paths])

View File

@@ -1,6 +1,7 @@
from itertools import product
import numpy as np
from tiffwrite import IJTiffFile
from itertools import product
def test_single(tmp_path):

View File

@@ -1,24 +1,22 @@
import os
import tifffile
import colorcet
import struct
import numpy as np
import multiprocessing
from multiprocessing import queues
from io import BytesIO
from tqdm.auto import tqdm
from itertools import product
from collections.abc import Iterable
from numbers import Number
from fractions import Fraction
from traceback import print_exc, format_exc
from functools import cached_property
from datetime import datetime
from matplotlib import colors as mpl_colors
from contextlib import contextmanager
from warnings import warn
from datetime import datetime
from fractions import Fraction
from functools import cached_property
from hashlib import sha1
from importlib.metadata import version
from io import BytesIO
from itertools import product
from numbers import Number
import colorcet
import numpy as np
import tifffile
from matplotlib import colors as mpl_colors
from parfor import ParPool, PoolSingleton
from tqdm.auto import tqdm
__all__ = ["IJTiffFile", "Tag", "tiffwrite"]
@@ -40,12 +38,12 @@ def tiffwrite(file, data, axes='TZCXY', dtype=None, bar=False, *args, **kwargs):
axes = axes[-np.ndim(data):].upper()
if not axes == 'CZTXY':
T = [axes.find(i) for i in 'CZTXY']
E = [i for i, j in enumerate(T) if j < 0]
T = [i for i in T if i >= 0]
data = np.transpose(data, T)
for e in E:
data = np.expand_dims(data, e)
axes_shuffle = [axes.find(i) for i in 'CZTXY']
axes_add = [i for i, j in enumerate(axes_shuffle) if j < 0]
axes_shuffle = [i for i in axes_shuffle if i >= 0]
data = np.transpose(data, axes_shuffle)
for axis in axes_add:
data = np.expand_dims(data, axis)
shape = data.shape[:3]
with IJTiffFile(file, shape, data.dtype if dtype is None else dtype, *args, **kwargs) as f:
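The hunk above only renames variables; a standalone sketch of what this axis handling does, with illustrative shapes:

import numpy as np

def to_cztxy(data, axes='TZCXY'):
    # mirror of the axis handling in tiffwrite(): reorder existing axes to
    # CZTXY and insert length-1 axes for any of c, z, t that are missing
    axes = axes[-np.ndim(data):].upper()
    if axes != 'CZTXY':
        axes_shuffle = [axes.find(i) for i in 'CZTXY']
        axes_add = [i for i, j in enumerate(axes_shuffle) if j < 0]
        axes_shuffle = [i for i in axes_shuffle if i >= 0]
        data = np.transpose(data, axes_shuffle)
        for axis in axes_add:
            data = np.expand_dims(data, axis)
    return data

print(to_cztxy(np.zeros((5, 3, 2, 16, 16))).shape)   # (2, 3, 5, 16, 16)
print(to_cztxy(np.zeros((5, 16, 16)), 'TXY').shape)  # (1, 1, 5, 16, 16)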
@@ -340,7 +338,7 @@ class IJTiffFile:
wp@tl20200214
"""
def __init__(self, path, shape, dtype='uint16', colors=None, colormap=None, pxsize=None, deltaz=None,
timeinterval=None, compression=(8, 9), comment=None, processes=None, **extratags):
timeinterval=None, compression=(8, 9), comment=None, **extratags):
assert len(shape) >= 3, 'please specify all c, z, t for the shape'
assert len(shape) <= 3, 'please specify only c, z, t for the shape'
assert np.dtype(dtype).char in 'BbHhf', 'datatype not supported'
@@ -365,25 +363,22 @@ class IJTiffFile:
self.frames = []
self.spp = self.shape[0] if self.colormap is None and self.colors is None else 1 # samples/pixel
self.nframes = np.prod(self.shape[1:]) if self.colormap is None and self.colors is None else np.prod(self.shape)
self.offsets = {}
self.strips = {}
self.ifds = {}
self.frame_extra_tags = {}
self.frames_added = []
self.frames_written = []
self.pool_manager = PoolManager(self, processes)
self.fh = FileHandle(path)
self.hashes = Manager().manager.dict()
self.main_pid = Manager().mp.current_process().pid
self.hashes = PoolSingleton().manager.dict()
self.pool = ParPool(self.compress_frame)
self.main_process = True
with self.fh.lock() as fh:
self.header.write(fh)
def __setstate__(self, state):
self.__dict__.update(state)
self.main_process = False
def __hash__(self):
return hash(self.path)
def update(self):
""" To be overloaded, will be called when a frame has been written. """
def get_frame_number(self, n):
if self.colormap is None and self.colors is None:
return n[1] + n[2] * self.shape[1], n[0]
@@ -400,53 +395,13 @@ class IJTiffFile:
def save(self, frame, c, z, t, **extratags):
""" save a 2d numpy array to the tiff at channel=c, slice=z, time=t, with optional extra tif tags
"""
assert (c, z, t) not in self.frames_written, f'frame {c} {z} {t} is added already'
assert (c, z, t) not in self.pool.tasks, f'frame {c} {z} {t} is added already'
assert all([0 <= i < s for i, s in zip((c, z, t), self.shape)]), \
'frame {} {} {} is outside shape {} {} {}'.format(c, z, t, *self.shape)
self.frames_added.append((c, z, t))
self.pool_manager.add_frame(self.path,
frame.astype(self.dtype) if isinstance(frame, np.ndarray) else frame, (c, z, t))
self.pool(frame.astype(self.dtype) if isinstance(frame, np.ndarray) else frame, handle=(c, z, t))
if extratags:
self.frame_extra_tags[(c, z, t)] = Tag.to_tags(extratags)
@cached_property
def loader(self):
tif = self.ij_tiff_frame(np.zeros((8, 8), self.dtype))
with BytesIO(tif) as fh:
return tif, Header(fh), IFD(fh)
def load(self, *n):
""" read back a frame that's just written
useful for simulating a large number of frames without using much memory
"""
if n not in self.frames_added:
raise KeyError(f'frame {n} has not been added yet')
while n not in self.frames_written:
self.pool_manager.get_ifds_from_queue()
tif, header, ifd = self.loader
framenumber, channel = self.get_frame_number(n)
with BytesIO(tif) as fh:
fh.seek(0, 2)
fstripbyteoffsets, fstripbytecounts = [], []
with open(self.path, 'rb') as file:
for stripbyteoffset, stripbytecount in zip(*self.strips[(framenumber, channel)]):
file.seek(stripbyteoffset)
bdata = file.read(stripbytecount)
fstripbyteoffsets.append(fh.tell())
fstripbytecounts.append(len(bdata))
fh.write(bdata)
ifd = ifd.copy()
for key, value in zip((257, 256, 278, 270, 273, 279),
(*self.frame_shape, self.frame_shape[0] // len(fstripbyteoffsets),
f'{{"shape": [{self.frame_shape[0]}, {self.frame_shape[1]}]}}',
fstripbyteoffsets, fstripbytecounts)):
tag = ifd[key]
tag.value = value
tag.write_tag(fh, key, header, tag.offset)
tag.write_data()
fh.seek(0)
return tifffile.TiffFile(fh).asarray()
@property
def description(self):
desc = ['ImageJ=1.11a']
@@ -473,21 +428,6 @@ class IJTiffFile:
desc.append(bytes(self.comment, 'ascii'))
return b'\n'.join(desc)
@cached_property
def empty_frame(self):
return self.compress_frame(np.zeros(self.frame_shape, self.dtype))
@cached_property
def frame_shape(self):
ifd = self.ifds[list(self.ifds.keys())[-1]].copy()
return ifd[257].value[0], ifd[256].value[0]
def add_empty_frame(self, n):
framenr, channel = self.get_frame_number(n)
ifd, strips = self.empty_frame
self.ifds[framenr] = ifd.copy()
self.strips[(framenr, channel)] = strips
@cached_property
def colormap_bytes(self):
colormap = getattr(colorcet, self.colormap)
@@ -507,48 +447,44 @@ class IJTiffFile:
dtype=int).T.flatten()]) for color in self.colors]
def close(self):
if Manager().mp.current_process().pid == self.main_pid:
self.pool_manager.close(self)
with self.fh.lock() as fh:
if len(self.frames_added) == 0:
warn('At least one frame should be added to the tiff, removing file.')
os.remove(self.path)
else:
if len(self.frames_written) < np.prod(self.shape): # add empty frames if needed
for n in product(*[range(i) for i in self.shape]):
if n not in self.frames_written:
self.add_empty_frame(n)
if self.main_process:
ifds, strips = {}, {}
for n in list(self.pool.tasks):
framenr, channel = self.get_frame_number(n)
ifds[framenr], strips[(framenr, channel)] = self.pool[n]
for n, tags in self.frame_extra_tags.items():
framenr, channel = self.get_frame_number(n)
self.ifds[framenr].update(tags)
if self.colormap is not None:
self.ifds[0][320] = Tag('SHORT', self.colormap_bytes)
self.ifds[0][262] = Tag('SHORT', 3)
if self.colors is not None:
for c, color in enumerate(self.colors_bytes):
self.ifds[c][320] = Tag('SHORT', color)
self.ifds[c][262] = Tag('SHORT', 3)
if 306 not in self.ifds[0]:
self.ifds[0][306] = Tag('ASCII', datetime.now().strftime('%Y:%m:%d %H:%M:%S'))
for framenr in range(self.nframes):
stripbyteoffsets, stripbytecounts = zip(*[self.strips[(framenr, channel)]
for channel in range(self.spp)])
self.ifds[framenr][258].value = self.spp * self.ifds[framenr][258].value
self.ifds[framenr][270] = Tag('ASCII', self.description)
self.ifds[framenr][273] = Tag('LONG8', sum(stripbyteoffsets, []))
self.ifds[framenr][277] = Tag('SHORT', self.spp)
self.ifds[framenr][279] = Tag('LONG8', sum(stripbytecounts, []))
self.ifds[framenr][305] = Tag('ASCII', 'tiffwrite_tllab_NKI')
if self.extratags is not None:
self.ifds[framenr].update(self.extratags)
if self.colormap is None and self.colors is None and self.shape[0] > 1:
self.ifds[framenr][284] = Tag('SHORT', 2)
self.ifds[framenr].write(fh, self.header, self.write)
if framenr:
self.ifds[framenr].write_offset(self.ifds[framenr - 1].where_to_write_next_ifd_offset)
else:
self.ifds[framenr].write_offset(self.header.offset - self.header.offsetsize)
self.pool.close()
with self.fh.lock() as fh:
for n, tags in self.frame_extra_tags.items():
framenr, channel = self.get_frame_number(n)
ifds[framenr].update(tags)
if self.colormap is not None:
ifds[0][320] = Tag('SHORT', self.colormap_bytes)
ifds[0][262] = Tag('SHORT', 3)
if self.colors is not None:
for c, color in enumerate(self.colors_bytes):
ifds[c][320] = Tag('SHORT', color)
ifds[c][262] = Tag('SHORT', 3)
if 306 not in ifds[0]:
ifds[0][306] = Tag('ASCII', datetime.now().strftime('%Y:%m:%d %H:%M:%S'))
for framenr in range(self.nframes):
stripbyteoffsets, stripbytecounts = zip(*[strips[(framenr, channel)]
for channel in range(self.spp)])
ifds[framenr][258].value = self.spp * ifds[framenr][258].value
ifds[framenr][270] = Tag('ASCII', self.description)
ifds[framenr][273] = Tag('LONG8', sum(stripbyteoffsets, []))
ifds[framenr][277] = Tag('SHORT', self.spp)
ifds[framenr][279] = Tag('LONG8', sum(stripbytecounts, []))
ifds[framenr][305] = Tag('ASCII', 'tiffwrite_tllab_NKI')
if self.extratags is not None:
ifds[framenr].update(self.extratags)
if self.colormap is None and self.colors is None and self.shape[0] > 1:
ifds[framenr][284] = Tag('SHORT', 2)
ifds[framenr].write(fh, self.header, self.write)
if framenr:
ifds[framenr].write_offset(ifds[framenr - 1].where_to_write_next_ifd_offset)
else:
ifds[framenr].write_offset(self.header.offset - self.header.offsetsize)
def __enter__(self):
return self
@@ -556,9 +492,6 @@ class IJTiffFile:
def __exit__(self, *args, **kwargs):
self.close()
def clean(self, *args, **kwargs):
""" To be overloaded, will be called when the parallel pool is closing. """
@staticmethod
def hash_check(fh, bvalue, offset):
addr = fh.tell()
@@ -568,7 +501,7 @@ class IJTiffFile:
return same
def write(self, fh, bvalue):
hash_value = hash(bvalue)
hash_value = sha1(bvalue).hexdigest() # hash uses a random seed making hashes different in different processes
if hash_value in self.hashes and self.hash_check(fh, bvalue, self.hashes[hash_value]):
return self.hashes[hash_value] # reuse previously saved data
else:
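This hash feeds the write deduplication behind the README's "referencing tag or image data" feature. A simplified, in-memory sketch of that pattern (not the actual implementation: the real method shares the dict between worker processes via a manager and verifies the bytes already on disk with hash_check before reusing an offset):

from hashlib import sha1
from io import BytesIO

hashes = {}  # sha1 hexdigest -> offset where these bytes were written before

def write_once(fh, bvalue):
    key = sha1(bvalue).hexdigest()   # stable across processes
    if key in hashes:
        return hashes[key]           # reuse previously saved data
    offset = fh.tell()
    fh.write(bvalue)
    hashes[key] = offset
    return offset

fh = BytesIO()
first = write_once(fh, b'long ascii tag repeated on every frame')
second = write_once(fh, b'long ascii tag repeated on every frame')
assert first == second               # the second call only references the first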
@@ -601,112 +534,10 @@ class IJTiffFile:
return stripbytecounts, ifd, chunks
class Manager:
instance = None
def __new__(cls, *args, **kwargs):
if cls.instance is None:
cls.instance = super().__new__(cls)
return cls.instance
def __init__(self):
if not hasattr(self, 'mp'):
self.mp = multiprocessing.get_context('spawn')
self.manager = self.mp.Manager()
class PoolManager:
instance = None
def __new__(cls, *args, **kwargs):
if cls.instance is None:
cls.instance = super().__new__(cls)
return cls.instance
def __init__(self, tif, processes=None):
if not hasattr(self, 'tifs'):
self.tifs = {}
if not hasattr(self, 'is_alive'):
self.is_alive = False
if self.is_alive:
raise ValueError('Cannot start new tifffile until previous tifffiles have been closed.')
self.tifs[tif.path] = tif
self.processes = processes
def close(self, tif):
while len(tif.frames_written) < len(tif.frames_added):
self.get_ifds_from_queue()
self.tifs.pop(tif.path)
if not self.tifs:
self.__class__.instance = None
if self.is_alive:
self.is_alive = False
self.done.set()
while not self.queue.empty():
self.queue.get()
self.queue.close()
self.queue.join_thread()
while not self.error_queue.empty():
print(self.error_queue.get())
self.error_queue.close()
self.ifd_queue.close()
self.ifd_queue.join_thread()
self.pool.close()
self.pool.join()
def get_ifds_from_queue(self):
while not self.ifd_queue.empty():
file, n, ifd, strip = self.ifd_queue.get()
framenr, channel = self.tifs[file].get_frame_number(n)
self.tifs[file].ifds[framenr] = ifd
self.tifs[file].strips[(framenr, channel)] = strip
self.tifs[file].frames_written.append(n)
self.tifs[file].update()
def add_frame(self, *args):
if not self.is_alive:
self.start_pool()
self.get_ifds_from_queue()
self.queue.put(args)
def start_pool(self):
mp = Manager().mp
self.is_alive = True
nframes = sum([np.prod(tif.shape) for tif in self.tifs.values()])
if self.processes is None:
self.processes = max(2, min(mp.cpu_count() // 6, nframes))
elif self.processes == 'all':
self.processes = max(2, min(mp.cpu_count(), nframes))
else:
self.processes = self.processes
self.queue = mp.Queue(10 * self.processes)
self.ifd_queue = mp.Queue(10 * self.processes)
self.error_queue = mp.Queue()
self.offsets_queue = mp.Queue()
self.done = mp.Event()
self.pool = mp.Pool(self.processes, self.run)
def run(self):
""" Only this is run in parallel processes. """
try:
while not self.done.is_set():
try:
file, frame, n = self.queue.get(True, 0.02)
self.ifd_queue.put((file, n, *self.tifs[file].compress_frame(frame)))
except queues.Empty:
continue
except Exception:
print_exc()
self.error_queue.put(format_exc())
finally:
for tif in self.tifs.values():
tif.clean()
class FileHandle:
""" Process safe file handle """
def __init__(self, name):
manager = Manager().manager
manager = PoolSingleton().manager
if os.path.exists(name):
os.remove(name)
with open(name, 'xb'):