From 58068503f4e79869a46c4ba0ce1aa5b5e7048e86 Mon Sep 17 00:00:00 2001 From: sjxue Date: Sun, 14 Aug 2022 17:39:35 -0400 Subject: [PATCH 1/5] HybridValues and optimize() method for hybrid gaussian bayes net --- gtsam/hybrid/HybridBayesNet.cpp | 13 + gtsam/hybrid/HybridBayesNet.h | 9 +- gtsam/hybrid/HybridLookupDAG.cpp | 73 +++++ gtsam/hybrid/HybridLookupDAG.h | 118 ++++++++ gtsam/hybrid/HybridValues.h | 138 +++++++++ gtsam/hybrid/hybrid.i | 17 ++ .../tests/testGaussianHybridFactorGraph.cpp | 24 ++ gtsam/hybrid/tests/testHybridLookupDAG.cpp | 276 ++++++++++++++++++ gtsam/hybrid/tests/testHybridValues.cpp | 59 ++++ 9 files changed, 726 insertions(+), 1 deletion(-) create mode 100644 gtsam/hybrid/HybridLookupDAG.cpp create mode 100644 gtsam/hybrid/HybridLookupDAG.h create mode 100644 gtsam/hybrid/HybridValues.h create mode 100644 gtsam/hybrid/tests/testHybridLookupDAG.cpp create mode 100644 gtsam/hybrid/tests/testHybridValues.cpp diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 1292711d89..fc0b93c5c9 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -10,7 +10,20 @@ * @file HybridBayesNet.cpp * @brief A bayes net of Gaussian Conditionals indexed by discrete keys. * @author Fan Jiang + * @author Shangjie Xue * @date January 2022 */ #include +#include +#include + +namespace gtsam { + +/* *******************************************************************************/ +HybridValues HybridBayesNet::optimize() const { + auto dag = HybridLookupDAG::FromBayesNet(*this); + return dag.argmax(); +} + +} // namespace gtsam \ No newline at end of file diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 43eead2801..5665ce5c99 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include namespace gtsam { @@ -34,8 +35,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using shared_ptr = boost::shared_ptr; using sharedConditional = boost::shared_ptr; - /** Construct empty bayes net */ + /// Construct empty bayes net HybridBayesNet() = default; + + /// Destructor + virtual ~HybridBayesNet() {} + + /// Solve the HybridBayesNet by back-substitution. + HybridValues optimize() const; }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridLookupDAG.cpp b/gtsam/hybrid/HybridLookupDAG.cpp new file mode 100644 index 0000000000..7232309f4d --- /dev/null +++ b/gtsam/hybrid/HybridLookupDAG.cpp @@ -0,0 +1,73 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.cpp + * @date Aug, 2022 + * @author Shangjie Xue + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using std::pair; +using std::vector; + +namespace gtsam { + + + +/* ************************************************************************** */ +void HybridLookupTable::argmaxInPlace(HybridValues* values) const { + // For discrete conditional, uses argmaxInPlace() method in DiscreteLookupTable. + if (isDiscrete()){ + boost::static_pointer_cast(inner_)->argmaxInPlace(&(values->discrete)); + } else if (isContinuous()){ + // For Gaussian conditional, uses solve() method in GaussianConditional. + values->continuous.insert(boost::static_pointer_cast(inner_)->solve(values->continuous)); + } else if (isHybrid()){ + // For hybrid conditional, since children should not contain discrete variable, we can condition on + // the discrete variable in the parents and solve the resulting GaussianConditional. + auto conditional = boost::static_pointer_cast(inner_)->conditionals()(values->discrete); + values->continuous.insert(conditional->solve(values->continuous)); + } +} + + +// /* ************************************************************************** */ +HybridLookupDAG HybridLookupDAG::FromBayesNet( + const HybridBayesNet& bayesNet) { + HybridLookupDAG dag; + for (auto&& conditional : bayesNet) { + HybridLookupTable hlt(*conditional); + dag.push_back(hlt); + } + return dag; +} + +HybridValues HybridLookupDAG::argmax(HybridValues result) const { + // Argmax each node in turn in topological sort order (parents first). + for (auto lookupTable : boost::adaptors::reverse(*this)) + lookupTable->argmaxInPlace(&result); + return result; +} +/* ************************************************************************** */ + +} // namespace gtsam \ No newline at end of file diff --git a/gtsam/hybrid/HybridLookupDAG.h b/gtsam/hybrid/HybridLookupDAG.h new file mode 100644 index 0000000000..903cc55190 --- /dev/null +++ b/gtsam/hybrid/HybridLookupDAG.h @@ -0,0 +1,118 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridLookupDAG.h + * @date Aug, 2022 + * @author Shangjie Xue + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace gtsam { + +/** + * @brief HybridLookupTable table for max-product + * + * Similar to DiscreteLookupTable, inherits from hybrid conditional for convenience. + * Is used in the max-product algorithm. + */ +class GTSAM_EXPORT HybridLookupTable : public HybridConditional { + public: + using Base = HybridConditional; + using This = HybridLookupTable; + using shared_ptr = boost::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Hybrid Lookup Table object form a HybridConditional. + * + * @param conditional input hybrid conditional + */ + HybridLookupTable(HybridConditional& conditional) : Base(conditional){}; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(HybridValues* parentsValues) const; +}; + +/** A DAG made from hybrid lookup tables, as defined above. Similar to DiscreteLookupDAG */ +class GTSAM_EXPORT HybridLookupDAG : public BayesNet { + public: + using Base = BayesNet; + using This = HybridLookupDAG; + using shared_ptr = boost::shared_ptr; + + /// @name Standard Constructors + /// @{ + + /// Construct empty DAG. + HybridLookupDAG() {} + + /// Create from BayesNet with LookupTables + static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet); + + /// Destructor + virtual ~HybridLookupDAG() {} + + /// @} + + /// @name Standard Interface + /// @{ + + /** Add a DiscreteLookupTable */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + + /** + * @brief argmax by back-substitution, optionally given certain variables. + * + * Assumes the DAG is reverse topologically sorted, i.e. last + * conditional will be optimized first *and* that the + * DAG does not contain any conditionals for the given variables. If the DAG + * resulted from eliminating a factor graph, this is true for the elimination + * ordering. + * + * @return given assignment extended w. optimal assignment for all variables. + */ + HybridValues argmax(HybridValues given = HybridValues()) const; + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h new file mode 100644 index 0000000000..98f862279a --- /dev/null +++ b/gtsam/hybrid/HybridValues.h @@ -0,0 +1,138 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridValues.h + * @date Jul 28, 2022 + * @author Shangjie Xue + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + + +#include +#include +#include + +namespace gtsam { + +/** + * HybridValues represents a collection of DiscreteValues and VectorValues. It is typically used to store the variables + * of a HybridGaussianFactorGraph. Optimizing a HybridGaussianBayesNet returns this class. + */ +class GTSAM_EXPORT HybridValues { + public: + // DiscreteValue stored the discrete components of the HybridValues. + DiscreteValues discrete; + + // VectorValue stored the continuous components of the HybridValues. + VectorValues continuous; + + // Default constructor creates an empty HybridValues. + HybridValues() : discrete(), continuous() {}; + + // Construct from DiscreteValues and VectorValues. + HybridValues(const DiscreteValues &dv, const VectorValues &cv) : discrete(dv), continuous(cv) {}; + + // print required by Testable for unit testing + void print(const std::string& s = "HybridValues", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ + std::cout << s << ": \n"; + discrete.print(" Discrete", keyFormatter); + continuous.print(" Continuous", keyFormatter); + }; + + // equals required by Testable for unit testing + bool equals(const HybridValues& other, double tol = 1e-9) const { + return discrete.equals(other.discrete, tol) && continuous.equals(other.continuous, tol); + } + + // Check whether a variable with key \c j exists in DiscreteValue. + bool existsDiscrete(Key j){ + return (discrete.find(j) != discrete.end()); + }; + + // Check whether a variable with key \c j exists in VectorValue. + bool existsVector(Key j){ + return continuous.exists(j); + }; + + // Check whether a variable with key \c j exists. + bool exists(Key j){ + return existsDiscrete(j) || existsVector(j); + }; + + /** Insert a discrete \c value with key \c j. Replaces the existing value if the key \c + * j is already used. + * @param value The vector to be inserted. + * @param j The index with which the value will be associated. */ + void insert(Key j, int value){ + discrete[j] = value; + }; + + /** Insert a vector \c value with key \c j. Throws an invalid_argument exception if the key \c + * j is already used. + * @param value The vector to be inserted. + * @param j The index with which the value will be associated. */ + void insert(Key j, const Vector& value) { + continuous.insert(j, value); + } + + // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h + + /** + * Read/write access to the discrete value with key \c j, throws + * std::out_of_range if \c j does not exist. + */ + size_t& atDiscrete(Key j){ + return discrete.at(j); + }; + + /** + * Read/write access to the vector value with key \c j, throws + * std::out_of_range if \c j does not exist. + */ + Vector& at(Key j) { + return continuous.at(j); + }; + + + /// @name Wrapper support + /// @{ + + /** + * @brief Output as a html table. + * + * @param keyFormatter function that formats keys. + * @return string html output. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ + std::stringstream ss; + ss << this->discrete.html(keyFormatter); + ss << this->continuous.html(keyFormatter); + return ss.str(); + }; + + /// @} +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index bbe1e2400d..207f3ff635 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -4,6 +4,22 @@ namespace gtsam { +#include +class HybridValues { + gtsam::DiscreteValues discrete; + gtsam::VectorValues continuous; + HybridValues(); + HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv); + void print(string s = "HybridValues", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridValues& other, double tol) const; + void insert(gtsam::Key j, int value); + void insert(gtsam::Key j, const gtsam::Vector& value); + size_t& atDiscrete(gtsam::Key j); + gtsam::Vector& at(gtsam::Key j); +}; + #include virtual class HybridFactor { void print(string s = "HybridFactor\n", @@ -84,6 +100,7 @@ class HybridBayesNet { size_t size() const; gtsam::KeySet keys() const; const gtsam::HybridConditional* at(size_t i) const; + gtsam::HybridValues optimize() const; void print(string s = "HybridBayesNet\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp index 552bb18f59..b1dbda836b 100644 --- a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -585,6 +587,28 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { } } +TEST(HybridGaussianFactorGraph, optimize) { + HybridGaussianFactorGraph hfg; + + DiscreteKey c1(C(1), 2); + + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + DecisionTree dt( + C(1), boost::make_shared(X(1), I_3x3, Z_3x1), + boost::make_shared(X(1), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + + auto result = + hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)})); + + HybridValues hv = result->optimize(); + + EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); + +} /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridLookupDAG.cpp b/gtsam/hybrid/tests/testHybridLookupDAG.cpp new file mode 100644 index 0000000000..70b09ecbe9 --- /dev/null +++ b/gtsam/hybrid/tests/testHybridLookupDAG.cpp @@ -0,0 +1,276 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridLookupDAG.cpp + * @date Aug, 2022 + * @author Shangjie Xue + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include for test suite +#include + +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +TEST(HybridLookupTable, basics) { + // create a conditional gaussian node + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 2; + S1(0, 1) = 3; + S1(1, 1) = 4; + + Matrix S2(2, 2); + S2(0, 0) = 6; + S2(1, 0) = 0.2; + S2(0, 1) = 8; + S2(1, 1) = 0.4; + + Matrix R1(2, 2); + R1(0, 0) = 0.1; + R1(1, 0) = 0.3; + R1(0, 1) = 0.0; + R1(1, 1) = 0.34; + + Matrix R2(2, 2); + R2(0, 0) = 0.1; + R2(1, 0) = 0.3; + R2(0, 1) = 0.0; + R2(1, 1) = 0.34; + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + Vector2 d1(0.2, 0.5), d2(0.5, 0.2); + + auto conditional0 = boost::make_shared(X(1), d1, R1, + X(2), S1, model), + conditional1 = boost::make_shared(X(1), d2, R2, + X(2), S2, model); + + // Create decision tree + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); +// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals); + + boost::shared_ptr mixtureFactor(new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals)); + + HybridConditional hc(mixtureFactor); + + GaussianMixture::Conditionals conditional2 = boost::static_pointer_cast(hc.inner())->conditionals(); + + DiscreteValues dv; + dv[1]=1; + + VectorValues cv; + cv.insert(X(2),Vector2(0.0, 0.0)); + + HybridValues hv(dv, cv); + + + + // std::cout << conditional2(values).markdown(); + EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6)); + EXPECT(conditional2(dv)==conditionals(dv)); + HybridLookupTable hlt(hc); + +// hlt.argmaxInPlace(&hv); + + HybridLookupDAG dag; + dag.push_back(hlt); + dag.argmax(hv); + +// HybridBayesNet hbn; +// hbn.push_back(hc); +// hbn.optimize(); + +} + +TEST(HybridLookupTable, hybrid_argmax) { + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 0; + S1(0, 1) = 0; + S1(1, 1) = 1; + + Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional0 = boost::make_shared(X(1), d1, S1, model), + conditional1 = boost::make_shared(X(1), d2, S1, model); + + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + boost::shared_ptr mixtureFactor(new GaussianMixture({X(1)},{}, {m1}, conditionals)); + + HybridConditional hc(mixtureFactor); + + DiscreteValues dv; + dv[1]=1; + VectorValues cv; + // cv.insert(X(2),Vector2(0.0, 0.0)); + HybridValues hv(dv, cv); + + HybridLookupTable hlt(hc); + + hlt.argmaxInPlace(&hv); + + EXPECT(assert_equal(hv.at(X(1)), d2)); + + +} + +TEST(HybridLookupTable, discrete_argmax) { + DiscreteKey X(0, 2), Y(1, 2); + + auto conditional = boost::make_shared(X | Y = "0/1 3/2"); + + HybridConditional hc(conditional); + + HybridLookupTable hlt(hc); + + DiscreteValues dv; + dv[1]=0; + VectorValues cv; + // cv.insert(X(2),Vector2(0.0, 0.0)); + HybridValues hv(dv, cv); + + + hlt.argmaxInPlace(&hv); + + EXPECT(assert_equal(hv.atDiscrete(0), 1)); + + DecisionTreeFactor f1(X , "2 3"); + auto conditional2 = boost::make_shared(1,f1); + + HybridConditional hc2(conditional2); + + HybridLookupTable hlt2(hc2); + + HybridValues hv2; + + + hlt2.argmaxInPlace(&hv2); + + EXPECT(assert_equal(hv2.atDiscrete(0), 1)); +} + +TEST(HybridLookupTable, gaussian_argmax) { + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 0; + S1(0, 1) = 0; + S1(1, 1) = 1; + + Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional = boost::make_shared(X(1), d1, S1, + X(2), -S1, model); + + HybridConditional hc(conditional); + + HybridLookupTable hlt(hc); + + DiscreteValues dv; + // dv[1]=0; + VectorValues cv; + cv.insert(X(2),d2); + HybridValues hv(dv, cv); + + + hlt.argmaxInPlace(&hv); + + EXPECT(assert_equal(hv.at(X(1)), d1+d2)); + +} + +TEST(HybridLookupDAG, argmax) { + + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 0; + S1(0, 1) = 0; + S1(1, 1) = 1; + + Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional0 = boost::make_shared(X(2), d1, S1, model), + conditional1 = boost::make_shared(X(2), d2, S1, model); + + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + boost::shared_ptr mixtureFactor(new GaussianMixture({X(2)},{}, {m1}, conditionals)); + HybridConditional hc2(mixtureFactor); + HybridLookupTable hlt2(hc2); + + + auto conditional2 = boost::make_shared(X(1), d1, S1, + X(2), -S1, model); + + HybridConditional hc1(conditional2); + HybridLookupTable hlt1(hc1); + + DecisionTreeFactor f1(m1 , "2 3"); + auto discrete_conditional = boost::make_shared(1,f1); + + HybridConditional hc3(discrete_conditional); + HybridLookupTable hlt3(hc3); + + HybridLookupDAG dag; + dag.push_back(hlt1); + dag.push_back(hlt2); + dag.push_back(hlt3); + auto hv = dag.argmax(); + + EXPECT(assert_equal(hv.atDiscrete(1), 1)); + EXPECT(assert_equal(hv.at(X(2)), d2)); + EXPECT(assert_equal(hv.at(X(1)), d2+d1)); +} + + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ \ No newline at end of file diff --git a/gtsam/hybrid/tests/testHybridValues.cpp b/gtsam/hybrid/tests/testHybridValues.cpp new file mode 100644 index 0000000000..3e821aef2f --- /dev/null +++ b/gtsam/hybrid/tests/testHybridValues.cpp @@ -0,0 +1,59 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridValues.cpp + * @date Jul 28, 2022 + * @author Shangjie Xue + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include for test suite +#include + + +using namespace std; +using namespace gtsam; + +TEST(HybridValues, basics) { + HybridValues values; + values.insert(99, Vector2(2, 3)); + values.insert(100, 3); + EXPECT(assert_equal(values.at(99), Vector2(2, 3))); + EXPECT(assert_equal(values.atDiscrete(100), int(3))); + + values.print(); + + HybridValues values2; + values2.insert(100, 3); + values2.insert(99, Vector2(2, 3)); + EXPECT(assert_equal(values2, values)); + + values2.insert(98, Vector2(2,3)); + EXPECT(!assert_equal(values2, values)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ \ No newline at end of file From 7d36a9eb98803bc4712917bee86583fec50ad030 Mon Sep 17 00:00:00 2001 From: sjxue Date: Sun, 14 Aug 2022 21:08:31 -0400 Subject: [PATCH 2/5] add some comments --- gtsam/hybrid/HybridBayesNet.h | 1 + gtsam/hybrid/HybridValues.h | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index c376c9a86f..b19528120a 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -64,6 +64,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { GaussianBayesNet choose(const DiscreteValues &assignment) const; /// Solve the HybridBayesNet by back-substitution. + /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and put this method there? HybridValues optimize() const; }; diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h index 98f862279a..89f7bb58a7 100644 --- a/gtsam/hybrid/HybridValues.h +++ b/gtsam/hybrid/HybridValues.h @@ -53,8 +53,8 @@ class GTSAM_EXPORT HybridValues { void print(const std::string& s = "HybridValues", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ std::cout << s << ": \n"; - discrete.print(" Discrete", keyFormatter); - continuous.print(" Continuous", keyFormatter); + discrete.print(" Discrete", keyFormatter); // print discrete components + continuous.print(" Continuous", keyFormatter); //print continuous components }; // equals required by Testable for unit testing From 379a65f40fd502c04eacd69a871575c0771a074b Mon Sep 17 00:00:00 2001 From: sjxue Date: Tue, 16 Aug 2022 18:26:59 -0400 Subject: [PATCH 3/5] Address review comments --- gtsam/hybrid/HybridBayesNet.h | 3 +- gtsam/hybrid/HybridLookupDAG.cpp | 40 ++++--- gtsam/hybrid/HybridLookupDAG.h | 11 +- gtsam/hybrid/HybridValues.h | 73 +++++------- .../tests/testGaussianHybridFactorGraph.cpp | 1 - gtsam/hybrid/tests/testHybridLookupDAG.cpp | 112 +++++++++--------- gtsam/hybrid/tests/testHybridValues.cpp | 11 +- 7 files changed, 120 insertions(+), 131 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index b19528120a..9d6d5f2361 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -64,7 +64,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { GaussianBayesNet choose(const DiscreteValues &assignment) const; /// Solve the HybridBayesNet by back-substitution. - /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and put this method there? + /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and + /// put this method there? HybridValues optimize() const; }; diff --git a/gtsam/hybrid/HybridLookupDAG.cpp b/gtsam/hybrid/HybridLookupDAG.cpp index 7232309f4d..7acff081bf 100644 --- a/gtsam/hybrid/HybridLookupDAG.cpp +++ b/gtsam/hybrid/HybridLookupDAG.cpp @@ -18,10 +18,10 @@ #include #include #include -#include -#include #include +#include #include +#include #include #include @@ -32,28 +32,32 @@ using std::vector; namespace gtsam { - - /* ************************************************************************** */ void HybridLookupTable::argmaxInPlace(HybridValues* values) const { - // For discrete conditional, uses argmaxInPlace() method in DiscreteLookupTable. - if (isDiscrete()){ - boost::static_pointer_cast(inner_)->argmaxInPlace(&(values->discrete)); - } else if (isContinuous()){ + // For discrete conditional, uses argmaxInPlace() method in + // DiscreteLookupTable. + if (isDiscrete()) { + boost::static_pointer_cast(inner_)->argmaxInPlace( + &(values->discrete)); + } else if (isContinuous()) { // For Gaussian conditional, uses solve() method in GaussianConditional. - values->continuous.insert(boost::static_pointer_cast(inner_)->solve(values->continuous)); - } else if (isHybrid()){ - // For hybrid conditional, since children should not contain discrete variable, we can condition on - // the discrete variable in the parents and solve the resulting GaussianConditional. - auto conditional = boost::static_pointer_cast(inner_)->conditionals()(values->discrete); + values->continuous.insert( + boost::static_pointer_cast(inner_)->solve( + values->continuous)); + } else if (isHybrid()) { + // For hybrid conditional, since children should not contain discrete + // variable, we can condition on the discrete variable in the parents and + // solve the resulting GaussianConditional. + auto conditional = + boost::static_pointer_cast(inner_)->conditionals()( + values->discrete); values->continuous.insert(conditional->solve(values->continuous)); - } + } } - -// /* ************************************************************************** */ -HybridLookupDAG HybridLookupDAG::FromBayesNet( - const HybridBayesNet& bayesNet) { +// /* ************************************************************************** +// */ +HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { HybridLookupDAG dag; for (auto&& conditional : bayesNet) { HybridLookupTable hlt(*conditional); diff --git a/gtsam/hybrid/HybridLookupDAG.h b/gtsam/hybrid/HybridLookupDAG.h index 903cc55190..cc1c58c58f 100644 --- a/gtsam/hybrid/HybridLookupDAG.h +++ b/gtsam/hybrid/HybridLookupDAG.h @@ -19,10 +19,10 @@ #include #include -#include -#include #include #include +#include +#include #include #include @@ -34,8 +34,8 @@ namespace gtsam { /** * @brief HybridLookupTable table for max-product * - * Similar to DiscreteLookupTable, inherits from hybrid conditional for convenience. - * Is used in the max-product algorithm. + * Similar to DiscreteLookupTable, inherits from hybrid conditional for + * convenience. Is used in the max-product algorithm. */ class GTSAM_EXPORT HybridLookupTable : public HybridConditional { public: @@ -58,7 +58,8 @@ class GTSAM_EXPORT HybridLookupTable : public HybridConditional { void argmaxInPlace(HybridValues* parentsValues) const; }; -/** A DAG made from hybrid lookup tables, as defined above. Similar to DiscreteLookupDAG */ +/** A DAG made from hybrid lookup tables, as defined above. Similar to + * DiscreteLookupDAG */ class GTSAM_EXPORT HybridLookupDAG : public BayesNet { public: using Base = BayesNet; diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h index 89f7bb58a7..5e1bd41646 100644 --- a/gtsam/hybrid/HybridValues.h +++ b/gtsam/hybrid/HybridValues.h @@ -19,11 +19,10 @@ #include #include +#include #include -#include #include -#include - +#include #include #include @@ -32,8 +31,9 @@ namespace gtsam { /** - * HybridValues represents a collection of DiscreteValues and VectorValues. It is typically used to store the variables - * of a HybridGaussianFactorGraph. Optimizing a HybridGaussianBayesNet returns this class. + * HybridValues represents a collection of DiscreteValues and VectorValues. It + * is typically used to store the variables of a HybridGaussianFactorGraph. + * Optimizing a HybridGaussianBayesNet returns this class. */ class GTSAM_EXPORT HybridValues { public: @@ -44,54 +44,47 @@ class GTSAM_EXPORT HybridValues { VectorValues continuous; // Default constructor creates an empty HybridValues. - HybridValues() : discrete(), continuous() {}; + HybridValues() : discrete(), continuous(){}; // Construct from DiscreteValues and VectorValues. - HybridValues(const DiscreteValues &dv, const VectorValues &cv) : discrete(dv), continuous(cv) {}; + HybridValues(const DiscreteValues& dv, const VectorValues& cv) + : discrete(dv), continuous(cv){}; // print required by Testable for unit testing void print(const std::string& s = "HybridValues", - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { std::cout << s << ": \n"; - discrete.print(" Discrete", keyFormatter); // print discrete components - continuous.print(" Continuous", keyFormatter); //print continuous components + discrete.print(" Discrete", keyFormatter); // print discrete components + continuous.print(" Continuous", + keyFormatter); // print continuous components }; // equals required by Testable for unit testing bool equals(const HybridValues& other, double tol = 1e-9) const { - return discrete.equals(other.discrete, tol) && continuous.equals(other.continuous, tol); + return discrete.equals(other.discrete, tol) && + continuous.equals(other.continuous, tol); } // Check whether a variable with key \c j exists in DiscreteValue. - bool existsDiscrete(Key j){ - return (discrete.find(j) != discrete.end()); - }; + bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); }; // Check whether a variable with key \c j exists in VectorValue. - bool existsVector(Key j){ - return continuous.exists(j); - }; + bool existsVector(Key j) { return continuous.exists(j); }; // Check whether a variable with key \c j exists. - bool exists(Key j){ - return existsDiscrete(j) || existsVector(j); - }; + bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; - /** Insert a discrete \c value with key \c j. Replaces the existing value if the key \c - * j is already used. - * @param value The vector to be inserted. - * @param j The index with which the value will be associated. */ - void insert(Key j, int value){ - discrete[j] = value; - }; + /** Insert a discrete \c value with key \c j. Replaces the existing value if + * the key \c j is already used. + * @param value The vector to be inserted. + * @param j The index with which the value will be associated. */ + void insert(Key j, int value) { discrete[j] = value; }; - /** Insert a vector \c value with key \c j. Throws an invalid_argument exception if the key \c - * j is already used. - * @param value The vector to be inserted. - * @param j The index with which the value will be associated. */ - void insert(Key j, const Vector& value) { - continuous.insert(j, value); - } + /** Insert a vector \c value with key \c j. Throws an invalid_argument + * exception if the key \c j is already used. + * @param value The vector to be inserted. + * @param j The index with which the value will be associated. */ + void insert(Key j, const Vector& value) { continuous.insert(j, value); } // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h @@ -99,18 +92,13 @@ class GTSAM_EXPORT HybridValues { * Read/write access to the discrete value with key \c j, throws * std::out_of_range if \c j does not exist. */ - size_t& atDiscrete(Key j){ - return discrete.at(j); - }; + size_t& atDiscrete(Key j) { return discrete.at(j); }; /** * Read/write access to the vector value with key \c j, throws * std::out_of_range if \c j does not exist. */ - Vector& at(Key j) { - return continuous.at(j); - }; - + Vector& at(Key j) { return continuous.at(j); }; /// @name Wrapper support /// @{ @@ -121,7 +109,8 @@ class GTSAM_EXPORT HybridValues { * @param keyFormatter function that formats keys. * @return string html output. */ - std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ + std::string html( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { std::stringstream ss; ss << this->discrete.html(keyFormatter); ss << this->continuous.html(keyFormatter); diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp index 7e532b0137..17a2d94d74 100644 --- a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp @@ -523,7 +523,6 @@ TEST(HybridGaussianFactorGraph, optimize) { HybridValues hv = result->optimize(); EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); - } /* ************************************************************************* */ int main() { diff --git a/gtsam/hybrid/tests/testHybridLookupDAG.cpp b/gtsam/hybrid/tests/testHybridLookupDAG.cpp index 70b09ecbe9..c472aa22f5 100644 --- a/gtsam/hybrid/tests/testHybridLookupDAG.cpp +++ b/gtsam/hybrid/tests/testHybridLookupDAG.cpp @@ -17,19 +17,19 @@ #include #include -#include #include -#include #include -#include -#include -#include -#include -#include +#include +#include +#include #include #include -#include +#include +#include +#include #include +#include +#include // Include for test suite #include @@ -43,7 +43,7 @@ using symbol_shorthand::M; using symbol_shorthand::X; TEST(HybridLookupTable, basics) { - // create a conditional gaussian node + // create a conditional gaussian node Matrix S1(2, 2); S1(0, 0) = 1; S1(1, 0) = 2; @@ -82,39 +82,38 @@ TEST(HybridLookupTable, basics) { GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); -// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals); - - boost::shared_ptr mixtureFactor(new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals)); - + // GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals); + + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals)); + HybridConditional hc(mixtureFactor); - GaussianMixture::Conditionals conditional2 = boost::static_pointer_cast(hc.inner())->conditionals(); + GaussianMixture::Conditionals conditional2 = + boost::static_pointer_cast(hc.inner())->conditionals(); DiscreteValues dv; - dv[1]=1; + dv[1] = 1; VectorValues cv; - cv.insert(X(2),Vector2(0.0, 0.0)); - - HybridValues hv(dv, cv); + cv.insert(X(2), Vector2(0.0, 0.0)); - + HybridValues hv(dv, cv); // std::cout << conditional2(values).markdown(); EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6)); - EXPECT(conditional2(dv)==conditionals(dv)); + EXPECT(conditional2(dv) == conditionals(dv)); HybridLookupTable hlt(hc); -// hlt.argmaxInPlace(&hv); - + // hlt.argmaxInPlace(&hv); + HybridLookupDAG dag; dag.push_back(hlt); dag.argmax(hv); -// HybridBayesNet hbn; -// hbn.push_back(hc); -// hbn.optimize(); - + // HybridBayesNet hbn; + // hbn.push_back(hc); + // hbn.optimize(); } TEST(HybridLookupTable, hybrid_argmax) { @@ -124,23 +123,26 @@ TEST(HybridLookupTable, hybrid_argmax) { S1(0, 1) = 0; S1(1, 1) = 1; - Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - auto conditional0 = boost::make_shared(X(1), d1, S1, model), - conditional1 = boost::make_shared(X(1), d2, S1, model); + auto conditional0 = + boost::make_shared(X(1), d1, S1, model), + conditional1 = + boost::make_shared(X(1), d2, S1, model); DiscreteKey m1(1, 2); GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); - boost::shared_ptr mixtureFactor(new GaussianMixture({X(1)},{}, {m1}, conditionals)); + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(1)}, {}, {m1}, conditionals)); HybridConditional hc(mixtureFactor); DiscreteValues dv; - dv[1]=1; + dv[1] = 1; VectorValues cv; // cv.insert(X(2),Vector2(0.0, 0.0)); HybridValues hv(dv, cv); @@ -150,8 +152,6 @@ TEST(HybridLookupTable, hybrid_argmax) { hlt.argmaxInPlace(&hv); EXPECT(assert_equal(hv.at(X(1)), d2)); - - } TEST(HybridLookupTable, discrete_argmax) { @@ -164,18 +164,17 @@ TEST(HybridLookupTable, discrete_argmax) { HybridLookupTable hlt(hc); DiscreteValues dv; - dv[1]=0; + dv[1] = 0; VectorValues cv; // cv.insert(X(2),Vector2(0.0, 0.0)); HybridValues hv(dv, cv); - hlt.argmaxInPlace(&hv); EXPECT(assert_equal(hv.atDiscrete(0), 1)); - DecisionTreeFactor f1(X , "2 3"); - auto conditional2 = boost::make_shared(1,f1); + DecisionTreeFactor f1(X, "2 3"); + auto conditional2 = boost::make_shared(1, f1); HybridConditional hc2(conditional2); @@ -183,7 +182,6 @@ TEST(HybridLookupTable, discrete_argmax) { HybridValues hv2; - hlt2.argmaxInPlace(&hv2); EXPECT(assert_equal(hv2.atDiscrete(0), 1)); @@ -196,12 +194,12 @@ TEST(HybridLookupTable, gaussian_argmax) { S1(0, 1) = 0; S1(1, 1) = 1; - Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - auto conditional = boost::make_shared(X(1), d1, S1, - X(2), -S1, model); + auto conditional = + boost::make_shared(X(1), d1, S1, X(2), -S1, model); HybridConditional hc(conditional); @@ -210,52 +208,51 @@ TEST(HybridLookupTable, gaussian_argmax) { DiscreteValues dv; // dv[1]=0; VectorValues cv; - cv.insert(X(2),d2); + cv.insert(X(2), d2); HybridValues hv(dv, cv); - hlt.argmaxInPlace(&hv); - EXPECT(assert_equal(hv.at(X(1)), d1+d2)); - + EXPECT(assert_equal(hv.at(X(1)), d1 + d2)); } TEST(HybridLookupDAG, argmax) { - Matrix S1(2, 2); S1(0, 0) = 1; S1(1, 0) = 0; S1(0, 1) = 0; S1(1, 1) = 1; - Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - auto conditional0 = boost::make_shared(X(2), d1, S1, model), - conditional1 = boost::make_shared(X(2), d2, S1, model); + auto conditional0 = + boost::make_shared(X(2), d1, S1, model), + conditional1 = + boost::make_shared(X(2), d2, S1, model); DiscreteKey m1(1, 2); GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); - boost::shared_ptr mixtureFactor(new GaussianMixture({X(2)},{}, {m1}, conditionals)); + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(2)}, {}, {m1}, conditionals)); HybridConditional hc2(mixtureFactor); HybridLookupTable hlt2(hc2); - - auto conditional2 = boost::make_shared(X(1), d1, S1, - X(2), -S1, model); + auto conditional2 = + boost::make_shared(X(1), d1, S1, X(2), -S1, model); HybridConditional hc1(conditional2); HybridLookupTable hlt1(hc1); - DecisionTreeFactor f1(m1 , "2 3"); - auto discrete_conditional = boost::make_shared(1,f1); + DecisionTreeFactor f1(m1, "2 3"); + auto discrete_conditional = boost::make_shared(1, f1); HybridConditional hc3(discrete_conditional); HybridLookupTable hlt3(hc3); - + HybridLookupDAG dag; dag.push_back(hlt1); dag.push_back(hlt2); @@ -264,10 +261,9 @@ TEST(HybridLookupDAG, argmax) { EXPECT(assert_equal(hv.atDiscrete(1), 1)); EXPECT(assert_equal(hv.at(X(2)), d2)); - EXPECT(assert_equal(hv.at(X(1)), d2+d1)); + EXPECT(assert_equal(hv.at(X(1)), d2 + d1)); } - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridValues.cpp b/gtsam/hybrid/tests/testHybridValues.cpp index 3e821aef2f..9581faaa09 100644 --- a/gtsam/hybrid/tests/testHybridValues.cpp +++ b/gtsam/hybrid/tests/testHybridValues.cpp @@ -17,19 +17,18 @@ #include #include -#include #include #include +#include +#include #include -#include -#include #include -#include +#include +#include // Include for test suite #include - using namespace std; using namespace gtsam; @@ -47,7 +46,7 @@ TEST(HybridValues, basics) { values2.insert(99, Vector2(2, 3)); EXPECT(assert_equal(values2, values)); - values2.insert(98, Vector2(2,3)); + values2.insert(98, Vector2(2, 3)); EXPECT(!assert_equal(values2, values)); } From 746ca7856df7848e677bf19751a4ea8b0f5a9e59 Mon Sep 17 00:00:00 2001 From: sjxue Date: Thu, 18 Aug 2022 17:50:20 -0400 Subject: [PATCH 4/5] Address review comments --- gtsam/hybrid/HybridLookupDAG.cpp | 7 ++-- python/gtsam/tests/test_HybridFactorGraph.py | 31 +++++++++++++++ python/gtsam/tests/test_HybridValues.py | 41 ++++++++++++++++++++ 3 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 python/gtsam/tests/test_HybridValues.py diff --git a/gtsam/hybrid/HybridLookupDAG.cpp b/gtsam/hybrid/HybridLookupDAG.cpp index 7acff081bf..a322a81776 100644 --- a/gtsam/hybrid/HybridLookupDAG.cpp +++ b/gtsam/hybrid/HybridLookupDAG.cpp @@ -55,8 +55,7 @@ void HybridLookupTable::argmaxInPlace(HybridValues* values) const { } } -// /* ************************************************************************** -// */ +/* ************************************************************************** */ HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { HybridLookupDAG dag; for (auto&& conditional : bayesNet) { @@ -66,12 +65,12 @@ HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { return dag; } +/* ************************************************************************** */ HybridValues HybridLookupDAG::argmax(HybridValues result) const { // Argmax each node in turn in topological sort order (parents first). for (auto lookupTable : boost::adaptors::reverse(*this)) lookupTable->argmaxInPlace(&result); return result; } -/* ************************************************************************** */ -} // namespace gtsam \ No newline at end of file +} // namespace gtsam diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 781cfd9240..44fb175e8f 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -55,6 +55,37 @@ def test_create(self): discrete_conditional = hbn.at(hbn.size() - 1).inner() self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional) + def test_optimize(self): + """Test contruction of hybrid factor graph.""" + noiseModel = gtsam.noiseModel.Unit.Create(3) + dk = gtsam.DiscreteKeys() + dk.push_back((C(0), 2)) + + jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)), + noiseModel) + jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)), + noiseModel) + + gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2]) + + hfg = gtsam.HybridGaussianFactorGraph() + hfg.add(jf1) + hfg.add(jf2) + hfg.push_back(gmf) + + dtf = gtsam.DecisionTreeFactor([(C(0), 2)],"0 1") + hfg.add(dtf) + + hbn = hfg.eliminateSequential( + gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph( + hfg, [C(0)])) + + # print("hbn = ", hbn) + hv = hbn.optimize() + self.assertEqual(hv.atDiscrete(C(0)), 1) + + self.assertEqual(hv.at(X(0)), np.ones((3, 1))) + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_HybridValues.py b/python/gtsam/tests/test_HybridValues.py new file mode 100644 index 0000000000..63e7c8e7dd --- /dev/null +++ b/python/gtsam/tests/test_HybridValues.py @@ -0,0 +1,41 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Hybrid Values. +Author: Shangjie Xue +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +import numpy as np +from gtsam.symbol_shorthand import C, X +from gtsam.utils.test_case import GtsamTestCase + + +class TestHybridGaussianFactorGraph(GtsamTestCase): + """Unit tests for HybridValues.""" + + def test_basic(self): + """Test contruction and basic methods of hybrid values.""" + + hv1 = gtsam.HybridValues() + hv1.insert(X(0), np.ones((3,1))) + hv1.insert(C(0), 2) + + hv2 = gtsam.HybridValues() + hv2.insert(C(0), 2) + hv2.insert(X(0), np.ones((3,1))) + + self.assertEqual(hv1.atDiscrete(C(0)), 2) + self.assertEqual(hv1.at(X(0))[0], np.ones((3,1))[0]) + +if __name__ == "__main__": + unittest.main() From c4184e192b4605303cc0b0d51129e470eb4b4ed1 Mon Sep 17 00:00:00 2001 From: sjxue Date: Thu, 18 Aug 2022 19:35:06 -0400 Subject: [PATCH 5/5] fix error in python unit test --- python/gtsam/tests/test_HybridFactorGraph.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 44fb175e8f..576efa82fd 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -84,8 +84,5 @@ def test_optimize(self): hv = hbn.optimize() self.assertEqual(hv.atDiscrete(C(0)), 1) - self.assertEqual(hv.at(X(0)), np.ones((3, 1))) - - if __name__ == "__main__": unittest.main()