- Improve transforms, implement drift correction transforms, require from_ methods for instantiation.

- Replace transform, drift and beadfile arguments for Imread by with_transform method.
- Bring Imread.transpose in line with numpy.transpose.
- Fix seqread.lazy_property.
This commit is contained in:
Wim Pomp
2023-11-02 15:05:40 +01:00
parent 2e56f45f3e
commit d13b702481
10 changed files with 447 additions and 507 deletions

View File

@@ -1,33 +1,32 @@
import multiprocessing
import re
import warnings
import pandas
import yaml
import numpy as np
import multiprocessing
import ome_types
from ome_types import ureg, model
from pint import set_application_registry
from datetime import datetime
from tqdm.auto import tqdm
from itertools import product
from collections import OrderedDict
from abc import ABCMeta, abstractmethod
from functools import cached_property, wraps
from parfor import parfor
from tiffwrite import IJTiffFile
from numbers import Number
from abc import ABC, ABCMeta, abstractmethod
from argparse import ArgumentParser
from pathlib import Path
from collections import OrderedDict
from datetime import datetime
from functools import cached_property, wraps
from importlib.metadata import version
from traceback import print_exc
from itertools import product
from numbers import Number
from operator import truediv
from .transforms import Transform, Transforms
from pathlib import Path
from traceback import print_exc
import numpy as np
import ome_types
import yaml
from ome_types import model, ureg
from pint import set_application_registry
from tiffwrite import IJTiffFile
from tqdm.auto import tqdm
from .jvm import JVM
from .transforms import Transform, Transforms
try:
__version__ = version(Path(__file__).parent.name)
except (Exception,):
except Exception: # noqa
__version__ = 'unknown'
try:
@@ -35,7 +34,7 @@ try:
head = g.read().split(':')[1].strip()
with open(Path(__file__).parent.parent / '.git' / head) as h:
__git_commit_hash__ = h.read().rstrip('\n')
except (Exception,):
except Exception: # noqa
__git_commit_hash__ = 'unknown'
ureg.default_format = '~P'
@@ -57,235 +56,6 @@ class TransformTiff(IJTiffFile):
return super().compress_frame(np.asarray(self.image(*frame)).astype(self.dtype))
class ImTransforms(Transforms):
""" Transforms class with methods to calculate channel transforms from bead files etc. """
def __init__(self, path, cyllens, file=None, transforms=None):
super().__init__()
self.cyllens = tuple(cyllens)
if transforms is None:
# TODO: check this
if re.search(r'^Pos\d+', path.name):
self.path = path.parent.parent
else:
self.path = path.parent
if file is not None:
if isinstance(file, str) and file.lower().endswith('.yml'):
self.ymlpath = file
self.beadfile = None
else:
self.ymlpath = self.path / 'transform.yml'
self.beadfile = file
else:
self.ymlpath = self.path / 'transform.yml'
self.beadfile = None
self.tifpath = self.ymlpath.with_suffix('.tif')
try:
self.load(self.ymlpath)
except (Exception,):
print('No transform file found, trying to generate one.')
if not self.files:
raise FileNotFoundError('No bead files found to calculate the transform from.')
self.calculate_transforms()
self.save(self.ymlpath)
self.save_transform_tiff()
print(f'Saving transform in {self.ymlpath}.')
print(f'Please check the transform in {self.tifpath}.')
else: # load from dict transforms
self.path = path
self.beadfile = file
for i, (key, value) in enumerate(transforms.items()):
self[key] = Transform(value)
def coords_pandas(self, array, cnamelist, colums=None):
if isinstance(array, pandas.DataFrame):
return pandas.concat([self[(cnamelist[int(row['C'])],)].coords(row, colums)
for _, row in array.iterrows()], axis=1).T
elif isinstance(array, pandas.Series):
return self[(cnamelist[int(array['C'])],)].coords(array, colums)
else:
raise TypeError('Not a pandas DataFrame or Series.')
@cached_property
def files(self):
try:
if self.beadfile is None:
files = self.get_bead_files()
else:
files = self.beadfile
if isinstance(files, str):
files = (Path(files),)
elif isinstance(files, Path):
files = (files,)
return tuple(files)
except (Exception,):
return ()
def get_bead_files(self):
files = sorted([f for f in self.path.iterdir() if f.name.lower().startswith('beads')
and not f.suffix.lower() == '.pdf' and not f.suffix.lower() == 'pkl'])
if not files:
raise Exception('No bead file found!')
checked_files = []
for file in files:
try:
if file.is_dir():
file /= 'Pos0'
with Imread(file): # check for errors opening the file
checked_files.append(file)
except (Exception,):
continue
if not checked_files:
raise Exception('No bead file found!')
return checked_files
def calculate_transform(self, file):
""" When no channel is not transformed by a cylindrical lens, assume that the image is scaled by a factor 1.162
in the horizontal direction """
with Imread(file, axes='zcxy') as im:
max_ims = im.max('z')
goodch = [c for c, max_im in enumerate(max_ims) if not im.is_noise(max_im)]
if not goodch:
goodch = list(range(len(max_ims)))
untransformed = [c for c in range(im.shape['c']) if self.cyllens[im.detector[c]].lower() == 'none']
good_and_untrans = sorted(set(goodch) & set(untransformed))
if good_and_untrans:
masterch = good_and_untrans[0]
else:
masterch = goodch[0]
transform = Transform()
if not good_and_untrans:
matrix = transform.matrix
matrix[0, 0] = 0.86
transform.matrix = matrix
transforms = Transforms()
for c in tqdm(goodch):
if c == masterch:
transforms[(im.cnamelist[c],)] = transform
else:
transforms[(im.cnamelist[c],)] = Transform(max_ims[masterch], max_ims[c]) * transform
return transforms
def calculate_transforms(self):
transforms = [self.calculate_transform(file) for file in self.files]
for key in set([key for transform in transforms for key in transform.keys()]):
new_transforms = [transform[key] for transform in transforms if key in transform]
if len(new_transforms) == 1:
self[key] = new_transforms[0]
else:
self[key] = Transform()
self[key].parameters = np.mean([t.parameters for t in new_transforms], 0)
self[key].dparameters = (np.std([t.parameters for t in new_transforms], 0) /
np.sqrt(len(new_transforms))).tolist()
def save_transform_tiff(self):
n_channels = 0
for file in self.files:
with Imread(file) as im:
n_channels = max(n_channels, im.shape['c'])
with IJTiffFile(self.tifpath, (n_channels, 1, len(self.files))) as tif:
for t, file in enumerate(self.files):
with Imread(file) as im:
with Imread(file, transform=True) as jm:
for c in range(im.shape['c']):
tif.save(np.hstack((im(c=c, t=0).max('z'), jm(c=c, t=0).max('z'))), c, 0, t)
class ImShiftTransforms(Transforms):
""" Class to handle drift in xy. The image this is applied to must have a channel transform already, which is then
replaced by this class. """
def __init__(self, im, shifts=None):
""" im: Calculate shifts from channel-transformed images
im, t x 2 array Sets shifts from array, one row per frame
im, dict {frame: shift} Sets shifts from dict, each key is a frame number from where a new shift is applied
im, file Loads shifts from a saved file """
super().__init__()
with (Imread(im, transform=True, drift=False) if isinstance(im, str)
else im.new(transform=True, drift=False)) as im:
self.impath = im.path
self.path = self.impath.parent / self.impath.stem + '_shifts.txt'
self.tracks, self.detectors, self.files = im.track, im.detector, im.beadfile
if shifts is not None:
if isinstance(shifts, np.ndarray):
self.shifts = shifts
self.shifts2transforms(im)
elif isinstance(shifts, dict):
self.shifts = np.zeros((im.shape['t'], 2))
for k in sorted(shifts.keys()):
self.shifts[k:] = shifts[k]
self.shifts2transforms(im)
elif isinstance(shifts, str):
self.load(im, shifts)
elif self.path.exists():
self.load(im, self.path)
else:
self.calulate_shifts(im)
self.save()
def __call__(self, channel, time, tracks=None, detectors=None):
tracks = tracks or self.tracks
detectors = detectors or self.detectors
track, detector = tracks[channel], detectors[channel]
if (track, detector, time) in self:
return self[track, detector, time]
elif (0, detector, time) in self:
return self[0, detector, time]
else:
return Transform()
def load(self, im, file):
self.shifts = np.loadtxt(file)
self.shifts2transforms(im)
def save(self, file=None):
self.path = file or self.path
np.savetxt(self.path, self.shifts)
def coords(self, array, colums=None):
if isinstance(array, pandas.DataFrame):
return pandas.concat([self(int(row['C']), int(row['T'])).coords(row, colums)
for _, row in array.iterrows()], axis=1).T
elif isinstance(array, pandas.Series):
return self(int(array['C']), int(array['T'])).coords(array, colums)
else:
raise TypeError('Not a pandas DataFrame or Series.')
def calulate_shifts0(self, im):
""" Calculate shifts relative to the first frame """
im0 = im[:, 0, 0].squeeze().transpose(2, 0, 1)
@parfor(range(1, im.shape['t']), (im, im0), desc='Calculating image shifts.')
def fun(t, im, im0):
return Transform(im0, im[:, 0, t].squeeze().transpose(2, 0, 1), 'translation')
transforms = [Transform()] + fun
self.shifts = np.array([t.parameters[4:] for t in transforms])
self.set_transforms(transforms, im.transform)
def calulate_shifts(self, im):
""" Calculate shifts relative to the previous frame """
@parfor(range(1, im.shape['t']), (im,), desc='Calculating image shifts.')
def fun(t, im):
return Transform(im[:, 0, t - 1].squeeze().transpose(2, 0, 1), im[:, 0, t].squeeze().transpose(2, 0, 1),
'translation')
transforms = [Transform()] + fun
self.shifts = np.cumsum([t.parameters[4:] for t in transforms])
self.set_transforms(transforms, im.transform)
def shifts2transforms(self, im):
self.set_transforms([Transform(np.array(((1, 0, s[0]), (0, 1, s[1]), (0, 0, 1))))
for s in self.shifts], im.transform)
def set_transforms(self, shift_transforms, channel_transforms):
for key, value in channel_transforms.items():
for t, T in enumerate(shift_transforms):
self[key[0], key[1], t] = T * channel_transforms[key]
class DequeDict(OrderedDict):
def __init__(self, maxlen=None, *args, **kwargs):
self.maxlen = maxlen
@@ -314,24 +84,24 @@ def find(obj, **kwargs):
def try_default(fun, default, *args, **kwargs):
try:
return fun(*args, **kwargs)
except (Exception,):
except Exception: # noqa
return default
def get_ome(path):
from .readers.bfread import jars
try:
jvm = JVM(jars)
jvm = JVM(jars) # noqa
ome_meta = jvm.metadata_tools.createOMEXMLMetadata()
reader = jvm.image_reader()
reader.setMetadataStore(ome_meta)
reader.setId(str(path))
ome = ome_types.from_xml(str(ome_meta.dumpXML()), parser='lxml')
except (Exception,):
except Exception: # noqa
print_exc()
ome = model.OME()
finally:
jvm.kill_vm()
jvm.kill_vm() # noqa
return ome
@@ -356,7 +126,59 @@ class Shape(tuple):
return tuple(self[i] for i in 'xyczt')
class Imread(np.lib.mixins.NDArrayOperatorsMixin):
class Imread(np.lib.mixins.NDArrayOperatorsMixin, ABC):
""" class to read image files, while taking good care of important metadata,
currently optimized for .czi files, but can open anything that bioformats can handle
path: path to the image file
optional:
series: in case multiple experiments are saved in one file, like in .lif files
dtype: datatype to be used when returning frames
meta: define metadata, used for pickle-ing
NOTE: run imread.kill_vm() at the end of your script/program, otherwise python might not terminate
modify images on the fly with a decorator function:
define a function which takes an instance of this object, one image frame,
and the coordinates c, z, t as arguments, and one image frame as return
>> imread.frame_decorator = fun
then use imread as usually
Examples:
>> im = imread('/DATA/lenstra_lab/w.pomp/data/20190913/01_YTL639_JF646_DefiniteFocus.czi')
>> im
<< shows summary
>> im.shape
<< (256, 256, 2, 1, 600)
>> plt.imshow(im(1, 0, 100))
<< plots frame at position c=1, z=0, t=100 (python type indexing), note: round brackets; always 2d array
with 1 frame
>> data = im[:,:,0,0,:25]
<< retrieves 5d numpy array containing first 25 frames at c=0, z=0, note: square brackets; always 5d array
>> plt.imshow(im.max(0, None, 0))
<< plots max-z projection at c=0, t=0
>> len(im)
<< total number of frames
>> im.pxsize
<< 0.09708737864077668 image-plane pixel size in um
>> im.laserwavelengths
<< [642, 488]
>> im.laserpowers
<< [0.02, 0.0005] in %
See __init__ and other functions for more ideas.
Subclassing:
Subclass AbstractReader to add more file types. A subclass should always have at least the following
methods:
staticmethod _can_open(path): returns True when the subclass can open the image in path
__metadata__(self): pulls some metadata from the file and do other format specific things,
it needs to define a few properties, like shape, etc.
__frame__(self, c, z, t): this should return a single frame at channel c, slice z and time t
optional close(self): close the file in a proper way
optional field priority: subclasses with lower priority will be tried first, default = 99
Any other method can be overridden as needed
wp@tl2019-2023 """
def __new__(cls, path=None, *args, **kwargs):
if cls is not Imread:
return super().__new__(cls)
@@ -376,18 +198,13 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
return super().__new__(subclass)
raise ReaderNotFoundError(f'No reader found for {path}.')
def __init__(self, base=None, slice=None, shape=(0, 0, 0, 0, 0), dtype=None,
transform=False, drift=False, beadfile=None, frame_decorator=None):
self.base = base
def __init__(self, base=None, slice=None, shape=(0, 0, 0, 0, 0), dtype=None, frame_decorator=None):
self.base = base or self
self.slice = slice
self._shape = Shape(shape)
self.dtype = dtype
self.frame_decorator = frame_decorator
self.transform = transform
self.drift = drift
self.beadfile = beadfile
self.transform = Transforms()
self.flags = dict(C_CONTIGUOUS=False, F_CONTIGUOUS=False, OWNDATA=False, WRITEABLE=False,
ALIGNED=False, WRITEBACKIFCOPY=False, UPDATEIFCOPY=False)
@@ -426,7 +243,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
def __getitem__(self, n):
""" slice like a numpy array but return an Imread instance """
if self.isclosed:
raise IOError("file is closed")
raise OSError("file is closed")
if isinstance(n, (slice, Number)): # None = :
n = (n,)
elif isinstance(n, type(Ellipsis)):
@@ -644,9 +461,6 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
cfun(*tmps)).astype(p.sub('', dtype.name))
return out
def __framet__(self, c, z, t):
return self.transform_frame(self.__frame__(c, z, t), c, t)
@property
def axes(self):
return self.shape.axes
@@ -677,7 +491,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
return
try:
return self.get_config(pname)
except (Exception,):
except Exception: # noqa
return
return
@@ -704,11 +518,8 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
def summary(self):
""" gives a helpful summary of the recorded experiment """
s = [f"path/filename: {self.path}",
f"series/pos: {self.series}"]
if isinstance(self, View):
s.append(f"reader: {self.base.__class__.__module__.split('.')[-1]} view")
else:
s.append(f"reader: {self.__class__.__module__.split('.')[-1]} base")
f"series/pos: {self.series}",
f"reader: {self.base.__class__.__module__.split('.')[-1]}"]
s.extend((f"dtype: {self.dtype}",
f"shape ({self.axes}):".ljust(15) + f"{' x '.join(str(i) for i in self.shape)}"))
if self.pxsize_um:
@@ -728,15 +539,15 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
s.append('laser powers: ' + ' | '.join([' & '.join(len(p) * ('{:.3g}',)).format(*[100 * i for i in p])
for p in self.laserpowers]) + ' %')
if self.objective:
s.append('objective: {}'.format(self.objective.model))
s.append(f'objective: {self.objective.model}')
if self.magnification:
s.append('magnification: {}x'.format(self.magnification))
s.append(f'magnification: {self.magnification}x')
if self.tubelens:
s.append('tubelens: {}'.format(self.tubelens.model))
s.append(f'tubelens: {self.tubelens.model}')
if self.filter:
s.append('filterset: {}'.format(self.filter))
s.append(f'filterset: {self.filter}')
if self.powermode:
s.append('powermode: {}'.format(self.powermode))
s.append(f'powermode: {self.powermode}')
if self.collimator:
s.append('collimator: ' + (' {}' * len(self.collimator)).format(*self.collimator))
if self.tirfangle:
@@ -881,9 +692,9 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
return new
@wraps(np.transpose)
def transpose(self, axes=None):
def transpose(self, *axes):
new = self.copy()
if axes is None:
if not axes:
new.axes = new.axes[::-1]
else:
new.axes = ''.join(ax if isinstance(ax, str) else new.axes[ax] for ax in axes)
@@ -918,8 +729,8 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
def block(self, x=None, y=None, c=None, z=None, t=None):
""" returns 5D block of frames """
x, y, c, z, t = [np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1)
for i, e in zip('xyczt', (x, y, c, z, t))]
x, y, c, z, t = (np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1)
for i, e in zip('xyczt', (x, y, c, z, t)))
d = np.empty((len(x), len(y), len(c), len(z), len(t)), self.dtype)
for (ci, cj), (zi, zj), (ti, tj) in product(enumerate(c), enumerate(z), enumerate(t)):
d[:, :, ci, zi, ti] = self.frame(cj, zj, tj)[x][:, y]
@@ -930,7 +741,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
def data(self, c=0, z=0, t=0):
""" returns 3D stack of frames """
c, z, t = [np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) for i, e in zip('czt', (c, z, t))]
c, z, t = (np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) for i, e in zip('czt', (c, z, t)))
return np.dstack([self.frame(ci, zi, ti) for ci, zi, ti in product(c, z, t)])
def frame(self, c=0, z=0, t=0):
@@ -946,7 +757,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
self.cache.move_to_end(key)
f = self.cache[key]
else:
f = self.__framet__(c, z, t)
f = self.transform[self.channel_names[c], t].frame(self.__frame__(c, z, t))
if self.frame_decorator is not None:
f = self.frame_decorator(self, f, c, z, t)
self.cache[key] = f
@@ -959,9 +770,9 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
if not isinstance(channel_name, str):
return channel_name
else:
c = [i for i, c in enumerate(self.cnamelist) if c.lower().startswith(channel_name.lower())]
assert len(c) > 0, 'Channel {} not found in {}'.format(c, self.cnamelist)
assert len(c) < 2, 'Channel {} not unique in {}'.format(c, self.cnamelist)
c = [i for i, c in enumerate(self.channel_names) if c.lower().startswith(channel_name.lower())]
assert len(c) > 0, f'Channel {c} not found in {self.channel_names}'
assert len(c) < 2, f'Channel {c} not unique in {self.channel_names}'
return c[0]
@staticmethod
@@ -978,7 +789,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
|[-+]?\\.(?:inf|Inf|INF)
|\.(?:nan|NaN|NAN))$''', re.X),
list(r'-+0123456789.'))
with open(file, 'r') as f:
with open(file) as f:
return yaml.load(f, loader)
def get_czt(self, c, z, t):
@@ -995,7 +806,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
stop = n.stop
czt.append(list(range(n.start % self.shape[i], stop, n.step)))
elif isinstance(n, Number):
czt.append([n % self.shape[i]])
czt.append([n % self.shape[i]]) # noqa
else:
czt.append([k % self.shape[i] for k in n])
return [self.get_channel(c) for c in czt[0]], *czt[1:]
@@ -1055,28 +866,42 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
shape = [len(i) for i in n]
with TransformTiff(self, fname.with_suffix('.tif'), shape, pixel_type,
pxsize=self.pxsize_um, deltaz=self.deltaz_um, **kwargs) as tif:
for i, m in tqdm(zip(product(*[range(s) for s in shape]), product(*n)),
for i, m in tqdm(zip(product(*[range(s) for s in shape]), product(*n)), # noqa
total=np.prod(shape), desc='Saving tiff', disable=not bar):
tif.save(m, *i)
def set_transform(self):
# handle transforms
if self.transform is False or self.transform is None:
self.transform = None
else:
if isinstance(self.transform, Transforms):
self.transform = self.transform
else:
if isinstance(self.transform, str):
self.transform = ImTransforms(self.path, self.cyllens, self.transform)
else:
self.transform = ImTransforms(self.path, self.cyllens, self.beadfile)
if self.drift is True:
self.transform = ImShiftTransforms(self)
elif not (self.drift is False or self.drift is None):
self.transform = ImShiftTransforms(self, self.drift)
self.transform.adapt(self.frameoffset, self.shape.xyczt)
self.beadfile = self.transform.files
def with_transform(self, channels=True, drift=False, file=None, bead_files=()):
""" returns a view where channels and/or frames are registered with an affine transformation
channels: True/False register channels using bead_files
drift: True/False register frames to correct drift
file: load registration from file with name file, default: transform.yml in self.path.parent
bead_files: files used to register channels, default: files in self.path.parent,
with names starting with 'beads'
"""
view = self.view()
if file is None:
file = Path(view.path.parent) / 'transform.yml'
if not bead_files:
bead_files = Transforms.get_bead_files(view.path.parent)
if channels:
try:
view.transform = Transforms.from_file(file, T=drift)
# for key in view.channel_names:
# if
except Exception: # noqa
view.transform = Transforms().with_beads(view.cyllens, bead_files)
if drift:
view.transform = view.transform.with_drift(view)
view.transform.save(file.with_suffix('.yml'))
view.transform.save_channel_transform_tiff(bead_files, file.with_suffix('.tif'))
elif drift:
try:
view.transform = Transforms.from_file(file, C=False)
except Exception: # noqa
view.transform = Transforms().with_drift(self)
view.transform.adapt(view.frameoffset, view.shape.xyczt, view.channel_names)
return view
@staticmethod
def split_path_series(path):
@@ -1086,21 +911,14 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
return path.parent, int(path.name.lstrip('Pos'))
return path, 0
def transform_frame(self, frame, c, t=0):
if self.transform is None:
return frame
else:
return self.transform[(self.cnamelist[c],)].frame(frame)
def view(self, *args, **kwargs):
return View(self, *args, **kwargs)
class View(Imread):
def __init__(self, base, dtype=None, transform=None, drift=None, beadfile=None):
super().__init__(base.base, base.slice, base.shape, dtype or base.dtype, transform or base.transform,
drift or base.drift, beadfile or base.beadfile, base.frame_decorator)
self.set_transform()
class View(Imread, ABC):
def __init__(self, base, dtype=None):
super().__init__(base.base, base.slice, base.shape, dtype or base.dtype, base.frame_decorator)
self.transform = base.transform
def __getattr__(self, item):
if not hasattr(self.base, item):
@@ -1109,60 +927,6 @@ class View(Imread):
class AbstractReader(Imread, metaclass=ABCMeta):
""" class to read image files, while taking good care of important metadata,
currently optimized for .czi files, but can open anything that bioformats can handle
path: path to the image file
optional:
series: in case multiple experiments are saved in one file, like in .lif files
transform: automatically correct warping between channels, need transforms.py among others
drift: automatically correct for drift, only works if transform is not None or False
beadfile: image file(s) with beads which can be used for correcting warp
dtype: datatype to be used when returning frames
meta: define metadata, used for pickle-ing
NOTE: run imread.kill_vm() at the end of your script/program, otherwise python might not terminate
modify images on the fly with a decorator function:
define a function which takes an instance of this object, one image frame,
and the coordinates c, z, t as arguments, and one image frame as return
>> imread.frame_decorator = fun
then use imread as usually
Examples:
>> im = imread('/DATA/lenstra_lab/w.pomp/data/20190913/01_YTL639_JF646_DefiniteFocus.czi')
>> im
<< shows summary
>> im.shape
<< (256, 256, 2, 1, 600)
>> plt.imshow(im(1, 0, 100))
<< plots frame at position c=1, z=0, t=100 (python type indexing), note: round brackets; always 2d array
with 1 frame
>> data = im[:,:,0,0,:25]
<< retrieves 5d numpy array containing first 25 frames at c=0, z=0, note: square brackets; always 5d array
>> plt.imshow(im.max(0, None, 0))
<< plots max-z projection at c=0, t=0
>> len(im)
<< total number of frames
>> im.pxsize
<< 0.09708737864077668 image-plane pixel size in um
>> im.laserwavelengths
<< [642, 488]
>> im.laserpowers
<< [0.02, 0.0005] in %
See __init__ and other functions for more ideas.
Subclassing:
Subclass this class to add more file types. A subclass should always have at least the following methods:
staticmethod _can_open(path): returns True when the subclass can open the image in path
__metadata__(self): pulls some metadata from the file and do other format specific things, it needs to
define a few properties, like shape, etc.
__frame__(self, c, z, t): this should return a single frame at channel c, slice z and time t
optional close(self): close the file in a proper way
optional field priority: subclasses with lower priority will be tried first, default = 99
Any other method can be overridden as needed
wp@tl2019-2023 """
priority = 99
do_not_pickle = 'cache'
ureg = ureg
@@ -1187,10 +951,10 @@ class AbstractReader(Imread, metaclass=ABCMeta):
def close(self): # Optionally override this, close file handles etc.
return
def __init__(self, path, transform=False, drift=False, beadfile=None, dtype=None, axes=None):
def __init__(self, path, dtype=None, axes=None):
if isinstance(path, Imread):
return
super().__init__(self, transform=transform, drift=drift, beadfile=beadfile)
super().__init__()
self.isclosed = False
if isinstance(path, str):
path = Path(path)
@@ -1248,7 +1012,9 @@ class AbstractReader(Imread, metaclass=ABCMeta):
self.pxsize *= self.binning[0]
except (AttributeError, IndexError, ValueError):
self.binning = None
self.cnamelist = [channel.name for channel in image.pixels.channels]
self.channel_names = [channel.name for channel in image.pixels.channels]
self.channel_names += [chr(97 + i) for i in range(len(self.channel_names), self.shape['c'])]
self.cnamelist = self.channel_names
try:
optovars = [objective for objective in instrument.objectives if 'tubelens' in objective.id.lower()]
except AttributeError:
@@ -1297,19 +1063,18 @@ class AbstractReader(Imread, metaclass=ABCMeta):
self.feedback = m['FeedbackChannels']
else:
self.feedback = m['FeedbackChannel']
except (Exception,):
except Exception: # noqa
self.cyllens = ['None', 'None']
self.duolink = 'None'
self.feedback = []
try:
self.cyllenschannels = np.where([self.cyllens[self.detector[c]].lower() != 'none'
for c in range(self.shape['c'])])[0].tolist()
except (Exception,):
except Exception: # noqa
pass
self.set_transform()
try:
s = int(re.findall(r'_(\d{3})_', self.duolink)[0]) * ureg.nm
except (Exception,):
except Exception: # noqa
s = 561 * ureg.nm
try:
sigma = []
@@ -1320,7 +1085,7 @@ class AbstractReader(Imread, metaclass=ABCMeta):
sigma[sigma == 0] = 600 * ureg.nm
sigma /= 2 * self.NA * self.pxsize
self.sigma = sigma.magnitude.tolist()
except (Exception,):
except Exception: # noqa
self.sigma = [2] * self.shape['c']
if not self.NA:
self.immersionN = 1
@@ -1341,6 +1106,7 @@ class AbstractReader(Imread, metaclass=ABCMeta):
except Exception:
pass
def main():
parser = ArgumentParser(description='Display info and save as tif')
parser.add_argument('file', help='image_file')
@@ -1353,13 +1119,15 @@ def main():
parser.add_argument('-f', '--force', help='force overwrite', action='store_true')
args = parser.parse_args()
with Imread(args.file, transform=args.register) as im:
with Imread(args.file) as im:
if args.register:
im = im.with_transform()
print(im.summary)
if args.out:
out = Path(args.out).absolute()
out.parent.mkdir(parents=True, exist_ok=True)
if out.exists() and not args.force:
print('File {} exists already, add the -f flag if you want to overwrite it.'.format(args.out))
print(f'File {args.out} exists already, add the -f flag if you want to overwrite it.')
else:
im.save_as_tiff(out, args.channel, args.zslice, args.time, args.split)

View File

@@ -19,27 +19,27 @@ try:
def __init__(self, jars=None):
if not self.vm_started and not self.vm_killed:
try:
jarpath = Path(__file__).parent / 'jars'
jar_path = Path(__file__).parent / 'jars'
if jars is None:
jars = {}
for jar, src in jars.items():
if not (jarpath / jar).exists():
JVM.download(src, jarpath / jar)
classpath = [str(jarpath / jar) for jar in jars.keys()]
if not (jar_path / jar).exists():
JVM.download(src, jar_path / jar)
classpath = [str(jar_path / jar) for jar in jars.keys()]
import jpype
jpype.startJVM(classpath=classpath)
except Exception:
except Exception: # noqa
self.vm_started = False
else:
self.vm_started = True
try:
import jpype.imports
from loci.common import DebugTools
from loci.formats import ImageReader
from loci.formats import ChannelSeparator
from loci.formats import FormatTools
from loci.formats import MetadataTools
from loci.common import DebugTools # noqa
from loci.formats import ImageReader # noqa
from loci.formats import ChannelSeparator # noqa
from loci.formats import FormatTools # noqa
from loci.formats import MetadataTools # noqa
DebugTools.setRootLevel("ERROR")
@@ -47,7 +47,7 @@ try:
self.channel_separator = ChannelSeparator
self.format_tools = FormatTools
self.metadata_tools = MetadataTools
except Exception:
except Exception: # noqa
pass
if self.vm_killed:
@@ -64,7 +64,7 @@ try:
self = cls._instance
if self is not None and self.vm_started and not self.vm_killed:
import jpype
jpype.shutdownJVM()
jpype.shutdownJVM() # noqa
self.vm_started = False
self.vm_killed = True

View File

@@ -153,7 +153,7 @@ class JVMReader:
self.queue_out.put(image[..., c])
else:
self.queue_out.put(image)
except queues.Empty:
except queues.Empty: # noqa
continue
except (Exception,):
print_exc()
@@ -171,7 +171,7 @@ def can_open(path):
except (Exception,):
return False
finally:
jvm.kill_vm()
jvm.kill_vm() # noqa
class Reader(AbstractReader, ABC):

View File

@@ -31,7 +31,7 @@ class Reader(AbstractReader, ABC):
filedict[c, z, t].append(directory_entry)
else:
filedict[c, z, t] = [directory_entry]
self.filedict = filedict
self.filedict = filedict # noqa
def close(self):
self.reader.close()
@@ -116,8 +116,8 @@ class Reader(AbstractReader, ABC):
y_max = max([f.start[f.axes.index('Y')] + f.shape[f.axes.index('Y')] for f in self.filedict[0, 0, 0]])
size_x = x_max - x_min
size_y = y_max - y_min
size_c, size_z, size_t = [self.reader.shape[self.reader.axes.index(directory_entry)]
for directory_entry in 'CZT']
size_c, size_z, size_t = (self.reader.shape[self.reader.axes.index(directory_entry)]
for directory_entry in 'CZT')
image = information.find("Image")
pixel_type = text(image.find("PixelType"), "Gray16")
@@ -277,23 +277,23 @@ class Reader(AbstractReader, ABC):
text(light_source.find("LightSourceType").find("Laser").find("Wavelength")))))
multi_track_setup = acquisition_block.find("MultiTrackSetup")
for idx, tube_lens in enumerate(set(text(track_setup.find("TubeLensPosition"))
for track_setup in multi_track_setup)):
for idx, tube_lens in enumerate({text(track_setup.find("TubeLensPosition"))
for track_setup in multi_track_setup}):
ome.instruments[0].objectives.append(
model.Objective(id=f"Objective:Tubelens:{idx}", model=tube_lens,
nominal_magnification=float(
re.findall(r'\d+[,.]\d*', tube_lens)[0].replace(',', '.'))
))
for idx, filter_ in enumerate(set(text(beam_splitter.find("Filter"))
for track_setup in multi_track_setup
for beam_splitter in track_setup.find("BeamSplitters"))):
for idx, filter_ in enumerate({text(beam_splitter.find("Filter"))
for track_setup in multi_track_setup
for beam_splitter in track_setup.find("BeamSplitters")}):
ome.instruments[0].filter_sets.append(
model.FilterSet(id=f"FilterSet:{idx}", model=filter_)
)
for idx, collimator in enumerate(set(text(track_setup.find("FWFOVPosition"))
for track_setup in multi_track_setup)):
for idx, collimator in enumerate({text(track_setup.find("FWFOVPosition"))
for track_setup in multi_track_setup}):
ome.instruments[0].filters.append(model.Filter(id=f"Filter:Collimator:{idx}", model=collimator))
x_min = min([f.start[f.axes.index('X')] for f in self.filedict[0, 0, 0]])
@@ -302,8 +302,8 @@ class Reader(AbstractReader, ABC):
y_max = max([f.start[f.axes.index('Y')] + f.shape[f.axes.index('Y')] for f in self.filedict[0, 0, 0]])
size_x = x_max - x_min
size_y = y_max - y_min
size_c, size_z, size_t = [self.reader.shape[self.reader.axes.index(directory_entry)]
for directory_entry in 'CZT']
size_c, size_z, size_t = (self.reader.shape[self.reader.axes.index(directory_entry)]
for directory_entry in 'CZT')
image = information.find("Image")
pixel_type = text(image.find("PixelType"), "Gray16")

View File

@@ -32,10 +32,10 @@ class Reader(AbstractReader, ABC):
self.reader = TiffFile(self.path)
assert self.reader.pages[0].compression == 1, "Can only read uncompressed tiff files."
assert self.reader.pages[0].samplesperpixel == 1, "Can only read 1 sample per pixel."
self.offset = self.reader.pages[0].dataoffsets[0]
self.count = self.reader.pages[0].databytecounts[0]
self.bytes_per_sample = self.reader.pages[0].bitspersample // 8
self.fmt = self.reader.byteorder + self.count // self.bytes_per_sample * 'BHILQ'[self.bytes_per_sample - 1]
self.offset = self.reader.pages[0].dataoffsets[0] # noqa
self.count = self.reader.pages[0].databytecounts[0] # noqa
self.bytes_per_sample = self.reader.pages[0].bitspersample // 8 # noqa
self.fmt = self.reader.byteorder + self.count // self.bytes_per_sample * 'BHILQ'[self.bytes_per_sample - 1] # noqa
def close(self):
self.reader.close()

View File

@@ -15,7 +15,7 @@ class Reader(AbstractReader, ABC):
@cached_property
def ome(self):
def shape(size_x=1, size_y=1, size_c=1, size_z=1, size_t=1):
def shape(size_x=1, size_y=1, size_c=1, size_z=1, size_t=1): # noqa
return size_x, size_y, size_c, size_z, size_t
size_x, size_y, size_c, size_z, size_t = shape(*self.array.shape)
try:
@@ -42,7 +42,7 @@ class Reader(AbstractReader, ABC):
if isinstance(self.path, np.ndarray):
self.array = np.array(self.path)
while self.array.ndim < 5:
self.array = np.expand_dims(self.array, -1)
self.array = np.expand_dims(self.array, -1) # noqa
self.path = 'numpy array'
def __frame__(self, c, z, t):

View File

@@ -4,7 +4,7 @@ import re
from pathlib import Path
from functools import cached_property
from ome_types import model
from ome_types.units import _quantity_property
from ome_types.units import _quantity_property # noqa
from itertools import product
from datetime import datetime
from abc import ABC
@@ -15,7 +15,10 @@ def lazy_property(function, field, *arg_fields):
def lazy(self):
if self.__dict__.get(field) is None:
self.__dict__[field] = function(*[getattr(self, arg_field) for arg_field in arg_fields])
self.model_fields_set.add(field)
try:
self.model_fields_set.add(field)
except Exception: # noqa
pass
return self.__dict__[field]
return property(lazy)
@@ -24,6 +27,7 @@ class Plane(model.Plane):
""" Lazily retrieve delta_t from metadata """
def __init__(self, t0, file, **kwargs):
super().__init__(**kwargs)
# setting fields here because they would be removed by ome_types/pydantic after class definition
setattr(self.__class__, 'delta_t', lazy_property(self.get_delta_t, 'delta_t', 't0', 'file'))
setattr(self.__class__, 'delta_t_quantity', _quantity_property('delta_t'))
self.__dict__['t0'] = t0
@@ -79,7 +83,7 @@ class Reader(AbstractReader, ABC):
else:
pixel_type = "uint16" # assume
size_c, size_z, size_t = [max(i) + 1 for i in zip(*self.filedict.keys())]
size_c, size_z, size_t = (max(i) + 1 for i in zip(*self.filedict.keys()))
t0 = datetime.strptime(metadata["Info"]["Time"], "%Y-%m-%d %H:%M:%S %z")
ome.images.append(
model.Image(
@@ -123,7 +127,7 @@ class Reader(AbstractReader, ABC):
pattern_c = re.compile(r"img_\d{3,}_(.*)_\d{3,}$")
pattern_z = re.compile(r"(\d{3,})$")
pattern_t = re.compile(r"img_(\d{3,})")
self.filedict = {(cnamelist.index(pattern_c.findall(file.stem)[0]),
self.filedict = {(cnamelist.index(pattern_c.findall(file.stem)[0]), # noqa
int(pattern_z.findall(file.stem)[0]),
int(pattern_t.findall(file.stem)[0])): file for file in filelist}

View File

@@ -17,7 +17,7 @@ class Reader(AbstractReader, ABC):
def _can_open(path):
if isinstance(path, Path) and path.suffix in ('.tif', '.tiff'):
with tifffile.TiffFile(path) as tif:
return tif.is_imagej and tif.pages[-1]._nextifd() == 0
return tif.is_imagej and tif.pages[-1]._nextifd() == 0 # noqa
else:
return False
@@ -27,12 +27,12 @@ class Reader(AbstractReader, ABC):
for key, value in self.reader.imagej_metadata.items()}
page = self.reader.pages[0]
self.p_ndim = page.ndim
self.p_ndim = page.ndim # noqa
size_x = page.imagelength
size_y = page.imagewidth
if self.p_ndim == 3:
size_c = page.samplesperpixel
self.p_transpose = [i for i in [page.axes.find(j) for j in 'SYX'] if i >= 0]
self.p_transpose = [i for i in [page.axes.find(j) for j in 'SYX'] if i >= 0] # noqa
size_t = metadata.get('frames', 1) # // C
else:
size_c = metadata.get('channels', 1)

View File

@@ -1,19 +1,24 @@
import yaml
import re
import numpy as np
import warnings
from copy import deepcopy
from pathlib import Path
import numpy as np
import yaml
from parfor import pmap, Chunks
from skimage import filters
from tiffwrite import IJTiffFile
from tqdm.auto import tqdm
try:
# best if SimpleElastix is installed: https://simpleelastix.readthedocs.io/GettingStarted.html
import SimpleITK as sitk
import SimpleITK as sitk # noqa
except ImportError:
sitk = None
try:
from pandas import DataFrame, Series
from pandas import DataFrame, Series, concat
except ImportError:
DataFrame, Series = None, None
DataFrame, Series, concat = None, None, None
if hasattr(yaml, 'full_load'):
@@ -24,10 +29,30 @@ else:
class Transforms(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args[1:], **kwargs)
super().__init__(*args, **kwargs)
self.default = Transform()
if len(args):
self.load(args[0])
@classmethod
def from_file(cls, file, C=True, T=True):
with open(Path(file).with_suffix('.yml')) as f:
return cls.from_dict(yamlload(f), C, T)
@classmethod
def from_dict(cls, d, C=True, T=True):
new = cls()
for key, value in d.items():
if isinstance(key, str) and C:
new[key.replace(r'\:', ':').replace('\\\\', '\\')] = Transform.from_dict(value)
elif T:
new[key] = Transform.from_dict(value)
return new
@classmethod
def from_shifts(cls, shifts):
new = cls()
for key, shift in shifts.items():
new[key] = Transform.from_shift(shift)
return new
def __mul__(self, other):
new = Transforms()
@@ -44,18 +69,11 @@ class Transforms(dict):
return new
def asdict(self):
return {':'.join(str(i).replace('\\', '\\\\').replace(':', r'\:') for i in key): value.asdict()
return {key.replace('\\', '\\\\').replace(':', r'\:') if isinstance(key, str) else key: value.asdict()
for key, value in self.items()}
def load(self, file):
if isinstance(file, dict):
d = file
else:
with open(file.with_suffix(".yml"), 'r') as f:
d = yamlload(f)
pattern = re.compile(r'[^\\]:')
for key, value in d.items():
self[tuple(i.replace(r'\:', ':').replace('\\\\', '\\') for i in pattern.split(key))] = Transform(value)
def __getitem__(self, item):
return np.prod([self[i] for i in item[::-1]]) if isinstance(item, tuple) else super().__getitem__(item)
def __missing__(self, key):
return self.default
@@ -70,47 +88,242 @@ class Transforms(dict):
return hash(frozenset((*self.__dict__.items(), *self.items())))
def save(self, file):
with open(file.with_suffix(".yml"), 'w') as f:
with open(Path(file).with_suffix(".yml"), 'w') as f:
yaml.safe_dump(self.asdict(), f, default_flow_style=None)
def copy(self):
return deepcopy(self)
def adapt(self, origin, shape):
def adapt(self, origin, shape, channel_names):
def key_map(a, b):
def fun(b, key_a):
for key_b in b:
if key_b in key_a or key_a in key_b:
return key_a, key_b
return {n[0]: n[1] for key_a in a if (n := fun(b, key_a))}
for value in self.values():
value.adapt(origin, shape)
self.default.adapt(origin, shape)
transform_channels = {key for key in self.keys() if isinstance(key, str)}
if set(channel_names) - transform_channels:
mapping = key_map(channel_names, transform_channels)
warnings.warn(f'The image file and the transform do not have the same channels,'
f' creating a mapping: {mapping}')
for key_im, key_t in mapping.items():
self[key_im] = self[key_t]
@property
def inverse(self):
# TODO: check for C@T
inverse = self.copy()
for key, value in self.items():
inverse[key] = value.inverse
return inverse
@property
def ndim(self):
return len(list(self.keys())[0])
def coords_pandas(self, array, channel_names, columns=None):
if isinstance(array, DataFrame):
return concat([self.coords_pandas(row, channel_names, columns) for _, row in array.iterrows()], axis=1).T
elif isinstance(array, Series):
key = []
if 'C' in array:
key.append(channel_names[int(array['C'])])
if 'T' in array:
key.append(int(array['T']))
return self[tuple(key)].coords(array, columns)
else:
raise TypeError('Not a pandas DataFrame or Series.')
def with_beads(self, cyllens, bead_files):
assert len(bead_files) > 0, "At least one file is needed to calculate the registration."
transforms = [self.calculate_channel_transforms(file, cyllens) for file in bead_files]
for key in {key for transform in transforms for key in transform.keys()}:
new_transforms = [transform[key] for transform in transforms if key in transform]
if len(new_transforms) == 1:
self[key] = new_transforms[0]
else:
self[key] = Transform()
self[key].parameters = np.mean([t.parameters for t in new_transforms], 0)
self[key].dparameters = (np.std([t.parameters for t in new_transforms], 0) /
np.sqrt(len(new_transforms))).tolist()
return self
@staticmethod
def get_bead_files(path):
from . import Imread
files = []
for file in path.iterdir():
if file.name.lower().startswith('beads'):
try:
with Imread(file):
files.append(file)
except Exception:
pass
files = sorted(files)
if not files:
raise Exception('No bead file found!')
checked_files = []
for file in files:
try:
if file.is_dir():
file /= 'Pos0'
with Imread(file): # check for errors opening the file
checked_files.append(file)
except (Exception,):
continue
if not checked_files:
raise Exception('No bead file found!')
return checked_files
@staticmethod
def calculate_channel_transforms(bead_file, cyllens):
""" When no channel is not transformed by a cylindrical lens, assume that the image is scaled by a factor 1.162
in the horizontal direction """
from . import Imread
with Imread(bead_file, axes='zcxy') as im: # noqa
max_ims = im.max('z')
goodch = [c for c, max_im in enumerate(max_ims) if not im.is_noise(max_im)]
if not goodch:
goodch = list(range(len(max_ims)))
untransformed = [c for c in range(im.shape['c']) if cyllens[im.detector[c]].lower() == 'none']
good_and_untrans = sorted(set(goodch) & set(untransformed))
if good_and_untrans:
masterch = good_and_untrans[0]
else:
masterch = goodch[0]
transform = Transform()
if not good_and_untrans:
matrix = transform.matrix
matrix[0, 0] = 0.86
transform.matrix = matrix
transforms = Transforms()
for c in tqdm(goodch, desc='Calculating channel transforms'): # noqa
if c == masterch:
transforms[im.channel_names[c]] = transform
else:
transforms[im.channel_names[c]] = Transform.register(max_ims[masterch], max_ims[c]) * transform
return transforms
@staticmethod
def save_channel_transform_tiff(bead_files, tiffile):
from . import Imread
n_channels = 0
for file in bead_files:
with Imread(file) as im:
n_channels = max(n_channels, im.shape['c'])
with IJTiffFile(tiffile, (n_channels, 1, len(bead_files))) as tif:
for t, file in enumerate(bead_files):
with Imread(file) as im:
with Imread(file).with_transform() as jm:
for c in range(im.shape['c']):
tif.save(np.hstack((im(c=c, t=0).max('z'), jm(c=c, t=0).max('z'))), c, 0, t)
def with_drift(self, im):
""" Calculate shifts relative to the first frame
divide the sequence into groups,
compare each frame to the frame in the middle of the group and compare these middle frames to each other
"""
im = im.transpose('tzycx')
t_groups = [list(chunk) for chunk in Chunks(range(im.shape['t']), size=round(np.sqrt(im.shape['t'])))]
t_keys = [int(np.round(np.mean(t_group))) for t_group in t_groups]
t_pairs = [(int(np.round(np.mean(t_group))), frame) for t_group in t_groups for frame in t_group]
t_pairs.extend(zip(t_keys, t_keys[1:]))
fmaxz_keys = {t_key: filters.gaussian(im[t_key].max('z'), 5) for t_key in t_keys}
def fun(t_key_t, im, fmaxz_keys):
t_key, t = t_key_t
if t_key == t:
return 0, 0
else:
fmaxz = filters.gaussian(im[t].max('z'), 5)
return Transform.register(fmaxz_keys[t_key], fmaxz, 'translation').parameters[4:]
shifts = np.array(pmap(fun, t_pairs, (im, fmaxz_keys), desc='Calculating image shifts.'))
shift_keys_cum = np.zeros(2)
for shift_keys, t_group in zip(np.vstack((-shifts[0], shifts[im.shape['t']:])), t_groups):
shift_keys_cum += shift_keys
shifts[t_group] += shift_keys_cum
for i, shift in enumerate(shifts[:im.shape['t']]):
self[i] = Transform.from_shift(shift)
return self
class Transform:
def __init__(self, *args):
def __init__(self):
if sitk is None:
raise ImportError('SimpleElastix is not installed: '
'https://simpleelastix.readthedocs.io/GettingStarted.html')
self.transform = sitk.ReadTransform(str(Path(__file__).parent / 'transform.txt'))
self.dparameters = 0, 0, 0, 0, 0, 0
self.shape = 512, 512
self.origin = 255.5, 255.5
if len(args) == 1: # load from file or dict
if isinstance(args[0], np.ndarray):
self.matrix = args[0]
else:
self.load(*args)
elif len(args) > 1: # make new transform using fixed and moving image
self.register(*args)
self.dparameters = [0., 0., 0., 0., 0., 0.]
self.shape = [512., 512.]
self.origin = [255.5, 255.5]
self._last, self._inverse = None, None
def __reduce__(self):
return self.from_dict, (self.asdict(),)
def __repr__(self):
return self.asdict().__repr__()
def __str__(self):
return self.asdict().__str__()
@classmethod
def register(cls, fix, mov, kind=None):
""" kind: 'affine', 'translation', 'rigid' """
new = cls()
kind = kind or 'affine'
new.shape = fix.shape
fix, mov = new.cast_image(fix), new.cast_image(mov)
# TODO: implement RigidTransform
tfilter = sitk.ElastixImageFilter()
tfilter.LogToConsoleOff()
tfilter.SetFixedImage(fix)
tfilter.SetMovingImage(mov)
tfilter.SetParameterMap(sitk.GetDefaultParameterMap(kind))
tfilter.Execute()
transform = tfilter.GetTransformParameterMap()[0]
if kind == 'affine':
new.parameters = [float(t) for t in transform['TransformParameters']]
new.shape = [float(t) for t in transform['Size']]
new.origin = [float(t) for t in transform['CenterOfRotationPoint']]
elif kind == 'translation':
new.parameters = [1.0, 0.0, 0.0, 1.0] + [float(t) for t in transform['TransformParameters']]
new.shape = [float(t) for t in transform['Size']]
new.origin = [(t - 1) / 2 for t in new.shape]
else:
raise NotImplementedError(f'{kind} tranforms not implemented (yet)')
new.dparameters = 6 * [np.nan]
return new
@classmethod
def from_shift(cls, shift):
return cls.from_array(np.array(((1, 0, shift[0]), (0, 1, shift[1]), (0, 0, 1))))
@classmethod
def from_array(cls, array):
new = cls()
new.matrix = array
return new
@classmethod
def from_file(cls, file):
with open(Path(file).with_suffix('.yml')) as f:
return cls.from_dict(yamlload(f))
@classmethod
def from_dict(cls, d):
new = cls()
new.origin = [float(i) for i in d['CenterOfRotationPoint']]
new.parameters = [float(i) for i in d['TransformParameters']]
new.dparameters = [float(i) for i in d['dTransformParameters']] if 'dTransformParameters' in d else 6 * [np.nan]
new.shape = [float(i) for i in d['Size']]
return new
def __mul__(self, other): # TODO: take care of dmatrix
result = self.copy()
if isinstance(other, Transform):
@@ -121,9 +334,6 @@ class Transform:
result.dmatrix = self.dmatrix @ other
return result
def __reduce__(self):
return self.__class__, (self.asdict(),)
def is_unity(self):
return self.parameters == [1, 0, 0, 1, 0, 0]
@@ -166,7 +376,7 @@ class Transform:
@property
def parameters(self):
return self.transform.GetParameters()
return list(self.transform.GetParameters())
@parameters.setter
def parameters(self, value):
@@ -186,7 +396,7 @@ class Transform:
def inverse(self):
if self._last is None or self._last != self.asdict():
self._last = self.asdict()
self._inverse = Transform(self.asdict())
self._inverse = Transform.from_dict(self.asdict())
self._inverse.transform = self._inverse.transform.GetInverse()
self._inverse._last = self._inverse.asdict()
self._inverse._inverse = self
@@ -234,45 +444,3 @@ class Transform:
file += '.yml'
with open(file, 'w') as f:
yaml.safe_dump(self.asdict(), f, default_flow_style=None)
def load(self, file):
""" load the parameters of a transform from a yaml file or a dict
"""
if isinstance(file, dict):
d = file
else:
if not file[-3:] == 'yml':
file += '.yml'
with open(file, 'r') as f:
d = yamlload(f)
self.origin = [float(i) for i in d['CenterOfRotationPoint']]
self.parameters = [float(i) for i in d['TransformParameters']]
self.dparameters = [float(i) for i in d['dTransformParameters']] \
if 'dTransformParameters' in d else 6 * [np.nan]
self.shape = [float(i) for i in d['Size']]
def register(self, fix, mov, kind=None):
""" kind: 'affine', 'translation', 'rigid'
"""
kind = kind or 'affine'
self.shape = fix.shape
fix, mov = self.cast_image(fix), self.cast_image(mov)
# TODO: implement RigidTransform
tfilter = sitk.ElastixImageFilter()
tfilter.LogToConsoleOff()
tfilter.SetFixedImage(fix)
tfilter.SetMovingImage(mov)
tfilter.SetParameterMap(sitk.GetDefaultParameterMap(kind))
tfilter.Execute()
transform = tfilter.GetTransformParameterMap()[0]
if kind == 'affine':
self.parameters = [float(t) for t in transform['TransformParameters']]
self.shape = [float(t) for t in transform['Size']]
self.origin = [float(t) for t in transform['CenterOfRotationPoint']]
elif kind == 'translation':
self.parameters = [1.0, 0.0, 0.0, 1.0] + [float(t) for t in transform['TransformParameters']]
self.shape = [float(t) for t in transform['Size']]
self.origin = [(t - 1) / 2 for t in self.shape]
else:
raise NotImplementedError(f'{kind} tranforms not implemented (yet)')
self.dparameters = 6 * [np.nan]

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "ndbioimage"
version = "2023.10.1"
version = "2023.10.2"
description = "Bio image reading, metadata and some affine registration."
authors = ["W. Pomp <w.pomp@nki.nl>"]
license = "GPLv3"
@@ -22,7 +22,7 @@ pint = "*"
tqdm = "*"
lxml = "*"
pyyaml = "*"
parfor = ">=2023.8.2"
parfor = ">=2023.10.1"
JPype1 = "*"
SimpleITK-SimpleElastix = "*"
pytest = { version = "*", optional = true }