- Improve transforms, implement drift correction transforms, require from_ methods for instantiation.

- Replace transform, drift and beadfile arguments for Imread by with_transform method.
- Bring Imread.transpose in line with numpy.transpose.
- Fix seqread.lazy_property.
This commit is contained in:
Wim Pomp
2023-11-02 15:05:40 +01:00
parent 2e56f45f3e
commit d13b702481
10 changed files with 447 additions and 507 deletions

View File

@@ -1,33 +1,32 @@
import multiprocessing
import re import re
import warnings import warnings
from abc import ABC, ABCMeta, abstractmethod
import pandas
import yaml
import numpy as np
import multiprocessing
import ome_types
from ome_types import ureg, model
from pint import set_application_registry
from datetime import datetime
from tqdm.auto import tqdm
from itertools import product
from collections import OrderedDict
from abc import ABCMeta, abstractmethod
from functools import cached_property, wraps
from parfor import parfor
from tiffwrite import IJTiffFile
from numbers import Number
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from collections import OrderedDict
from datetime import datetime
from functools import cached_property, wraps
from importlib.metadata import version from importlib.metadata import version
from traceback import print_exc from itertools import product
from numbers import Number
from operator import truediv from operator import truediv
from .transforms import Transform, Transforms from pathlib import Path
from traceback import print_exc
import numpy as np
import ome_types
import yaml
from ome_types import model, ureg
from pint import set_application_registry
from tiffwrite import IJTiffFile
from tqdm.auto import tqdm
from .jvm import JVM from .jvm import JVM
from .transforms import Transform, Transforms
try: try:
__version__ = version(Path(__file__).parent.name) __version__ = version(Path(__file__).parent.name)
except (Exception,): except Exception: # noqa
__version__ = 'unknown' __version__ = 'unknown'
try: try:
@@ -35,7 +34,7 @@ try:
head = g.read().split(':')[1].strip() head = g.read().split(':')[1].strip()
with open(Path(__file__).parent.parent / '.git' / head) as h: with open(Path(__file__).parent.parent / '.git' / head) as h:
__git_commit_hash__ = h.read().rstrip('\n') __git_commit_hash__ = h.read().rstrip('\n')
except (Exception,): except Exception: # noqa
__git_commit_hash__ = 'unknown' __git_commit_hash__ = 'unknown'
ureg.default_format = '~P' ureg.default_format = '~P'
@@ -57,235 +56,6 @@ class TransformTiff(IJTiffFile):
return super().compress_frame(np.asarray(self.image(*frame)).astype(self.dtype)) return super().compress_frame(np.asarray(self.image(*frame)).astype(self.dtype))
class ImTransforms(Transforms):
""" Transforms class with methods to calculate channel transforms from bead files etc. """
def __init__(self, path, cyllens, file=None, transforms=None):
super().__init__()
self.cyllens = tuple(cyllens)
if transforms is None:
# TODO: check this
if re.search(r'^Pos\d+', path.name):
self.path = path.parent.parent
else:
self.path = path.parent
if file is not None:
if isinstance(file, str) and file.lower().endswith('.yml'):
self.ymlpath = file
self.beadfile = None
else:
self.ymlpath = self.path / 'transform.yml'
self.beadfile = file
else:
self.ymlpath = self.path / 'transform.yml'
self.beadfile = None
self.tifpath = self.ymlpath.with_suffix('.tif')
try:
self.load(self.ymlpath)
except (Exception,):
print('No transform file found, trying to generate one.')
if not self.files:
raise FileNotFoundError('No bead files found to calculate the transform from.')
self.calculate_transforms()
self.save(self.ymlpath)
self.save_transform_tiff()
print(f'Saving transform in {self.ymlpath}.')
print(f'Please check the transform in {self.tifpath}.')
else: # load from dict transforms
self.path = path
self.beadfile = file
for i, (key, value) in enumerate(transforms.items()):
self[key] = Transform(value)
def coords_pandas(self, array, cnamelist, colums=None):
if isinstance(array, pandas.DataFrame):
return pandas.concat([self[(cnamelist[int(row['C'])],)].coords(row, colums)
for _, row in array.iterrows()], axis=1).T
elif isinstance(array, pandas.Series):
return self[(cnamelist[int(array['C'])],)].coords(array, colums)
else:
raise TypeError('Not a pandas DataFrame or Series.')
@cached_property
def files(self):
try:
if self.beadfile is None:
files = self.get_bead_files()
else:
files = self.beadfile
if isinstance(files, str):
files = (Path(files),)
elif isinstance(files, Path):
files = (files,)
return tuple(files)
except (Exception,):
return ()
def get_bead_files(self):
files = sorted([f for f in self.path.iterdir() if f.name.lower().startswith('beads')
and not f.suffix.lower() == '.pdf' and not f.suffix.lower() == 'pkl'])
if not files:
raise Exception('No bead file found!')
checked_files = []
for file in files:
try:
if file.is_dir():
file /= 'Pos0'
with Imread(file): # check for errors opening the file
checked_files.append(file)
except (Exception,):
continue
if not checked_files:
raise Exception('No bead file found!')
return checked_files
def calculate_transform(self, file):
""" When no channel is not transformed by a cylindrical lens, assume that the image is scaled by a factor 1.162
in the horizontal direction """
with Imread(file, axes='zcxy') as im:
max_ims = im.max('z')
goodch = [c for c, max_im in enumerate(max_ims) if not im.is_noise(max_im)]
if not goodch:
goodch = list(range(len(max_ims)))
untransformed = [c for c in range(im.shape['c']) if self.cyllens[im.detector[c]].lower() == 'none']
good_and_untrans = sorted(set(goodch) & set(untransformed))
if good_and_untrans:
masterch = good_and_untrans[0]
else:
masterch = goodch[0]
transform = Transform()
if not good_and_untrans:
matrix = transform.matrix
matrix[0, 0] = 0.86
transform.matrix = matrix
transforms = Transforms()
for c in tqdm(goodch):
if c == masterch:
transforms[(im.cnamelist[c],)] = transform
else:
transforms[(im.cnamelist[c],)] = Transform(max_ims[masterch], max_ims[c]) * transform
return transforms
def calculate_transforms(self):
transforms = [self.calculate_transform(file) for file in self.files]
for key in set([key for transform in transforms for key in transform.keys()]):
new_transforms = [transform[key] for transform in transforms if key in transform]
if len(new_transforms) == 1:
self[key] = new_transforms[0]
else:
self[key] = Transform()
self[key].parameters = np.mean([t.parameters for t in new_transforms], 0)
self[key].dparameters = (np.std([t.parameters for t in new_transforms], 0) /
np.sqrt(len(new_transforms))).tolist()
def save_transform_tiff(self):
n_channels = 0
for file in self.files:
with Imread(file) as im:
n_channels = max(n_channels, im.shape['c'])
with IJTiffFile(self.tifpath, (n_channels, 1, len(self.files))) as tif:
for t, file in enumerate(self.files):
with Imread(file) as im:
with Imread(file, transform=True) as jm:
for c in range(im.shape['c']):
tif.save(np.hstack((im(c=c, t=0).max('z'), jm(c=c, t=0).max('z'))), c, 0, t)
class ImShiftTransforms(Transforms):
""" Class to handle drift in xy. The image this is applied to must have a channel transform already, which is then
replaced by this class. """
def __init__(self, im, shifts=None):
""" im: Calculate shifts from channel-transformed images
im, t x 2 array Sets shifts from array, one row per frame
im, dict {frame: shift} Sets shifts from dict, each key is a frame number from where a new shift is applied
im, file Loads shifts from a saved file """
super().__init__()
with (Imread(im, transform=True, drift=False) if isinstance(im, str)
else im.new(transform=True, drift=False)) as im:
self.impath = im.path
self.path = self.impath.parent / self.impath.stem + '_shifts.txt'
self.tracks, self.detectors, self.files = im.track, im.detector, im.beadfile
if shifts is not None:
if isinstance(shifts, np.ndarray):
self.shifts = shifts
self.shifts2transforms(im)
elif isinstance(shifts, dict):
self.shifts = np.zeros((im.shape['t'], 2))
for k in sorted(shifts.keys()):
self.shifts[k:] = shifts[k]
self.shifts2transforms(im)
elif isinstance(shifts, str):
self.load(im, shifts)
elif self.path.exists():
self.load(im, self.path)
else:
self.calulate_shifts(im)
self.save()
def __call__(self, channel, time, tracks=None, detectors=None):
tracks = tracks or self.tracks
detectors = detectors or self.detectors
track, detector = tracks[channel], detectors[channel]
if (track, detector, time) in self:
return self[track, detector, time]
elif (0, detector, time) in self:
return self[0, detector, time]
else:
return Transform()
def load(self, im, file):
self.shifts = np.loadtxt(file)
self.shifts2transforms(im)
def save(self, file=None):
self.path = file or self.path
np.savetxt(self.path, self.shifts)
def coords(self, array, colums=None):
if isinstance(array, pandas.DataFrame):
return pandas.concat([self(int(row['C']), int(row['T'])).coords(row, colums)
for _, row in array.iterrows()], axis=1).T
elif isinstance(array, pandas.Series):
return self(int(array['C']), int(array['T'])).coords(array, colums)
else:
raise TypeError('Not a pandas DataFrame or Series.')
def calulate_shifts0(self, im):
""" Calculate shifts relative to the first frame """
im0 = im[:, 0, 0].squeeze().transpose(2, 0, 1)
@parfor(range(1, im.shape['t']), (im, im0), desc='Calculating image shifts.')
def fun(t, im, im0):
return Transform(im0, im[:, 0, t].squeeze().transpose(2, 0, 1), 'translation')
transforms = [Transform()] + fun
self.shifts = np.array([t.parameters[4:] for t in transforms])
self.set_transforms(transforms, im.transform)
def calulate_shifts(self, im):
""" Calculate shifts relative to the previous frame """
@parfor(range(1, im.shape['t']), (im,), desc='Calculating image shifts.')
def fun(t, im):
return Transform(im[:, 0, t - 1].squeeze().transpose(2, 0, 1), im[:, 0, t].squeeze().transpose(2, 0, 1),
'translation')
transforms = [Transform()] + fun
self.shifts = np.cumsum([t.parameters[4:] for t in transforms])
self.set_transforms(transforms, im.transform)
def shifts2transforms(self, im):
self.set_transforms([Transform(np.array(((1, 0, s[0]), (0, 1, s[1]), (0, 0, 1))))
for s in self.shifts], im.transform)
def set_transforms(self, shift_transforms, channel_transforms):
for key, value in channel_transforms.items():
for t, T in enumerate(shift_transforms):
self[key[0], key[1], t] = T * channel_transforms[key]
class DequeDict(OrderedDict): class DequeDict(OrderedDict):
def __init__(self, maxlen=None, *args, **kwargs): def __init__(self, maxlen=None, *args, **kwargs):
self.maxlen = maxlen self.maxlen = maxlen
@@ -314,24 +84,24 @@ def find(obj, **kwargs):
def try_default(fun, default, *args, **kwargs): def try_default(fun, default, *args, **kwargs):
try: try:
return fun(*args, **kwargs) return fun(*args, **kwargs)
except (Exception,): except Exception: # noqa
return default return default
def get_ome(path): def get_ome(path):
from .readers.bfread import jars from .readers.bfread import jars
try: try:
jvm = JVM(jars) jvm = JVM(jars) # noqa
ome_meta = jvm.metadata_tools.createOMEXMLMetadata() ome_meta = jvm.metadata_tools.createOMEXMLMetadata()
reader = jvm.image_reader() reader = jvm.image_reader()
reader.setMetadataStore(ome_meta) reader.setMetadataStore(ome_meta)
reader.setId(str(path)) reader.setId(str(path))
ome = ome_types.from_xml(str(ome_meta.dumpXML()), parser='lxml') ome = ome_types.from_xml(str(ome_meta.dumpXML()), parser='lxml')
except (Exception,): except Exception: # noqa
print_exc() print_exc()
ome = model.OME() ome = model.OME()
finally: finally:
jvm.kill_vm() jvm.kill_vm() # noqa
return ome return ome
@@ -356,7 +126,59 @@ class Shape(tuple):
return tuple(self[i] for i in 'xyczt') return tuple(self[i] for i in 'xyczt')
class Imread(np.lib.mixins.NDArrayOperatorsMixin): class Imread(np.lib.mixins.NDArrayOperatorsMixin, ABC):
""" class to read image files, while taking good care of important metadata,
currently optimized for .czi files, but can open anything that bioformats can handle
path: path to the image file
optional:
series: in case multiple experiments are saved in one file, like in .lif files
dtype: datatype to be used when returning frames
meta: define metadata, used for pickle-ing
NOTE: run imread.kill_vm() at the end of your script/program, otherwise python might not terminate
modify images on the fly with a decorator function:
define a function which takes an instance of this object, one image frame,
and the coordinates c, z, t as arguments, and one image frame as return
>> imread.frame_decorator = fun
then use imread as usually
Examples:
>> im = imread('/DATA/lenstra_lab/w.pomp/data/20190913/01_YTL639_JF646_DefiniteFocus.czi')
>> im
<< shows summary
>> im.shape
<< (256, 256, 2, 1, 600)
>> plt.imshow(im(1, 0, 100))
<< plots frame at position c=1, z=0, t=100 (python type indexing), note: round brackets; always 2d array
with 1 frame
>> data = im[:,:,0,0,:25]
<< retrieves 5d numpy array containing first 25 frames at c=0, z=0, note: square brackets; always 5d array
>> plt.imshow(im.max(0, None, 0))
<< plots max-z projection at c=0, t=0
>> len(im)
<< total number of frames
>> im.pxsize
<< 0.09708737864077668 image-plane pixel size in um
>> im.laserwavelengths
<< [642, 488]
>> im.laserpowers
<< [0.02, 0.0005] in %
See __init__ and other functions for more ideas.
Subclassing:
Subclass AbstractReader to add more file types. A subclass should always have at least the following
methods:
staticmethod _can_open(path): returns True when the subclass can open the image in path
__metadata__(self): pulls some metadata from the file and do other format specific things,
it needs to define a few properties, like shape, etc.
__frame__(self, c, z, t): this should return a single frame at channel c, slice z and time t
optional close(self): close the file in a proper way
optional field priority: subclasses with lower priority will be tried first, default = 99
Any other method can be overridden as needed
wp@tl2019-2023 """
def __new__(cls, path=None, *args, **kwargs): def __new__(cls, path=None, *args, **kwargs):
if cls is not Imread: if cls is not Imread:
return super().__new__(cls) return super().__new__(cls)
@@ -376,18 +198,13 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
return super().__new__(subclass) return super().__new__(subclass)
raise ReaderNotFoundError(f'No reader found for {path}.') raise ReaderNotFoundError(f'No reader found for {path}.')
def __init__(self, base=None, slice=None, shape=(0, 0, 0, 0, 0), dtype=None, def __init__(self, base=None, slice=None, shape=(0, 0, 0, 0, 0), dtype=None, frame_decorator=None):
transform=False, drift=False, beadfile=None, frame_decorator=None): self.base = base or self
self.base = base
self.slice = slice self.slice = slice
self._shape = Shape(shape) self._shape = Shape(shape)
self.dtype = dtype self.dtype = dtype
self.frame_decorator = frame_decorator self.frame_decorator = frame_decorator
self.transform = Transforms()
self.transform = transform
self.drift = drift
self.beadfile = beadfile
self.flags = dict(C_CONTIGUOUS=False, F_CONTIGUOUS=False, OWNDATA=False, WRITEABLE=False, self.flags = dict(C_CONTIGUOUS=False, F_CONTIGUOUS=False, OWNDATA=False, WRITEABLE=False,
ALIGNED=False, WRITEBACKIFCOPY=False, UPDATEIFCOPY=False) ALIGNED=False, WRITEBACKIFCOPY=False, UPDATEIFCOPY=False)
@@ -426,7 +243,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
def __getitem__(self, n): def __getitem__(self, n):
""" slice like a numpy array but return an Imread instance """ """ slice like a numpy array but return an Imread instance """
if self.isclosed: if self.isclosed:
raise IOError("file is closed") raise OSError("file is closed")
if isinstance(n, (slice, Number)): # None = : if isinstance(n, (slice, Number)): # None = :
n = (n,) n = (n,)
elif isinstance(n, type(Ellipsis)): elif isinstance(n, type(Ellipsis)):
@@ -644,9 +461,6 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
cfun(*tmps)).astype(p.sub('', dtype.name)) cfun(*tmps)).astype(p.sub('', dtype.name))
return out return out
def __framet__(self, c, z, t):
return self.transform_frame(self.__frame__(c, z, t), c, t)
@property @property
def axes(self): def axes(self):
return self.shape.axes return self.shape.axes
@@ -677,7 +491,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
return return
try: try:
return self.get_config(pname) return self.get_config(pname)
except (Exception,): except Exception: # noqa
return return
return return
@@ -704,11 +518,8 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
def summary(self): def summary(self):
""" gives a helpful summary of the recorded experiment """ """ gives a helpful summary of the recorded experiment """
s = [f"path/filename: {self.path}", s = [f"path/filename: {self.path}",
f"series/pos: {self.series}"] f"series/pos: {self.series}",
if isinstance(self, View): f"reader: {self.base.__class__.__module__.split('.')[-1]}"]
s.append(f"reader: {self.base.__class__.__module__.split('.')[-1]} view")
else:
s.append(f"reader: {self.__class__.__module__.split('.')[-1]} base")
s.extend((f"dtype: {self.dtype}", s.extend((f"dtype: {self.dtype}",
f"shape ({self.axes}):".ljust(15) + f"{' x '.join(str(i) for i in self.shape)}")) f"shape ({self.axes}):".ljust(15) + f"{' x '.join(str(i) for i in self.shape)}"))
if self.pxsize_um: if self.pxsize_um:
@@ -728,15 +539,15 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
s.append('laser powers: ' + ' | '.join([' & '.join(len(p) * ('{:.3g}',)).format(*[100 * i for i in p]) s.append('laser powers: ' + ' | '.join([' & '.join(len(p) * ('{:.3g}',)).format(*[100 * i for i in p])
for p in self.laserpowers]) + ' %') for p in self.laserpowers]) + ' %')
if self.objective: if self.objective:
s.append('objective: {}'.format(self.objective.model)) s.append(f'objective: {self.objective.model}')
if self.magnification: if self.magnification:
s.append('magnification: {}x'.format(self.magnification)) s.append(f'magnification: {self.magnification}x')
if self.tubelens: if self.tubelens:
s.append('tubelens: {}'.format(self.tubelens.model)) s.append(f'tubelens: {self.tubelens.model}')
if self.filter: if self.filter:
s.append('filterset: {}'.format(self.filter)) s.append(f'filterset: {self.filter}')
if self.powermode: if self.powermode:
s.append('powermode: {}'.format(self.powermode)) s.append(f'powermode: {self.powermode}')
if self.collimator: if self.collimator:
s.append('collimator: ' + (' {}' * len(self.collimator)).format(*self.collimator)) s.append('collimator: ' + (' {}' * len(self.collimator)).format(*self.collimator))
if self.tirfangle: if self.tirfangle:
@@ -881,9 +692,9 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
return new return new
@wraps(np.transpose) @wraps(np.transpose)
def transpose(self, axes=None): def transpose(self, *axes):
new = self.copy() new = self.copy()
if axes is None: if not axes:
new.axes = new.axes[::-1] new.axes = new.axes[::-1]
else: else:
new.axes = ''.join(ax if isinstance(ax, str) else new.axes[ax] for ax in axes) new.axes = ''.join(ax if isinstance(ax, str) else new.axes[ax] for ax in axes)
@@ -918,8 +729,8 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
def block(self, x=None, y=None, c=None, z=None, t=None): def block(self, x=None, y=None, c=None, z=None, t=None):
""" returns 5D block of frames """ """ returns 5D block of frames """
x, y, c, z, t = [np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) x, y, c, z, t = (np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1)
for i, e in zip('xyczt', (x, y, c, z, t))] for i, e in zip('xyczt', (x, y, c, z, t)))
d = np.empty((len(x), len(y), len(c), len(z), len(t)), self.dtype) d = np.empty((len(x), len(y), len(c), len(z), len(t)), self.dtype)
for (ci, cj), (zi, zj), (ti, tj) in product(enumerate(c), enumerate(z), enumerate(t)): for (ci, cj), (zi, zj), (ti, tj) in product(enumerate(c), enumerate(z), enumerate(t)):
d[:, :, ci, zi, ti] = self.frame(cj, zj, tj)[x][:, y] d[:, :, ci, zi, ti] = self.frame(cj, zj, tj)[x][:, y]
@@ -930,7 +741,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
def data(self, c=0, z=0, t=0): def data(self, c=0, z=0, t=0):
""" returns 3D stack of frames """ """ returns 3D stack of frames """
c, z, t = [np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) for i, e in zip('czt', (c, z, t))] c, z, t = (np.arange(self.shape[i]) if e is None else np.array(e, ndmin=1) for i, e in zip('czt', (c, z, t)))
return np.dstack([self.frame(ci, zi, ti) for ci, zi, ti in product(c, z, t)]) return np.dstack([self.frame(ci, zi, ti) for ci, zi, ti in product(c, z, t)])
def frame(self, c=0, z=0, t=0): def frame(self, c=0, z=0, t=0):
@@ -946,7 +757,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
self.cache.move_to_end(key) self.cache.move_to_end(key)
f = self.cache[key] f = self.cache[key]
else: else:
f = self.__framet__(c, z, t) f = self.transform[self.channel_names[c], t].frame(self.__frame__(c, z, t))
if self.frame_decorator is not None: if self.frame_decorator is not None:
f = self.frame_decorator(self, f, c, z, t) f = self.frame_decorator(self, f, c, z, t)
self.cache[key] = f self.cache[key] = f
@@ -959,9 +770,9 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
if not isinstance(channel_name, str): if not isinstance(channel_name, str):
return channel_name return channel_name
else: else:
c = [i for i, c in enumerate(self.cnamelist) if c.lower().startswith(channel_name.lower())] c = [i for i, c in enumerate(self.channel_names) if c.lower().startswith(channel_name.lower())]
assert len(c) > 0, 'Channel {} not found in {}'.format(c, self.cnamelist) assert len(c) > 0, f'Channel {c} not found in {self.channel_names}'
assert len(c) < 2, 'Channel {} not unique in {}'.format(c, self.cnamelist) assert len(c) < 2, f'Channel {c} not unique in {self.channel_names}'
return c[0] return c[0]
@staticmethod @staticmethod
@@ -978,7 +789,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
|[-+]?\\.(?:inf|Inf|INF) |[-+]?\\.(?:inf|Inf|INF)
|\.(?:nan|NaN|NAN))$''', re.X), |\.(?:nan|NaN|NAN))$''', re.X),
list(r'-+0123456789.')) list(r'-+0123456789.'))
with open(file, 'r') as f: with open(file) as f:
return yaml.load(f, loader) return yaml.load(f, loader)
def get_czt(self, c, z, t): def get_czt(self, c, z, t):
@@ -995,7 +806,7 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
stop = n.stop stop = n.stop
czt.append(list(range(n.start % self.shape[i], stop, n.step))) czt.append(list(range(n.start % self.shape[i], stop, n.step)))
elif isinstance(n, Number): elif isinstance(n, Number):
czt.append([n % self.shape[i]]) czt.append([n % self.shape[i]]) # noqa
else: else:
czt.append([k % self.shape[i] for k in n]) czt.append([k % self.shape[i] for k in n])
return [self.get_channel(c) for c in czt[0]], *czt[1:] return [self.get_channel(c) for c in czt[0]], *czt[1:]
@@ -1055,28 +866,42 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
shape = [len(i) for i in n] shape = [len(i) for i in n]
with TransformTiff(self, fname.with_suffix('.tif'), shape, pixel_type, with TransformTiff(self, fname.with_suffix('.tif'), shape, pixel_type,
pxsize=self.pxsize_um, deltaz=self.deltaz_um, **kwargs) as tif: pxsize=self.pxsize_um, deltaz=self.deltaz_um, **kwargs) as tif:
for i, m in tqdm(zip(product(*[range(s) for s in shape]), product(*n)), for i, m in tqdm(zip(product(*[range(s) for s in shape]), product(*n)), # noqa
total=np.prod(shape), desc='Saving tiff', disable=not bar): total=np.prod(shape), desc='Saving tiff', disable=not bar):
tif.save(m, *i) tif.save(m, *i)
def set_transform(self): def with_transform(self, channels=True, drift=False, file=None, bead_files=()):
# handle transforms """ returns a view where channels and/or frames are registered with an affine transformation
if self.transform is False or self.transform is None: channels: True/False register channels using bead_files
self.transform = None drift: True/False register frames to correct drift
else: file: load registration from file with name file, default: transform.yml in self.path.parent
if isinstance(self.transform, Transforms): bead_files: files used to register channels, default: files in self.path.parent,
self.transform = self.transform with names starting with 'beads'
else: """
if isinstance(self.transform, str): view = self.view()
self.transform = ImTransforms(self.path, self.cyllens, self.transform) if file is None:
else: file = Path(view.path.parent) / 'transform.yml'
self.transform = ImTransforms(self.path, self.cyllens, self.beadfile) if not bead_files:
if self.drift is True: bead_files = Transforms.get_bead_files(view.path.parent)
self.transform = ImShiftTransforms(self)
elif not (self.drift is False or self.drift is None): if channels:
self.transform = ImShiftTransforms(self, self.drift) try:
self.transform.adapt(self.frameoffset, self.shape.xyczt) view.transform = Transforms.from_file(file, T=drift)
self.beadfile = self.transform.files # for key in view.channel_names:
# if
except Exception: # noqa
view.transform = Transforms().with_beads(view.cyllens, bead_files)
if drift:
view.transform = view.transform.with_drift(view)
view.transform.save(file.with_suffix('.yml'))
view.transform.save_channel_transform_tiff(bead_files, file.with_suffix('.tif'))
elif drift:
try:
view.transform = Transforms.from_file(file, C=False)
except Exception: # noqa
view.transform = Transforms().with_drift(self)
view.transform.adapt(view.frameoffset, view.shape.xyczt, view.channel_names)
return view
@staticmethod @staticmethod
def split_path_series(path): def split_path_series(path):
@@ -1086,21 +911,14 @@ class Imread(np.lib.mixins.NDArrayOperatorsMixin):
return path.parent, int(path.name.lstrip('Pos')) return path.parent, int(path.name.lstrip('Pos'))
return path, 0 return path, 0
def transform_frame(self, frame, c, t=0):
if self.transform is None:
return frame
else:
return self.transform[(self.cnamelist[c],)].frame(frame)
def view(self, *args, **kwargs): def view(self, *args, **kwargs):
return View(self, *args, **kwargs) return View(self, *args, **kwargs)
class View(Imread): class View(Imread, ABC):
def __init__(self, base, dtype=None, transform=None, drift=None, beadfile=None): def __init__(self, base, dtype=None):
super().__init__(base.base, base.slice, base.shape, dtype or base.dtype, transform or base.transform, super().__init__(base.base, base.slice, base.shape, dtype or base.dtype, base.frame_decorator)
drift or base.drift, beadfile or base.beadfile, base.frame_decorator) self.transform = base.transform
self.set_transform()
def __getattr__(self, item): def __getattr__(self, item):
if not hasattr(self.base, item): if not hasattr(self.base, item):
@@ -1109,60 +927,6 @@ class View(Imread):
class AbstractReader(Imread, metaclass=ABCMeta): class AbstractReader(Imread, metaclass=ABCMeta):
""" class to read image files, while taking good care of important metadata,
currently optimized for .czi files, but can open anything that bioformats can handle
path: path to the image file
optional:
series: in case multiple experiments are saved in one file, like in .lif files
transform: automatically correct warping between channels, need transforms.py among others
drift: automatically correct for drift, only works if transform is not None or False
beadfile: image file(s) with beads which can be used for correcting warp
dtype: datatype to be used when returning frames
meta: define metadata, used for pickle-ing
NOTE: run imread.kill_vm() at the end of your script/program, otherwise python might not terminate
modify images on the fly with a decorator function:
define a function which takes an instance of this object, one image frame,
and the coordinates c, z, t as arguments, and one image frame as return
>> imread.frame_decorator = fun
then use imread as usually
Examples:
>> im = imread('/DATA/lenstra_lab/w.pomp/data/20190913/01_YTL639_JF646_DefiniteFocus.czi')
>> im
<< shows summary
>> im.shape
<< (256, 256, 2, 1, 600)
>> plt.imshow(im(1, 0, 100))
<< plots frame at position c=1, z=0, t=100 (python type indexing), note: round brackets; always 2d array
with 1 frame
>> data = im[:,:,0,0,:25]
<< retrieves 5d numpy array containing first 25 frames at c=0, z=0, note: square brackets; always 5d array
>> plt.imshow(im.max(0, None, 0))
<< plots max-z projection at c=0, t=0
>> len(im)
<< total number of frames
>> im.pxsize
<< 0.09708737864077668 image-plane pixel size in um
>> im.laserwavelengths
<< [642, 488]
>> im.laserpowers
<< [0.02, 0.0005] in %
See __init__ and other functions for more ideas.
Subclassing:
Subclass this class to add more file types. A subclass should always have at least the following methods:
staticmethod _can_open(path): returns True when the subclass can open the image in path
__metadata__(self): pulls some metadata from the file and do other format specific things, it needs to
define a few properties, like shape, etc.
__frame__(self, c, z, t): this should return a single frame at channel c, slice z and time t
optional close(self): close the file in a proper way
optional field priority: subclasses with lower priority will be tried first, default = 99
Any other method can be overridden as needed
wp@tl2019-2023 """
priority = 99 priority = 99
do_not_pickle = 'cache' do_not_pickle = 'cache'
ureg = ureg ureg = ureg
@@ -1187,10 +951,10 @@ class AbstractReader(Imread, metaclass=ABCMeta):
def close(self): # Optionally override this, close file handles etc. def close(self): # Optionally override this, close file handles etc.
return return
def __init__(self, path, transform=False, drift=False, beadfile=None, dtype=None, axes=None): def __init__(self, path, dtype=None, axes=None):
if isinstance(path, Imread): if isinstance(path, Imread):
return return
super().__init__(self, transform=transform, drift=drift, beadfile=beadfile) super().__init__()
self.isclosed = False self.isclosed = False
if isinstance(path, str): if isinstance(path, str):
path = Path(path) path = Path(path)
@@ -1248,7 +1012,9 @@ class AbstractReader(Imread, metaclass=ABCMeta):
self.pxsize *= self.binning[0] self.pxsize *= self.binning[0]
except (AttributeError, IndexError, ValueError): except (AttributeError, IndexError, ValueError):
self.binning = None self.binning = None
self.cnamelist = [channel.name for channel in image.pixels.channels] self.channel_names = [channel.name for channel in image.pixels.channels]
self.channel_names += [chr(97 + i) for i in range(len(self.channel_names), self.shape['c'])]
self.cnamelist = self.channel_names
try: try:
optovars = [objective for objective in instrument.objectives if 'tubelens' in objective.id.lower()] optovars = [objective for objective in instrument.objectives if 'tubelens' in objective.id.lower()]
except AttributeError: except AttributeError:
@@ -1297,19 +1063,18 @@ class AbstractReader(Imread, metaclass=ABCMeta):
self.feedback = m['FeedbackChannels'] self.feedback = m['FeedbackChannels']
else: else:
self.feedback = m['FeedbackChannel'] self.feedback = m['FeedbackChannel']
except (Exception,): except Exception: # noqa
self.cyllens = ['None', 'None'] self.cyllens = ['None', 'None']
self.duolink = 'None' self.duolink = 'None'
self.feedback = [] self.feedback = []
try: try:
self.cyllenschannels = np.where([self.cyllens[self.detector[c]].lower() != 'none' self.cyllenschannels = np.where([self.cyllens[self.detector[c]].lower() != 'none'
for c in range(self.shape['c'])])[0].tolist() for c in range(self.shape['c'])])[0].tolist()
except (Exception,): except Exception: # noqa
pass pass
self.set_transform()
try: try:
s = int(re.findall(r'_(\d{3})_', self.duolink)[0]) * ureg.nm s = int(re.findall(r'_(\d{3})_', self.duolink)[0]) * ureg.nm
except (Exception,): except Exception: # noqa
s = 561 * ureg.nm s = 561 * ureg.nm
try: try:
sigma = [] sigma = []
@@ -1320,7 +1085,7 @@ class AbstractReader(Imread, metaclass=ABCMeta):
sigma[sigma == 0] = 600 * ureg.nm sigma[sigma == 0] = 600 * ureg.nm
sigma /= 2 * self.NA * self.pxsize sigma /= 2 * self.NA * self.pxsize
self.sigma = sigma.magnitude.tolist() self.sigma = sigma.magnitude.tolist()
except (Exception,): except Exception: # noqa
self.sigma = [2] * self.shape['c'] self.sigma = [2] * self.shape['c']
if not self.NA: if not self.NA:
self.immersionN = 1 self.immersionN = 1
@@ -1341,6 +1106,7 @@ class AbstractReader(Imread, metaclass=ABCMeta):
except Exception: except Exception:
pass pass
def main(): def main():
parser = ArgumentParser(description='Display info and save as tif') parser = ArgumentParser(description='Display info and save as tif')
parser.add_argument('file', help='image_file') parser.add_argument('file', help='image_file')
@@ -1353,13 +1119,15 @@ def main():
parser.add_argument('-f', '--force', help='force overwrite', action='store_true') parser.add_argument('-f', '--force', help='force overwrite', action='store_true')
args = parser.parse_args() args = parser.parse_args()
with Imread(args.file, transform=args.register) as im: with Imread(args.file) as im:
if args.register:
im = im.with_transform()
print(im.summary) print(im.summary)
if args.out: if args.out:
out = Path(args.out).absolute() out = Path(args.out).absolute()
out.parent.mkdir(parents=True, exist_ok=True) out.parent.mkdir(parents=True, exist_ok=True)
if out.exists() and not args.force: if out.exists() and not args.force:
print('File {} exists already, add the -f flag if you want to overwrite it.'.format(args.out)) print(f'File {args.out} exists already, add the -f flag if you want to overwrite it.')
else: else:
im.save_as_tiff(out, args.channel, args.zslice, args.time, args.split) im.save_as_tiff(out, args.channel, args.zslice, args.time, args.split)

View File

@@ -19,27 +19,27 @@ try:
def __init__(self, jars=None): def __init__(self, jars=None):
if not self.vm_started and not self.vm_killed: if not self.vm_started and not self.vm_killed:
try: try:
jarpath = Path(__file__).parent / 'jars' jar_path = Path(__file__).parent / 'jars'
if jars is None: if jars is None:
jars = {} jars = {}
for jar, src in jars.items(): for jar, src in jars.items():
if not (jarpath / jar).exists(): if not (jar_path / jar).exists():
JVM.download(src, jarpath / jar) JVM.download(src, jar_path / jar)
classpath = [str(jarpath / jar) for jar in jars.keys()] classpath = [str(jar_path / jar) for jar in jars.keys()]
import jpype import jpype
jpype.startJVM(classpath=classpath) jpype.startJVM(classpath=classpath)
except Exception: except Exception: # noqa
self.vm_started = False self.vm_started = False
else: else:
self.vm_started = True self.vm_started = True
try: try:
import jpype.imports import jpype.imports
from loci.common import DebugTools from loci.common import DebugTools # noqa
from loci.formats import ImageReader from loci.formats import ImageReader # noqa
from loci.formats import ChannelSeparator from loci.formats import ChannelSeparator # noqa
from loci.formats import FormatTools from loci.formats import FormatTools # noqa
from loci.formats import MetadataTools from loci.formats import MetadataTools # noqa
DebugTools.setRootLevel("ERROR") DebugTools.setRootLevel("ERROR")
@@ -47,7 +47,7 @@ try:
self.channel_separator = ChannelSeparator self.channel_separator = ChannelSeparator
self.format_tools = FormatTools self.format_tools = FormatTools
self.metadata_tools = MetadataTools self.metadata_tools = MetadataTools
except Exception: except Exception: # noqa
pass pass
if self.vm_killed: if self.vm_killed:
@@ -64,7 +64,7 @@ try:
self = cls._instance self = cls._instance
if self is not None and self.vm_started and not self.vm_killed: if self is not None and self.vm_started and not self.vm_killed:
import jpype import jpype
jpype.shutdownJVM() jpype.shutdownJVM() # noqa
self.vm_started = False self.vm_started = False
self.vm_killed = True self.vm_killed = True

View File

@@ -153,7 +153,7 @@ class JVMReader:
self.queue_out.put(image[..., c]) self.queue_out.put(image[..., c])
else: else:
self.queue_out.put(image) self.queue_out.put(image)
except queues.Empty: except queues.Empty: # noqa
continue continue
except (Exception,): except (Exception,):
print_exc() print_exc()
@@ -171,7 +171,7 @@ def can_open(path):
except (Exception,): except (Exception,):
return False return False
finally: finally:
jvm.kill_vm() jvm.kill_vm() # noqa
class Reader(AbstractReader, ABC): class Reader(AbstractReader, ABC):

View File

@@ -31,7 +31,7 @@ class Reader(AbstractReader, ABC):
filedict[c, z, t].append(directory_entry) filedict[c, z, t].append(directory_entry)
else: else:
filedict[c, z, t] = [directory_entry] filedict[c, z, t] = [directory_entry]
self.filedict = filedict self.filedict = filedict # noqa
def close(self): def close(self):
self.reader.close() self.reader.close()
@@ -116,8 +116,8 @@ class Reader(AbstractReader, ABC):
y_max = max([f.start[f.axes.index('Y')] + f.shape[f.axes.index('Y')] for f in self.filedict[0, 0, 0]]) y_max = max([f.start[f.axes.index('Y')] + f.shape[f.axes.index('Y')] for f in self.filedict[0, 0, 0]])
size_x = x_max - x_min size_x = x_max - x_min
size_y = y_max - y_min size_y = y_max - y_min
size_c, size_z, size_t = [self.reader.shape[self.reader.axes.index(directory_entry)] size_c, size_z, size_t = (self.reader.shape[self.reader.axes.index(directory_entry)]
for directory_entry in 'CZT'] for directory_entry in 'CZT')
image = information.find("Image") image = information.find("Image")
pixel_type = text(image.find("PixelType"), "Gray16") pixel_type = text(image.find("PixelType"), "Gray16")
@@ -277,23 +277,23 @@ class Reader(AbstractReader, ABC):
text(light_source.find("LightSourceType").find("Laser").find("Wavelength"))))) text(light_source.find("LightSourceType").find("Laser").find("Wavelength")))))
multi_track_setup = acquisition_block.find("MultiTrackSetup") multi_track_setup = acquisition_block.find("MultiTrackSetup")
for idx, tube_lens in enumerate(set(text(track_setup.find("TubeLensPosition")) for idx, tube_lens in enumerate({text(track_setup.find("TubeLensPosition"))
for track_setup in multi_track_setup)): for track_setup in multi_track_setup}):
ome.instruments[0].objectives.append( ome.instruments[0].objectives.append(
model.Objective(id=f"Objective:Tubelens:{idx}", model=tube_lens, model.Objective(id=f"Objective:Tubelens:{idx}", model=tube_lens,
nominal_magnification=float( nominal_magnification=float(
re.findall(r'\d+[,.]\d*', tube_lens)[0].replace(',', '.')) re.findall(r'\d+[,.]\d*', tube_lens)[0].replace(',', '.'))
)) ))
for idx, filter_ in enumerate(set(text(beam_splitter.find("Filter")) for idx, filter_ in enumerate({text(beam_splitter.find("Filter"))
for track_setup in multi_track_setup for track_setup in multi_track_setup
for beam_splitter in track_setup.find("BeamSplitters"))): for beam_splitter in track_setup.find("BeamSplitters")}):
ome.instruments[0].filter_sets.append( ome.instruments[0].filter_sets.append(
model.FilterSet(id=f"FilterSet:{idx}", model=filter_) model.FilterSet(id=f"FilterSet:{idx}", model=filter_)
) )
for idx, collimator in enumerate(set(text(track_setup.find("FWFOVPosition")) for idx, collimator in enumerate({text(track_setup.find("FWFOVPosition"))
for track_setup in multi_track_setup)): for track_setup in multi_track_setup}):
ome.instruments[0].filters.append(model.Filter(id=f"Filter:Collimator:{idx}", model=collimator)) ome.instruments[0].filters.append(model.Filter(id=f"Filter:Collimator:{idx}", model=collimator))
x_min = min([f.start[f.axes.index('X')] for f in self.filedict[0, 0, 0]]) x_min = min([f.start[f.axes.index('X')] for f in self.filedict[0, 0, 0]])
@@ -302,8 +302,8 @@ class Reader(AbstractReader, ABC):
y_max = max([f.start[f.axes.index('Y')] + f.shape[f.axes.index('Y')] for f in self.filedict[0, 0, 0]]) y_max = max([f.start[f.axes.index('Y')] + f.shape[f.axes.index('Y')] for f in self.filedict[0, 0, 0]])
size_x = x_max - x_min size_x = x_max - x_min
size_y = y_max - y_min size_y = y_max - y_min
size_c, size_z, size_t = [self.reader.shape[self.reader.axes.index(directory_entry)] size_c, size_z, size_t = (self.reader.shape[self.reader.axes.index(directory_entry)]
for directory_entry in 'CZT'] for directory_entry in 'CZT')
image = information.find("Image") image = information.find("Image")
pixel_type = text(image.find("PixelType"), "Gray16") pixel_type = text(image.find("PixelType"), "Gray16")

View File

@@ -32,10 +32,10 @@ class Reader(AbstractReader, ABC):
self.reader = TiffFile(self.path) self.reader = TiffFile(self.path)
assert self.reader.pages[0].compression == 1, "Can only read uncompressed tiff files." assert self.reader.pages[0].compression == 1, "Can only read uncompressed tiff files."
assert self.reader.pages[0].samplesperpixel == 1, "Can only read 1 sample per pixel." assert self.reader.pages[0].samplesperpixel == 1, "Can only read 1 sample per pixel."
self.offset = self.reader.pages[0].dataoffsets[0] self.offset = self.reader.pages[0].dataoffsets[0] # noqa
self.count = self.reader.pages[0].databytecounts[0] self.count = self.reader.pages[0].databytecounts[0] # noqa
self.bytes_per_sample = self.reader.pages[0].bitspersample // 8 self.bytes_per_sample = self.reader.pages[0].bitspersample // 8 # noqa
self.fmt = self.reader.byteorder + self.count // self.bytes_per_sample * 'BHILQ'[self.bytes_per_sample - 1] self.fmt = self.reader.byteorder + self.count // self.bytes_per_sample * 'BHILQ'[self.bytes_per_sample - 1] # noqa
def close(self): def close(self):
self.reader.close() self.reader.close()

View File

@@ -15,7 +15,7 @@ class Reader(AbstractReader, ABC):
@cached_property @cached_property
def ome(self): def ome(self):
def shape(size_x=1, size_y=1, size_c=1, size_z=1, size_t=1): def shape(size_x=1, size_y=1, size_c=1, size_z=1, size_t=1): # noqa
return size_x, size_y, size_c, size_z, size_t return size_x, size_y, size_c, size_z, size_t
size_x, size_y, size_c, size_z, size_t = shape(*self.array.shape) size_x, size_y, size_c, size_z, size_t = shape(*self.array.shape)
try: try:
@@ -42,7 +42,7 @@ class Reader(AbstractReader, ABC):
if isinstance(self.path, np.ndarray): if isinstance(self.path, np.ndarray):
self.array = np.array(self.path) self.array = np.array(self.path)
while self.array.ndim < 5: while self.array.ndim < 5:
self.array = np.expand_dims(self.array, -1) self.array = np.expand_dims(self.array, -1) # noqa
self.path = 'numpy array' self.path = 'numpy array'
def __frame__(self, c, z, t): def __frame__(self, c, z, t):

View File

@@ -4,7 +4,7 @@ import re
from pathlib import Path from pathlib import Path
from functools import cached_property from functools import cached_property
from ome_types import model from ome_types import model
from ome_types.units import _quantity_property from ome_types.units import _quantity_property # noqa
from itertools import product from itertools import product
from datetime import datetime from datetime import datetime
from abc import ABC from abc import ABC
@@ -15,7 +15,10 @@ def lazy_property(function, field, *arg_fields):
def lazy(self): def lazy(self):
if self.__dict__.get(field) is None: if self.__dict__.get(field) is None:
self.__dict__[field] = function(*[getattr(self, arg_field) for arg_field in arg_fields]) self.__dict__[field] = function(*[getattr(self, arg_field) for arg_field in arg_fields])
try:
self.model_fields_set.add(field) self.model_fields_set.add(field)
except Exception: # noqa
pass
return self.__dict__[field] return self.__dict__[field]
return property(lazy) return property(lazy)
@@ -24,6 +27,7 @@ class Plane(model.Plane):
""" Lazily retrieve delta_t from metadata """ """ Lazily retrieve delta_t from metadata """
def __init__(self, t0, file, **kwargs): def __init__(self, t0, file, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
# setting fields here because they would be removed by ome_types/pydantic after class definition
setattr(self.__class__, 'delta_t', lazy_property(self.get_delta_t, 'delta_t', 't0', 'file')) setattr(self.__class__, 'delta_t', lazy_property(self.get_delta_t, 'delta_t', 't0', 'file'))
setattr(self.__class__, 'delta_t_quantity', _quantity_property('delta_t')) setattr(self.__class__, 'delta_t_quantity', _quantity_property('delta_t'))
self.__dict__['t0'] = t0 self.__dict__['t0'] = t0
@@ -79,7 +83,7 @@ class Reader(AbstractReader, ABC):
else: else:
pixel_type = "uint16" # assume pixel_type = "uint16" # assume
size_c, size_z, size_t = [max(i) + 1 for i in zip(*self.filedict.keys())] size_c, size_z, size_t = (max(i) + 1 for i in zip(*self.filedict.keys()))
t0 = datetime.strptime(metadata["Info"]["Time"], "%Y-%m-%d %H:%M:%S %z") t0 = datetime.strptime(metadata["Info"]["Time"], "%Y-%m-%d %H:%M:%S %z")
ome.images.append( ome.images.append(
model.Image( model.Image(
@@ -123,7 +127,7 @@ class Reader(AbstractReader, ABC):
pattern_c = re.compile(r"img_\d{3,}_(.*)_\d{3,}$") pattern_c = re.compile(r"img_\d{3,}_(.*)_\d{3,}$")
pattern_z = re.compile(r"(\d{3,})$") pattern_z = re.compile(r"(\d{3,})$")
pattern_t = re.compile(r"img_(\d{3,})") pattern_t = re.compile(r"img_(\d{3,})")
self.filedict = {(cnamelist.index(pattern_c.findall(file.stem)[0]), self.filedict = {(cnamelist.index(pattern_c.findall(file.stem)[0]), # noqa
int(pattern_z.findall(file.stem)[0]), int(pattern_z.findall(file.stem)[0]),
int(pattern_t.findall(file.stem)[0])): file for file in filelist} int(pattern_t.findall(file.stem)[0])): file for file in filelist}

View File

@@ -17,7 +17,7 @@ class Reader(AbstractReader, ABC):
def _can_open(path): def _can_open(path):
if isinstance(path, Path) and path.suffix in ('.tif', '.tiff'): if isinstance(path, Path) and path.suffix in ('.tif', '.tiff'):
with tifffile.TiffFile(path) as tif: with tifffile.TiffFile(path) as tif:
return tif.is_imagej and tif.pages[-1]._nextifd() == 0 return tif.is_imagej and tif.pages[-1]._nextifd() == 0 # noqa
else: else:
return False return False
@@ -27,12 +27,12 @@ class Reader(AbstractReader, ABC):
for key, value in self.reader.imagej_metadata.items()} for key, value in self.reader.imagej_metadata.items()}
page = self.reader.pages[0] page = self.reader.pages[0]
self.p_ndim = page.ndim self.p_ndim = page.ndim # noqa
size_x = page.imagelength size_x = page.imagelength
size_y = page.imagewidth size_y = page.imagewidth
if self.p_ndim == 3: if self.p_ndim == 3:
size_c = page.samplesperpixel size_c = page.samplesperpixel
self.p_transpose = [i for i in [page.axes.find(j) for j in 'SYX'] if i >= 0] self.p_transpose = [i for i in [page.axes.find(j) for j in 'SYX'] if i >= 0] # noqa
size_t = metadata.get('frames', 1) # // C size_t = metadata.get('frames', 1) # // C
else: else:
size_c = metadata.get('channels', 1) size_c = metadata.get('channels', 1)

View File

@@ -1,19 +1,24 @@
import yaml import warnings
import re
import numpy as np
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import numpy as np
import yaml
from parfor import pmap, Chunks
from skimage import filters
from tiffwrite import IJTiffFile
from tqdm.auto import tqdm
try: try:
# best if SimpleElastix is installed: https://simpleelastix.readthedocs.io/GettingStarted.html # best if SimpleElastix is installed: https://simpleelastix.readthedocs.io/GettingStarted.html
import SimpleITK as sitk import SimpleITK as sitk # noqa
except ImportError: except ImportError:
sitk = None sitk = None
try: try:
from pandas import DataFrame, Series from pandas import DataFrame, Series, concat
except ImportError: except ImportError:
DataFrame, Series = None, None DataFrame, Series, concat = None, None, None
if hasattr(yaml, 'full_load'): if hasattr(yaml, 'full_load'):
@@ -24,10 +29,30 @@ else:
class Transforms(dict): class Transforms(dict):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args[1:], **kwargs) super().__init__(*args, **kwargs)
self.default = Transform() self.default = Transform()
if len(args):
self.load(args[0]) @classmethod
def from_file(cls, file, C=True, T=True):
with open(Path(file).with_suffix('.yml')) as f:
return cls.from_dict(yamlload(f), C, T)
@classmethod
def from_dict(cls, d, C=True, T=True):
new = cls()
for key, value in d.items():
if isinstance(key, str) and C:
new[key.replace(r'\:', ':').replace('\\\\', '\\')] = Transform.from_dict(value)
elif T:
new[key] = Transform.from_dict(value)
return new
@classmethod
def from_shifts(cls, shifts):
new = cls()
for key, shift in shifts.items():
new[key] = Transform.from_shift(shift)
return new
def __mul__(self, other): def __mul__(self, other):
new = Transforms() new = Transforms()
@@ -44,18 +69,11 @@ class Transforms(dict):
return new return new
def asdict(self): def asdict(self):
return {':'.join(str(i).replace('\\', '\\\\').replace(':', r'\:') for i in key): value.asdict() return {key.replace('\\', '\\\\').replace(':', r'\:') if isinstance(key, str) else key: value.asdict()
for key, value in self.items()} for key, value in self.items()}
def load(self, file): def __getitem__(self, item):
if isinstance(file, dict): return np.prod([self[i] for i in item[::-1]]) if isinstance(item, tuple) else super().__getitem__(item)
d = file
else:
with open(file.with_suffix(".yml"), 'r') as f:
d = yamlload(f)
pattern = re.compile(r'[^\\]:')
for key, value in d.items():
self[tuple(i.replace(r'\:', ':').replace('\\\\', '\\') for i in pattern.split(key))] = Transform(value)
def __missing__(self, key): def __missing__(self, key):
return self.default return self.default
@@ -70,47 +88,242 @@ class Transforms(dict):
return hash(frozenset((*self.__dict__.items(), *self.items()))) return hash(frozenset((*self.__dict__.items(), *self.items())))
def save(self, file): def save(self, file):
with open(file.with_suffix(".yml"), 'w') as f: with open(Path(file).with_suffix(".yml"), 'w') as f:
yaml.safe_dump(self.asdict(), f, default_flow_style=None) yaml.safe_dump(self.asdict(), f, default_flow_style=None)
def copy(self): def copy(self):
return deepcopy(self) return deepcopy(self)
def adapt(self, origin, shape): def adapt(self, origin, shape, channel_names):
def key_map(a, b):
def fun(b, key_a):
for key_b in b:
if key_b in key_a or key_a in key_b:
return key_a, key_b
return {n[0]: n[1] for key_a in a if (n := fun(b, key_a))}
for value in self.values(): for value in self.values():
value.adapt(origin, shape) value.adapt(origin, shape)
self.default.adapt(origin, shape) self.default.adapt(origin, shape)
transform_channels = {key for key in self.keys() if isinstance(key, str)}
if set(channel_names) - transform_channels:
mapping = key_map(channel_names, transform_channels)
warnings.warn(f'The image file and the transform do not have the same channels,'
f' creating a mapping: {mapping}')
for key_im, key_t in mapping.items():
self[key_im] = self[key_t]
@property @property
def inverse(self): def inverse(self):
# TODO: check for C@T
inverse = self.copy() inverse = self.copy()
for key, value in self.items(): for key, value in self.items():
inverse[key] = value.inverse inverse[key] = value.inverse
return inverse return inverse
@property def coords_pandas(self, array, channel_names, columns=None):
def ndim(self): if isinstance(array, DataFrame):
return len(list(self.keys())[0]) return concat([self.coords_pandas(row, channel_names, columns) for _, row in array.iterrows()], axis=1).T
elif isinstance(array, Series):
key = []
if 'C' in array:
key.append(channel_names[int(array['C'])])
if 'T' in array:
key.append(int(array['T']))
return self[tuple(key)].coords(array, columns)
else:
raise TypeError('Not a pandas DataFrame or Series.')
def with_beads(self, cyllens, bead_files):
assert len(bead_files) > 0, "At least one file is needed to calculate the registration."
transforms = [self.calculate_channel_transforms(file, cyllens) for file in bead_files]
for key in {key for transform in transforms for key in transform.keys()}:
new_transforms = [transform[key] for transform in transforms if key in transform]
if len(new_transforms) == 1:
self[key] = new_transforms[0]
else:
self[key] = Transform()
self[key].parameters = np.mean([t.parameters for t in new_transforms], 0)
self[key].dparameters = (np.std([t.parameters for t in new_transforms], 0) /
np.sqrt(len(new_transforms))).tolist()
return self
@staticmethod
def get_bead_files(path):
from . import Imread
files = []
for file in path.iterdir():
if file.name.lower().startswith('beads'):
try:
with Imread(file):
files.append(file)
except Exception:
pass
files = sorted(files)
if not files:
raise Exception('No bead file found!')
checked_files = []
for file in files:
try:
if file.is_dir():
file /= 'Pos0'
with Imread(file): # check for errors opening the file
checked_files.append(file)
except (Exception,):
continue
if not checked_files:
raise Exception('No bead file found!')
return checked_files
@staticmethod
def calculate_channel_transforms(bead_file, cyllens):
""" When no channel is not transformed by a cylindrical lens, assume that the image is scaled by a factor 1.162
in the horizontal direction """
from . import Imread
with Imread(bead_file, axes='zcxy') as im: # noqa
max_ims = im.max('z')
goodch = [c for c, max_im in enumerate(max_ims) if not im.is_noise(max_im)]
if not goodch:
goodch = list(range(len(max_ims)))
untransformed = [c for c in range(im.shape['c']) if cyllens[im.detector[c]].lower() == 'none']
good_and_untrans = sorted(set(goodch) & set(untransformed))
if good_and_untrans:
masterch = good_and_untrans[0]
else:
masterch = goodch[0]
transform = Transform()
if not good_and_untrans:
matrix = transform.matrix
matrix[0, 0] = 0.86
transform.matrix = matrix
transforms = Transforms()
for c in tqdm(goodch, desc='Calculating channel transforms'): # noqa
if c == masterch:
transforms[im.channel_names[c]] = transform
else:
transforms[im.channel_names[c]] = Transform.register(max_ims[masterch], max_ims[c]) * transform
return transforms
@staticmethod
def save_channel_transform_tiff(bead_files, tiffile):
from . import Imread
n_channels = 0
for file in bead_files:
with Imread(file) as im:
n_channels = max(n_channels, im.shape['c'])
with IJTiffFile(tiffile, (n_channels, 1, len(bead_files))) as tif:
for t, file in enumerate(bead_files):
with Imread(file) as im:
with Imread(file).with_transform() as jm:
for c in range(im.shape['c']):
tif.save(np.hstack((im(c=c, t=0).max('z'), jm(c=c, t=0).max('z'))), c, 0, t)
def with_drift(self, im):
""" Calculate shifts relative to the first frame
divide the sequence into groups,
compare each frame to the frame in the middle of the group and compare these middle frames to each other
"""
im = im.transpose('tzycx')
t_groups = [list(chunk) for chunk in Chunks(range(im.shape['t']), size=round(np.sqrt(im.shape['t'])))]
t_keys = [int(np.round(np.mean(t_group))) for t_group in t_groups]
t_pairs = [(int(np.round(np.mean(t_group))), frame) for t_group in t_groups for frame in t_group]
t_pairs.extend(zip(t_keys, t_keys[1:]))
fmaxz_keys = {t_key: filters.gaussian(im[t_key].max('z'), 5) for t_key in t_keys}
def fun(t_key_t, im, fmaxz_keys):
t_key, t = t_key_t
if t_key == t:
return 0, 0
else:
fmaxz = filters.gaussian(im[t].max('z'), 5)
return Transform.register(fmaxz_keys[t_key], fmaxz, 'translation').parameters[4:]
shifts = np.array(pmap(fun, t_pairs, (im, fmaxz_keys), desc='Calculating image shifts.'))
shift_keys_cum = np.zeros(2)
for shift_keys, t_group in zip(np.vstack((-shifts[0], shifts[im.shape['t']:])), t_groups):
shift_keys_cum += shift_keys
shifts[t_group] += shift_keys_cum
for i, shift in enumerate(shifts[:im.shape['t']]):
self[i] = Transform.from_shift(shift)
return self
class Transform: class Transform:
def __init__(self, *args): def __init__(self):
if sitk is None: if sitk is None:
raise ImportError('SimpleElastix is not installed: ' raise ImportError('SimpleElastix is not installed: '
'https://simpleelastix.readthedocs.io/GettingStarted.html') 'https://simpleelastix.readthedocs.io/GettingStarted.html')
self.transform = sitk.ReadTransform(str(Path(__file__).parent / 'transform.txt')) self.transform = sitk.ReadTransform(str(Path(__file__).parent / 'transform.txt'))
self.dparameters = 0, 0, 0, 0, 0, 0 self.dparameters = [0., 0., 0., 0., 0., 0.]
self.shape = 512, 512 self.shape = [512., 512.]
self.origin = 255.5, 255.5 self.origin = [255.5, 255.5]
if len(args) == 1: # load from file or dict
if isinstance(args[0], np.ndarray):
self.matrix = args[0]
else:
self.load(*args)
elif len(args) > 1: # make new transform using fixed and moving image
self.register(*args)
self._last, self._inverse = None, None self._last, self._inverse = None, None
def __reduce__(self):
return self.from_dict, (self.asdict(),)
def __repr__(self):
return self.asdict().__repr__()
def __str__(self):
return self.asdict().__str__()
@classmethod
def register(cls, fix, mov, kind=None):
""" kind: 'affine', 'translation', 'rigid' """
new = cls()
kind = kind or 'affine'
new.shape = fix.shape
fix, mov = new.cast_image(fix), new.cast_image(mov)
# TODO: implement RigidTransform
tfilter = sitk.ElastixImageFilter()
tfilter.LogToConsoleOff()
tfilter.SetFixedImage(fix)
tfilter.SetMovingImage(mov)
tfilter.SetParameterMap(sitk.GetDefaultParameterMap(kind))
tfilter.Execute()
transform = tfilter.GetTransformParameterMap()[0]
if kind == 'affine':
new.parameters = [float(t) for t in transform['TransformParameters']]
new.shape = [float(t) for t in transform['Size']]
new.origin = [float(t) for t in transform['CenterOfRotationPoint']]
elif kind == 'translation':
new.parameters = [1.0, 0.0, 0.0, 1.0] + [float(t) for t in transform['TransformParameters']]
new.shape = [float(t) for t in transform['Size']]
new.origin = [(t - 1) / 2 for t in new.shape]
else:
raise NotImplementedError(f'{kind} tranforms not implemented (yet)')
new.dparameters = 6 * [np.nan]
return new
@classmethod
def from_shift(cls, shift):
return cls.from_array(np.array(((1, 0, shift[0]), (0, 1, shift[1]), (0, 0, 1))))
@classmethod
def from_array(cls, array):
new = cls()
new.matrix = array
return new
@classmethod
def from_file(cls, file):
with open(Path(file).with_suffix('.yml')) as f:
return cls.from_dict(yamlload(f))
@classmethod
def from_dict(cls, d):
new = cls()
new.origin = [float(i) for i in d['CenterOfRotationPoint']]
new.parameters = [float(i) for i in d['TransformParameters']]
new.dparameters = [float(i) for i in d['dTransformParameters']] if 'dTransformParameters' in d else 6 * [np.nan]
new.shape = [float(i) for i in d['Size']]
return new
def __mul__(self, other): # TODO: take care of dmatrix def __mul__(self, other): # TODO: take care of dmatrix
result = self.copy() result = self.copy()
if isinstance(other, Transform): if isinstance(other, Transform):
@@ -121,9 +334,6 @@ class Transform:
result.dmatrix = self.dmatrix @ other result.dmatrix = self.dmatrix @ other
return result return result
def __reduce__(self):
return self.__class__, (self.asdict(),)
def is_unity(self): def is_unity(self):
return self.parameters == [1, 0, 0, 1, 0, 0] return self.parameters == [1, 0, 0, 1, 0, 0]
@@ -166,7 +376,7 @@ class Transform:
@property @property
def parameters(self): def parameters(self):
return self.transform.GetParameters() return list(self.transform.GetParameters())
@parameters.setter @parameters.setter
def parameters(self, value): def parameters(self, value):
@@ -186,7 +396,7 @@ class Transform:
def inverse(self): def inverse(self):
if self._last is None or self._last != self.asdict(): if self._last is None or self._last != self.asdict():
self._last = self.asdict() self._last = self.asdict()
self._inverse = Transform(self.asdict()) self._inverse = Transform.from_dict(self.asdict())
self._inverse.transform = self._inverse.transform.GetInverse() self._inverse.transform = self._inverse.transform.GetInverse()
self._inverse._last = self._inverse.asdict() self._inverse._last = self._inverse.asdict()
self._inverse._inverse = self self._inverse._inverse = self
@@ -234,45 +444,3 @@ class Transform:
file += '.yml' file += '.yml'
with open(file, 'w') as f: with open(file, 'w') as f:
yaml.safe_dump(self.asdict(), f, default_flow_style=None) yaml.safe_dump(self.asdict(), f, default_flow_style=None)
def load(self, file):
""" load the parameters of a transform from a yaml file or a dict
"""
if isinstance(file, dict):
d = file
else:
if not file[-3:] == 'yml':
file += '.yml'
with open(file, 'r') as f:
d = yamlload(f)
self.origin = [float(i) for i in d['CenterOfRotationPoint']]
self.parameters = [float(i) for i in d['TransformParameters']]
self.dparameters = [float(i) for i in d['dTransformParameters']] \
if 'dTransformParameters' in d else 6 * [np.nan]
self.shape = [float(i) for i in d['Size']]
def register(self, fix, mov, kind=None):
""" kind: 'affine', 'translation', 'rigid'
"""
kind = kind or 'affine'
self.shape = fix.shape
fix, mov = self.cast_image(fix), self.cast_image(mov)
# TODO: implement RigidTransform
tfilter = sitk.ElastixImageFilter()
tfilter.LogToConsoleOff()
tfilter.SetFixedImage(fix)
tfilter.SetMovingImage(mov)
tfilter.SetParameterMap(sitk.GetDefaultParameterMap(kind))
tfilter.Execute()
transform = tfilter.GetTransformParameterMap()[0]
if kind == 'affine':
self.parameters = [float(t) for t in transform['TransformParameters']]
self.shape = [float(t) for t in transform['Size']]
self.origin = [float(t) for t in transform['CenterOfRotationPoint']]
elif kind == 'translation':
self.parameters = [1.0, 0.0, 0.0, 1.0] + [float(t) for t in transform['TransformParameters']]
self.shape = [float(t) for t in transform['Size']]
self.origin = [(t - 1) / 2 for t in self.shape]
else:
raise NotImplementedError(f'{kind} tranforms not implemented (yet)')
self.dparameters = 6 * [np.nan]

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "ndbioimage" name = "ndbioimage"
version = "2023.10.1" version = "2023.10.2"
description = "Bio image reading, metadata and some affine registration." description = "Bio image reading, metadata and some affine registration."
authors = ["W. Pomp <w.pomp@nki.nl>"] authors = ["W. Pomp <w.pomp@nki.nl>"]
license = "GPLv3" license = "GPLv3"
@@ -22,7 +22,7 @@ pint = "*"
tqdm = "*" tqdm = "*"
lxml = "*" lxml = "*"
pyyaml = "*" pyyaml = "*"
parfor = ">=2023.8.2" parfor = ">=2023.10.1"
JPype1 = "*" JPype1 = "*"
SimpleITK-SimpleElastix = "*" SimpleITK-SimpleElastix = "*"
pytest = { version = "*", optional = true } pytest = { version = "*", optional = true }