Skip to content

Commit

Permalink
ENH: Add ElastixRegistrationMethod::SetInitialTransform
Browse files Browse the repository at this point in the history
Supporting using an ITK transform object as initial transformation of a registration.

Added itkElastixRegistrationMethod GoogleTest unit tests.

Each of the member functions SetInitialTransform, SetInitialTransformParameterFileName, and SetInitialTransformParameterObject now resets the initial transform.
  • Loading branch information
N-Dekker committed Jul 19, 2023
1 parent d33a0d5 commit 115d8e1
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 12 deletions.
117 changes: 117 additions & 0 deletions Core/Main/GTesting/itkElastixRegistrationMethodGTest.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,20 @@ Test_WriteBSplineTransformToItkFileFormat(const std::string & rootOutputDirector
}
}
}


template <typename TParametersValueType, unsigned int VInputDimension, unsigned int VOutputDimension>
itk::SizeValueType
GetNumberOfTransforms(const itk::Transform<TParametersValueType, VInputDimension, VOutputDimension> & transform)
{
if (const auto multiTransform =
dynamic_cast<const itk::MultiTransform<TParametersValueType, VInputDimension, VOutputDimension> *>(&transform))
{
return multiTransform->GetNumberOfTransforms();
}
return 1;
};

} // namespace


Expand Down Expand Up @@ -1032,6 +1046,109 @@ GTEST_TEST(itkElastixRegistrationMethod, InitialTransformParameterFileWithInitia
}


GTEST_TEST(itkElastixRegistrationMethod, SetInitialTransform)
{
using PixelType = float;
enum
{
ImageDimension = 2U
};
using ImageType = itk::Image<PixelType, ImageDimension>;
using SizeType = itk::Size<ImageDimension>;
using IndexType = itk::Index<ImageDimension>;
using OffsetType = itk::Offset<ImageDimension>;

const OffsetType initialTranslation{ { 1, -2 } };
const auto regionSize = SizeType::Filled(2);
const SizeType imageSize{ { 5, 6 } };
const IndexType fixedImageRegionIndex{ { 1, 3 } };

using TransformType = ElastixRegistrationMethodType<ImageType>::TransformType;

const TransformType::ConstPointer singleInitialTransform = [] {
const auto singleTransform = itk::TranslationTransform<double, ImageDimension>::New();
singleTransform->SetOffset(itk::MakeVector(1.0, -2.0));
return singleTransform;
}();

const TransformType::ConstPointer compositeInitialTransform = [] {
const auto translationTransformX = itk::TranslationTransform<double, ImageDimension>::New();
translationTransformX->SetOffset(itk::MakeVector(1.0, 0.0));
const auto translationTransformY = itk::TranslationTransform<double, ImageDimension>::New();
translationTransformY->SetOffset(itk::MakeVector(0.0, -2.0));

const auto compositeTransform = itk::CompositeTransform<double, ImageDimension>::New();
compositeTransform->AddTransform(translationTransformX);
compositeTransform->AddTransform(translationTransformY);
return compositeTransform;
}();

// Test both a single and a composite transform as initial transform.
for (const TransformType * const initialTransform : { singleInitialTransform, compositeInitialTransform })
{
const auto fixedImage = CreateImage<PixelType>(imageSize);
FillImageRegion(*fixedImage, fixedImageRegionIndex, regionSize);

const auto movingImage = CreateImage<PixelType>(imageSize);

elx::DefaultConstruct<elx::ParameterObject> registrationParameterObject{};
elx::DefaultConstruct<ElastixRegistrationMethodType<ImageType>> registration{};
registration.SetFixedImage(fixedImage);
registration.SetInitialTransform(initialTransform);
registration.SetParameterObject(&registrationParameterObject);

const elx::ParameterObject::ParameterMapType registrationParameterMap{
// Parameters in alphabetic order:
{ "ImageSampler", { "Full" } },
{ "MaximumNumberOfIterations", { "2" } },
{ "Metric", { "AdvancedNormalizedCorrelation" } },
{ "Optimizer", { "AdaptiveStochasticGradientDescent" } },
{ "Transform", { "TranslationTransform" } }
};

for (const unsigned int numberOfRegistrationParameterMaps : { 1, 2, 3 })
{
// Specify multiple (one or more) registration parameter maps.
registrationParameterObject.SetParameterMaps(
ParameterMapVectorType(numberOfRegistrationParameterMaps, registrationParameterMap));

const auto numberOfInitialTransformParameterMaps = GetNumberOfTransforms(*initialTransform);

// Do the test for a few possible translations.
for (const auto index :
itk::ImageRegionIndexRange<ImageDimension>(itk::ImageRegion<ImageDimension>({ 0, -2 }, { 2, 3 })))
{
const auto actualTranslation = ConvertIndexToOffset(index);
movingImage->FillBuffer(0);
FillImageRegion(*movingImage, fixedImageRegionIndex + actualTranslation, regionSize);
registration.SetMovingImage(movingImage);
registration.Update();

const auto & transformParameterMaps =
DerefRawPointer(registration.GetTransformParameterObject()).GetParameterMaps();

ASSERT_EQ(transformParameterMaps.size(),
numberOfInitialTransformParameterMaps + numberOfRegistrationParameterMaps);

// All transform parameter maps, except for the initial transformations and the transform parameter map of the
// first registration should just have a zero-translation.
for (auto i = numberOfInitialTransformParameterMaps + 1; i < numberOfRegistrationParameterMaps; ++i)
{
const auto transformParameters =
ConvertStringsToVectorOfDouble(transformParameterMaps[i].at("TransformParameters"));
EXPECT_EQ(ConvertToOffset<ImageDimension>(transformParameters), OffsetType{});
}

// Together the initial translation and the first registration should yield the actual image translation.
const auto transformParameters = ConvertStringsToVectorOfDouble(
transformParameterMaps[numberOfInitialTransformParameterMaps].at("TransformParameters"));
EXPECT_EQ(initialTranslation + ConvertToOffset<ImageDimension>(transformParameters), actualTranslation);
}
}
}
}


GTEST_TEST(itkElastixRegistrationMethod, SetInitialTransformParameterObject)
{
using PixelType = float;
Expand Down
29 changes: 27 additions & 2 deletions Core/Main/itkElastixRegistrationMethod.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ class ITK_TEMPLATE_EXPORT ElastixRegistrationMethod : public itk::ImageSource<TF
GetInput(DataObjectPointerArraySizeType index) const;

/** Set/Get/Remove initial transform parameter filename. */
itkSetMacro(InitialTransformParameterFileName, std::string);
void SetInitialTransformParameterFileName(std::string);

itkGetConstMacro(InitialTransformParameterFileName, std::string);
virtual void
RemoveInitialTransformParameterFileName()
Expand All @@ -202,7 +203,12 @@ class ITK_TEMPLATE_EXPORT ElastixRegistrationMethod : public itk::ImageSource<TF
}

/** Set initial transform parameter object. */
itkSetConstObjectMacro(InitialTransformParameterObject, elx::ParameterObject);
void
SetInitialTransformParameterObject(const elx::ParameterObject *);

/** Set the initial transformation by means of an ITK Transform. */
void
SetInitialTransform(const TransformType *);

/** Set/Get/Remove fixed point set filename. */
itkSetMacro(FixedPointSetFileName, std::string);
Expand Down Expand Up @@ -324,6 +330,24 @@ class ITK_TEMPLATE_EXPORT ElastixRegistrationMethod : public itk::ImageSource<TF
return images;
}

void
ResetInitialTransformWithoutModified()
{
m_InitialTransform = nullptr;
m_InitialTransformParameterFileName.clear();
m_InitialTransformParameterObject = nullptr;
}

void
ResetInitialTransformAndModified()
{
if (m_InitialTransform || m_InitialTransformParameterObject || !m_InitialTransformParameterFileName.empty())
{
ResetInitialTransformWithoutModified();
this->Modified();
}
}

/** Private using-declaration, just to avoid GCC compilation warnings: '...' was hidden [-Woverloaded-virtual] */
using Superclass::SetInput;

Expand All @@ -333,6 +357,7 @@ class ITK_TEMPLATE_EXPORT ElastixRegistrationMethod : public itk::ImageSource<TF

std::string m_InitialTransformParameterFileName{};
elx::ParameterObject::ConstPointer m_InitialTransformParameterObject{};
SmartPointer<const TransformType> m_InitialTransform{};

std::string m_FixedPointSetFileName{};
std::string m_MovingPointSetFileName{};
Expand Down
129 changes: 119 additions & 10 deletions Core/Main/itkElastixRegistrationMethod.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,7 @@ ElastixRegistrationMethod<TFixedImage, TMovingImage>::GenerateData()
DataObjectContainerPointer movingMaskContainer = nullptr;
DataObjectContainerPointer resultImageContainer = nullptr;
ElastixMainObjectPointer transform = nullptr;
ParameterMapVectorType transformParameterMapVector = m_InitialTransformParameterObject
? m_InitialTransformParameterObject->GetParameterMaps()
: ParameterMapVectorType{};
FlatDirectionCosinesType fixedImageOriginalDirection;
FlatDirectionCosinesType fixedImageOriginalDirection;

// Split inputs into separate containers
for (const auto & inputName : this->GetInputNames())
Expand Down Expand Up @@ -177,14 +174,65 @@ ElastixRegistrationMethod<TFixedImage, TMovingImage>::GenerateData()
m_EnableOutput && m_LogToConsole,
static_cast<elastix::log::level>(m_LogLevel));

if (m_InitialTransformParameterObject && !m_OutputDirectory.empty())
const auto getInitialTransformParameterMaps = [this]() -> ParameterMapVectorType {
if (m_InitialTransformParameterObject)
{
return m_InitialTransformParameterObject->GetParameterMaps();
}

if (m_InitialTransform)
{
const auto transformToMap = [](const itk::TransformBase & transform) {
return ParameterMapType{
{ "ITKTransformFixedParameters", elx::Conversion::ToVectorOfStrings(transform.GetFixedParameters()) },
{ "ITKTransformParameters", elx::Conversion::ToVectorOfStrings(transform.GetParameters()) },
{ "ITKTransformType", { transform.GetTransformTypeAsString() } },
{ "Transform", { elx::TransformIO::ConvertITKNameOfClassToElastixClassName(transform.GetNameOfClass()) } }
};
};

const auto compositeTransform =
dynamic_cast<const CompositeTransform<double, MovingImageDimension> *>(&*m_InitialTransform);

if (compositeTransform)
{
const auto & transformQueue = compositeTransform->GetTransformQueue();

ParameterMapVectorType transformParameterMaps(transformQueue.size());

auto reverseIterator = transformParameterMaps.rbegin();

for (const auto & transform : transformQueue)
{
if (transform == nullptr)
{
itkGenericExceptionMacro("One of the subtransforms of the specified composite transform is null!");
}
*reverseIterator = transformToMap(*transform);
++reverseIterator;
}
return transformParameterMaps;
}

// Assume in this case that it is just a single transform.
assert((dynamic_cast<const MultiTransform<double, MovingImageDimension> *>(&*m_InitialTransform)) == nullptr);

// For a single transform, there should be only a single transform parameter map.
return ParameterMapVectorType{ transformToMap(*m_InitialTransform) };
}
return {};
};

ParameterMapVectorType transformParameterMapVector = getInitialTransformParameterMaps();

if (!transformParameterMapVector.empty() && !m_OutputDirectory.empty())
{
std::string initialTransformParameterFileName = "NoInitialTransform";

// Write InitialTransformParameters.0.txt, InitialTransformParameters.1.txt, InitialTransformParameters.2.txt, etc.
unsigned i{};

for (auto transformParameterMap : m_InitialTransformParameterObject->GetParameterMaps())
for (auto transformParameterMap : transformParameterMapVector)
{
transformParameterMap["InitialTransformParameterFileName"] = { initialTransformParameterFileName };

Expand Down Expand Up @@ -273,10 +321,10 @@ ElastixRegistrationMethod<TFixedImage, TMovingImage>::GenerateData()
unsigned int isError = 0;
try
{
isError = ((i == 0) && m_InitialTransformParameterObject)
? elastixMain->RunWithInitialTransformParameterMaps(
argumentMap, parameterMap, m_InitialTransformParameterObject->GetParameterMaps())
: elastixMain->Run(argumentMap, parameterMap);
isError =
((i == 0) && !transformParameterMapVector.empty())
? elastixMain->RunWithInitialTransformParameterMaps(argumentMap, parameterMap, transformParameterMapVector)
: elastixMain->Run(argumentMap, parameterMap);
}
catch (const itk::ExceptionObject & e)
{
Expand Down Expand Up @@ -737,6 +785,67 @@ ElastixRegistrationMethod<TFixedImage, TMovingImage>::SetInput(DataObjectPointer
}


template <typename TFixedImage, typename TMovingImage>
void
ElastixRegistrationMethod<TFixedImage, TMovingImage>::SetInitialTransformParameterFileName(std::string fileName)
{
if (fileName.empty())
{
ResetInitialTransformAndModified();
}
else
{
if (m_InitialTransformParameterFileName != fileName)
{
ResetInitialTransformWithoutModified();
m_InitialTransformParameterFileName = std::move(fileName);
this->Modified();
}
}
}


template <typename TFixedImage, typename TMovingImage>
void
ElastixRegistrationMethod<TFixedImage, TMovingImage>::SetInitialTransformParameterObject(
const elx::ParameterObject * const parameterObject)
{
if (parameterObject)
{
if (m_InitialTransformParameterObject != parameterObject)
{
ResetInitialTransformWithoutModified();
m_InitialTransformParameterObject = parameterObject;
this->Modified();
}
}
else
{
ResetInitialTransformAndModified();
}
}


template <typename TFixedImage, typename TMovingImage>
void
ElastixRegistrationMethod<TFixedImage, TMovingImage>::SetInitialTransform(const TransformType * const transform)
{
if (transform)
{
if (m_InitialTransform != transform)
{
ResetInitialTransformWithoutModified();
m_InitialTransform = transform;
this->Modified();
}
}
else
{
ResetInitialTransformAndModified();
}
}


template <typename TFixedImage, typename TMovingImage>
void
ElastixRegistrationMethod<TFixedImage, TMovingImage>::SetLogFileName(const std::string logFileName)
Expand Down

0 comments on commit 115d8e1

Please sign in to comment.