Skip to content

Commit

Permalink
Separate L2 and normalized L2 distance classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
astamm committed Jan 9, 2024
1 parent d304788 commit 3999ed1
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 23 deletions.
8 changes: 7 additions & 1 deletion src/kmaModelClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "polyCenterClass.h"
#include "pearsonDissimilarityClass.h"
#include "l2DissimilarityClass.h"
#include "normalizedL2DissimilarityClass.h"

#include "utilityFunctions.h"
#include "sharedFactoryClass.h"
Expand Down Expand Up @@ -100,6 +101,7 @@ void KmaModel::SetDissimilarityMethod(const std::string &val)
SharedFactory<BaseDissimilarityFunction> dissimilarityFactory;
dissimilarityFactory.Register<PearsonDissimilarityFunction>("pearson");
dissimilarityFactory.Register<L2DissimilarityFunction>("l2");
dissimilarityFactory.Register<NormalizedL2DissimilarityFunction>("normalized_l2");

m_DissimilarityPointer = dissimilarityFactory.Instantiate(val);

Expand Down Expand Up @@ -472,7 +474,7 @@ Rcpp::List KmaModel::FitModel()
unsigned int numberOfClusters = m_NumberOfClusters;
arma::rowvec observationDistances(m_NumberOfObservations, arma::fill::ones);
arma::rowvec oldObservationDistances(m_NumberOfObservations, arma::fill::zeros);
arma::urowvec observationMemberships(m_NumberOfObservations, arma::fill::ones);
arma::urowvec observationMemberships(m_NumberOfObservations, arma::fill::zeros);
arma::urowvec oldObservationMemberships(m_NumberOfObservations, arma::fill::zeros);
arma::urowvec clusterIndices = arma::linspace<arma::urowvec>(0, m_NumberOfClusters - 1, m_NumberOfClusters);
arma::mat warpedGrids = m_InputGrids;
Expand Down Expand Up @@ -584,6 +586,9 @@ Rcpp::List KmaModel::FitModel()
templateGrids.set_size(numberOfClusters, m_NumberOfPoints);
templateValues.set_size(numberOfClusters, m_NumberOfDimensions, m_NumberOfPoints);

if (m_UseVerbose)
Rcpp::Rcout << " - Updating templates" << std::endl;

this->UpdateTemplates(
numberOfIterations,
clusterIndices,
Expand Down Expand Up @@ -612,6 +617,7 @@ Rcpp::List KmaModel::FitModel()
observationDistances = oldObservationDistances;
observationMemberships = oldObservationMemberships;
--numberOfIterations;
break;
}
}
}
Expand Down
23 changes: 1 addition & 22 deletions src/l2DissimilarityClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,10 @@ double L2DissimilarityFunction::GetDistance(const arma::rowvec& grid1,
if (pair.Grid.is_empty())
return DBL_MAX;

unsigned int nDim = pair.Values1.n_rows;
unsigned int nPts = pair.Grid.size();

if (nPts <= 1.0)
return DBL_MAX;

double squaredDistanceValue = 0.0;
double squaredNorm1Value = 0.0;
double squaredNorm2Value = 0.0;

arma::rowvec workVector;

for (unsigned int k = 0;k < nDim;++k)
{
workVector = pair.Values1.row(k).cols(1, nPts - 1) - pair.Values2.row(k).cols(1, nPts - 1);
squaredDistanceValue += arma::dot(workVector, workVector);
workVector = pair.Values1.row(k).cols(1, nPts - 1);
squaredNorm1Value += arma::dot(workVector, workVector);
workVector = pair.Values2.row(k).cols(1, nPts - 1);
squaredNorm2Value += arma::dot(workVector, workVector);
}

double epsValue = std::sqrt(std::numeric_limits<double>::epsilon());
if (squaredNorm1Value < epsValue && squaredNorm2Value < epsValue)
return 0.0;

return std::sqrt(squaredDistanceValue) / (std::sqrt(squaredNorm1Value) + std::sqrt(squaredNorm2Value));
return std::sqrt(arma::sum(arma::trapz(pair.Grid, arma::pow(pair.Values1 - pair.Values2, 2.0), 1)));
}
28 changes: 28 additions & 0 deletions src/normalizedL2DissimilarityClass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "normalizedL2DissimilarityClass.h"

/// Normalized L2 distance between two functions.
///
/// Both inputs are first passed through GetComparableFunctions(), which
/// yields a pair of value matrices together with a single grid. Squared L2
/// quantities are then obtained by trapezoidal integration over that grid
/// (arma::trapz along dimension 1) and summed over the integration results.
///
/// @param grid1   Grid associated with the first function.
/// @param grid2   Grid associated with the second function.
/// @param values1 Values of the first function.
/// @param values2 Values of the second function.
/// @return DBL_MAX when the comparable grid holds fewer than two points;
///         0 when both squared norms are below sqrt(machine epsilon);
///         otherwise ||f1 - f2|| / (||f1|| + ||f2||).
double NormalizedL2DissimilarityFunction::GetDistance(const arma::rowvec& grid1,
                                                      const arma::rowvec& grid2,
                                                      const arma::mat& values1,
                                                      const arma::mat& values2)
{
    FunctionPairType commonPair = this->GetComparableFunctions(grid1, grid2, values1, values2);

    // Trapezoidal integration needs at least two grid points; an empty or
    // single-point grid is reported as "infinitely far".
    if (commonPair.Grid.size() < 2)
        return DBL_MAX;

    const double firstSquaredNorm =
        arma::sum(arma::trapz(commonPair.Grid, arma::pow(commonPair.Values1, 2.0), 1));
    const double secondSquaredNorm =
        arma::sum(arma::trapz(commonPair.Grid, arma::pow(commonPair.Values2, 2.0), 1));

    // Two (numerically) null functions are considered identical; returning
    // early here also avoids a 0 / 0 in the normalized ratio below.
    const double tolerance = std::sqrt(std::numeric_limits<double>::epsilon());
    if (firstSquaredNorm < tolerance && secondSquaredNorm < tolerance)
        return 0.0;

    const double squaredDifferenceNorm =
        arma::sum(arma::trapz(commonPair.Grid, arma::pow(commonPair.Values1 - commonPair.Values2, 2.0), 1));

    const double normalizer = std::sqrt(firstSquaredNorm) + std::sqrt(secondSquaredNorm);

    return std::sqrt(squaredDifferenceNorm) / normalizer;
}
18 changes: 18 additions & 0 deletions src/normalizedL2DissimilarityClass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef NORMALIZEDL2DISSIMILARITYCLASS_H
#define NORMALIZEDL2DISSIMILARITYCLASS_H

#include "baseDissimilarityClass.h"

/// Normalized L2 distance.
///
/// Dissimilarity that divides the L2 norm of the difference between two
/// functions by the sum of their individual L2 norms:
/// ||f1 - f2|| / (||f1|| + ||f2||).
class NormalizedL2DissimilarityFunction : public BaseDissimilarityFunction
{
public:
/// Computes the normalized L2 distance between two functions.
///
/// @param grid1   Grid associated with the first function.
/// @param grid2   Grid associated with the second function.
/// @param values1 Values of the first function
///                (presumably one row per dimension - TODO confirm).
/// @param values2 Values of the second function (same layout as values1).
/// @return DBL_MAX when the two functions cannot be compared on a grid of
///         at least two points, 0 when both functions are numerically
///         null, and the normalized L2 distance otherwise.
double GetDistance(
const arma::rowvec& grid1,
const arma::rowvec& grid2,
const arma::mat& values1,
const arma::mat& values2
);
};

#endif /* NORMALIZEDL2DISSIMILARITYCLASS_H */

0 comments on commit 3999ed1

Please sign in to comment.