573 lines
19 KiB
Python
573 lines
19 KiB
Python
import warnings
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import yaml
|
|
from parfor import Chunks, pmap
|
|
from skimage import filters
|
|
from tiffwrite import IJTiffFile
|
|
from tqdm.auto import tqdm
|
|
|
|
try:
|
|
# best if SimpleElastix is installed: https://simpleelastix.readthedocs.io/GettingStarted.html
|
|
import SimpleITK as sitk # noqa
|
|
except ImportError:
|
|
sitk = None
|
|
|
|
try:
|
|
from pandas import DataFrame, Series, concat
|
|
except ImportError:
|
|
DataFrame, Series, concat = None, None, None
|
|
|
|
|
|
# Loader used for the transform yaml files: yaml.full_load when this PyYAML provides it,
# otherwise fall back to yaml.load.
# NOTE(review): neither yaml.load nor yaml.full_load is as strict as yaml.safe_load. The transform
# files are written with yaml.safe_dump, so safe_load should be sufficient if these files can come
# from untrusted sources — confirm no other caller of `yamlload` relies on the looser loaders.
yamlload = getattr(yaml, "full_load", yaml.load)
|
|
|
|
|
|
class Transforms(dict):
    """Collection of :class:`Transform` objects keyed by channel name (str) or other keys (e.g. frame index).

    Looking up a missing key returns ``self.default`` (a freshly constructed ``Transform``),
    and looking up a tuple of keys returns the product of the individual transforms.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # returned by __missing__ for keys without a dedicated transform
        self.default = Transform()

    @classmethod
    def from_file(cls, file, C=True, T=True):
        """Load transforms from ``file`` with its suffix replaced by ``.yml``.

        C: load the channel (str) keys; T: load the non-str keys (e.g. frame indices).
        """
        with open(Path(file).with_suffix(".yml")) as f:
            return cls.from_dict(yamlload(f), C, T)

    @classmethod
    def from_dict(cls, d, C=True, T=True):
        """Build a Transforms from a dict as produced by :meth:`asdict`.

        C: load the channel (str) keys; T: load the non-str keys (e.g. frame indices).
        """
        new = cls()
        for key, value in d.items():
            if isinstance(key, str):
                # Channel keys: undo the escaping applied by asdict. Only loaded when C is set;
                # previously they fell through to the T branch when C was False.
                if C:
                    new[key.replace(r"\:", ":").replace("\\\\", "\\")] = (
                        Transform.from_dict(value)
                    )
            elif T:
                new[key] = Transform.from_dict(value)
        return new

    @classmethod
    def from_shifts(cls, shifts):
        """Build a Transforms of pure translations from a mapping key -> (shift0, shift1)."""
        new = cls()
        for key, shift in shifts.items():
            new[key] = Transform.from_shift(shift)
        return new

    def __mul__(self, other):
        """Multiply transforms.

        Transforms * Transforms: every pair is multiplied, keyed by ``key0 + key1``.
        Transforms * None: returns self unchanged.
        Transforms * anything else: every transform is multiplied by ``other``.
        """
        new = Transforms()
        if isinstance(other, Transforms):
            for key0, value0 in self.items():
                for key1, value1 in other.items():
                    new[key0 + key1] = value0 * value1
            return new
        elif other is None:
            return self
        else:
            for key in self.keys():
                new[key] = self[key] * other
            return new

    def asdict(self):
        """Return a plain dict for yaml serialization.

        In str keys, backslashes and colons are escaped (reversed by from_dict).
        """
        return {
            key.replace("\\", "\\\\").replace(":", r"\:")
            if isinstance(key, str)
            else key: value.asdict()
            for key, value in self.items()
        }

    def __getitem__(self, item):
        """Return the transform for ``item``.

        A tuple of keys returns the product of their transforms, composed right to left:
        ``self[a, b] == self[b] * self[a]``. An empty tuple returns ``self.default``.
        """
        if not isinstance(item, tuple):
            return super().__getitem__(item)
        if not item:
            # np.prod of an empty sequence would give the float 1.0 instead of a Transform
            return self.default
        transforms = [self[key] for key in reversed(item)]
        result = transforms[0]
        for transform in transforms[1:]:
            result = result * transform
        return result

    def __missing__(self, key):
        return self.default

    def __getstate__(self):
        # pickle the instance attributes (e.g. `default`); dict items are pickled by dict itself
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __hash__(self):
        return hash(frozenset((*self.__dict__.items(), *self.items())))

    def save(self, file):
        """Save to ``file`` with its suffix replaced by ``.yml``."""
        with open(Path(file).with_suffix(".yml"), "w") as f:
            yaml.safe_dump(self.asdict(), f, default_flow_style=None)

    def copy(self):
        return deepcopy(self)

    def adapt(self, origin, shape, channel_names):
        """Adapt every transform (and the default) to ``origin`` and ``shape``.

        Also ensure every name in ``channel_names`` has a transform. An image channel without
        an exact match is mapped, with a warning, to a transform channel whose name contains,
        or is contained in, the image channel name.
        """

        def key_map(a, b):
            def fun(b, key_a):
                if key_a in b:  # an exact match always wins over substring matches
                    return key_a, key_a
                for key_b in sorted(b):  # sorted: deterministic choice between candidates
                    if key_b in key_a or key_a in key_b:
                        return key_a, key_b

            return {n[0]: n[1] for key_a in a if (n := fun(b, key_a))}

        # The same Transform object can be stored under several keys (e.g. by a previous channel
        # mapping below); adapt each distinct object only once so it is not shifted twice.
        adapted = set()
        for transform in (*self.values(), self.default):
            if id(transform) not in adapted:
                adapted.add(id(transform))
                transform.adapt(origin, shape)
        transform_channels = {key for key in self.keys() if isinstance(key, str)}
        if set(channel_names) - transform_channels:
            mapping = key_map(channel_names, transform_channels)
            warnings.warn(
                f"The image file and the transform do not have the same channels,"
                f" creating a mapping: {mapping}"
            )
            for key_im, key_t in mapping.items():
                self[key_im] = self[key_t]

    @property
    def inverse(self):
        """A copy in which every stored transform is replaced by its inverse (the default is only copied)."""
        # TODO: check for C@T
        inverse = self.copy()
        for key, value in self.items():
            inverse[key] = value.inverse
        return inverse

    def coords_pandas(self, array, channel_names, columns=None):
        """Transform coordinates in a pandas DataFrame (row by row) or Series.

        The transform for a row is looked up from its "C" column (an index into
        ``channel_names``) and/or its "T" column. Raises TypeError for other input types.
        """
        if DataFrame is not None and isinstance(array, DataFrame):
            if array.empty:
                # nothing to transform; concat([]) would raise
                return array.copy()
            return concat(
                [
                    self.coords_pandas(row, channel_names, columns)
                    for _, row in array.iterrows()
                ],
                axis=1,
            ).T
        elif Series is not None and isinstance(array, Series):
            key = []
            if "C" in array:
                key.append(channel_names[int(array["C"])])
            if "T" in array:
                key.append(int(array["T"]))
            return self[tuple(key)].coords(array, columns)
        else:
            raise TypeError("Not a pandas DataFrame or Series.")

    def with_beads(self, cyllens, bead_files):
        """Compute channel transforms from ``bead_files`` and store them in self (returned).

        When several files give a transform for the same channel, the parameters are averaged
        and ``dparameters`` is set to their standard error of the mean.
        """
        assert len(bead_files) > 0, (
            "At least one file is needed to calculate the registration."
        )
        transforms = [
            self.calculate_channel_transforms(file, cyllens) for file in bead_files
        ]
        for key in {key for transform in transforms for key in transform.keys()}:
            new_transforms = [
                transform[key] for transform in transforms if key in transform
            ]
            if len(new_transforms) == 1:
                self[key] = new_transforms[0]
            else:
                self[key] = Transform()
                self[key].parameters = np.mean(
                    [t.parameters for t in new_transforms], 0
                )
                self[key].dparameters = (
                    np.std([t.parameters for t in new_transforms], 0)
                    / np.sqrt(len(new_transforms))
                ).tolist()
        return self

    @staticmethod
    def get_bead_files(path):
        """Return the bead files in directory ``path`` (str or Path) that Imread can open.

        Candidates are entries whose name starts with "beads" (case-insensitive). For directories,
        the "Pos0" subdirectory is what is finally returned. Raises Exception when nothing usable
        is found.
        """
        from . import Imread

        files = []
        for file in Path(path).iterdir():
            if file.name.lower().startswith("beads"):
                try:
                    with Imread(file):
                        files.append(file)
                except Exception:  # best effort: skip whatever Imread cannot open
                    pass
        files = sorted(files)
        if not files:
            raise Exception("No bead file found!")
        checked_files = []
        for file in files:
            try:
                if file.is_dir():
                    file /= "Pos0"
                with Imread(file):  # check for errors opening the file
                    checked_files.append(file)
            except (Exception,):
                continue
        if not checked_files:
            raise Exception("No bead file found!")
        return checked_files

    @staticmethod
    def calculate_channel_transforms(bead_file, cyllens):
        """Register the z-max projection of every good channel in ``bead_file`` onto a reference channel.

        Good channels are those for which ``im.is_noise`` is False (all channels if there are
        none). The reference is the first good channel whose detector has no cylindrical lens
        (``cyllens[detector].lower() == "none"``).

        When every good channel is behind a cylindrical lens, the first good channel is the
        reference. The image is then assumed to be scaled by a factor 1.162 in the horizontal
        direction, so the reference transform gets a horizontal scale of 0.86 (about 1 / 1.162).

        Returns a Transforms keyed by channel name.
        """
        from . import Imread

        with Imread(bead_file, axes="zcyx") as im:  # noqa
            max_ims = im.max("z")
            goodch = [c for c, max_im in enumerate(max_ims) if not im.is_noise(max_im)]
            if not goodch:
                goodch = list(range(len(max_ims)))
            untransformed = [
                c
                for c in range(im.shape["c"])
                if cyllens[im.detector[c]].lower() == "none"
            ]

            good_and_untrans = sorted(set(goodch) & set(untransformed))
            if good_and_untrans:
                masterch = good_and_untrans[0]
            else:
                masterch = goodch[0]
            transform = Transform()
            if not good_and_untrans:
                matrix = transform.matrix
                matrix[0, 0] = 0.86
                transform.matrix = matrix
            transforms = Transforms()
            for c in tqdm(goodch, desc="Calculating channel transforms"):  # noqa
                if c == masterch:
                    transforms[im.channel_names[c]] = transform
                else:
                    transforms[im.channel_names[c]] = (
                        Transform.register(max_ims[masterch], max_ims[c]) * transform
                    )
        return transforms

    @staticmethod
    def save_channel_transform_tiff(bead_files, tiffile):
        """Write a tiff with, per bead file (t) and channel (c), the z-max projection of the image
        next to the z-max projection of its ``with_transform()`` view."""
        from . import Imread

        with IJTiffFile(tiffile) as tif:
            for t, file in enumerate(bead_files):
                with Imread(file) as im:
                    with Imread(file).with_transform() as jm:
                        for c in range(im.shape["c"]):
                            tif.save(
                                np.hstack(
                                    (im(c=c, t=0).max("z"), jm(c=c, t=0).max("z"))
                                ),
                                c,
                                0,
                                t,
                            )

    def with_drift(self, im):
        """Calculate shifts relative to the first frame.

        Divide the sequence into groups, compare each frame to the frame in the middle of its
        group, and compare these middle frames to each other. The resulting per-frame translations
        are stored in self under the frame index, and self is returned.
        """
        im = im.transpose("tzycx")
        t_groups = [
            list(chunk)
            for chunk in Chunks(
                range(im.shape["t"]), size=round(np.sqrt(im.shape["t"]))
            )
        ]
        t_keys = [int(np.round(np.mean(t_group))) for t_group in t_groups]
        # First one pair per frame (group key, frame), in frame order; then the key-to-key pairs.
        t_pairs = [
            (int(np.round(np.mean(t_group))), frame)
            for t_group in t_groups
            for frame in t_group
        ]
        t_pairs.extend(zip(t_keys, t_keys[1:]))
        fmaxz_keys = {
            t_key: filters.gaussian(im[t_key].max("z"), 5) for t_key in t_keys
        }

        def fun(t_key_t, im, fmaxz_keys):
            t_key, t = t_key_t
            if t_key == t:
                return 0, 0
            else:
                fmaxz = filters.gaussian(im[t].max("z"), 5)
                return Transform.register(
                    fmaxz_keys[t_key], fmaxz, "translation"
                ).parameters[4:]

        # dtype=float: if every pair is trivial (e.g. a single frame) fun only returns int zeros,
        # and the in-place float additions below would fail on an int array.
        shifts = np.array(
            pmap(fun, t_pairs, (im, fmaxz_keys), desc="Calculating image shifts."),
            dtype=float,
        )
        shift_keys_cum = np.zeros(2)
        # -shifts[0] makes frame 0 the reference; shifts[n_t:] chain the group keys together
        for shift_keys, t_group in zip(
            np.vstack((-shifts[0], shifts[im.shape["t"] :])), t_groups
        ):
            shift_keys_cum += shift_keys
            shifts[t_group] += shift_keys_cum

        for i, shift in enumerate(shifts[: im.shape["t"]]):
            self[i] = Transform.from_shift(shift)
        return self
|
|
|
|
|
|
class Transform:
    """2D affine transform (2x2 matrix plus translation, 6 parameters) with parameter uncertainties.

    The parameters are stored in a SimpleITK transform when SimpleITK is importable. Without it,
    ``transform`` is None, ``parameters`` always reports the identity, ``origin`` is None, and
    setting ``parameters`` or ``origin`` has no effect.

    Attributes:
        transform: the SimpleITK transform, or None.
        dparameters: uncertainties of the 6 parameters (NaN when unknown).
        shape: size of the image the transform refers to (2 entries).
        origin (property): center of rotation (the SimpleITK fixed parameters).
    """

    def __init__(self):
        if sitk is None:
            self.transform = None
        else:
            # start from the template transform file stored next to this module
            self.transform = sitk.ReadTransform(
                str(Path(__file__).parent / "transform.txt")
            )
        self.dparameters = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        self.shape = [512.0, 512.0]
        self.origin = [255.5, 255.5]
        # cache for the `inverse` property: asdict() of self when the inverse was computed
        self._last, self._inverse = None, None

    def __reduce__(self):
        return self.from_dict, (self.asdict(),)

    def __repr__(self):
        return self.asdict().__repr__()

    def __str__(self):
        return self.asdict().__str__()

    @classmethod
    def register(cls, fix, mov, kind=None):
        """Register image ``mov`` onto image ``fix`` with SimpleElastix and return the Transform.

        kind: 'affine' (default) or 'translation'. Other kinds (e.g. 'rigid') are not supported
        yet and raise NotImplementedError before any registration is run.
        Raises ImportError when SimpleElastix is not installed.
        """
        if sitk is None:
            raise ImportError(
                "SimpleElastix is not installed: "
                "https://simpleelastix.readthedocs.io/GettingStarted.html"
            )
        kind = kind or "affine"
        if kind not in ("affine", "translation"):
            # TODO: implement RigidTransform
            raise NotImplementedError(f"{kind} tranforms not implemented (yet)")
        new = cls()
        fix, mov = new.cast_image(fix), new.cast_image(mov)
        tfilter = sitk.ElastixImageFilter()
        tfilter.LogToConsoleOff()
        tfilter.SetFixedImage(fix)
        tfilter.SetMovingImage(mov)
        tfilter.SetParameterMap(sitk.GetDefaultParameterMap(kind))
        tfilter.Execute()
        transform = tfilter.GetTransformParameterMap()[0]
        if kind == "affine":
            new.parameters = [float(t) for t in transform["TransformParameters"]]
            new.shape = [float(t) for t in transform["Size"]]
            new.origin = [float(t) for t in transform["CenterOfRotationPoint"]]
        else:  # translation: elastix only estimates the two shifts
            new.parameters = [1.0, 0.0, 0.0, 1.0] + [
                float(t) for t in transform["TransformParameters"]
            ]
            new.shape = [float(t) for t in transform["Size"]]
            new.origin = [(t - 1) / 2 for t in new.shape]
        new.dparameters = 6 * [np.nan]
        return new

    @classmethod
    def from_shift(cls, shift):
        """A pure translation by (shift[0], shift[1])."""
        return cls.from_array(np.array(((1, 0, shift[0]), (0, 1, shift[1]), (0, 0, 1))))

    @classmethod
    def from_array(cls, array):
        """A transform from a 3x3 homogeneous matrix."""
        new = cls()
        new.matrix = array
        return new

    @classmethod
    def from_file(cls, file):
        """Load from ``file`` with its suffix replaced by ``.yml``."""
        with open(Path(file).with_suffix(".yml")) as f:
            return cls.from_dict(yamlload(f))

    @classmethod
    def from_dict(cls, d):
        """Build a Transform from a dict as produced by :meth:`asdict`.

        A None center of rotation keeps the default origin. A None parameter list gives the
        identity. A missing or None ``dTransformParameters`` gives NaN uncertainties, and a None
        element in it gives 0.0.
        """
        new = cls()
        if d["CenterOfRotationPoint"] is not None:
            # passing None to the SimpleITK setter would raise, so keep the default instead
            new.origin = [float(i) for i in d["CenterOfRotationPoint"]]
        new.parameters = (
            (1.0, 0.0, 0.0, 1.0, 0.0, 0.0)
            if d["TransformParameters"] is None
            else [float(i) for i in d["TransformParameters"]]
        )
        dparameters = d.get("dTransformParameters")
        new.dparameters = (
            6 * [np.nan]
            if dparameters is None
            else [0.0 if i is None else float(i) for i in dparameters]
        )
        new.shape = (
            None
            if d["Size"] is None
            else [None if i is None else float(i) for i in d["Size"]]
        )
        return new

    def __mul__(self, other):
        """Compose with another Transform (matrix product) or multiply by a 3x3 array.

        Uncertainties are propagated to first order: d(AB) = dA B + A dB.
        """
        result = self.copy()
        if isinstance(other, Transform):
            result.matrix = self.matrix @ other.matrix
            result.dmatrix = self.dmatrix @ other.matrix + self.matrix @ other.dmatrix
        else:
            result.matrix = self.matrix @ other
            result.dmatrix = self.dmatrix @ other
        return result

    def is_unity(self):
        """True when the parameters are exactly the identity."""
        return self.parameters == [1, 0, 0, 1, 0, 0]

    def copy(self):
        return deepcopy(self)

    @staticmethod
    def cast_image(im):
        """Convert an array-like to a SimpleITK image (images are returned unchanged)."""
        if not isinstance(im, sitk.Image):
            im = sitk.GetImageFromArray(np.asarray(im))
        return im

    @staticmethod
    def cast_array(im):
        """Convert a SimpleITK image to a numpy array (anything else is returned unchanged)."""
        if isinstance(im, sitk.Image):
            im = sitk.GetArrayFromImage(im)
        return im

    @property
    def matrix(self):
        """3x3 homogeneous matrix built from the 6 parameters (4 matrix entries, 2 translations)."""
        return np.array(
            (
                (*self.parameters[:2], self.parameters[4]),
                (*self.parameters[2:4], self.parameters[5]),
                (0, 0, 1),
            )
        )

    @matrix.setter
    def matrix(self, value):
        value = np.asarray(value)
        self.parameters = [*value[0, :2], *value[1, :2], *value[:2, 2]]

    @property
    def dmatrix(self):
        """Uncertainties of ``matrix``, laid out the same way (last row zero)."""
        return np.array(
            (
                (*self.dparameters[:2], self.dparameters[4]),
                (*self.dparameters[2:4], self.dparameters[5]),
                (0, 0, 0),
            )
        )

    @dmatrix.setter
    def dmatrix(self, value):
        value = np.asarray(value)
        self.dparameters = [*value[0, :2], *value[1, :2], *value[:2, 2]]

    @property
    def parameters(self):
        if self.transform is not None:
            return list(self.transform.GetParameters())
        else:
            return [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]

    @parameters.setter
    def parameters(self, value):
        if self.transform is not None:
            value = np.asarray(value)
            self.transform.SetParameters(value.tolist())

    @property
    def origin(self):
        if self.transform is not None:
            return self.transform.GetFixedParameters()

    @origin.setter
    def origin(self, value):
        if self.transform is not None:
            value = np.asarray(value)
            self.transform.SetFixedParameters(value.tolist())

    @property
    def inverse(self):
        """The inverse Transform.

        The result is cached and recomputed only when ``asdict()`` of self has changed. The
        identity is its own inverse.
        """
        if self.is_unity():
            return self
        if self._last is None or self._last != self.asdict():
            self._last = self.asdict()
            self._inverse = Transform.from_dict(self.asdict())
            self._inverse.transform = self._inverse.transform.GetInverse()
            self._inverse._last = self._inverse.asdict()
            self._inverse._inverse = self
        return self._inverse

    def adapt(self, origin, shape):
        """Re-center the transform for an image of ``shape`` (first 2 entries used).

        The center of rotation is moved by ``-origin`` and by half the size difference between
        the stored shape and the new one. The shape is then stored as a list of floats.
        Without SimpleITK there is no origin, so only the shape is updated.
        """
        if self.origin is not None:
            self.origin = np.asarray(self.origin) - (
                np.asarray(origin)
                + (np.asarray(self.shape) - np.asarray(shape)[:2]) / 2
            )
        # floats, like everywhere else shape is set, so yaml.safe_dump can serialize it
        self.shape = [float(s) for s in shape[:2]]

    def asdict(self):
        """Plain dict for yaml serialization. NaN uncertainties are written as 1e99."""
        return {
            "CenterOfRotationPoint": self.origin,
            "Size": self.shape,
            "TransformParameters": self.parameters,
            "dTransformParameters": np.nan_to_num(self.dparameters, nan=1e99).tolist(),
        }

    def frame(self, im, default=0):
        """Resample the 2D image ``im`` through the transform, keeping its dtype.

        Float images use B-spline interpolation; other dtypes use nearest neighbour. ``default``
        is the value used for pixels that map outside the input image.
        """
        if self.is_unity():
            return im
        else:
            if sitk is None:
                raise ImportError(
                    "SimpleElastix is not installed: "
                    "https://simpleelastix.readthedocs.io/GettingStarted.html"
                )
            dtype = im.dtype
            im = im.astype("float")
            intp = (
                sitk.sitkBSpline
                if np.issubdtype(dtype, np.floating)
                else sitk.sitkNearestNeighbor
            )
            return self.cast_array(
                sitk.Resample(self.cast_image(im), self.transform, intp, default)
            ).astype(dtype)

    def coords(self, array, columns=None):
        """Transform coordinates in a 2 column numpy array,
        or in pandas DataFrame or Series objects in columns ['x', 'y'] (or ``columns``).
        """
        if self.is_unity():
            return array.copy()
        elif DataFrame is not None and isinstance(array, (DataFrame, Series)):
            columns = columns or ["x", "y"]
            array = array.copy()
            if isinstance(array, DataFrame):
                array[columns] = self.coords(np.atleast_2d(array[columns].to_numpy()))
            elif isinstance(array, Series):
                array[columns] = self.coords(np.atleast_2d(array[columns].to_numpy()))[0]
            return array
        else:
            # sitk.Resample (used by `frame`) maps output positions to input positions with the
            # transform it is given. Moving points the way `frame` moves the image therefore needs
            # the inverse transform.
            points = np.asarray(array)
            inverse = self.inverse.transform  # hoisted: computing `inverse` compares dicts
            transformed = [inverse.TransformPoint(point.tolist()) for point in points]
            # reshape keeps an empty (0, 2) input as (0, 2) instead of collapsing to (0,)
            return np.array(transformed, dtype=float).reshape(points.shape)

    def save(self, file):
        """Save the transform parameters to ``file`` (str or Path) with its suffix replaced by
        ``.yml``, the same file name that :meth:`from_file` reads.
        """
        with open(Path(file).with_suffix(".yml"), "w") as f:
            yaml.safe_dump(self.asdict(), f, default_flow_style=None)
|