diff --git a/Cargo.toml b/Cargo.toml index aa9db4a..e98959a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sitk-registration-sys" -version = "2025.3.1" +version = "2025.3.2" edition = "2024" license = "MIT OR Apache-2.0" description = "register and interpolate images" @@ -14,11 +14,11 @@ keywords = ["registration", "affine", "bspline", "transform"] categories = ["multimedia::images", "science"] [lib] -name = "sitk_regsitration_sys" +name = "sitk_registration_sys" crate-type = ["cdylib", "rlib"] [dependencies] -anyhow = "1.0.96" +anyhow = "1.0.97" libc = "0.2.170" ndarray = "0.16.1" num = "0.4.3" @@ -28,4 +28,7 @@ serde_yaml = "0.9.33" [build-dependencies] cmake = "0.1.54" -git2 = "0.20.0" \ No newline at end of file +git2 = "0.20.0" + +[dev-dependencies] +tempfile = "3.18.0" \ No newline at end of file diff --git a/cpp/sitk_adapter.cxx b/cpp/sitk_adapter.cxx index 52e9a85..b65d997 100644 --- a/cpp/sitk_adapter.cxx +++ b/cpp/sitk_adapter.cxx @@ -286,8 +286,8 @@ reg( try { string kind = (t_or_a == false) ? "translation" : "affine"; // std::filesystem::path output_path = std::filesystem::temp_directory_path() / gen_random(12); - std::filesystem::path output_path = std::filesystem::temp_directory_path(); // std::filesystem::create_directory(output_path); + std::filesystem::path output_path = std::filesystem::temp_directory_path(); sitk::ElastixImageFilter tfilter = sitk::ElastixImageFilter(); tfilter.LogToConsoleOff(); @@ -298,11 +298,7 @@ reg( tfilter.SetParameterMap(sitk::GetDefaultParameterMap(kind)); tfilter.SetParameter("WriteResultImage", "false"); tfilter.SetOutputDirectory(output_path); -// cout << "output_path: " << output_path << endl; - // tfilter.PrintParameterMap(); -// cout << "r6 " << std::flush; tfilter.Execute(); -// cout << "r7 " << std::flush; sitk::ElastixImageFilter::ParameterMapType parameter_map = tfilter.GetTransformParameterMap(0); for (sitk::ElastixImageFilter::ParameterMapType::iterator parameter = parameter_map.begin(); parameter != parameter_map.end(); ++parameter) { if (parameter->first == "TransformParameters") { @@ -323,7 +319,6 @@ reg( break; } } -// cout << "r8 " << std::flush; } catch (const std::exception &exc) { cerr << exc.what(); } @@ -334,14 +329,14 @@ extern "C" void register_u8( unsigned int width, unsigned int height, - uint8_t* fixed_arr, - uint8_t* moving_arr, + uint8_t** fixed_arr, + uint8_t** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkUInt8; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -349,14 +344,14 @@ extern "C" void register_i8( unsigned int width, unsigned int height, - int8_t* fixed_arr, - int8_t* moving_arr, + int8_t** fixed_arr, + int8_t** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkInt8; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -364,14 +359,14 @@ extern "C" void register_u16( unsigned int width, unsigned int height, - uint16_t* fixed_arr, - uint16_t* moving_arr, + uint16_t** fixed_arr, + uint16_t** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkUInt16; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -379,14 +374,14 @@ extern "C" void register_i16( unsigned int width, unsigned int height, - int16_t* fixed_arr, - int16_t* moving_arr, + int16_t** fixed_arr, + int16_t** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkInt16; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -394,14 +389,14 @@ extern "C" void register_u32( unsigned int width, unsigned int height, - uint32_t* fixed_arr, - uint32_t* moving_arr, + uint32_t** fixed_arr, + uint32_t** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkUInt32; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -409,14 +404,14 @@ extern "C" void register_i32( unsigned int width, unsigned int height, - int32_t* fixed_arr, - int32_t* moving_arr, + int32_t** fixed_arr, + int32_t** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkInt32; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -424,14 +419,14 @@ extern "C" void register_u64( unsigned int width, unsigned int height, - uint64_t* fixed_arr, - uint64_t* moving_arr, + uint64_t** fixed_arr, + uint64_t** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkUInt64; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -439,14 +434,14 @@ extern "C" void register_i64( unsigned int width, unsigned int height, - int64_t* fixed_arr, - int64_t* moving_arr, + int64_t** fixed_arr, + int64_t** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkInt64; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -454,14 +449,14 @@ extern "C" void register_f32( unsigned int width, unsigned int height, - float* fixed_arr, - float* moving_arr, + float** fixed_arr, + float** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkFloat32; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } @@ -469,13 +464,13 @@ extern "C" void register_f64( unsigned int width, unsigned int height, - double* fixed_arr, - double* moving_arr, + double** fixed_arr, + double** moving_arr, bool t_or_a, double** transform ) { sitk::PixelIDValueEnum id = sitk::PixelIDValueEnum::sitkFloat64; - sitk::Image fixed = make_image(width, height, fixed_arr, id); - sitk::Image moving = make_image(width, height, moving_arr, id); + sitk::Image fixed = make_image(width, height, *fixed_arr, id); + sitk::Image moving = make_image(width, height, *moving_arr, id); reg(fixed, moving, t_or_a, transform); } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index d813138..057a7e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,52 @@ mod sys; -use crate::sys::{PixelType, interp, register}; +use crate::sys::{interp, register}; use anyhow::{Result, anyhow}; -use ndarray::{Array2, ArrayView2, array, s}; +use ndarray::{Array2, ArrayView2, AsArray, Ix2, array, s}; use serde::{Deserialize, Serialize}; use serde_yaml::{from_reader, to_writer}; use std::fs::File; use std::ops::Mul; use std::path::PathBuf; +/// a trait marking number types that can be used in sitk: +/// (u/i)(8/16/32/64), (u/i)size, f(32/64) +pub trait PixelType: Clone { + const PT: u8; +} + +macro_rules! sitk_impl { + ($($T:ty: $sitk:expr $(,)?)*) => { + $( + impl PixelType for $T { + const PT: u8 = $sitk; + } + )* + }; +} + +sitk_impl! { + u8: 1, + i8: 2, + u16: 3, + i16: 4, + u32: 5, + i32: 6, + u64: 7, + i64: 8, + f32: 9, + f64: 10, +} + +#[cfg(target_pointer_width = "64")] +sitk_impl!(usize: 7); +#[cfg(target_pointer_width = "32")] +sitk_impl!(usize: 5); +#[cfg(target_pointer_width = "64")] +sitk_impl!(isize: 8); +#[cfg(target_pointer_width = "32")] +sitk_impl!(isize: 6); + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct Transform { pub parameters: [f64; 6], @@ -47,6 +85,17 @@ impl Mul for Transform { } } +impl PartialEq for Transform { + fn eq(&self, other: &Self) -> bool { + self.parameters == other.parameters + && self.dparameters == other.dparameters + && self.origin == other.origin + && self.shape == other.shape + } +} + +impl Eq for Transform {} + impl Transform { /// parameters: flat 2x2 part of matrix, translation; origin: center of rotation pub fn new(parameters: [f64; 6], origin: [f64; 2], shape: [usize; 2]) -> Self { @@ -59,10 +108,11 @@ impl Transform { } /// find the affine transform which transforms moving into fixed - pub fn register_affine( - fixed: ArrayView2, - moving: ArrayView2, - ) -> Result { + pub fn register_affine<'a, A, T>(fixed: A, moving: A) -> Result + where + T: 'a + PixelType, + A: AsArray<'a, T, Ix2>, + { let (parameters, origin, shape) = register(fixed, moving, true)?; Ok(Transform { parameters, @@ -73,10 +123,11 @@ impl Transform { } /// find the translation which transforms moving into fixed - pub fn register_translation( - fixed: ArrayView2, - moving: ArrayView2, - ) -> Result { + pub fn register_translation<'a, A, T>(fixed: A, moving: A) -> Result + where + T: 'a + PixelType, + A: AsArray<'a, T, Ix2>, + { let (parameters, origin, shape) = register(fixed, moving, false)?; Ok(Transform { parameters, @@ -104,7 +155,11 @@ impl Transform { /// write a transform to a file pub fn to_file(&self, path: PathBuf) -> Result<()> { - let mut file = File::open(path)?; + let mut file = std::fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(path)?; to_writer(&mut file, self)?; Ok(()) } @@ -115,23 +170,30 @@ impl Transform { } /// transform an image using nearest neighbor interpolation - pub fn transform_image_bspline(&self, image: ArrayView2) -> Result> { + pub fn transform_image_bspline<'a, A, T>(&self, image: A) -> Result> + where + T: 'a + PixelType, + A: AsArray<'a, T, Ix2>, + { interp(self.parameters, self.origin, image, false) } /// transform an image using bspline interpolation - pub fn transform_image_nearest_neighbor( - &self, - image: ArrayView2, - ) -> Result> { + pub fn transform_image_nearest_neighbor<'a, A, T>(&self, image: A) -> Result> + where + T: 'a + PixelType, + A: AsArray<'a, T, Ix2>, + { interp(self.parameters, self.origin, image, true) } /// get coordinates resulting from transforming input coordinates - pub fn transform_coordinates(&self, coordinates: ArrayView2) -> Result> + pub fn transform_coordinates<'a, A, T>(&self, coordinates: A) -> Result> where - T: Clone + Into, + T: 'a + Clone + Into, + A: AsArray<'a, T, Ix2>, { + let coordinates = coordinates.into(); let s = coordinates.shape(); if s[1] != 2 { return Err(anyhow!("coordinates must have two columns")); @@ -233,6 +295,7 @@ mod tests { use anyhow::Result; use ndarray::Array2; use num::Complex; + use tempfile::NamedTempFile; /// An example of generating julia fractals. fn julia_image(shift_x: f32, shift_y: f32) -> Result> { @@ -263,25 +326,35 @@ mod tests { Ok(im) } + #[test] + fn test_serialization() -> Result<()> { + let file = NamedTempFile::new()?; + let t = Transform::new([1.2, 0.3, -0.4, 0.9, 10.2, -9.5], [59.5, 49.5], [120, 100]); + t.to_file(file.path().to_path_buf())?; + let s = Transform::from_file(file.path().to_path_buf())?; + assert_eq!(s, t); + Ok(()) + } + macro_rules! interp_tests_bspline { ($($name:ident: $t:ty $(,)?)*) => { - $( - #[test] - fn $name() -> Result<()> { - let j = julia_image(-120f32, 10f32)?.mapv(|x| x as $t); - let k = julia_image(0f32, 0f32)?.mapv(|x| x as $t); - let shape = j.shape(); - let origin = [ - ((shape[1] - 1) as f64) / 2f64, - ((shape[0] - 1) as f64) / 2f64, - ]; - let transform = Transform::new([1., 0., 0., 1., 120., -10.], origin, [shape[0], shape[1]]); - let n = transform.transform_image_bspline(j.view())?; - let d = (k.mapv(|x| x as f64) - n.mapv(|x| x as f64)).powi(2).sum(); - assert!(d <= (shape[0] * shape[1]) as f64); - Ok(()) - } - )* + $( + #[test] + fn $name() -> Result<()> { + let j = julia_image(-120f32, 10f32)?.mapv(|x| x as $t); + let k = julia_image(0f32, 0f32)?.mapv(|x| x as $t); + let shape = j.shape(); + let origin = [ + ((shape[1] - 1) as f64) / 2f64, + ((shape[0] - 1) as f64) / 2f64, + ]; + let transform = Transform::new([1., 0., 0., 1., 120., -10.], origin, [shape[0], shape[1]]); + let n = transform.transform_image_bspline(j.view())?; + let d = (k.mapv(|x| x as f64) - n.mapv(|x| x as f64)).powi(2).sum(); + assert!(d <= (shape[0] * shape[1]) as f64); + Ok(()) + } + )* } } @@ -300,23 +373,28 @@ mod tests { macro_rules! interp_tests_nearest_neighbor { ($($name:ident: $t:ty $(,)?)*) => { - $( - #[test] - fn $name() -> Result<()> { - let j = julia_image(-120f32, 10f32)?.mapv(|x| x as $t); - let k = julia_image(0f32, 0f32)?.mapv(|x| x as $t); - let shape = j.shape(); - let origin = [ - ((shape[1] - 1) as f64) / 2f64, - ((shape[0] - 1) as f64) / 2f64, - ]; - let transform = Transform::new([1., 0., 0., 1., 120., -10.], origin, [shape[0], shape[1]]); - let n = transform.transform_image_nearest_neighbor(j.view())?; - let d = (k.mapv(|x| x as f64) - n.mapv(|x| x as f64)).powi(2).sum(); - assert!(d <= (shape[0] * shape[1]) as f64); - Ok(()) - } - )* + $( + #[test] + fn $name() -> Result<()> { + let j = julia_image(-120f32, 10f32)?.mapv(|x| x as $t); + let k = julia_image(0f32, 0f32)?.mapv(|x| x as $t); + let shape = j.shape(); + let origin = [ + ((shape[1] - 1) as f64) / 2f64, + ((shape[0] - 1) as f64) / 2f64, + ]; + let j0 = j.clone(); + let k0 = k.clone(); + let transform = Transform::new([1., 0., 0., 1., 120., -10.], origin, [shape[0], shape[1]]); + // make sure j & k weren't mutated + assert!(j.iter().zip(j0.iter()).map(|(a, b)| a == b).all(|x| x)); + assert!(k.iter().zip(k0.iter()).map(|(a, b)| a == b).all(|x| x)); + let n = transform.transform_image_nearest_neighbor(j.view())?; + let d = (k.mapv(|x| x as f64) - n.mapv(|x| x as f64)).powi(2).sum(); + assert!(d <= (shape[0] * shape[1]) as f64); + Ok(()) + } + )* } } @@ -340,7 +418,12 @@ mod tests { fn $name() -> Result<()> { let j = julia_image(0f32, 0f32)?.mapv(|x| x as $t); let k = julia_image(10f32, 20f32)?.mapv(|x| x as $t); + let j0 = j.clone(); + let k0 = k.clone(); let t = Transform::register_translation(j.view(), k.view())?; + // make sure j & k weren't mutated + assert!(j.iter().zip(j0.iter()).map(|(a, b)| a == b).all(|x| x)); + assert!(k.iter().zip(k0.iter()).map(|(a, b)| a == b).all(|x| x)); let mut m = Array2::eye(3); m[[0, 2]] = -10f64; m[[1, 2]] = -20f64; diff --git a/src/sys.rs b/src/sys.rs index 7802a5a..ef34bdc 100644 --- a/src/sys.rs +++ b/src/sys.rs @@ -1,6 +1,7 @@ +use crate::PixelType; use anyhow::Result; use libc::{c_double, c_uint}; -use ndarray::{Array2, ArrayView2}; +use ndarray::{Array2, AsArray, Ix2}; use one_at_a_time_please::one_at_a_time; use std::ptr; @@ -62,48 +63,17 @@ unsafe extern "C" { } } -pub trait PixelType: Clone { - const PT: u8; -} - -macro_rules! sitk_impl { - ($($T:ty: $sitk:expr $(,)?)*) => { - $( - impl PixelType for $T { - const PT: u8 = $sitk; - } - )* - }; -} - -sitk_impl! { - u8: 1, - i8: 2, - u16: 3, - i16: 4, - u32: 5, - i32: 6, - u64: 7, - i64: 8, - f32: 9, - f64: 10, -} - -#[cfg(target_pointer_width = "64")] -sitk_impl!(usize: 7); -#[cfg(target_pointer_width = "32")] -sitk_impl!(usize: 5); -#[cfg(target_pointer_width = "64")] -sitk_impl!(isize: 8); -#[cfg(target_pointer_width = "32")] -sitk_impl!(isize: 6); - -pub(crate) fn interp( +pub(crate) fn interp<'a, A, T>( parameters: [f64; 6], origin: [f64; 2], - image: ArrayView2, + image: A, bspline_or_nn: bool, -) -> Result> { +) -> Result> +where + T: 'a + PixelType, + A: AsArray<'a, T, Ix2>, +{ + let image = image.into(); let shape: Vec = image.shape().to_vec(); let width = shape[1] as c_uint; let height = shape[0] as c_uint; @@ -220,16 +190,22 @@ pub(crate) fn interp( } #[one_at_a_time] -pub(crate) fn register( - fixed: ArrayView2, - moving: ArrayView2, +pub(crate) fn register<'a, A, T>( + fixed: A, + moving: A, translation_or_affine: bool, -) -> Result<([f64; 6], [f64; 2], [usize; 2])> { +) -> Result<([f64; 6], [f64; 2], [usize; 2])> +where + T: 'a + PixelType, + A: AsArray<'a, T, Ix2>, +{ + let fixed = fixed.into(); + let moving = moving.into(); let shape: Vec = fixed.shape().to_vec(); let width = shape[1] as c_uint; let height = shape[0] as c_uint; - let fixed: Vec<_> = fixed.into_iter().cloned().collect(); - let moving: Vec<_> = moving.into_iter().cloned().collect(); + let fixed: Vec<_> = fixed.into_iter().collect(); + let moving: Vec<_> = moving.into_iter().collect(); let fixed_ptr = fixed.as_ptr(); let moving_ptr = moving.as_ptr(); let mut transform: Vec = vec![0.0; 6];