Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Mixture Error #1318

Merged
merged 20 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace gtsam {
static inline double id(const double& x) { return x; }
};

AlgebraicDecisionTree() : Base(1.0) {}
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}

// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}
Expand Down Expand Up @@ -158,9 +158,9 @@ namespace gtsam {
}

/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
void print(const std::string& s = "",
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.8g") % v).str();
};
Expand Down
24 changes: 24 additions & 0 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,28 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
conditionals_.root_ = pruned_conditionals.root_;
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousVals) const {
// functor to convert from GaussianConditional to double error value.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convert? -? calculate

auto errorFunc =
[continuousVals](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->error(continuousVals);
} else {
// return arbitrarily large error
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on when this would happen?

return 1e50;
}
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousVals,
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const DiscreteValues &discreteValues) const {
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousVals);
}

} // namespace gtsam
20 changes: 20 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,26 @@ class GTSAM_EXPORT GaussianMixture
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();

/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousVals The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
* as the factor but leaf values as the error.
dellaert marked this conversation as resolved.
Show resolved Hide resolved
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousVals The continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousVals,
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const DiscreteValues &discreteValues) const;

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
Expand Down
20 changes: 20 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,24 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
};
return {factors_, wrap};
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousVals) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousVals](const GaussianFactor::shared_ptr &factor) {
return factor->error(continuousVals);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixtureFactor::error(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const VectorValues &continuousVals,
const DiscreteValues &discreteValues) const {
auto factor = factors_(discreteValues);
return factor->error(continuousVals);
}

} // namespace gtsam
24 changes: 24 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@

#pragma once

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>

namespace gtsam {

class GaussianFactorGraph;

// Needed for wrapper.
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;

/**
Expand Down Expand Up @@ -126,6 +130,26 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
Sum add(const Sum &sum) const;

/**
* @brief Compute error of the GaussianMixtureFactor as a tree.
*
* @param continuousVals The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
* as the factor but leaf values as the error.
dellaert marked this conversation as resolved.
Show resolved Hide resolved
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousVals) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousVals The continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousVals,
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const DiscreteValues &discreteValues) const;

/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
Expand Down
47 changes: 47 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,51 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize();
}

/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = this->choose(discreteValues);
return gbn.error(continuousValues);
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree;

for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> conditional_error;
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);

if (idx == 0) {
error_tree = conditional_error;
} else {
error_tree = error_tree + conditional_error;
}

} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues);
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}

return error_tree;
}

AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
return error_tree.apply([](double error) { return exp(-error); });
}

} // namespace gtsam
31 changes: 31 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,37 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves);

/**
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;

/**
dellaert marked this conversation as resolved.
Show resolved Hide resolved
* @brief Compute conditional error for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute unnormalized probability for each discrete assignment,
dellaert marked this conversation as resolved.
Show resolved Hide resolved
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues &continuousValues) const;

/// @}

private:
Expand Down
49 changes: 49 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,4 +423,53 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
return ordering;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);

for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> factor_error;

if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
factor_error = gaussianMixture->error(continuousValues);

if (idx == 0) {
error_tree = factor_error;
} else {
error_tree = error_tree + factor_error;
}

} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
auto hybridGaussianFactor =
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();

double error = gaussian->error(continuousValues);
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}

return error_tree;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree =
error_tree.apply([](double error) { return exp(-error); });
return prob_tree;
}

} // namespace gtsam
22 changes: 21 additions & 1 deletion gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class JacobianFactor;

/**
* @brief Main elimination function for HybridGaussianFactorGraph.
*
*
* @param factors The factor graph to eliminate.
* @param keys The elimination ordering.
* @return The conditional on the ordering keys and the remaining factors.
Expand Down Expand Up @@ -170,6 +170,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
}
}

/**
* @brief Compute error for each discrete assignment,
dellaert marked this conversation as resolved.
Show resolved Hide resolved
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;

/**
* @brief Compute unnormalized probability for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
Expand Down
26 changes: 23 additions & 3 deletions gtsam/hybrid/MixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,12 @@ class MixtureFactor : public HybridFactor {
std::copy(f->keys().begin(), f->keys().end(),
std::inserter(factor_keys_set, factor_keys_set.end()));

nonlinear_factors.push_back(
boost::dynamic_pointer_cast<NonlinearFactor>(f));
if (auto nf = boost::dynamic_pointer_cast<NonlinearFactor>(f)) {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
nonlinear_factors.push_back(nf);
} else {
throw std::runtime_error(
"Factors passed into MixtureFactor need to be nonlinear!");
}
}
factors_ = Factors(discreteKeys, nonlinear_factors);

Expand All @@ -121,6 +125,22 @@ class MixtureFactor : public HybridFactor {

~MixtureFactor() = default;

/**
* @brief Compute error of the MixtureFactor as a tree.
*
* @param continuousVals The continuous values for which to compute the error.
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
* as the factor but leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const Values& continuousVals) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousVals](const sharedFactor& factor) {
return factor->error(continuousVals);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}

/**
* @brief Compute error of factor given both continuous and discrete values.
*
Expand Down Expand Up @@ -149,7 +169,7 @@ class MixtureFactor : public HybridFactor {

/// print to stdout
void print(
const std::string& s = "MixtureFactor",
const std::string& s = "",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It anyway prints MixtureFactor below on line 176.

const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
std::cout << (s.empty() ? "" : s + " ");
Base::print("", keyFormatter);
Expand Down
Loading