Skip to content

Commit

Permalink
improved implementation. added the type classing setup to scalismo.co…
Browse files Browse the repository at this point in the history
…mmon.RealignExtendedBasis
  • Loading branch information
JonathanAellen committed Jun 10, 2024
1 parent d3d064c commit ba6c5f1
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 69 deletions.
65 changes: 65 additions & 0 deletions src/main/scala/scalismo/common/RealignExtendedBasis.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package scalismo.common

import breeze.linalg.DenseMatrix
import scalismo.geometry.{EuclideanVector, NDSpace, Point, _2D, _3D}
import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess

/**
* types the whole discrete low rank gp to make sure that it is applied to the appropriate models. The value type could
* be left out if the user knows what to do.
*/
trait RealignExtendedBasis[D: NDSpace, Value]:

def useTranslation: Boolean
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](model: DiscreteLowRankGaussianProcess[D, DDomain, Value], center: Point[D]): DenseMatrix[Double]
def centeredP[DDomain[DD] <: DiscreteDomain[DD]](domain: DDomain[D], center: Point[D]): DenseMatrix[Double] = {
// build centered data matrix
val x = DenseMatrix.zeros[Double](center.dimensionality, domain.pointSet.numberOfPoints)
val c = center.toBreezeVector
for (p, i) <- domain.pointSet.points.zipWithIndex do x(::, i) := p.toBreezeVector - c
x
}

object RealignExtendedBasis:
/**
* returns a projection basis for rotation. that is the tangential speed for the rotations around the three cardinal
* directions.
*/
given realignBasis3D: RealignExtendedBasis[_3D, EuclideanVector[_3D]] with
def useTranslation: Boolean = true
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](model: DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]], center: Point[_3D]): DenseMatrix[Double] = {
val np = model.domain.pointSet.numberOfPoints
val x = centeredP(model.domain, center)

val pr = DenseMatrix.zeros[Double](np * 3, 3)
// the derivative of the rotation matrices
val dr = new DenseMatrix[Double](9, 3,
Array(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0,
0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 1.0)
)
// get tangential speed
val dx = dr * x
for i <- 0 until 3 do
val v = dx(3 * i until 3 * i + 3, ::).toDenseVector
pr(::, i) := v / breeze.linalg.norm(v)
pr
}

/**
* returns a projection basis for rotation. that is the tangential speed for the single 2d rotation.
*/
given realignBasis2D: RealignExtendedBasis[_2D, EuclideanVector[_2D]] with
def useTranslation: Boolean = true
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](model: DiscreteLowRankGaussianProcess[_2D, DDomain, EuclideanVector[_2D]], center: Point[_2D]): DenseMatrix[Double] = {
val np = model.domain.pointSet.numberOfPoints
val x = centeredP(model.domain, center)

//derivative of the rotation matrix
val dr = new DenseMatrix[Double](2, 2, Array(0.0, -1.0, 1.0, 0.0))
val dx = (dr * x).reshape(2 * np, 1)
val n = breeze.linalg.norm(dx, breeze.linalg.Axis._0)
dx / n(0)
}



Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import scalismo.statisticalmodel.NaNStrategy.NanIsNumericValue
import scalismo.statisticalmodel.dataset.DataCollection
import scalismo.utils.{Memoize, Random}

import scala.annotation.threadUnsafe
import scala.language.higherKinds
import scala.collection.parallel.immutable.ParVector

Expand Down Expand Up @@ -364,18 +365,21 @@ class DiscreteLowRankGaussianProcess[D: NDSpace, DDomain[DD] <: DiscreteDomain[D
* around the cardinal directions.
*
* @param ids
* these define the parts of the domain that are aligned to
* @param withRotation
* True if the rotation should be included. False makes the realignment over translation exact.
* these define the parts of the domain that are aligned to. Depending on the withExtendedBasis parameter has a
* minimum requirements (default basis extension in 3D should be used with >=4 provided ids)
* @param withExtendedBasis
* True if the extended basis should be included. By default this uses a rotation extension. that means False makes
* the realignment over translation exact.
* @param diagonalize
* True if a diagonal basis should be returned. False is cheaper for exclusively drawing samples.
* @return
* The resulting [[DiscreteLowRankGaussianProcess]] aligned on the provided instances of [[PointId]]
*/
def realign(ids: IndexedSeq[PointId], withRotation: Boolean = true, diagonalize: Boolean = true)(implicit
vectorizer: Vectorizer[Value]
def realign(ids: IndexedSeq[PointId], withExtendedBasis: Boolean = true, diagonalize: Boolean = true)(using
vectorizer: Vectorizer[Value],
realigning: RealignExtendedBasis[D, Value]
): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
DiscreteLowRankGaussianProcess.realignment(this, ids, withRotation, diagonalize)
DiscreteLowRankGaussianProcess.realignment(this, ids, withExtendedBasis, diagonalize)
}

protected[statisticalmodel] def instanceVector(alpha: DenseVector[Double]): DenseVector[Double] = {
Expand Down Expand Up @@ -667,32 +671,22 @@ object DiscreteLowRankGaussianProcess {
def realignment[D: NDSpace, DDomain[DD] <: DiscreteDomain[DD], Value](
model: DiscreteLowRankGaussianProcess[D, DDomain, Value],
ids: IndexedSeq[PointId],
withRotation: Boolean,
withExtendedBasis: Boolean,
diagonalize: Boolean
)(implicit vectorizer: Vectorizer[Value]): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
model match
case m: DiscreteLowRankGaussianProcess[`_3D`, _, EuclideanVector[`_3D`]] =>
realignment3D(m, ids, withRotation, diagonalize)
case _ if !withRotation =>
// TODO general pure translation
throw new NotImplementedError("not yet implemented")
case _ =>
throw new NotImplementedError("not yet implemented")
}

private def realignment3D[DDomain[_3D] <: DiscreteDomain[_3D]](
model: DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]],
ids: IndexedSeq[PointId],
withRotation: Boolean,
diagonalize: Boolean
): DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]] = {
)(using
vectorizer: Vectorizer[Value],
realigning: RealignExtendedBasis[D, Value]
): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
val d = model.domain.pointSet.point(ids.head).dimensionality
// build the projection matrix for the desired pose
val p = {
val pt = breeze.linalg.tile(DenseMatrix.eye[Double](3), model.domain.pointSet.numberOfPoints, 1)
if withRotation then
@threadUnsafe
lazy val pt = breeze.linalg.tile(DenseMatrix.eye[Double](d), model.domain.pointSet.numberOfPoints, 1)
if withExtendedBasis then
val center = ids.map(id => model.domain.pointSet.point(id).toVector).reduce(_ + _).map(_ / ids.length).toPoint
val pr = getTangentialSpeedMatrix(model.domain, center)
DenseMatrix.horzcat(pt, pr)
val pr = realigning.getBasis[DDomain](model, center)
if realigning.useTranslation then DenseMatrix.horzcat(pt, pr)
else pr
else pt
}
// call the realignment implementation
Expand All @@ -701,12 +695,12 @@ object DiscreteLowRankGaussianProcess {
model.variance,
p,
ids.map(_.id),
dim = 3,
dim = d,
diagonalize = diagonalize,
projectMean = false
)

new DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]](model.domain, nmean, nvar, nbasis)
new DiscreteLowRankGaussianProcess[D, DDomain, Value](model.domain, nmean, nvar, nbasis)
}

private def realignmentComputation(mean: DenseVector[Double],
Expand Down Expand Up @@ -743,39 +737,6 @@ object DiscreteLowRankGaussianProcess {
(alignedMean, newbasis, news)
}

private def getTangentialSpeedMatrix[D: NDSpace](domain: DiscreteDomain[D], center: Point[D]): DenseMatrix[Double] = {
val dim = center.dimensionality
val np = domain.pointSet.numberOfPoints
require(dim >= 2, "requires at least two dimensions to calculate a tangential speed")
// build centered data matrix
val x = DenseMatrix.zeros[Double](dim, np)
val c = center.toBreezeVector
for (p, i) <- domain.pointSet.points.zipWithIndex do x(::, i) := p.toBreezeVector - c

val pr = if dim == 3 then
val pr = DenseMatrix.zeros[Double](np * dim, 3)
// the derivative of the rotation matrix
val dr = new DenseMatrix[Double](9,
3,
Array(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0,
0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 1.0)
)
// get tangential speed
val dx = dr * x
for i <- 0 until 3 do
val v = dx(3 * i until 3 * i + 3, ::).toDenseVector
pr(::, i) := v / breeze.linalg.norm(v)
pr
else if dim == 2 then
val dr = new DenseMatrix[Double](2, 2, Array(0.0, -1.0, 1.0, 0.0))
val dx = (dr * x).reshape(dim * np, 1)
val n = breeze.linalg.norm(dx, breeze.linalg.Axis._0)
dx / n(0)
else throw new NotImplementedError("tangential speed only implemented for 2d and 3d space")

pr
}

}

// Explicit variants for 1D, 2D and 3D
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,11 @@ case class PointDistributionModel[D: NDSpace, DDomain[D] <: DiscreteDomain[D]](
/**
* realigns the [[DiscreteLowRankGaussianProcess]] and returns the resulting [[PointDistributionModel]]
*/
def realign(ids: IndexedSeq[PointId],
withRotation: Boolean = true,
diagonalize: Boolean = true
def realign(ids: IndexedSeq[PointId], withExtendedBasis: Boolean = true, diagonalize: Boolean = true)(using
vectorizer: Vectorizer[EuclideanVector[D]],
realign: RealignExtendedBasis[D, EuclideanVector[D]]
): PointDistributionModel[D, DDomain] = {
new PointDistributionModel[D, DDomain](this.gp.realign(ids, withRotation, diagonalize))
new PointDistributionModel[D, DDomain](this.gp.realign(ids, withExtendedBasis, diagonalize))
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ class GaussianProcessTests extends ScalismoTestSuite {
val f = Fixture
val dgp = f.discreteLowRankGp

val alignedDgp = dgp.realign(dgp.mean.pointsWithIds.map(t => t._2).toIndexedSeq, withRotation = true)
val alignedDgp = dgp.realign(dgp.mean.pointsWithIds.map(t => t._2).toIndexedSeq)

val shifts: IndexedSeq[Double] = alignedDgp.klBasis
.map(klp => {
Expand Down Expand Up @@ -769,7 +769,7 @@ class GaussianProcessTests extends ScalismoTestSuite {
})
rotations.map(m => m.data.map(math.abs).sum).sum
})
res(1) shouldBe <(res(0) * 0.9)
res(1) shouldBe <(res(0) * 0.6)
}
}
}
Expand Down

0 comments on commit ba6c5f1

Please sign in to comment.