Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harmonized error/logProbability in entire Factor/Conditional hierarchy #1380

Merged
merged 21 commits into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

#include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>

Expand Down Expand Up @@ -56,6 +57,16 @@ namespace gtsam {
}
}

/* ************************************************************************ */
double DecisionTreeFactor::error(const DiscreteValues& values) const {
return -std::log(evaluate(values));
}

/* ************************************************************************ */
double DecisionTreeFactor::error(const HybridValues& values) const {
return error(values.discrete());
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
24 changes: 22 additions & 2 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
namespace gtsam {

class DiscreteConditional;
class HybridValues;

/**
* A discrete probabilistic factor.
Expand Down Expand Up @@ -97,11 +98,20 @@ namespace gtsam {
/// @name Standard Interface
/// @{

/// Value is just look up in AlgebraicDecisonTree
/// Calculate probability for given values `x`,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}

/// Evaluate probability density, sugar.
double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values);
}

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const;

/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, ADT::Ring::mul);
Expand Down Expand Up @@ -230,7 +240,17 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;

/// @}
/// @}
/// @name HybridValues methods.
/// @{

/**
* Calculate error for HybridValues `x`, is -log(probability)
* Simply dispatches to DiscreteValues version.
*/
double error(const HybridValues& values) const override;

/// @}

private:
/** Serialization function */
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
double DiscreteBayesNet::logProbability(const DiscreteValues& values) const {
// evaluate all conditionals and add
double result = 0.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result += conditional->logProbability(values);
return result;
}

/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply
Expand Down
13 changes: 12 additions & 1 deletion gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
return evaluate(values);
}

//** log(evaluate(values)) for given DiscreteValues */
double logProbability(const DiscreteValues & values) const;

/**
* @brief do ancestral sampling
*
Expand Down Expand Up @@ -136,7 +139,15 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;

///@}
/// @}
/// @name HybridValues methods.
/// @{

using Base::error; // Expose error(const HybridValues&) method..
using Base::evaluate; // Expose evaluate(const HybridValues&) method..
using Base::logProbability; // Expose logProbability(const HybridValues&)

/// @}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
Expand Down
22 changes: 21 additions & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

#pragma once

#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional.h>

#include <boost/make_shared.hpp>
#include <boost/shared_ptr.hpp>
Expand Down Expand Up @@ -147,6 +147,11 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Standard Interface
/// @{

/// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const {
return -error(x);
}

/// print index signature only
void printSignature(
const std::string& s = "Discrete Conditional: ",
Expand Down Expand Up @@ -225,6 +230,21 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override;


/// @}
/// @name HybridValues methods.
/// @{

/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* This is actually just -error(x).
*/
double logProbability(const HybridValues& x) const override {
return -error(x);
}

using DecisionTreeFactor::evaluate;

/// @}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
Expand Down
11 changes: 11 additions & 0 deletions gtsam/discrete/DiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>

#include <cmath>
#include <sstream>
Expand All @@ -27,6 +28,16 @@ using namespace std;

namespace gtsam {

/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));
}

/* ************************************************************************* */
double DiscreteFactor::error(const HybridValues& c) const {
return this->error(c.discrete());
}

/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace gtsam {

class DecisionTreeFactor;
class DiscreteConditional;
class HybridValues;

/**
* Base class for discrete probabilistic factors
Expand Down Expand Up @@ -83,6 +84,15 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;

/// Error is just -log(value)
double error(const DiscreteValues& values) const;

/**
* The Factor::error simply extracts the \class DiscreteValues from the
* \class HybridValues and calculates the error.
*/
double error(const HybridValues& c) const override;

/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

Expand Down
6 changes: 6 additions & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ class GTSAM_EXPORT DiscreteFactorGraph
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;

/// @}
/// @name HybridValues methods.
/// @{

using Base::error; // Expose error(const HybridValues&) method..

/// @}
}; // \ DiscreteFactorGraph

Expand Down
8 changes: 8 additions & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const;
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
Expand Down Expand Up @@ -157,7 +160,12 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;

// Standard interface.
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;

gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;

Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);

// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
}

/* ************************************************************************* */
Expand Down
6 changes: 5 additions & 1 deletion gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

#include <CppUnitLite/TestHarness.h>


#include <iostream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -101,6 +100,11 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back()));

// Check evaluate and logProbability
auto result = fg.optimize();
EXPECT_DOUBLES_EQUAL(asia.logProbability(result),
std::log(asia.evaluate(result)), 1e-9);

// add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1");
Expand Down
23 changes: 23 additions & 0 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,29 @@ TEST(DiscreteConditional, constructors3) {
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
}

/* ****************************************************************************/
// Test evaluate for a discrete Prior P(Asia).
TEST(DiscreteConditional, PriorProbability) {
constexpr Key asiaKey = 0;
const DiscreteKey Asia(asiaKey, 2);
DiscreteConditional dc(Asia, "4/6");
DiscreteValues values{{asiaKey, 0}};
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
}

/* ************************************************************************* */
// Check that error, logProbability, evaluate all work as expected.
TEST(DiscreteConditional, probability) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");

DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
}

/* ************************************************************************* */
// Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) {
Expand Down
13 changes: 7 additions & 6 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,16 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
const VectorValues &continuousValues) const {
// functor to calculate to double error value from GaussianConditional.
// functor to calculate to double logProbability value from
// GaussianConditional.
auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->error(continuousValues);
return conditional->logProbability(continuousValues);
} else {
// Return arbitrarily large error if conditional is null
// Return arbitrarily large logProbability if conditional is null
// Conditional is null if it is pruned out.
return 1e50;
}
Expand All @@ -289,10 +290,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
}

/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
double GaussianMixture::logProbability(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous());
return conditional->logProbability(values.continuous());
}

} // namespace gtsam
16 changes: 7 additions & 9 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,22 +164,23 @@ class GTSAM_EXPORT GaussianMixture
const Conditionals &conditionals() const;

/**
* @brief Compute error of the GaussianMixture as a tree.
* @brief Compute logProbability of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the error.
* as the conditionals, and leaf values as the logProbability.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
* @brief Compute the logProbability of this Gaussian Mixture given the
* continuous values and a discrete assignment.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;
double logProbability(const HybridValues &values) const override;

// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
Expand All @@ -188,9 +189,6 @@ class GTSAM_EXPORT GaussianMixture
// double operator()(const HybridValues &values) const { return
// evaluate(values); }

// /// Calculate log-density for given values `x`.
// double logDensity(const HybridValues &values) const;

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
Expand Down
Loading