diff --git a/ndbioimage/__init__.py b/ndbioimage/__init__.py index b08c80f..4b39ff3 100755 --- a/ndbioimage/__init__.py +++ b/ndbioimage/__init__.py @@ -1,33 +1,32 @@ +import multiprocessing import re import warnings - -import pandas -import yaml -import numpy as np -import multiprocessing -import ome_types -from ome_types import ureg, model -from pint import set_application_registry -from datetime import datetime -from tqdm.auto import tqdm -from itertools import product -from collections import OrderedDict -from abc import ABCMeta, abstractmethod -from functools import cached_property, wraps -from parfor import parfor -from tiffwrite import IJTiffFile -from numbers import Number +from abc import ABC, ABCMeta, abstractmethod from argparse import ArgumentParser -from pathlib import Path +from collections import OrderedDict +from datetime import datetime +from functools import cached_property, wraps from importlib.metadata import version -from traceback import print_exc +from itertools import product +from numbers import Number from operator import truediv -from .transforms import Transform, Transforms +from pathlib import Path +from traceback import print_exc + +import numpy as np +import ome_types +import yaml +from ome_types import model, ureg +from pint import set_application_registry +from tiffwrite import IJTiffFile +from tqdm.auto import tqdm + from .jvm import JVM +from .transforms import Transform, Transforms try: __version__ = version(Path(__file__).parent.name) -except (Exception,): +except Exception: # noqa __version__ = 'unknown' try: @@ -35,7 +34,7 @@ try: head = g.read().split(':')[1].strip() with open(Path(__file__).parent.parent / '.git' / head) as h: __git_commit_hash__ = h.read().rstrip('\n') -except (Exception,): +except Exception: # noqa __git_commit_hash__ = 'unknown' ureg.default_format = '~P' @@ -57,235 +56,6 @@ class TransformTiff(IJTiffFile): return super().compress_frame(np.asarray(self.image(*frame)).astype(self.dtype)) -class ImTransforms(Transforms): - """ Transforms class with methods to calculate channel transforms from bead files etc. """ - - def __init__(self, path, cyllens, file=None, transforms=None): - super().__init__() - self.cyllens = tuple(cyllens) - if transforms is None: - # TODO: check this - if re.search(r'^Pos\d+', path.name): - self.path = path.parent.parent - else: - self.path = path.parent - if file is not None: - if isinstance(file, str) and file.lower().endswith('.yml'): - self.ymlpath = file - self.beadfile = None - else: - self.ymlpath = self.path / 'transform.yml' - self.beadfile = file - else: - self.ymlpath = self.path / 'transform.yml' - self.beadfile = None - self.tifpath = self.ymlpath.with_suffix('.tif') - try: - self.load(self.ymlpath) - except (Exception,): - print('No transform file found, trying to generate one.') - if not self.files: - raise FileNotFoundError('No bead files found to calculate the transform from.') - self.calculate_transforms() - self.save(self.ymlpath) - self.save_transform_tiff() - print(f'Saving transform in {self.ymlpath}.') - print(f'Please check the transform in {self.tifpath}.') - else: # load from dict transforms - self.path = path - self.beadfile = file - for i, (key, value) in enumerate(transforms.items()): - self[key] = Transform(value) - - def coords_pandas(self, array, cnamelist, colums=None): - if isinstance(array, pandas.DataFrame): - return pandas.concat([self[(cnamelist[int(row['C'])],)].coords(row, colums) - for _, row in array.iterrows()], axis=1).T - elif isinstance(array, pandas.Series): - return self[(cnamelist[int(array['C'])],)].coords(array, colums) - else: - raise TypeError('Not a pandas DataFrame or Series.') - - @cached_property - def files(self): - try: - if self.beadfile is None: - files = self.get_bead_files() - else: - files = self.beadfile - if isinstance(files, str): - files = (Path(files),) - elif isinstance(files, Path): - files = (files,) - return tuple(files) - except (Exception,): - return () - - def get_bead_files(self): - files = sorted([f for f in self.path.iterdir() if f.name.lower().startswith('beads') - and not f.suffix.lower() == '.pdf' and not f.suffix.lower() == 'pkl']) - if not files: - raise Exception('No bead file found!') - checked_files = [] - for file in files: - try: - if file.is_dir(): - file /= 'Pos0' - with Imread(file): # check for errors opening the file - checked_files.append(file) - except (Exception,): - continue - if not checked_files: - raise Exception('No bead file found!') - return checked_files - - def calculate_transform(self, file): - """ When no channel is not transformed by a cylindrical lens, assume that the image is scaled by a factor 1.162 - in the horizontal direction """ - with Imread(file, axes='zcxy') as im: - max_ims = im.max('z') - goodch = [c for c, max_im in enumerate(max_ims) if not im.is_noise(max_im)] - if not goodch: - goodch = list(range(len(max_ims))) - untransformed = [c for c in range(im.shape['c']) if self.cyllens[im.detector[c]].lower() == 'none'] - - good_and_untrans = sorted(set(goodch) & set(untransformed)) - if good_and_untrans: - masterch = good_and_untrans[0] - else: - masterch = goodch[0] - transform = Transform() - if not good_and_untrans: - matrix = transform.matrix - matrix[0, 0] = 0.86 - transform.matrix = matrix - transforms = Transforms() - for c in tqdm(goodch): - if c == masterch: - transforms[(im.cnamelist[c],)] = transform - else: - transforms[(im.cnamelist[c],)] = Transform(max_ims[masterch], max_ims[c]) * transform - return transforms - - def calculate_transforms(self): - transforms = [self.calculate_transform(file) for file in self.files] - for key in set([key for transform in transforms for key in transform.keys()]): - new_transforms = [transform[key] for transform in transforms if key in transform] - if len(new_transforms) == 1: - self[key] = new_transforms[0] - else: - self[key] = Transform() - self[key].parameters = np.mean([t.parameters for t in new_transforms], 0) - self[key].dparameters = (np.std([t.parameters for t in new_transforms], 0) / - np.sqrt(len(new_transforms))).tolist() - - def save_transform_tiff(self): - n_channels = 0 - for file in self.files: - with Imread(file) as im: - n_channels = max(n_channels, im.shape['c']) - with IJTiffFile(self.tifpath, (n_channels, 1, len(self.files))) as tif: - for t, file in enumerate(self.files): - with Imread(file) as im: - with Imread(file, transform=True) as jm: - for c in range(im.shape['c']): - tif.save(np.hstack((im(c=c, t=0).max('z'), jm(c=c, t=0).max('z'))), c, 0, t) - - -class ImShiftTransforms(Transforms): - """ Class to handle drift in xy. The image this is applied to must have a channel transform already, which is then - replaced by this class. """ - - def __init__(self, im, shifts=None): - """ im: Calculate shifts from channel-transformed images - im, t x 2 array Sets shifts from array, one row per frame - im, dict {frame: shift} Sets shifts from dict, each key is a frame number from where a new shift is applied - im, file Loads shifts from a saved file """ - super().__init__() - with (Imread(im, transform=True, drift=False) if isinstance(im, str) - else im.new(transform=True, drift=False)) as im: - self.impath = im.path - self.path = self.impath.parent / self.impath.stem + '_shifts.txt' - self.tracks, self.detectors, self.files = im.track, im.detector, im.beadfile - if shifts is not None: - if isinstance(shifts, np.ndarray): - self.shifts = shifts - self.shifts2transforms(im) - elif isinstance(shifts, dict): - self.shifts = np.zeros((im.shape['t'], 2)) - for k in sorted(shifts.keys()): - self.shifts[k:] = shifts[k] - self.shifts2transforms(im) - elif isinstance(shifts, str): - self.load(im, shifts) - elif self.path.exists(): - self.load(im, self.path) - else: - self.calulate_shifts(im) - self.save() - - def __call__(self, channel, time, tracks=None, detectors=None): - tracks = tracks or self.tracks - detectors = detectors or self.detectors - track, detector = tracks[channel], detectors[channel] - if (track, detector, time) in self: - return self[track, detector, time] - elif (0, detector, time) in self: - return self[0, detector, time] - else: - return Transform() - - def load(self, im, file): - self.shifts = np.loadtxt(file) - self.shifts2transforms(im) - - def save(self, file=None): - self.path = file or self.path - np.savetxt(self.path, self.shifts) - - def coords(self, array, colums=None): - if isinstance(array, pandas.DataFrame): - return pandas.concat([self(int(row['C']), int(row['T'])).coords(row, colums) - for _, row in array.iterrows()], axis=1).T - elif isinstance(array, pandas.Series): - return self(int(array['C']), int(array['T'])).coords(array, colums) - else: - raise TypeError('Not a pandas DataFrame or Series.') - - def calulate_shifts0(self, im): - """ Calculate shifts relative to the first frame """ - im0 = im[:, 0, 0].squeeze().transpose(2, 0, 1) - - @parfor(range(1, im.shape['t']), (im, im0), desc='Calculating image shifts.') - def fun(t, im, im0): - return Transform(im0, im[:, 0, t].squeeze().transpose(2, 0, 1), 'translation') - - transforms = [Transform()] + fun - self.shifts = np.array([t.parameters[4:] for t in transforms]) - self.set_transforms(transforms, im.transform) - - def calulate_shifts(self, im): - """ Calculate shifts relative to the previous frame """ - - @parfor(range(1, im.shape['t']), (im,), desc='Calculating image shifts.') - def fun(t, im): - return Transform(im[:, 0, t - 1].squeeze().transpose(2, 0, 1), im[:, 0, t].squeeze().transpose(2, 0, 1), - 'translation') - - transforms = [Transform()] + fun - self.shifts = np.cumsum([t.parameters[4:] for t in transforms]) - self.set_transforms(transforms, im.transform) - - def shifts2transforms(self, im): - self.set_transforms([Transform(np.array(((1, 0, s[0]), (0, 1, s[1]), (0, 0, 1)))) - for s in self.shifts], im.transform) - - def set_transforms(self, shift_transforms, channel_transforms): - for key, value in channel_transforms.items(): - for t, T in enumerate(shift_transforms): - self[key[0], key[1], t] = T * channel_transforms[key] - - class DequeDict(OrderedDict): def __init__(self, maxlen=None, *args, **kwargs): self.maxlen = maxlen @@ -314,24 +84,24 @@ def find(obj, **kwargs): def try_default(fun, default, *args, **kwargs): try: return fun(*args, **kwargs) - except (Exception,): + except Exception: # noqa return default def get_ome(path): from .readers.bfread import jars try: - jvm = JVM(jars) + jvm = JVM(jars) # noqa ome_meta = jvm.metadata_tools.createOMEXMLMetadata() reader = jvm.image_reader() reader.setMetadataStore(ome_meta) reader.setId(str(path)) ome = ome_types.from_xml(str(ome_meta.dumpXML()), parser='lxml') - except (Exception,): + except Exception: # noqa print_exc() ome = model.OME() finally: - jvm.kill_vm() + jvm.kill_vm() # noqa return ome @@ -356,7 +126,59 @@ class Shape(tuple): return tuple(self[i] for i in 'xyczt') -class Imread(np.lib.mixins.NDArrayOperatorsMixin): +class Imread(np.lib.mixins.NDArrayOperatorsMixin, ABC): + """ class to read image files, while taking good care of important metadata, + currently optimized for .czi files, but can open anything that bioformats can handle + path: path to the image file + optional: + series: in case multiple experiments are saved in one file, like in .lif files + dtype: datatype to be used when returning frames + meta: define metadata, used for pickle-ing + + NOTE: run imread.kill_vm() at the end of your script/program, otherwise python might not terminate + + modify images on the fly with a decorator function: + define a function which takes an instance of this object, one image frame, + and the coordinates c, z, t as arguments, and one image frame as return + >> imread.frame_decorator = fun + then use imread as usually + + Examples: + >> im = imread('/DATA/lenstra_lab/w.pomp/data/20190913/01_YTL639_JF646_DefiniteFocus.czi') + >> im + << shows summary + >> im.shape + << (256, 256, 2, 1, 600) + >> plt.imshow(im(1, 0, 100)) + << plots frame at position c=1, z=0, t=100 (python type indexing), note: round brackets; always 2d array + with 1 frame + >> data = im[:,:,0,0,:25] + << retrieves 5d numpy array containing first 25 frames at c=0, z=0, note: square brackets; always 5d array + >> plt.imshow(im.max(0, None, 0)) + << plots max-z projection at c=0, t=0 + >> len(im) + << total number of frames + >> im.pxsize + << 0.09708737864077668 image-plane pixel size in um + >> im.laserwavelengths + << [642, 488] + >> im.laserpowers + << [0.02, 0.0005] in % + + See __init__ and other functions for more ideas. + + Subclassing: + Subclass AbstractReader to add more file types. A subclass should always have at least the following + methods: + staticmethod _can_open(path): returns True when the subclass can open the image in path + __metadata__(self): pulls some metadata from the file and do other format specific things, + it needs to define a few properties, like shape, etc. + __frame__(self, c, z, t): this should return a single frame at channel c, slice z and time t + optional close(self): close the file in a proper way + optional field priority: subclasses with lower priority will be tried first, default = 99 + Any other method can be overridden as needed + wp@tl2019-2023 """ + def __new__(cls, path=None, *args, **kwargs): if cls is not Imread: return super().__new__(cls) @@ -376,18 +198,13 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): return super().__new__(subclass) raise ReaderNotFoundError(f'No reader found for {path}.') - def __init__(self, base=None, slice=None, shape=(0, 0, 0, 0, 0), dtype=None, - transform=False, drift=False, beadfile=None, frame_decorator=None): - self.base = base + def __init__(self, base=None, slice=None, shape=(0, 0, 0, 0, 0), dtype=None, frame_decorator=None): + self.base = base or self self.slice = slice self._shape = Shape(shape) self.dtype = dtype self.frame_decorator = frame_decorator - - self.transform = transform - self.drift = drift - self.beadfile = beadfile - + self.transform = Transforms() self.flags = dict(C_CONTIGUOUS=False, F_CONTIGUOUS=False, OWNDATA=False, WRITEABLE=False, ALIGNED=False, WRITEBACKIFCOPY=False, UPDATEIFCOPY=False) @@ -426,7 +243,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): def __getitem__(self, n): """ slice like a numpy array but return an Imread instance """ if self.isclosed: - raise IOError("file is closed") + raise OSError("file is closed") if isinstance(n, (slice, Number)): # None = : n = (n,) elif isinstance(n, type(Ellipsis)): @@ -644,9 +461,6 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): cfun(*tmps)).astype(p.sub('', dtype.name)) return out - def __framet__(self, c, z, t): - return self.transform_frame(self.__frame__(c, z, t), c, t) - @property def axes(self): return self.shape.axes @@ -677,7 +491,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): return try: return self.get_config(pname) - except (Exception,): + except Exception: # noqa return return @@ -704,11 +518,8 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): def summary(self): """ gives a helpful summary of the recorded experiment """ s = [f"path/filename: {self.path}", - f"series/pos: {self.series}"] - if isinstance(self, View): - s.append(f"reader: {self.base.__class__.__module__.split('.')[-1]} view") - else: - s.append(f"reader: {self.__class__.__module__.split('.')[-1]} base") + f"series/pos: {self.series}", + f"reader: {self.base.__class__.__module__.split('.')[-1]}"] s.extend((f"dtype: {self.dtype}", f"shape ({self.axes}):".ljust(15) + f"{' x '.join(str(i) for i in self.shape)}")) if self.pxsize_um: @@ -728,15 +539,15 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): s.append('laser powers: ' + ' | '.join([' & '.join(len(p) * ('{:.3g}',)).format(*[100 * i for i in p]) for p in self.laserpowers]) + ' %') if self.objective: - s.append('objective: {}'.format(self.objective.model)) + s.append(f'objective: {self.objective.model}') if self.magnification: - s.append('magnification: {}x'.format(self.magnification)) + s.append(f'magnification: {self.magnification}x') if self.tubelens: - s.append('tubelens: {}'.format(self.tubelens.model)) + s.append(f'tubelens: {self.tubelens.model}') if self.filter: - s.append('filterset: {}'.format(self.filter)) + s.append(f'filterset: {self.filter}') if self.powermode: - s.append('powermode: {}'.format(self.powermode)) + s.append(f'powermode: {self.powermode}') if self.collimator: s.append('collimator: ' + (' {}' * len(self.collimator)).format(*self.collimator)) if self.tirfangle: @@ -881,9 +692,9 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): return new @wraps(np.transpose) - def transpose(self, axes=None): + def transpose(self, *axes): new = self.copy() - if axes is None: + if not axes: new.axes = new.axes[::-1] else: new.axes = ''.join(ax if isinstance(ax, str) else new.axes[ax] for ax in axes) @@ -918,8 +729,8 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): def block(self, x=None, y=None, c=None, z=None, t=None): """ returns 5D block of frames """ - x, y, c, z, t = [np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) - for i, e in zip('xyczt', (x, y, c, z, t))] + x, y, c, z, t = (np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) + for i, e in zip('xyczt', (x, y, c, z, t))) d = np.empty((len(x), len(y), len(c), len(z), len(t)), self.dtype) for (ci, cj), (zi, zj), (ti, tj) in product(enumerate(c), enumerate(z), enumerate(t)): d[:, :, ci, zi, ti] = self.frame(cj, zj, tj)[x][:, y] @@ -930,7 +741,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): def data(self, c=0, z=0, t=0): """ returns 3D stack of frames """ - c, z, t = [np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) for i, e in zip('czt', (c, z, t))] + c, z, t = (np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) for i, e in zip('czt', (c, z, t))) return np.dstack([self.frame(ci, zi, ti) for ci, zi, ti in product(c, z, t)]) def frame(self, c=0, z=0, t=0): @@ -946,7 +757,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): self.cache.move_to_end(key) f = self.cache[key] else: - f = self.__framet__(c, z, t) + f = self.transform[self.channel_names[c], t].frame(self.__frame__(c, z, t)) if self.frame_decorator is not None: f = self.frame_decorator(self, f, c, z, t) self.cache[key] = f @@ -959,9 +770,9 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): if not isinstance(channel_name, str): return channel_name else: - c = [i for i, c in enumerate(self.cnamelist) if c.lower().startswith(channel_name.lower())] - assert len(c) > 0, 'Channel {} not found in {}'.format(c, self.cnamelist) - assert len(c) < 2, 'Channel {} not unique in {}'.format(c, self.cnamelist) + c = [i for i, c in enumerate(self.channel_names) if c.lower().startswith(channel_name.lower())] + assert len(c) > 0, f'Channel {c} not found in {self.channel_names}' + assert len(c) < 2, f'Channel {c} not unique in {self.channel_names}' return c[0] @staticmethod @@ -978,7 +789,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): |[-+]?\\.(?:inf|Inf|INF) |\.(?:nan|NaN|NAN))$''', re.X), list(r'-+0123456789.')) - with open(file, 'r') as f: + with open(file) as f: return yaml.load(f, loader) def get_czt(self, c, z, t): @@ -995,7 +806,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): stop = n.stop czt.append(list(range(n.start % self.shape[i], stop, n.step))) elif isinstance(n, Number): - czt.append([n % self.shape[i]]) + czt.append([n % self.shape[i]]) # noqa else: czt.append([k % self.shape[i] for k in n]) return [self.get_channel(c) for c in czt[0]], *czt[1:] @@ -1055,28 +866,42 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): shape = [len(i) for i in n] with TransformTiff(self, fname.with_suffix('.tif'), shape, pixel_type, pxsize=self.pxsize_um, deltaz=self.deltaz_um, **kwargs) as tif: - for i, m in tqdm(zip(product(*[range(s) for s in shape]), product(*n)), + for i, m in tqdm(zip(product(*[range(s) for s in shape]), product(*n)), # noqa total=np.prod(shape), desc='Saving tiff', disable=not bar): tif.save(m, *i) - def set_transform(self): - # handle transforms - if self.transform is False or self.transform is None: - self.transform = None - else: - if isinstance(self.transform, Transforms): - self.transform = self.transform - else: - if isinstance(self.transform, str): - self.transform = ImTransforms(self.path, self.cyllens, self.transform) - else: - self.transform = ImTransforms(self.path, self.cyllens, self.beadfile) - if self.drift is True: - self.transform = ImShiftTransforms(self) - elif not (self.drift is False or self.drift is None): - self.transform = ImShiftTransforms(self, self.drift) - self.transform.adapt(self.frameoffset, self.shape.xyczt) - self.beadfile = self.transform.files + def with_transform(self, channels=True, drift=False, file=None, bead_files=()): + """ returns a view where channels and/or frames are registered with an affine transformation + channels: True/False register channels using bead_files + drift: True/False register frames to correct drift + file: load registration from file with name file, default: transform.yml in self.path.parent + bead_files: files used to register channels, default: files in self.path.parent, + with names starting with 'beads' + """ + view = self.view() + if file is None: + file = Path(view.path.parent) / 'transform.yml' + if not bead_files: + bead_files = Transforms.get_bead_files(view.path.parent) + + if channels: + try: + view.transform = Transforms.from_file(file, T=drift) + # for key in view.channel_names: + # if + except Exception: # noqa + view.transform = Transforms().with_beads(view.cyllens, bead_files) + if drift: + view.transform = view.transform.with_drift(view) + view.transform.save(file.with_suffix('.yml')) + view.transform.save_channel_transform_tiff(bead_files, file.with_suffix('.tif')) + elif drift: + try: + view.transform = Transforms.from_file(file, C=False) + except Exception: # noqa + view.transform = Transforms().with_drift(self) + view.transform.adapt(view.frameoffset, view.shape.xyczt, view.channel_names) + return view @staticmethod def split_path_series(path): @@ -1086,21 +911,14 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin): return path.parent, int(path.name.lstrip('Pos')) return path, 0 - def transform_frame(self, frame, c, t=0): - if self.transform is None: - return frame - else: - return self.transform[(self.cnamelist[c],)].frame(frame) - def view(self, *args, **kwargs): return View(self, *args, **kwargs) -class View(Imread): - def __init__(self, base, dtype=None, transform=None, drift=None, beadfile=None): - super().__init__(base.base, base.slice, base.shape, dtype or base.dtype, transform or base.transform, - drift or base.drift, beadfile or base.beadfile, base.frame_decorator) - self.set_transform() +class View(Imread, ABC): + def __init__(self, base, dtype=None): + super().__init__(base.base, base.slice, base.shape, dtype or base.dtype, base.frame_decorator) + self.transform = base.transform def __getattr__(self, item): if not hasattr(self.base, item): @@ -1109,60 +927,6 @@ class View(Imread): class AbstractReader(Imread, metaclass=ABCMeta): - """ class to read image files, while taking good care of important metadata, - currently optimized for .czi files, but can open anything that bioformats can handle - path: path to the image file - optional: - series: in case multiple experiments are saved in one file, like in .lif files - transform: automatically correct warping between channels, need transforms.py among others - drift: automatically correct for drift, only works if transform is not None or False - beadfile: image file(s) with beads which can be used for correcting warp - dtype: datatype to be used when returning frames - meta: define metadata, used for pickle-ing - - NOTE: run imread.kill_vm() at the end of your script/program, otherwise python might not terminate - - modify images on the fly with a decorator function: - define a function which takes an instance of this object, one image frame, - and the coordinates c, z, t as arguments, and one image frame as return - >> imread.frame_decorator = fun - then use imread as usually - - Examples: - >> im = imread('/DATA/lenstra_lab/w.pomp/data/20190913/01_YTL639_JF646_DefiniteFocus.czi') - >> im - << shows summary - >> im.shape - << (256, 256, 2, 1, 600) - >> plt.imshow(im(1, 0, 100)) - << plots frame at position c=1, z=0, t=100 (python type indexing), note: round brackets; always 2d array - with 1 frame - >> data = im[:,:,0,0,:25] - << retrieves 5d numpy array containing first 25 frames at c=0, z=0, note: square brackets; always 5d array - >> plt.imshow(im.max(0, None, 0)) - << plots max-z projection at c=0, t=0 - >> len(im) - << total number of frames - >> im.pxsize - << 0.09708737864077668 image-plane pixel size in um - >> im.laserwavelengths - << [642, 488] - >> im.laserpowers - << [0.02, 0.0005] in % - - See __init__ and other functions for more ideas. - - Subclassing: - Subclass this class to add more file types. A subclass should always have at least the following methods: - staticmethod _can_open(path): returns True when the subclass can open the image in path - __metadata__(self): pulls some metadata from the file and do other format specific things, it needs to - define a few properties, like shape, etc. - __frame__(self, c, z, t): this should return a single frame at channel c, slice z and time t - optional close(self): close the file in a proper way - optional field priority: subclasses with lower priority will be tried first, default = 99 - Any other method can be overridden as needed - wp@tl2019-2023 """ - priority = 99 do_not_pickle = 'cache' ureg = ureg @@ -1187,10 +951,10 @@ class AbstractReader(Imread, metaclass=ABCMeta): def close(self): # Optionally override this, close file handles etc. return - def __init__(self, path, transform=False, drift=False, beadfile=None, dtype=None, axes=None): + def __init__(self, path, dtype=None, axes=None): if isinstance(path, Imread): return - super().__init__(self, transform=transform, drift=drift, beadfile=beadfile) + super().__init__() self.isclosed = False if isinstance(path, str): path = Path(path) @@ -1248,7 +1012,9 @@ class AbstractReader(Imread, metaclass=ABCMeta): self.pxsize *= self.binning[0] except (AttributeError, IndexError, ValueError): self.binning = None - self.cnamelist = [channel.name for channel in image.pixels.channels] + self.channel_names = [channel.name for channel in image.pixels.channels] + self.channel_names += [chr(97 + i) for i in range(len(self.channel_names), self.shape['c'])] + self.cnamelist = self.channel_names try: optovars = [objective for objective in instrument.objectives if 'tubelens' in objective.id.lower()] except AttributeError: @@ -1297,19 +1063,18 @@ class AbstractReader(Imread, metaclass=ABCMeta): self.feedback = m['FeedbackChannels'] else: self.feedback = m['FeedbackChannel'] - except (Exception,): + except Exception: # noqa self.cyllens = ['None', 'None'] self.duolink = 'None' self.feedback = [] try: self.cyllenschannels = np.where([self.cyllens[self.detector[c]].lower() != 'none' for c in range(self.shape['c'])])[0].tolist() - except (Exception,): + except Exception: # noqa pass - self.set_transform() try: s = int(re.findall(r'_(\d{3})_', self.duolink)[0]) * ureg.nm - except (Exception,): + except Exception: # noqa s = 561 * ureg.nm try: sigma = [] @@ -1320,7 +1085,7 @@ class AbstractReader(Imread, metaclass=ABCMeta): sigma[sigma == 0] = 600 * ureg.nm sigma /= 2 * self.NA * self.pxsize self.sigma = sigma.magnitude.tolist() - except (Exception,): + except Exception: # noqa self.sigma = [2] * self.shape['c'] if not self.NA: self.immersionN = 1 @@ -1341,6 +1106,7 @@ class AbstractReader(Imread, metaclass=ABCMeta): except Exception: pass + def main(): parser = ArgumentParser(description='Display info and save as tif') parser.add_argument('file', help='image_file') @@ -1353,13 +1119,15 @@ def main(): parser.add_argument('-f', '--force', help='force overwrite', action='store_true') args = parser.parse_args() - with Imread(args.file, transform=args.register) as im: + with Imread(args.file) as im: + if args.register: + im = im.with_transform() print(im.summary) if args.out: out = Path(args.out).absolute() out.parent.mkdir(parents=True, exist_ok=True) if out.exists() and not args.force: - print('File {} exists already, add the -f flag if you want to overwrite it.'.format(args.out)) + print(f'File {args.out} exists already, add the -f flag if you want to overwrite it.') else: im.save_as_tiff(out, args.channel, args.zslice, args.time, args.split) diff --git a/ndbioimage/jvm.py b/ndbioimage/jvm.py index 0c6494d..85ecacf 100644 --- a/ndbioimage/jvm.py +++ b/ndbioimage/jvm.py @@ -19,27 +19,27 @@ try: def __init__(self, jars=None): if not self.vm_started and not self.vm_killed: try: - jarpath = Path(__file__).parent / 'jars' + jar_path = Path(__file__).parent / 'jars' if jars is None: jars = {} for jar, src in jars.items(): - if not (jarpath / jar).exists(): - JVM.download(src, jarpath / jar) - classpath = [str(jarpath / jar) for jar in jars.keys()] + if not (jar_path / jar).exists(): + JVM.download(src, jar_path / jar) + classpath = [str(jar_path / jar) for jar in jars.keys()] import jpype jpype.startJVM(classpath=classpath) - except Exception: + except Exception: # noqa self.vm_started = False else: self.vm_started = True try: import jpype.imports - from loci.common import DebugTools - from loci.formats import ImageReader - from loci.formats import ChannelSeparator - from loci.formats import FormatTools - from loci.formats import MetadataTools + from loci.common import DebugTools # noqa + from loci.formats import ImageReader # noqa + from loci.formats import ChannelSeparator # noqa + from loci.formats import FormatTools # noqa + from loci.formats import MetadataTools # noqa DebugTools.setRootLevel("ERROR") @@ -47,7 +47,7 @@ try: self.channel_separator = ChannelSeparator self.format_tools = FormatTools self.metadata_tools = MetadataTools - except Exception: + except Exception: # noqa pass if self.vm_killed: @@ -64,7 +64,7 @@ try: self = cls._instance if self is not None and self.vm_started and not self.vm_killed: import jpype - jpype.shutdownJVM() + jpype.shutdownJVM() # noqa self.vm_started = False self.vm_killed = True diff --git a/ndbioimage/readers/bfread.py b/ndbioimage/readers/bfread.py index bd4864e..dc62689 100644 --- a/ndbioimage/readers/bfread.py +++ b/ndbioimage/readers/bfread.py @@ -153,7 +153,7 @@ class JVMReader: self.queue_out.put(image[..., c]) else: self.queue_out.put(image) - except queues.Empty: + except queues.Empty: # noqa continue except (Exception,): print_exc() @@ -171,7 +171,7 @@ def can_open(path): except (Exception,): return False finally: - jvm.kill_vm() + jvm.kill_vm() # noqa class Reader(AbstractReader, ABC): diff --git a/ndbioimage/readers/cziread.py b/ndbioimage/readers/cziread.py index dc138b9..9bb540f 100644 --- a/ndbioimage/readers/cziread.py +++ b/ndbioimage/readers/cziread.py @@ -31,7 +31,7 @@ class Reader(AbstractReader, ABC): filedict[c, z, t].append(directory_entry) else: filedict[c, z, t] = [directory_entry] - self.filedict = filedict + self.filedict = filedict # noqa def close(self): self.reader.close() @@ -116,8 +116,8 @@ class Reader(AbstractReader, ABC): y_max = max([f.start[f.axes.index('Y')] + f.shape[f.axes.index('Y')] for f in self.filedict[0, 0, 0]]) size_x = x_max - x_min size_y = y_max - y_min - size_c, size_z, size_t = [self.reader.shape[self.reader.axes.index(directory_entry)] - for directory_entry in 'CZT'] + size_c, size_z, size_t = (self.reader.shape[self.reader.axes.index(directory_entry)] + for directory_entry in 'CZT') image = information.find("Image") pixel_type = text(image.find("PixelType"), "Gray16") @@ -277,23 +277,23 @@ class Reader(AbstractReader, ABC): text(light_source.find("LightSourceType").find("Laser").find("Wavelength"))))) multi_track_setup = acquisition_block.find("MultiTrackSetup") - for idx, tube_lens in enumerate(set(text(track_setup.find("TubeLensPosition")) - for track_setup in multi_track_setup)): + for idx, tube_lens in enumerate({text(track_setup.find("TubeLensPosition")) + for track_setup in multi_track_setup}): ome.instruments[0].objectives.append( model.Objective(id=f"Objective:Tubelens:{idx}", model=tube_lens, nominal_magnification=float( re.findall(r'\d+[,.]\d*', tube_lens)[0].replace(',', '.')) )) - for idx, filter_ in enumerate(set(text(beam_splitter.find("Filter")) - for track_setup in multi_track_setup - for beam_splitter in track_setup.find("BeamSplitters"))): + for idx, filter_ in enumerate({text(beam_splitter.find("Filter")) + for track_setup in multi_track_setup + for beam_splitter in track_setup.find("BeamSplitters")}): ome.instruments[0].filter_sets.append( model.FilterSet(id=f"FilterSet:{idx}", model=filter_) ) - for idx, collimator in enumerate(set(text(track_setup.find("FWFOVPosition")) - for track_setup in multi_track_setup)): + for idx, collimator in enumerate({text(track_setup.find("FWFOVPosition")) + for track_setup in multi_track_setup}): ome.instruments[0].filters.append(model.Filter(id=f"Filter:Collimator:{idx}", model=collimator)) x_min = min([f.start[f.axes.index('X')] for f in self.filedict[0, 0, 0]]) @@ -302,8 +302,8 @@ class Reader(AbstractReader, ABC): y_max = max([f.start[f.axes.index('Y')] + f.shape[f.axes.index('Y')] for f in self.filedict[0, 0, 0]]) size_x = x_max - x_min size_y = y_max - y_min - size_c, size_z, size_t = [self.reader.shape[self.reader.axes.index(directory_entry)] - for directory_entry in 'CZT'] + size_c, size_z, size_t = (self.reader.shape[self.reader.axes.index(directory_entry)] + for directory_entry in 'CZT') image = information.find("Image") pixel_type = text(image.find("PixelType"), "Gray16") diff --git a/ndbioimage/readers/fijiread.py b/ndbioimage/readers/fijiread.py index 26083b0..3d513fd 100644 --- a/ndbioimage/readers/fijiread.py +++ b/ndbioimage/readers/fijiread.py @@ -32,10 +32,10 @@ class Reader(AbstractReader, ABC): self.reader = TiffFile(self.path) assert self.reader.pages[0].compression == 1, "Can only read uncompressed tiff files." assert self.reader.pages[0].samplesperpixel == 1, "Can only read 1 sample per pixel." - self.offset = self.reader.pages[0].dataoffsets[0] - self.count = self.reader.pages[0].databytecounts[0] - self.bytes_per_sample = self.reader.pages[0].bitspersample // 8 - self.fmt = self.reader.byteorder + self.count // self.bytes_per_sample * 'BHILQ'[self.bytes_per_sample - 1] + self.offset = self.reader.pages[0].dataoffsets[0] # noqa + self.count = self.reader.pages[0].databytecounts[0] # noqa + self.bytes_per_sample = self.reader.pages[0].bitspersample // 8 # noqa + self.fmt = self.reader.byteorder + self.count // self.bytes_per_sample * 'BHILQ'[self.bytes_per_sample - 1] # noqa def close(self): self.reader.close() diff --git a/ndbioimage/readers/ndread.py b/ndbioimage/readers/ndread.py index cc2b303..9c64525 100644 --- a/ndbioimage/readers/ndread.py +++ b/ndbioimage/readers/ndread.py @@ -15,7 +15,7 @@ class Reader(AbstractReader, ABC): @cached_property def ome(self): - def shape(size_x=1, size_y=1, size_c=1, size_z=1, size_t=1): + def shape(size_x=1, size_y=1, size_c=1, size_z=1, size_t=1): # noqa return size_x, size_y, size_c, size_z, size_t size_x, size_y, size_c, size_z, size_t = shape(*self.array.shape) try: @@ -42,7 +42,7 @@ class Reader(AbstractReader, ABC): if isinstance(self.path, np.ndarray): self.array = np.array(self.path) while self.array.ndim < 5: - self.array = np.expand_dims(self.array, -1) + self.array = np.expand_dims(self.array, -1) # noqa self.path = 'numpy array' def __frame__(self, c, z, t): diff --git a/ndbioimage/readers/seqread.py b/ndbioimage/readers/seqread.py index addbb37..07971a4 100644 --- a/ndbioimage/readers/seqread.py +++ b/ndbioimage/readers/seqread.py @@ -4,7 +4,7 @@ import re from pathlib import Path from functools import cached_property from ome_types import model -from ome_types.units import _quantity_property +from ome_types.units import _quantity_property # noqa from itertools import product from datetime import datetime from abc import ABC @@ -15,7 +15,10 @@ def lazy_property(function, field, *arg_fields): def lazy(self): if self.__dict__.get(field) is None: self.__dict__[field] = function(*[getattr(self, arg_field) for arg_field in arg_fields]) - self.model_fields_set.add(field) + try: + self.model_fields_set.add(field) + except Exception: # noqa + pass return self.__dict__[field] return property(lazy) @@ -24,6 +27,7 @@ class Plane(model.Plane): """ Lazily retrieve delta_t from metadata """ def __init__(self, t0, file, **kwargs): super().__init__(**kwargs) + # setting fields here because they would be removed by ome_types/pydantic after class definition setattr(self.__class__, 'delta_t', lazy_property(self.get_delta_t, 'delta_t', 't0', 'file')) setattr(self.__class__, 'delta_t_quantity', _quantity_property('delta_t')) self.__dict__['t0'] = t0 @@ -79,7 +83,7 @@ class Reader(AbstractReader, ABC): else: pixel_type = "uint16" # assume - size_c, size_z, size_t = [max(i) + 1 for i in zip(*self.filedict.keys())] + size_c, size_z, size_t = (max(i) + 1 for i in zip(*self.filedict.keys())) t0 = datetime.strptime(metadata["Info"]["Time"], "%Y-%m-%d %H:%M:%S %z") ome.images.append( model.Image( @@ -123,7 +127,7 @@ class Reader(AbstractReader, ABC): pattern_c = re.compile(r"img_\d{3,}_(.*)_\d{3,}$") pattern_z = re.compile(r"(\d{3,})$") pattern_t = re.compile(r"img_(\d{3,})") - self.filedict = {(cnamelist.index(pattern_c.findall(file.stem)[0]), + self.filedict = {(cnamelist.index(pattern_c.findall(file.stem)[0]), # noqa int(pattern_z.findall(file.stem)[0]), int(pattern_t.findall(file.stem)[0])): file for file in filelist} diff --git a/ndbioimage/readers/tifread.py b/ndbioimage/readers/tifread.py index 0bb2fd9..6ea2934 100644 --- a/ndbioimage/readers/tifread.py +++ b/ndbioimage/readers/tifread.py @@ -17,7 +17,7 @@ class Reader(AbstractReader, ABC): def _can_open(path): if isinstance(path, Path) and path.suffix in ('.tif', '.tiff'): with tifffile.TiffFile(path) as tif: - return tif.is_imagej and tif.pages[-1]._nextifd() == 0 + return tif.is_imagej and tif.pages[-1]._nextifd() == 0 # noqa else: return False @@ -27,12 +27,12 @@ class Reader(AbstractReader, ABC): for key, value in self.reader.imagej_metadata.items()} page = self.reader.pages[0] - self.p_ndim = page.ndim + self.p_ndim = page.ndim # noqa size_x = page.imagelength size_y = page.imagewidth if self.p_ndim == 3: size_c = page.samplesperpixel - self.p_transpose = [i for i in [page.axes.find(j) for j in 'SYX'] if i >= 0] + self.p_transpose = [i for i in [page.axes.find(j) for j in 'SYX'] if i >= 0] # noqa size_t = metadata.get('frames', 1) # // C else: size_c = metadata.get('channels', 1) diff --git a/ndbioimage/transforms.py b/ndbioimage/transforms.py index 8972325..15c1df7 100644 --- a/ndbioimage/transforms.py +++ b/ndbioimage/transforms.py @@ -1,19 +1,24 @@ -import yaml -import re -import numpy as np +import warnings from copy import deepcopy from pathlib import Path +import numpy as np +import yaml +from parfor import pmap, Chunks +from skimage import filters +from tiffwrite import IJTiffFile +from tqdm.auto import tqdm + try: # best if SimpleElastix is installed: https://simpleelastix.readthedocs.io/GettingStarted.html - import SimpleITK as sitk + import SimpleITK as sitk # noqa except ImportError: sitk = None try: - from pandas import DataFrame, Series + from pandas import DataFrame, Series, concat except ImportError: - DataFrame, Series = None, None + DataFrame, Series, concat = None, None, None if hasattr(yaml, 'full_load'): @@ -24,10 +29,30 @@ else: class Transforms(dict): def __init__(self, *args, **kwargs): - super().__init__(*args[1:], **kwargs) + super().__init__(*args, **kwargs) self.default = Transform() - if len(args): - self.load(args[0]) + + @classmethod + def from_file(cls, file, C=True, T=True): + with open(Path(file).with_suffix('.yml')) as f: + return cls.from_dict(yamlload(f), C, T) + + @classmethod + def from_dict(cls, d, C=True, T=True): + new = cls() + for key, value in d.items(): + if isinstance(key, str) and C: + new[key.replace(r'\:', ':').replace('\\\\', '\\')] = Transform.from_dict(value) + elif T: + new[key] = Transform.from_dict(value) + return new + + @classmethod + def from_shifts(cls, shifts): + new = cls() + for key, shift in shifts.items(): + new[key] = Transform.from_shift(shift) + return new def __mul__(self, other): new = Transforms() @@ -44,18 +69,11 @@ class Transforms(dict): return new def asdict(self): - return {':'.join(str(i).replace('\\', '\\\\').replace(':', r'\:') for i in key): value.asdict() + return {key.replace('\\', '\\\\').replace(':', r'\:') if isinstance(key, str) else key: value.asdict() for key, value in self.items()} - def load(self, file): - if isinstance(file, dict): - d = file - else: - with open(file.with_suffix(".yml"), 'r') as f: - d = yamlload(f) - pattern = re.compile(r'[^\\]:') - for key, value in d.items(): - self[tuple(i.replace(r'\:', ':').replace('\\\\', '\\') for i in pattern.split(key))] = Transform(value) + def __getitem__(self, item): + return np.prod([self[i] for i in item[::-1]]) if isinstance(item, tuple) else super().__getitem__(item) def __missing__(self, key): return self.default @@ -70,47 +88,242 @@ class Transforms(dict): return hash(frozenset((*self.__dict__.items(), *self.items()))) def save(self, file): - with open(file.with_suffix(".yml"), 'w') as f: + with open(Path(file).with_suffix(".yml"), 'w') as f: yaml.safe_dump(self.asdict(), f, default_flow_style=None) def copy(self): return deepcopy(self) - def adapt(self, origin, shape): + def adapt(self, origin, shape, channel_names): + def key_map(a, b): + def fun(b, key_a): + for key_b in b: + if key_b in key_a or key_a in key_b: + return key_a, key_b + + return {n[0]: n[1] for key_a in a if (n := fun(b, key_a))} + for value in self.values(): value.adapt(origin, shape) self.default.adapt(origin, shape) + transform_channels = {key for key in self.keys() if isinstance(key, str)} + if set(channel_names) - transform_channels: + mapping = key_map(channel_names, transform_channels) + warnings.warn(f'The image file and the transform do not have the same channels,' + f' creating a mapping: {mapping}') + for key_im, key_t in mapping.items(): + self[key_im] = self[key_t] @property def inverse(self): + # TODO: check for C@T inverse = self.copy() for key, value in self.items(): inverse[key] = value.inverse return inverse - @property - def ndim(self): - return len(list(self.keys())[0]) + def coords_pandas(self, array, channel_names, columns=None): + if isinstance(array, DataFrame): + return concat([self.coords_pandas(row, channel_names, columns) for _, row in array.iterrows()], axis=1).T + elif isinstance(array, Series): + key = [] + if 'C' in array: + key.append(channel_names[int(array['C'])]) + if 'T' in array: + key.append(int(array['T'])) + return self[tuple(key)].coords(array, columns) + else: + raise TypeError('Not a pandas DataFrame or Series.') + + def with_beads(self, cyllens, bead_files): + assert len(bead_files) > 0, "At least one file is needed to calculate the registration." + transforms = [self.calculate_channel_transforms(file, cyllens) for file in bead_files] + for key in {key for transform in transforms for key in transform.keys()}: + new_transforms = [transform[key] for transform in transforms if key in transform] + if len(new_transforms) == 1: + self[key] = new_transforms[0] + else: + self[key] = Transform() + self[key].parameters = np.mean([t.parameters for t in new_transforms], 0) + self[key].dparameters = (np.std([t.parameters for t in new_transforms], 0) / + np.sqrt(len(new_transforms))).tolist() + return self + + @staticmethod + def get_bead_files(path): + from . import Imread + files = [] + for file in path.iterdir(): + if file.name.lower().startswith('beads'): + try: + with Imread(file): + files.append(file) + except Exception: + pass + files = sorted(files) + if not files: + raise Exception('No bead file found!') + checked_files = [] + for file in files: + try: + if file.is_dir(): + file /= 'Pos0' + with Imread(file): # check for errors opening the file + checked_files.append(file) + except (Exception,): + continue + if not checked_files: + raise Exception('No bead file found!') + return checked_files + + @staticmethod + def calculate_channel_transforms(bead_file, cyllens): + """ When no channel is not transformed by a cylindrical lens, assume that the image is scaled by a factor 1.162 + in the horizontal direction """ + from . import Imread + + with Imread(bead_file, axes='zcxy') as im: # noqa + max_ims = im.max('z') + goodch = [c for c, max_im in enumerate(max_ims) if not im.is_noise(max_im)] + if not goodch: + goodch = list(range(len(max_ims))) + untransformed = [c for c in range(im.shape['c']) if cyllens[im.detector[c]].lower() == 'none'] + + good_and_untrans = sorted(set(goodch) & set(untransformed)) + if good_and_untrans: + masterch = good_and_untrans[0] + else: + masterch = goodch[0] + transform = Transform() + if not good_and_untrans: + matrix = transform.matrix + matrix[0, 0] = 0.86 + transform.matrix = matrix + transforms = Transforms() + for c in tqdm(goodch, desc='Calculating channel transforms'): # noqa + if c == masterch: + transforms[im.channel_names[c]] = transform + else: + transforms[im.channel_names[c]] = Transform.register(max_ims[masterch], max_ims[c]) * transform + return transforms + + @staticmethod + def save_channel_transform_tiff(bead_files, tiffile): + from . import Imread + n_channels = 0 + for file in bead_files: + with Imread(file) as im: + n_channels = max(n_channels, im.shape['c']) + with IJTiffFile(tiffile, (n_channels, 1, len(bead_files))) as tif: + for t, file in enumerate(bead_files): + with Imread(file) as im: + with Imread(file).with_transform() as jm: + for c in range(im.shape['c']): + tif.save(np.hstack((im(c=c, t=0).max('z'), jm(c=c, t=0).max('z'))), c, 0, t) + + def with_drift(self, im): + """ Calculate shifts relative to the first frame + divide the sequence into groups, + compare each frame to the frame in the middle of the group and compare these middle frames to each other + """ + im = im.transpose('tzycx') + t_groups = [list(chunk) for chunk in Chunks(range(im.shape['t']), size=round(np.sqrt(im.shape['t'])))] + t_keys = [int(np.round(np.mean(t_group))) for t_group in t_groups] + t_pairs = [(int(np.round(np.mean(t_group))), frame) for t_group in t_groups for frame in t_group] + t_pairs.extend(zip(t_keys, t_keys[1:])) + fmaxz_keys = {t_key: filters.gaussian(im[t_key].max('z'), 5) for t_key in t_keys} + + def fun(t_key_t, im, fmaxz_keys): + t_key, t = t_key_t + if t_key == t: + return 0, 0 + else: + fmaxz = filters.gaussian(im[t].max('z'), 5) + return Transform.register(fmaxz_keys[t_key], fmaxz, 'translation').parameters[4:] + + shifts = np.array(pmap(fun, t_pairs, (im, fmaxz_keys), desc='Calculating image shifts.')) + shift_keys_cum = np.zeros(2) + for shift_keys, t_group in zip(np.vstack((-shifts[0], shifts[im.shape['t']:])), t_groups): + shift_keys_cum += shift_keys + shifts[t_group] += shift_keys_cum + + for i, shift in enumerate(shifts[:im.shape['t']]): + self[i] = Transform.from_shift(shift) + return self class Transform: - def __init__(self, *args): + def __init__(self): if sitk is None: raise ImportError('SimpleElastix is not installed: ' 'https://simpleelastix.readthedocs.io/GettingStarted.html') self.transform = sitk.ReadTransform(str(Path(__file__).parent / 'transform.txt')) - self.dparameters = 0, 0, 0, 0, 0, 0 - self.shape = 512, 512 - self.origin = 255.5, 255.5 - if len(args) == 1: # load from file or dict - if isinstance(args[0], np.ndarray): - self.matrix = args[0] - else: - self.load(*args) - elif len(args) > 1: # make new transform using fixed and moving image - self.register(*args) + self.dparameters = [0., 0., 0., 0., 0., 0.] + self.shape = [512., 512.] + self.origin = [255.5, 255.5] self._last, self._inverse = None, None + def __reduce__(self): + return self.from_dict, (self.asdict(),) + + def __repr__(self): + return self.asdict().__repr__() + + def __str__(self): + return self.asdict().__str__() + + @classmethod + def register(cls, fix, mov, kind=None): + """ kind: 'affine', 'translation', 'rigid' """ + new = cls() + kind = kind or 'affine' + new.shape = fix.shape + fix, mov = new.cast_image(fix), new.cast_image(mov) + # TODO: implement RigidTransform + tfilter = sitk.ElastixImageFilter() + tfilter.LogToConsoleOff() + tfilter.SetFixedImage(fix) + tfilter.SetMovingImage(mov) + tfilter.SetParameterMap(sitk.GetDefaultParameterMap(kind)) + tfilter.Execute() + transform = tfilter.GetTransformParameterMap()[0] + if kind == 'affine': + new.parameters = [float(t) for t in transform['TransformParameters']] + new.shape = [float(t) for t in transform['Size']] + new.origin = [float(t) for t in transform['CenterOfRotationPoint']] + elif kind == 'translation': + new.parameters = [1.0, 0.0, 0.0, 1.0] + [float(t) for t in transform['TransformParameters']] + new.shape = [float(t) for t in transform['Size']] + new.origin = [(t - 1) / 2 for t in new.shape] + else: + raise NotImplementedError(f'{kind} tranforms not implemented (yet)') + new.dparameters = 6 * [np.nan] + return new + + @classmethod + def from_shift(cls, shift): + return cls.from_array(np.array(((1, 0, shift[0]), (0, 1, shift[1]), (0, 0, 1)))) + + @classmethod + def from_array(cls, array): + new = cls() + new.matrix = array + return new + + @classmethod + def from_file(cls, file): + with open(Path(file).with_suffix('.yml')) as f: + return cls.from_dict(yamlload(f)) + + @classmethod + def from_dict(cls, d): + new = cls() + new.origin = [float(i) for i in d['CenterOfRotationPoint']] + new.parameters = [float(i) for i in d['TransformParameters']] + new.dparameters = [float(i) for i in d['dTransformParameters']] if 'dTransformParameters' in d else 6 * [np.nan] + new.shape = [float(i) for i in d['Size']] + return new + def __mul__(self, other): # TODO: take care of dmatrix result = self.copy() if isinstance(other, Transform): @@ -121,9 +334,6 @@ class Transform: result.dmatrix = self.dmatrix @ other return result - def __reduce__(self): - return self.__class__, (self.asdict(),) - def is_unity(self): return self.parameters == [1, 0, 0, 1, 0, 0] @@ -166,7 +376,7 @@ class Transform: @property def parameters(self): - return self.transform.GetParameters() + return list(self.transform.GetParameters()) @parameters.setter def parameters(self, value): @@ -186,7 +396,7 @@ class Transform: def inverse(self): if self._last is None or self._last != self.asdict(): self._last = self.asdict() - self._inverse = Transform(self.asdict()) + self._inverse = Transform.from_dict(self.asdict()) self._inverse.transform = self._inverse.transform.GetInverse() self._inverse._last = self._inverse.asdict() self._inverse._inverse = self @@ -234,45 +444,3 @@ class Transform: file += '.yml' with open(file, 'w') as f: yaml.safe_dump(self.asdict(), f, default_flow_style=None) - - def load(self, file): - """ load the parameters of a transform from a yaml file or a dict - """ - if isinstance(file, dict): - d = file - else: - if not file[-3:] == 'yml': - file += '.yml' - with open(file, 'r') as f: - d = yamlload(f) - self.origin = [float(i) for i in d['CenterOfRotationPoint']] - self.parameters = [float(i) for i in d['TransformParameters']] - self.dparameters = [float(i) for i in d['dTransformParameters']] \ - if 'dTransformParameters' in d else 6 * [np.nan] - self.shape = [float(i) for i in d['Size']] - - def register(self, fix, mov, kind=None): - """ kind: 'affine', 'translation', 'rigid' - """ - kind = kind or 'affine' - self.shape = fix.shape - fix, mov = self.cast_image(fix), self.cast_image(mov) - # TODO: implement RigidTransform - tfilter = sitk.ElastixImageFilter() - tfilter.LogToConsoleOff() - tfilter.SetFixedImage(fix) - tfilter.SetMovingImage(mov) - tfilter.SetParameterMap(sitk.GetDefaultParameterMap(kind)) - tfilter.Execute() - transform = tfilter.GetTransformParameterMap()[0] - if kind == 'affine': - self.parameters = [float(t) for t in transform['TransformParameters']] - self.shape = [float(t) for t in transform['Size']] - self.origin = [float(t) for t in transform['CenterOfRotationPoint']] - elif kind == 'translation': - self.parameters = [1.0, 0.0, 0.0, 1.0] + [float(t) for t in transform['TransformParameters']] - self.shape = [float(t) for t in transform['Size']] - self.origin = [(t - 1) / 2 for t in self.shape] - else: - raise NotImplementedError(f'{kind} tranforms not implemented (yet)') - self.dparameters = 6 * [np.nan] diff --git a/pyproject.toml b/pyproject.toml index 84605e0..72e3085 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ndbioimage" -version = "2023.10.1" +version = "2023.10.2" description = "Bio image reading, metadata and some affine registration." authors = ["W. Pomp "] license = "GPLv3" @@ -22,7 +22,7 @@ pint = "*" tqdm = "*" lxml = "*" pyyaml = "*" -parfor = ">=2023.8.2" +parfor = ">=2023.10.1" JPype1 = "*" SimpleITK-SimpleElastix = "*" pytest = { version = "*", optional = true }