Skip to content

Commit

Permalink
Merge pull request #1380 from borglab/feature/uniform_error
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Jan 12, 2023
2 parents e9485d1 + f4c3131 commit a34c463
Show file tree
Hide file tree
Showing 59 changed files with 828 additions and 421 deletions.
11 changes: 11 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

#include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>

Expand Down Expand Up @@ -56,6 +57,16 @@ namespace gtsam {
}
}

/* ************************************************************************ */
double DecisionTreeFactor::error(const DiscreteValues& values) const {
return -std::log(evaluate(values));
}

/* ************************************************************************ */
double DecisionTreeFactor::error(const HybridValues& values) const {
return error(values.discrete());
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
24 changes: 22 additions & 2 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
namespace gtsam {

class DiscreteConditional;
class HybridValues;

/**
* A discrete probabilistic factor.
Expand Down Expand Up @@ -97,11 +98,20 @@ namespace gtsam {
/// @name Standard Interface
/// @{

/// Value is just look up in AlgebraicDecisonTree
/// Calculate probability for given values `x`,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}

/// Evaluate probability density, sugar.
double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values);
}

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;

/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, ADT::Ring::mul);
Expand Down Expand Up @@ -230,7 +240,17 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;

/// @}
/// @}
/// @name HybridValues methods.
/// @{

/**
* Calculate error for HybridValues `x`, is -log(probability)
* Simply dispatches to DiscreteValues version.
*/
double error(const HybridValues& values) const override;

/// @}

private:
/** Serialization function */
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
double DiscreteBayesNet::logProbability(const DiscreteValues& values) const {
// evaluate all conditionals and add
double result = 0.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result += conditional->logProbability(values);
return result;
}

/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply
Expand Down
13 changes: 12 additions & 1 deletion gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
return evaluate(values);
}

//** log(evaluate(values)) for given DiscreteValues */
double logProbability(const DiscreteValues & values) const;

/**
* @brief do ancestral sampling
*
Expand Down Expand Up @@ -136,7 +139,15 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;

///@}
/// @}
/// @name HybridValues methods.
/// @{

using Base::error; // Expose error(const HybridValues&) method..
using Base::evaluate; // Expose evaluate(const HybridValues&) method..
using Base::logProbability; // Expose logProbability(const HybridValues&)

/// @}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
Expand Down
22 changes: 21 additions & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

#pragma once

#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional.h>

#include <boost/make_shared.hpp>
#include <boost/shared_ptr.hpp>
Expand Down Expand Up @@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Standard Interface
/// @{

/// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const {
return -error(x);
}

/// print index signature only
void printSignature(
const std::string& s = "Discrete Conditional: ",
Expand Down Expand Up @@ -225,6 +230,21 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;


/// @}
/// @name HybridValues methods.
/// @{

/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* This is actually just -error(x).
*/
double logProbability(const HybridValues& x) const override {
return -error(x);
}

using DecisionTreeFactor::evaluate;

/// @}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
Expand Down
11 changes: 11 additions & 0 deletions gtsam/discrete/DiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>

#include <cmath>
#include <sstream>
Expand All @@ -27,6 +28,16 @@ using namespace std;

namespace gtsam {

/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));
}

/* ************************************************************************* */
double DiscreteFactor::error(const HybridValues& c) const {
return this->error(c.discrete());
}

/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace gtsam {

class DecisionTreeFactor;
class DiscreteConditional;
class HybridValues;

/**
* Base class for discrete probabilistic factors
Expand Down Expand Up @@ -83,6 +84,15 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;

/// Error is just -log(value)
double error(const DiscreteValues& values) const;

/**
* The Factor::error simply extracts the \class DiscreteValues from the
* \class HybridValues and calculates the error.
*/
double error(const HybridValues& c) const override;

/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

Expand Down
6 changes: 6 additions & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ class GTSAM_EXPORT DiscreteFactorGraph
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;

/// @}
/// @name HybridValues methods.
/// @{

using Base::error; // Expose error(const HybridValues&) method..

/// @}
}; // \ DiscreteFactorGraph

Expand Down
8 changes: 8 additions & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const;
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
Expand Down Expand Up @@ -157,7 +160,12 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;

// Standard interface.
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;

gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;

Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);

// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
}

/* ************************************************************************* */
Expand Down
6 changes: 5 additions & 1 deletion gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

#include <CppUnitLite/TestHarness.h>


#include <iostream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -101,6 +100,11 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back()));

// Check evaluate and logProbability
auto result = fg.optimize();
EXPECT_DOUBLES_EQUAL(asia.logProbability(result),
std::log(asia.evaluate(result)), 1e-9);

// add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1");
Expand Down
23 changes: 23 additions & 0 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,29 @@ TEST(DiscreteConditional, constructors3) {
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
}

/* ****************************************************************************/
// Test evaluate for a discrete Prior P(Asia).
TEST(DiscreteConditional, PriorProbability) {
constexpr Key asiaKey = 0;
const DiscreteKey Asia(asiaKey, 2);
DiscreteConditional dc(Asia, "4/6");
DiscreteValues values{{asiaKey, 0}};
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
}

/* ************************************************************************* */
// Check that error, logProbability, evaluate all work as expected.
TEST(DiscreteConditional, probability) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");

DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
}

/* ************************************************************************* */
// Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) {
Expand Down
13 changes: 7 additions & 6 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,16 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
const VectorValues &continuousValues) const {
// functor to calculate to double error value from GaussianConditional.
// functor to calculate to double logProbability value from
// GaussianConditional.
auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->error(continuousValues);
return conditional->logProbability(continuousValues);
} else {
// Return arbitrarily large error if conditional is null
// Return arbitrarily large logProbability if conditional is null
// Conditional is null if it is pruned out.
return 1e50;
}
Expand All @@ -289,10 +290,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
}

/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
double GaussianMixture::logProbability(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous());
return conditional->logProbability(values.continuous());
}

} // namespace gtsam
16 changes: 7 additions & 9 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,22 +164,23 @@ class GTSAM_EXPORT GaussianMixture
const Conditionals &conditionals() const;

/**
* @brief Compute error of the GaussianMixture as a tree.
* @brief Compute logProbability of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the error.
* as the conditionals, and leaf values as the logProbability.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
* @brief Compute the logProbability of this Gaussian Mixture given the
* continuous values and a discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;
double logProbability(const HybridValues &values) const override;

// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
Expand All @@ -188,9 +189,6 @@ class GTSAM_EXPORT GaussianMixture
// double operator()(const HybridValues &values) const { return
// evaluate(values); }

// /// Calculate log-density for given values `x`.
// double logDensity(const HybridValues &values) const;

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
Expand Down
Loading

0 comments on commit a34c463

Please sign in to comment.