- special IJTiffParallel class to help generate frames in parallel

- warning now shows which frames are missing
This commit is contained in:
Wim Pomp
2024-10-16 14:26:30 +02:00
parent 654755ab83
commit 83c0e221fb
3 changed files with 90 additions and 54 deletions

View File

@@ -14,24 +14,19 @@ from tqdm.auto import tqdm
from . import tiffwrite_rs as rs # noqa
__all__ = ['Header', 'IJTiffFile', 'IFD', 'FrameInfo', 'Tag', 'Strip', 'tiffwrite']
__all__ = ['IJTiffFile', 'IJTiffParallel', 'FrameInfo', 'Tag', 'tiffwrite']
Tag = rs.Tag
FrameInfo = tuple[np.ndarray, int, int, int]
class Header:
pass
""" deprecated """
class IFD(dict):
pass
class Tag(rs.Tag):
pass
Strip = tuple[list[int], list[int]]
CZT = tuple[int, int, int]
FrameInfo = tuple[np.ndarray, None, CZT]
""" deprecated """
class TiffWriteWarning(UserWarning):
@@ -40,6 +35,7 @@ class TiffWriteWarning(UserWarning):
class IJTiffFile(rs.IJTiffFile):
""" Writes a tiff file in a format that the BioFormats reader in Fiji understands.
Zstd compression is done in parallel using Rust.
file: filename of the new tiff file
shape: not used anymore
dtype: datatype to use when saving to tiff
@@ -97,38 +93,33 @@ class IJTiffFile(rs.IJTiffFile):
def save(self, frame: ArrayLike, c: int, z: int, t: int, extratags: Sequence[Tag] = None) -> None:
""" save a 2d numpy array to the tiff at channel=c, slice=z, time=t, with optional extra tif tags """
for frame, _, (cn, zn, tn) in self.compress_frame(frame):
frame = np.asarray(frame).astype(self.dtype)
match self.dtype:
case np.uint8:
self.save_u8(frame, c + cn, z + zn, t + tn)
case np.uint16:
self.save_u16(frame, c + cn, z + zn, t + tn)
case np.uint32:
self.save_u32(frame, c + cn, z + zn, t + tn)
case np.uint64:
self.save_u64(frame, c + cn, z + zn, t + tn)
case np.int8:
self.save_i8(frame, c + cn, z + zn, t + tn)
case np.int16:
self.save_i16(frame, c + cn, z + zn, t + tn)
case np.int32:
self.save_i32(frame, c + cn, z + zn, t + tn)
case np.int64:
self.save_i64(frame, c + cn, z + zn, t + tn)
case np.float32:
self.save_f32(frame, c + cn, z + zn, t + tn)
case np.float64:
self.save_f64(frame, c + cn, z + zn, t + tn)
case _:
raise TypeError(f'Cannot save type {self.dtype}')
if extratags is not None:
for extra_tag in extratags:
self.append_extra_tag(extra_tag, (c, z, t))
def compress_frame(self, frame: ArrayLike) -> tuple[FrameInfo]: # noqa
""" backwards compatibility """
return (frame, None, (0, 0, 0)),
frame = np.asarray(frame).astype(self.dtype)
match self.dtype:
case np.uint8:
self.save_u8(frame, c, z, t)
case np.uint16:
self.save_u16(frame, c, z, t)
case np.uint32:
self.save_u32(frame, c, z, t)
case np.uint64:
self.save_u64(frame, c, z, t)
case np.int8:
self.save_i8(frame, c, z, t)
case np.int16:
self.save_i16(frame, c, z, t)
case np.int32:
self.save_i32(frame, c, z, t)
case np.int64:
self.save_i64(frame, c, z, t)
case np.float32:
self.save_f32(frame, c, z, t)
case np.float64:
self.save_f64(frame, c, z, t)
case _:
raise TypeError(f'Cannot save type {self.dtype}')
if extratags is not None:
for extra_tag in extratags:
self.append_extra_tag(extra_tag, (c, z, t))
def get_colormap(colormap: str) -> np.ndarray:
@@ -181,3 +172,45 @@ def tiffwrite(file: str | Path, data: np.ndarray, axes: str = 'TZCXY', dtype: DT
for n in tqdm(product(*[range(i) for i in shape]), total=np.prod(shape), # type: ignore
desc='Saving tiff', disable=not bar):
f.save(data[n], *n)
try:
from parfor import ParPool, Task
from abc import abstractmethod, ABCMeta
from functools import wraps
class IJTiffParallel(ParPool, metaclass=ABCMeta):
""" wraps IJTiffFile.save in a parallel pool, the method 'parallel' needs to be overloaded """
@abstractmethod
def parallel(self, frame: Any) -> Sequence[tuple[ArrayLike, int, int, int]]:
""" does something with frame in a parallel process,
and returns a sequence of frames and offsets to c, z and t to save in the tif """
@wraps(IJTiffFile.__init__)
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.ijtifffile = IJTiffFile(*args, **kwargs)
super().__init__(self.parallel) # noqa
def done(self, task: Task) -> None:
c, z, t = task.handle
super().done(task)
for frame, cn, zn, tn in self[c, z, t]:
self.ijtifffile.save(frame, c + cn, z + zn, t + tn)
@wraps(IJTiffFile.close)
def close(self) -> None:
while len(self.tasks):
self.get_newest()
super().close()
self.ijtifffile.close()
@wraps(IJTiffFile.save)
def save(self, frame: Any, c: int, z: int, t: int, extratags: Sequence[Tag] = None) -> None:
self[c, z, t] = frame
if extratags is not None:
for extra_tag in extratags:
self.ijtifffile.append_extra_tag(extra_tag, (c, z, t))
except ImportError:
IJTiffPool = None