- add more tests

- do not allow multiple registrations to run at once because it causes a memory error in the sitk code
This commit is contained in:
Wim Pomp
2025-03-03 09:40:46 +01:00
parent 0aeb3c8c7e
commit 934f038ea1
8 changed files with 378 additions and 167 deletions

1
.gitignore vendored
View File

@@ -81,4 +81,5 @@ docs/_build/
/cpp/Makefile /cpp/Makefile
/sitk /sitk
*.nii *.nii
*.tif
TransformParameters* TransformParameters*

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "sitk-registration-sys" name = "sitk-registration-sys"
version = "2025.3.0" version = "2025.3.1"
edition = "2024" edition = "2024"
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
description = "register and interpolate images" description = "register and interpolate images"
@@ -22,9 +22,9 @@ anyhow = "1.0.96"
libc = "0.2.170" libc = "0.2.170"
ndarray = "0.16.1" ndarray = "0.16.1"
num = "0.4.3" num = "0.4.3"
one_at_a_time_please = "1.0.1"
[dev-dependencies] serde = { version = "1.0.218", features = ["derive"] }
tiffwrite = "2025.2.0" serde_yaml = "0.9.33"
[build-dependencies] [build-dependencies]
cmake = "0.1.54" cmake = "0.1.54"

View File

@@ -4,7 +4,7 @@ This crate does two things:
- find an affine transform or translation that transforms one image into the other - find an affine transform or translation that transforms one image into the other
- use bpline or nearest neighbor interpolation to apply a transformation to an image - use bpline or nearest neighbor interpolation to apply a transformation to an image
To do this [SimpleITK](https://github.com/SimpleITK/SimpleITK.git), which is written in To do this, [SimpleITK](https://github.com/SimpleITK/SimpleITK.git), which is written in
C++, is used. An adapter library is created to expose the required functionality in SimpleITK C++, is used. An adapter library is created to expose the required functionality in SimpleITK
in a shared library. Because of this, compilation of this crate requires quite some time, as in a shared library. Because of this, compilation of this crate requires quite some time, as
wel as cmake. wel as cmake.

View File

@@ -41,7 +41,6 @@ fn main() {
.define("SimpleITK_USE_ELASTIX", "ON") .define("SimpleITK_USE_ELASTIX", "ON")
.build(); .build();
} }
// println!("cargo::rustc-env=CMAKE_INSTALL_PREFIX=/home/wim/code/rust/sitk-sys/cpp");
println!( println!(
"cargo::rustc-env=CMAKE_INSTALL_PREFIX={}", "cargo::rustc-env=CMAKE_INSTALL_PREFIX={}",
out_dir.display() out_dir.display()
@@ -57,6 +56,6 @@ fn main() {
println!("cargo::rustc-link-search={}", path.join("build").display()); println!("cargo::rustc-link-search={}", path.join("build").display());
println!("cargo::rustc-link-lib=dylib=sitk_adapter"); println!("cargo::rustc-link-lib=dylib=sitk_adapter");
println!("cargo::rerun-if-changed=build.rs"); println!("cargo::rerun-if-changed=build.rs");
println!("cargo::rerun-if-changed=cpp/*.cxx"); println!("cargo::rerun-if-changed=cpp");
} }
} }

View File

@@ -1,10 +1,8 @@
cmake_minimum_required(VERSION 3.16.3) cmake_minimum_required(VERSION 3.16.3)
project(sitk_adapter) project(sitk_adapter)
set(ENV{Elastix_DIR} "../sitk/build/Elastix-build" ) set(ENV{Elastix_DIR} "../sitk/build/Elastix-build" )
set(ENV{ITK_DIR} "~/code/c/SimpleITK/build/ITK-build" ) set(ENV{ITK_DIR} "../sitk/build/ITK-build" )
set(ENV{SimpleITK_DIR} "~/code/c/SimpleITK/build/SimpleITK-build" ) set(ENV{SimpleITK_DIR} "~../sitk/build/SimpleITK-build" )
find_package(SimpleITK) find_package(SimpleITK)
add_library(sitk_adapter SHARED sitk_adapter.cxx) add_library(sitk_adapter SHARED sitk_adapter.cxx)
target_link_libraries (sitk_adapter ${SimpleITK_LIBRARIES}) target_link_libraries (sitk_adapter ${SimpleITK_LIBRARIES})

View File

@@ -1,12 +1,29 @@
#include <SimpleITK.h> #include <SimpleITK.h>
#include <sitkImageOperators.h> #include <sitkImageOperators.h>
#include <cstring> #include <cstring>
#include <filesystem>
namespace sitk = itk::simple; namespace sitk = itk::simple;
using namespace std; using namespace std;
std::string gen_random(const int len) {
static const char alphanum[] =
"0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
std::string tmp_s;
tmp_s.reserve(len);
for (int i = 0; i < len; ++i) {
tmp_s += alphanum[rand() % (sizeof(alphanum) - 1)];
}
return tmp_s;
}
template <typename T> template <typename T>
sitk::Image make_image( sitk::Image make_image(
unsigned int width, unsigned int width,
@@ -114,7 +131,7 @@ interp_u16(
sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkUInt16); sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkUInt16);
im = interp(transform, origin, im, bspline_or_nn); im = interp(transform, origin, im, bspline_or_nn);
uint16_t* c = im.GetBufferAsUInt16(); uint16_t* c = im.GetBufferAsUInt16();
memcpy(*image, c, width * height); memcpy(*image, c, width * height * 2);
} }
extern "C" void extern "C" void
@@ -129,7 +146,7 @@ interp_i16(
sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkInt16); sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkInt16);
im = interp(transform, origin, im, bspline_or_nn); im = interp(transform, origin, im, bspline_or_nn);
int16_t* c = im.GetBufferAsInt16(); int16_t* c = im.GetBufferAsInt16();
memcpy(*image, c, width * height); memcpy(*image, c, width * height * 2);
} }
extern "C" void extern "C" void
@@ -144,7 +161,7 @@ interp_u32(
sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkUInt32); sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkUInt32);
im = interp(transform, origin, im, bspline_or_nn); im = interp(transform, origin, im, bspline_or_nn);
uint32_t* c = im.GetBufferAsUInt32(); uint32_t* c = im.GetBufferAsUInt32();
memcpy(*image, c, width * height); memcpy(*image, c, width * height * 4);
} }
extern "C" void extern "C" void
@@ -159,7 +176,7 @@ interp_i32(
sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkInt32); sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkInt32);
im = interp(transform, origin, im, bspline_or_nn); im = interp(transform, origin, im, bspline_or_nn);
int32_t* c = im.GetBufferAsInt32(); int32_t* c = im.GetBufferAsInt32();
memcpy(*image, c, width * height); memcpy(*image, c, width * height * 4);
} }
extern "C" void extern "C" void
@@ -174,7 +191,7 @@ interp_u64(
sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkUInt64); sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkUInt64);
im = interp(transform, origin, im, bspline_or_nn); im = interp(transform, origin, im, bspline_or_nn);
uint64_t* c = im.GetBufferAsUInt64(); uint64_t* c = im.GetBufferAsUInt64();
memcpy(*image, c, width * height); memcpy(*image, c, width * height * 8);
} }
extern "C" void extern "C" void
@@ -189,7 +206,7 @@ interp_i64(
sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkInt64); sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkInt64);
im = interp(transform, origin, im, bspline_or_nn); im = interp(transform, origin, im, bspline_or_nn);
int64_t* c = im.GetBufferAsInt64(); int64_t* c = im.GetBufferAsInt64();
memcpy(*image, c, width * height); memcpy(*image, c, width * height * 8);
} }
extern "C" void extern "C" void
@@ -204,7 +221,7 @@ interp_f32(
sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkFloat32); sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkFloat32);
im = interp(transform, origin, im, bspline_or_nn); im = interp(transform, origin, im, bspline_or_nn);
float* c = im.GetBufferAsFloat(); float* c = im.GetBufferAsFloat();
memcpy(*image, c, width * height); memcpy(*image, c, width * height * 4);
} }
extern "C" void extern "C" void
@@ -219,48 +236,97 @@ interp_f64(
sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkFloat64); sitk::Image im = make_image(width, height, *image, sitk::PixelIDValueEnum::sitkFloat64);
im = interp(transform, origin, im, bspline_or_nn); im = interp(transform, origin, im, bspline_or_nn);
double* c = im.GetBufferAsDouble(); double* c = im.GetBufferAsDouble();
memcpy(*image, c, width * height); memcpy(*image, c, width * height * 8);
}
void
reg2(
sitk::Image fixed,
sitk::Image moving,
bool t_or_a,
double** transform
) {
try {
string kind = (t_or_a == false) ? "translation" : "affine";
sitk::ImageRegistrationMethod R;
R.SetMetricAsMattesMutualInformation();
const double maxStep = 4.0;
const double minStep = 0.01;
const unsigned int numberOfIterations = 200;
const double relaxationFactor = 0.5;
// R.SetOptimizerAsLBFGSB(maxStep, minStep, numberOfIterations, relaxationFactor);
R.SetOptimizerAsRegularStepGradientDescent(maxStep, minStep, numberOfIterations, relaxationFactor);
// R.SetOptimizerAsLBFGS2();
vector<double> matrix = {1.0, 0.0, 0.0, 1.0};
vector<double> translation = {0., 0.};
vector<double> origin = {399.5, 299.5};
R.SetInitialTransform(sitk::AffineTransform(matrix, translation, origin));
R.SetInterpolator(sitk::sitkBSpline);
sitk::Transform outTx = R.Execute(fixed, moving);
vector<double> t = outTx.GetParameters();
for (int i = 0; i < t.size(); i++) {
cout << t[i] << " ";
(*transform)[i] = t[i];
}
} catch (const std::exception &exc) {
cerr << exc.what();
}
} }
void void
reg( reg(
sitk::Image fixed, sitk::Image fixed,
sitk::Image moving, sitk::Image moving,
bool t_or_a, bool t_or_a,
double** transform double** transform
) { ) {
try { try {
string kind = (t_or_a == false) ? "translation" : "affine"; string kind = (t_or_a == false) ? "translation" : "affine";
sitk::ElastixImageFilter tfilter = sitk::ElastixImageFilter(); // std::filesystem::path output_path = std::filesystem::temp_directory_path() / gen_random(12);
tfilter.LogToConsoleOff(); std::filesystem::path output_path = std::filesystem::temp_directory_path();
tfilter.SetFixedImage(fixed); // std::filesystem::create_directory(output_path);
tfilter.SetMovingImage(moving);
tfilter.SetParameterMap(sitk::GetDefaultParameterMap(kind));
tfilter.Execute();
sitk::ElastixImageFilter::ParameterMapType parameter_map = tfilter.GetTransformParameterMap(0); sitk::ElastixImageFilter tfilter = sitk::ElastixImageFilter();
for (sitk::ElastixImageFilter::ParameterMapType::iterator parameter = parameter_map.begin(); parameter != parameter_map.end(); ++parameter) { tfilter.LogToConsoleOff();
if (parameter->first == "TransformParameters") { tfilter.LogToFileOff();
vector<string> tp = parameter->second; tfilter.SetLogToFile(false);
if (t_or_a == true) { tfilter.SetFixedImage(fixed);
for (int j = 0; j < tp.size(); j++) { tfilter.SetMovingImage(moving);
(*transform)[j] = stod(tp[j]); tfilter.SetParameterMap(sitk::GetDefaultParameterMap(kind));
} tfilter.SetParameter("WriteResultImage", "false");
} else { tfilter.SetOutputDirectory(output_path);
(*transform)[0] = 1.0; // cout << "output_path: " << output_path << endl;
(*transform)[1] = 0.0; // tfilter.PrintParameterMap();
(*transform)[2] = 0.0; // cout << "r6 " << std::flush;
(*transform)[3] = 1.0; tfilter.Execute();
for (int j = 0; j < tp.size(); j++) { // cout << "r7 " << std::flush;
(*transform)[j + 4] = stod(tp[j]); sitk::ElastixImageFilter::ParameterMapType parameter_map = tfilter.GetTransformParameterMap(0);
for (sitk::ElastixImageFilter::ParameterMapType::iterator parameter = parameter_map.begin(); parameter != parameter_map.end(); ++parameter) {
if (parameter->first == "TransformParameters") {
vector<string> tp = parameter->second;
if (t_or_a == true) {
for (int j = 0; j < tp.size(); j++) {
(*transform)[j] = stod(tp[j]);
}
} else {
(*transform)[0] = 1.0;
(*transform)[1] = 0.0;
(*transform)[2] = 0.0;
(*transform)[3] = 1.0;
for (int j = 0; j < tp.size(); j++) {
(*transform)[j + 4] = stod(tp[j]);
}
}
break;
} }
} }
} // cout << "r8 " << std::flush;
} catch (const std::exception &exc) {
cerr << exc.what();
} }
} catch (const std::exception &exc) {
cerr << exc.what();
}
} }

View File

@@ -3,10 +3,13 @@ mod sys;
use crate::sys::{PixelType, interp, register}; use crate::sys::{PixelType, interp, register};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use ndarray::{Array2, ArrayView2, array, s}; use ndarray::{Array2, ArrayView2, array, s};
use serde::{Deserialize, Serialize};
use serde_yaml::{from_reader, to_writer};
use std::fs::File;
use std::ops::Mul; use std::ops::Mul;
use std::path::PathBuf; use std::path::PathBuf;
#[derive(Clone, Debug)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Transform { pub struct Transform {
pub parameters: [f64; 6], pub parameters: [f64; 6],
pub dparameters: [f64; 6], pub dparameters: [f64; 6],
@@ -94,13 +97,16 @@ impl Transform {
} }
/// read a transform from a file /// read a transform from a file
pub fn from_file(file: PathBuf) -> Result<Self> { pub fn from_file(path: PathBuf) -> Result<Self> {
todo!() let file = File::open(path)?;
Ok(from_reader(file)?)
} }
/// write a transform to a file /// write a transform to a file
pub fn to_file(&self, file: PathBuf) -> Result<()> { pub fn to_file(&self, path: PathBuf) -> Result<()> {
todo!() let mut file = File::open(path)?;
to_writer(&mut file, self)?;
Ok(())
} }
/// true if transform does nothing /// true if transform does nothing
@@ -190,18 +196,17 @@ impl Transform {
} }
let m = self.matrix(); let m = self.matrix();
let d0 = det(m.slice(s![1.., 1..])); let d = det(m.slice(s![..2, ..2]));
if d0 == 0f64 { if d == 0f64 {
return Err(anyhow!("transform matrix is not invertible")); return Err(anyhow!("transform matrix is not invertible"));
} }
let d2 = det(m.slice(s![..2, ..2]));
let parameters = [ let parameters = [
d0 / d2, det(m.slice(s![1.., 1..])) / d,
-det(m.slice(s![..;2, 1..])) / d2, -det(m.slice(s![..;2, 1..])) / d,
-det(m.slice(s![1.., ..;2])) / d2, -det(m.slice(s![1.., ..;2])) / d,
det(m.slice(s![..;2, ..;2])) / d2, det(m.slice(s![..;2, ..;2])) / d,
det(m.slice(s![..2, 1..])) / d2, det(m.slice(s![..2, 1..])) / d,
-det(m.slice(s![..2, ..;2])) / d2, -det(m.slice(s![..2, ..;2])) / d,
]; ];
Ok(Transform { Ok(Transform {
@@ -228,10 +233,9 @@ mod tests {
use anyhow::Result; use anyhow::Result;
use ndarray::Array2; use ndarray::Array2;
use num::Complex; use num::Complex;
use tiffwrite::IJTiffFile;
/// An example of generating julia fractals. /// An example of generating julia fractals.
fn julia_image() -> Result<Array2<u8>> { fn julia_image(shift_x: f32, shift_y: f32) -> Result<Array2<u8>> {
let imgx = 800; let imgx = 800;
let imgy = 600; let imgy = 600;
@@ -241,11 +245,11 @@ mod tests {
let mut im = Array2::<u8>::zeros((imgy, imgx)); let mut im = Array2::<u8>::zeros((imgy, imgx));
for x in 0..imgx { for x in 0..imgx {
for y in 0..imgy { for y in 0..imgy {
let cx = y as f32 * scalex - 1.5; let cy = (y as f32 + shift_y) * scalex - 1.5;
let cy = x as f32 * scaley - 1.5; let cx = (x as f32 + shift_x) * scaley - 1.5;
let c = Complex::new(-0.4, 0.6); let c = Complex::new(-0.4, 0.6);
let mut z = Complex::new(cx, cy); let mut z = Complex::new(cy, cx);
let mut i = 0; let mut i = 0;
while i < 255 && z.norm() <= 2.0 { while i < 255 && z.norm() <= 2.0 {
@@ -259,24 +263,140 @@ mod tests {
Ok(im) Ok(im)
} }
#[test] macro_rules! interp_tests_bspline {
fn test_interp() -> Result<()> { ($($name:ident: $t:ty $(,)?)*) => {
let j = julia_image()?; $(
let mut tif = IJTiffFile::new("interp_test.tif")?; #[test]
tif.save(&j, 0, 0, 0)?; fn $name() -> Result<()> {
let shape = j.shape(); let j = julia_image(-120f32, 10f32)?.mapv(|x| x as $t);
let origin = [ let k = julia_image(0f32, 0f32)?.mapv(|x| x as $t);
((shape[1] - 1) as f64) / 2f64, let shape = j.shape();
((shape[0] - 1) as f64) / 2f64, let origin = [
]; ((shape[1] - 1) as f64) / 2f64,
let transform = Transform::new([1.2, 0., 0., 1., 10., 0.], origin, [shape[0], shape[1]]); ((shape[0] - 1) as f64) / 2f64,
let k = transform.transform_image_bspline(j.view())?; ];
tif.save(&k, 1, 0, 0)?; let transform = Transform::new([1., 0., 0., 1., 120., -10.], origin, [shape[0], shape[1]]);
let n = transform.transform_image_bspline(j.view())?;
let d = (k.mapv(|x| x as f64) - n.mapv(|x| x as f64)).powi(2).sum();
assert!(d <= (shape[0] * shape[1]) as f64);
Ok(())
}
)*
}
}
let t = Transform::register_affine(k.view(), j.view())?; interp_tests_bspline! {
println!("t: {:#?}", t); interpbs_u8: u8,
println!("m: {:#?}", t.matrix()); interpbs_i8: i8,
println!("i: {:#?}", t.inverse()?.matrix()); interpbs_u16: u16,
Ok(()) interpbs_i16: i16,
interpbs_u32: u32,
interpbs_i32: i32,
interpbs_u64: u64,
interpbs_i64: i64,
interpbs_f32: f32,
interpbs_f64: f64,
}
macro_rules! interp_tests_nearest_neighbor {
($($name:ident: $t:ty $(,)?)*) => {
$(
#[test]
fn $name() -> Result<()> {
let j = julia_image(-120f32, 10f32)?.mapv(|x| x as $t);
let k = julia_image(0f32, 0f32)?.mapv(|x| x as $t);
let shape = j.shape();
let origin = [
((shape[1] - 1) as f64) / 2f64,
((shape[0] - 1) as f64) / 2f64,
];
let transform = Transform::new([1., 0., 0., 1., 120., -10.], origin, [shape[0], shape[1]]);
let n = transform.transform_image_nearest_neighbor(j.view())?;
let d = (k.mapv(|x| x as f64) - n.mapv(|x| x as f64)).powi(2).sum();
assert!(d <= (shape[0] * shape[1]) as f64);
Ok(())
}
)*
}
}
interp_tests_nearest_neighbor! {
interpnn_u8: u8,
interpnn_i8: i8,
interpnn_u16: u16,
interpnn_i16: i16,
interpnn_u32: u32,
interpnn_i32: i32,
interpnn_u64: u64,
interpnn_i64: i64,
interpnn_f32: f32,
interpnn_f64: f64,
}
macro_rules! registration_tests_translation {
($($name:ident: $t:ty $(,)?)*) => {
$(
#[test]
fn $name() -> Result<()> {
let j = julia_image(0f32, 0f32)?.mapv(|x| x as $t);
let k = julia_image(10f32, 20f32)?.mapv(|x| x as $t);
let t = Transform::register_translation(j.view(), k.view())?;
let mut m = Array2::eye(3);
m[[0, 2]] = -10f64;
m[[1, 2]] = -20f64;
let d = (t.matrix() - m).powi(2).sum();
assert!(d < 0.01);
Ok(())
}
)*
}
}
registration_tests_translation! {
registration_translation_u8: u8,
registration_translation_i8: i8,
registration_translation_u16: u16,
registration_translation_i16: i16,
registration_translation_u32: u32,
registration_translation_i32: i32,
registration_translation_u64: u64,
registration_translation_i64: i64,
registration_translation_f32: f32,
registration_translation_f64: f64,
}
macro_rules! registration_tests_affine {
($($name:ident: $t:ty $(,)?)*) => {
$(
#[test]
fn $name() -> Result<()> {
let j = julia_image(0f32, 0f32)?.mapv(|x| x as $t);
let shape = j.shape();
let origin = [
((shape[1] - 1) as f64) / 2f64,
((shape[0] - 1) as f64) / 2f64,
];
let s = Transform::new([1.2, 0., 0., 1., 5., 7.], origin, [shape[0], shape[1]]);
let k = s.transform_image_bspline(j.view())?;
let t = Transform::register_affine(j.view(), k.view())?.inverse()?;
let d = (t.matrix() - s.matrix()).powi(2).sum();
assert!(d < 0.01);
Ok(())
}
)*
}
}
registration_tests_affine! {
registration_tests_affine_u8: u8,
registration_tests_affine_i8: i8,
registration_tests_affine_u16: u16,
registration_tests_affine_i16: i16,
registration_tests_affine_u32: u32,
registration_tests_affine_i32: i32,
registration_tests_affine_u64: u64,
registration_tests_affine_i64: i64,
registration_tests_affine_f32: f32,
registration_tests_affine_f64: f64,
} }
} }

View File

@@ -1,55 +1,65 @@
use anyhow::Result; use anyhow::Result;
use libc::{c_double, c_uint}; use libc::{c_double, c_uint};
use ndarray::{Array2, ArrayView2}; use ndarray::{Array2, ArrayView2};
use one_at_a_time_please::one_at_a_time;
use std::ptr; use std::ptr;
macro_rules! register_fn { macro_rules! register_fn {
($T:ty, $t:ident) => { ($($name:ident: $T:ty $(,)?)*) => {
fn $t( $(
width: c_uint, fn $name(
height: c_uint, width: c_uint,
fixed_arr: *const $T, height: c_uint,
moving_arr: *const $T, fixed_arr: *const $T,
translation_or_affine: bool, moving_arr: *const $T,
transform: &mut *mut c_double, translation_or_affine: bool,
); transform: &mut *mut c_double,
);
)*
}; };
} }
macro_rules! interp_fn { macro_rules! interp_fn {
($T:ty, $t:ident) => { ($($name:ident: $T:ty $(,)?)*) => {
fn $t( $(
width: c_uint, fn $name(
height: c_uint, width: c_uint,
transform: *const c_double, height: c_uint,
origin: *const c_double, transform: *const c_double,
image: &mut *mut $T, origin: *const c_double,
bspline_or_nn: bool, image: &mut *mut $T,
); bspline_or_nn: bool,
);
)*
}; };
} }
unsafe extern "C" { unsafe extern "C" {
register_fn!(u8, register_u8); register_fn! {
register_fn!(i8, register_i8); register_u8: u8,
register_fn!(u16, register_u16); register_i8: i8,
register_fn!(i16, register_i16); register_u16: u16,
register_fn!(u32, register_u32); register_i16: i16,
register_fn!(i32, register_i32); register_u32: u32,
register_fn!(u64, register_u64); register_i32: i32,
register_fn!(i64, register_i64); register_u64: u64,
register_fn!(f32, register_f32); register_i64: i64,
register_fn!(f64, register_f64); register_f32: f32,
interp_fn!(u8, interp_u8); register_f64: f64,
interp_fn!(i8, interp_i8); }
interp_fn!(u16, interp_u16);
interp_fn!(i16, interp_i16); interp_fn! {
interp_fn!(u32, interp_u32); interp_u8: u8,
interp_fn!(i32, interp_i32); interp_i8: i8,
interp_fn!(u64, interp_u64); interp_u16: u16,
interp_fn!(i64, interp_i64); interp_i16: i16,
interp_fn!(f32, interp_f32); interp_u32: u32,
interp_fn!(f64, interp_f64); interp_i32: i32,
interp_u64: u64,
interp_i64: i64,
interp_f32: f32,
interp_f64: f64,
}
} }
pub trait PixelType: Clone { pub trait PixelType: Clone {
@@ -57,31 +67,36 @@ pub trait PixelType: Clone {
} }
macro_rules! sitk_impl { macro_rules! sitk_impl {
($T:ty, $sitk:expr) => { ($($T:ty: $sitk:expr $(,)?)*) => {
impl PixelType for $T { $(
const PT: u8 = $sitk; impl PixelType for $T {
} const PT: u8 = $sitk;
}
)*
}; };
} }
sitk_impl!(u8, 1); sitk_impl! {
sitk_impl!(i8, 2); u8: 1,
sitk_impl!(u16, 3); i8: 2,
sitk_impl!(i16, 4); u16: 3,
sitk_impl!(u32, 5); i16: 4,
sitk_impl!(i32, 6); u32: 5,
sitk_impl!(u64, 7); i32: 6,
sitk_impl!(i64, 8); u64: 7,
i64: 8,
f32: 9,
f64: 10,
}
#[cfg(target_pointer_width = "64")] #[cfg(target_pointer_width = "64")]
sitk_impl!(usize, 7); sitk_impl!(usize: 7);
#[cfg(target_pointer_width = "32")] #[cfg(target_pointer_width = "32")]
sitk_impl!(usize, 5); sitk_impl!(usize: 5);
#[cfg(target_pointer_width = "64")] #[cfg(target_pointer_width = "64")]
sitk_impl!(isize, 8); sitk_impl!(isize: 8);
#[cfg(target_pointer_width = "32")] #[cfg(target_pointer_width = "32")]
sitk_impl!(isize, 6); sitk_impl!(isize: 6);
sitk_impl!(f32, 9);
sitk_impl!(f64, 10);
pub(crate) fn interp<T: PixelType>( pub(crate) fn interp<T: PixelType>(
parameters: [f64; 6], parameters: [f64; 6],
@@ -204,6 +219,7 @@ pub(crate) fn interp<T: PixelType>(
)?) )?)
} }
#[one_at_a_time]
pub(crate) fn register<T: PixelType>( pub(crate) fn register<T: PixelType>(
fixed: ArrayView2<T>, fixed: ArrayView2<T>,
moving: ArrayView2<T>, moving: ArrayView2<T>,
@@ -212,19 +228,24 @@ pub(crate) fn register<T: PixelType>(
let shape: Vec<usize> = fixed.shape().to_vec(); let shape: Vec<usize> = fixed.shape().to_vec();
let width = shape[1] as c_uint; let width = shape[1] as c_uint;
let height = shape[0] as c_uint; let height = shape[0] as c_uint;
let fixed: Vec<_> = fixed.into_iter().collect(); let fixed: Vec<_> = fixed.into_iter().cloned().collect();
let moving: Vec<_> = moving.into_iter().collect(); let moving: Vec<_> = moving.into_iter().cloned().collect();
let fixed_ptr = fixed.as_ptr();
let moving_ptr = moving.as_ptr();
let mut transform: Vec<c_double> = vec![0.0; 6]; let mut transform: Vec<c_double> = vec![0.0; 6];
let mut transform_ptr: *mut c_double = ptr::from_mut(unsafe { &mut *transform.as_mut_ptr() }); let mut transform_ptr: *mut c_double = ptr::from_mut(unsafe { &mut *transform.as_mut_ptr() });
// let ma0 = &mut moving as *mut Vec<T> as usize;
// println!("ma0: {:#x}", ma0);
match T::PT { match T::PT {
1 => { 1 => {
unsafe { unsafe {
register_u8( register_u8(
width, width,
height, height,
fixed.as_ptr() as *const u8, fixed_ptr as *const u8,
moving.as_ptr() as *const u8, moving_ptr as *const u8,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -235,8 +256,8 @@ pub(crate) fn register<T: PixelType>(
register_i8( register_i8(
width, width,
height, height,
fixed.as_ptr() as *const i8, fixed_ptr as *const i8,
moving.as_ptr() as *const i8, moving_ptr as *const i8,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -247,8 +268,8 @@ pub(crate) fn register<T: PixelType>(
register_u16( register_u16(
width, width,
height, height,
fixed.as_ptr() as *const u16, fixed_ptr as *const u16,
moving.as_ptr() as *const u16, moving_ptr as *const u16,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -259,8 +280,8 @@ pub(crate) fn register<T: PixelType>(
register_i16( register_i16(
width, width,
height, height,
fixed.as_ptr() as *const i16, fixed_ptr as *const i16,
moving.as_ptr() as *const i16, moving_ptr as *const i16,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -271,8 +292,8 @@ pub(crate) fn register<T: PixelType>(
register_u32( register_u32(
width, width,
height, height,
fixed.as_ptr() as *const u32, fixed_ptr as *const u32,
moving.as_ptr() as *const u32, moving_ptr as *const u32,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -283,8 +304,8 @@ pub(crate) fn register<T: PixelType>(
register_i32( register_i32(
width, width,
height, height,
fixed.as_ptr() as *const i32, fixed_ptr as *const i32,
moving.as_ptr() as *const i32, moving_ptr as *const i32,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -295,8 +316,8 @@ pub(crate) fn register<T: PixelType>(
register_u64( register_u64(
width, width,
height, height,
fixed.as_ptr() as *const u64, fixed_ptr as *const u64,
moving.as_ptr() as *const u64, moving_ptr as *const u64,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -307,8 +328,8 @@ pub(crate) fn register<T: PixelType>(
register_i64( register_i64(
width, width,
height, height,
fixed.as_ptr() as *const i64, fixed_ptr as *const i64,
moving.as_ptr() as *const i64, moving_ptr as *const i64,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -319,8 +340,8 @@ pub(crate) fn register<T: PixelType>(
register_f32( register_f32(
width, width,
height, height,
fixed.as_ptr() as *const f32, fixed_ptr as *const f32,
moving.as_ptr() as *const f32, moving_ptr as *const f32,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -331,8 +352,8 @@ pub(crate) fn register<T: PixelType>(
register_f64( register_f64(
width, width,
height, height,
fixed.as_ptr() as *const f64, fixed_ptr as *const f64,
moving.as_ptr() as *const f64, moving_ptr as *const f64,
translation_or_affine, translation_or_affine,
&mut transform_ptr, &mut transform_ptr,
) )
@@ -341,6 +362,12 @@ pub(crate) fn register<T: PixelType>(
_ => {} _ => {}
} }
// let ma1 = &mut moving as *mut Vec<T> as usize;
// println!("ma1: {:#x}", ma1);
// println!("{}", fixed.len());
// println!("{}", moving.len());
Ok(( Ok((
[ [
transform[0] as f64, transform[0] as f64,