Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Elimination with Normalization Constant #1360

Merged
merged 83 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from 76 commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
99825fc
use FactorAndConstant error()
varunagrawal Dec 31, 2022
383439e
make DecisionTree docstrings clearer
varunagrawal Dec 31, 2022
02d7b87
fix docstring
varunagrawal Dec 31, 2022
878eeb5
simplify the error addition
varunagrawal Dec 31, 2022
d0a56da
add logNormalizingConstant test for GaussianConditional and make some…
varunagrawal Dec 31, 2022
1e9cbeb
Merge branch 'develop' into hybrid/elimination
dellaert Dec 31, 2022
2e6f477
update all tests and mark things that need to be addressed
varunagrawal Jan 1, 2023
a233223
GraphAndConstant struct for hybrid elimination
varunagrawal Jan 1, 2023
efd8eb1
Switch to EliminateDiscrete for max-product
dellaert Dec 31, 2022
dcb07fe
Test eliminate
dellaert Dec 31, 2022
4d3bbf6
HBN::evaluate
dellaert Dec 28, 2022
b772d67
refactoring variables for clarity
dellaert Dec 28, 2022
be8008e
Also print mean if no parents
dellaert Dec 31, 2022
f6f782a
Add static
dellaert Dec 31, 2022
143022c
Tiny Bayes net example
dellaert Dec 31, 2022
df4fb13
fix comment
dellaert Dec 31, 2022
4023e71
continuousSubset
dellaert Dec 31, 2022
c8008cb
tiny FG
dellaert Dec 31, 2022
b463386
Made SumFrontals a method to test
dellaert Dec 31, 2022
92e2a39
Added factor and constant and removed factors method
dellaert Jan 1, 2023
fa76d53
refactored and documented SumFrontals
dellaert Jan 1, 2023
039c9b9
Test SumFrontals
dellaert Jan 1, 2023
526da2c
Add Testable to GraphAndConstant
dellaert Jan 1, 2023
b094953
Fix compile issues after rebase
dellaert Jan 1, 2023
4cb03b3
Fix SumFrontals test
dellaert Jan 1, 2023
7ab4c3e
Change to real test
dellaert Jan 1, 2023
6483130
Print estimated marginals and ratios!
dellaert Jan 1, 2023
dbd9faf
Fix quality testing
dellaert Jan 1, 2023
3d821ec
Now test elimination in c++
dellaert Jan 1, 2023
0095f73
attempt to fix elimination
dellaert Jan 1, 2023
665cb29
Make testcase exactly 5.0 mean
dellaert Jan 1, 2023
2c7b3a2
Refactoring in elimination
dellaert Jan 1, 2023
4d313fa
Comment on constant
dellaert Jan 1, 2023
064f17b
Added two-measurement example
dellaert Jan 2, 2023
8f5689b
minor readjustment
varunagrawal Jan 2, 2023
cae98e1
rename eliminate to eliminateFunc
varunagrawal Jan 2, 2023
312ba5f
Synced two examples
dellaert Jan 2, 2023
7c27061
Added missing methods
dellaert Jan 2, 2023
bd8d2ea
Added error for all versions - should become logDiensity?
dellaert Jan 2, 2023
e3a63aa
check for 0 sum
varunagrawal Jan 2, 2023
40a67b5
prune nonlinear by calling method rather than on the bayestree
varunagrawal Jan 2, 2023
021ee1a
Deterministic example, much more generic importance sampler
dellaert Jan 2, 2023
fbfc20b
Fixed conversion arguments
dellaert Jan 2, 2023
06aed53
rename
dellaert Jan 2, 2023
f8d75ab
name change of Sum to GaussianFactorGraphTree and SumFrontals to asse…
dellaert Jan 2, 2023
12d02be
Right marginals for tiny1
dellaert Jan 2, 2023
797ac34
Same correct error with factor_z.error()
dellaert Jan 2, 2023
625977e
Example with 2 measurements agrees with importance sampling
dellaert Jan 2, 2023
c3f0469
Add mean to test
dellaert Jan 2, 2023
f726cf6
f(x0, x1, m0; z0, z1) now has constant ratios !
dellaert Jan 2, 2023
5c87d7d
disable test instead of commenting out
varunagrawal Jan 3, 2023
66b846f
Merge branch 'hybrid/elimination' into hybrid/test_with_evaluate
varunagrawal Jan 3, 2023
38f3209
fix GaussianConditional print test
varunagrawal Jan 3, 2023
195dddf
clean up HybridGaussianFactorGraph
varunagrawal Jan 3, 2023
47346c5
move GraphAndConstant traits definition to HybridFactor
varunagrawal Jan 3, 2023
ca1c517
remove extra print statements
varunagrawal Jan 3, 2023
7825ffd
fix tests due to change to EliminateDiscrete
varunagrawal Jan 3, 2023
f117da2
remove extra print
varunagrawal Jan 3, 2023
cb885fb
check for nullptr in HybridConditional::equals
varunagrawal Jan 3, 2023
46acba5
serialize inner_, need to test
varunagrawal Jan 3, 2023
41c73fd
comment out failing tests, need to serialize DecisionTree
varunagrawal Jan 3, 2023
e01f7e7
kill unnecessary method
varunagrawal Jan 3, 2023
9e7fcc8
make header functions as inline
varunagrawal Jan 3, 2023
3771d63
simplify HybridConditional equality check
varunagrawal Jan 3, 2023
385ae34
Merge pull request #1363 from borglab/hybrid/test_with_evaluate-2
varunagrawal Jan 3, 2023
99a3fba
DecisionTree serialization
varunagrawal Jan 3, 2023
2653c2f
serialization support for DecisionTreeFactor
varunagrawal Jan 3, 2023
0ab15cc
fix equality of HybridDiscreteFactor and HybridGaussianFactor
varunagrawal Jan 3, 2023
3f2bff8
hybrid serialization tests
varunagrawal Jan 3, 2023
74142e4
GaussianMixture serialization
varunagrawal Jan 3, 2023
2bb4fd6
fix minor bug in HybridConditional and test its serialization
varunagrawal Jan 3, 2023
6fcc087
serialize DiscreteConditional
varunagrawal Jan 3, 2023
230f65d
serialization tests for HybridBayesNet and HybridBayesTree
varunagrawal Jan 3, 2023
4cb910c
move discrete serialization tests to common file to remove export key…
varunagrawal Jan 4, 2023
5cdff9e
Merge pull request #1362 from borglab/hybrid/test_with_evaluate
varunagrawal Jan 4, 2023
bb31956
enable previously failing test, now works!!
varunagrawal Jan 4, 2023
ed16c33
add check in GaussianMixtureFactor::likelihood
varunagrawal Jan 4, 2023
34daecd
remove deferredFactors
varunagrawal Jan 4, 2023
7dd4bc9
implement dim() for MixtureFactor
varunagrawal Jan 4, 2023
b16480b
simplify code
varunagrawal Jan 4, 2023
25fd618
rename getAssignment to assignment, fix printing
varunagrawal Jan 4, 2023
b62f397
Merge pull request #1368 from borglab/hybrid/serialization
varunagrawal Jan 4, 2023
78926f7
Merge pull request #1369 from borglab/hybrid/various-fixes
varunagrawal Jan 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ namespace gtsam {
*/
size_t nrAssignments_;

/// Default constructor for serialization.
Leaf() {}

/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
Expand Down Expand Up @@ -154,6 +157,18 @@ namespace gtsam {
}

bool isLeaf() const override { return true; }

private:
using Base = DecisionTree<L, Y>::Node;

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
}; // Leaf

/****************************************************************************/
Expand All @@ -177,6 +192,9 @@ namespace gtsam {
using ChoicePtr = boost::shared_ptr<const Choice>;

public:
/// Default constructor for serialization.
Choice() {}

~Choice() override {
#ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
Expand Down Expand Up @@ -428,6 +446,19 @@ namespace gtsam {
r->push_back(branch->choose(label, index));
return Unique(r);
}

private:
using Base = DecisionTree<L, Y>::Node;

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(label_);
ar& BOOST_SERIALIZATION_NVP(branches_);
ar& BOOST_SERIALIZATION_NVP(allSame_);
}
}; // Choice

/****************************************************************************/
Expand Down
27 changes: 23 additions & 4 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

#pragma once

#include <gtsam/base/Testable.h>
#include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h>

#include <boost/serialization/nvp.hpp>
#include <boost/shared_ptr.hpp>
#include <functional>
#include <iostream>
Expand Down Expand Up @@ -113,6 +115,12 @@ namespace gtsam {
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
virtual Ptr choose(const L& label, size_t index) const = 0;
virtual bool isLeaf() const = 0;

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
};
/** ------------------------ Node base class --------------------------- */

Expand Down Expand Up @@ -236,7 +244,7 @@ namespace gtsam {
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f (side-effect) Function taking a value.
* @param f (side-effect) Function taking the value of the leaf node.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
Expand All @@ -245,7 +253,7 @@ namespace gtsam {
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
* tree.visitWith(visitor);
* tree.visit(visitor);
*/
template <typename Func>
void visit(Func f) const;
Expand All @@ -261,8 +269,8 @@ namespace gtsam {
*
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
* tree.visitWith(visitor);
* auto visitor = [&](const Leaf& leaf) { sum += leaf.constant(); };
* tree.visitLeaf(visitor);
*/
template <typename Func>
void visitLeaf(Func f) const;
Expand Down Expand Up @@ -364,8 +372,19 @@ namespace gtsam {
compose(Iterator begin, Iterator end, const L& label) const;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_NVP(root_);
}
}; // DecisionTree

template <class L, class Y>
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};

/** free versions of apply */

/// Apply unary operator `op` to DecisionTree `f`.
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,16 @@ namespace gtsam {
const Names& names = {}) const override;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
};

// traits
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;

private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
};
// DiscreteConditional

Expand Down
7 changes: 3 additions & 4 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
// #define DT_DEBUG_MEMORY
// #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>

#include <CppUnitLite/TestHarness.h>

using namespace std;
using namespace gtsam;

Expand Down
1 change: 1 addition & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
Expand Down
6 changes: 3 additions & 3 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
* @date Feb 14, 2011
*/

#include <boost/make_shared.hpp>

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h>

#include <boost/make_shared.hpp>

using namespace std;
using namespace gtsam;

Expand Down Expand Up @@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
Expand Down
105 changes: 105 additions & 0 deletions gtsam/discrete/tests/testSerializationDiscrete.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/*
* testSerializtionDiscrete.cpp
*
* @date January 2023
* @author Varun Agrawal
*/

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/Symbol.h>

using namespace std;
using namespace gtsam;

using Tree = gtsam::DecisionTree<string, int>;

BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")

BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");

using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")

/* ****************************************************************************/
// Test DecisionTree serialization.
TEST(DiscreteSerialization, DecisionTree) {
Tree tree({{"A", 2}}, std::vector<int>{1, 2});

using namespace serializationTestHelpers;

// Object roundtrip
Tree outputObj = create<Tree>();
roundtrip<Tree>(tree, outputObj);
EXPECT(tree.equals(outputObj));

// XML roundtrip
Tree outputXml = create<Tree>();
roundtripXML<Tree>(tree, outputXml);
EXPECT(tree.equals(outputXml));

// Binary roundtrip
Tree outputBinary = create<Tree>();
roundtripBinary<Tree>(tree, outputBinary);
EXPECT(tree.equals(outputBinary));
}

/* ************************************************************************* */
// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
TEST(DiscreteSerialization, DecisionTreeFactor) {
using namespace serializationTestHelpers;

DiscreteKey A(1, 2), B(2, 2), C(3, 2);

DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));

DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor>(f));
EXPECT(equalsXML<DecisionTreeFactor>(f));
EXPECT(equalsBinary<DecisionTreeFactor>(f));
}

/* ************************************************************************* */
// Check serialization for DiscreteConditional & DiscreteDistribution
TEST(DiscreteSerialization, DiscreteConditional) {
using namespace serializationTestHelpers;

DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");

EXPECT(equalsObj<DiscreteConditional>(conditional));
EXPECT(equalsXML<DiscreteConditional>(conditional));
EXPECT(equalsBinary<DiscreteConditional>(conditional));

DiscreteDistribution P(A % "3/2/1");
EXPECT(equalsObj<DiscreteDistribution>(P));
EXPECT(equalsXML<DiscreteDistribution>(P));
EXPECT(equalsBinary<DiscreteDistribution>(P));
}

/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */
41 changes: 29 additions & 12 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,29 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionalsList)) {}

/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add(
const GaussianMixture::Sum &sum) const {
using Y = GaussianFactorGraph;
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
};
const Sum tree = asGaussianFactorGraphTree();
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result;
result.push_back(factor);
return result;
result.push_back(conditional);
if (conditional) {
return GraphAndConstant(
result, conditional->logNormalizationConstant());
} else {
return GraphAndConstant(result, 0.0);
}
};
return {conditionals_, lambda};
}
Expand Down Expand Up @@ -98,7 +103,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
/* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol);
if (e == nullptr) return false;

// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;

// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
}

/* *******************************************************************************/
Expand Down
Loading