diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index b3df73bf20..816450ff32 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -11,10 +11,13 @@ * @brief A bayes net of Gaussian Conditionals indexed by discrete keys. * @author Fan Jiang * @author Varun Agrawal + * @author Shangjie Xue * @date January 2022 */ #include +#include +#include namespace gtsam { @@ -40,4 +43,10 @@ GaussianBayesNet HybridBayesNet::choose( return gbn; } +/* *******************************************************************************/ +HybridValues HybridBayesNet::optimize() const { + auto dag = HybridLookupDAG::FromBayesNet(*this); + return dag.argmax(); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 7fe19c0ea7..9d6d5f2361 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include @@ -61,6 +62,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @return GaussianBayesNet */ GaussianBayesNet choose(const DiscreteValues &assignment) const; + + /// Solve the HybridBayesNet by back-substitution. + /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and + /// put this method there? + HybridValues optimize() const; }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridLookupDAG.cpp b/gtsam/hybrid/HybridLookupDAG.cpp new file mode 100644 index 0000000000..a322a81776 --- /dev/null +++ b/gtsam/hybrid/HybridLookupDAG.cpp @@ -0,0 +1,76 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.cpp + * @date Aug, 2022 + * @author Shangjie Xue + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using std::pair; +using std::vector; + +namespace gtsam { + +/* ************************************************************************** */ +void HybridLookupTable::argmaxInPlace(HybridValues* values) const { + // For discrete conditional, uses argmaxInPlace() method in + // DiscreteLookupTable. + if (isDiscrete()) { + boost::static_pointer_cast(inner_)->argmaxInPlace( + &(values->discrete)); + } else if (isContinuous()) { + // For Gaussian conditional, uses solve() method in GaussianConditional. + values->continuous.insert( + boost::static_pointer_cast(inner_)->solve( + values->continuous)); + } else if (isHybrid()) { + // For hybrid conditional, since children should not contain discrete + // variable, we can condition on the discrete variable in the parents and + // solve the resulting GaussianConditional. + auto conditional = + boost::static_pointer_cast(inner_)->conditionals()( + values->discrete); + values->continuous.insert(conditional->solve(values->continuous)); + } +} + +/* ************************************************************************** */ +HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { + HybridLookupDAG dag; + for (auto&& conditional : bayesNet) { + HybridLookupTable hlt(*conditional); + dag.push_back(hlt); + } + return dag; +} + +/* ************************************************************************** */ +HybridValues HybridLookupDAG::argmax(HybridValues result) const { + // Argmax each node in turn in topological sort order (parents first). + for (auto lookupTable : boost::adaptors::reverse(*this)) + lookupTable->argmaxInPlace(&result); + return result; +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridLookupDAG.h b/gtsam/hybrid/HybridLookupDAG.h new file mode 100644 index 0000000000..cc1c58c58f --- /dev/null +++ b/gtsam/hybrid/HybridLookupDAG.h @@ -0,0 +1,119 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridLookupDAG.h + * @date Aug, 2022 + * @author Shangjie Xue + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace gtsam { + +/** + * @brief HybridLookupTable table for max-product + * + * Similar to DiscreteLookupTable, inherits from hybrid conditional for + * convenience. Is used in the max-product algorithm. + */ +class GTSAM_EXPORT HybridLookupTable : public HybridConditional { + public: + using Base = HybridConditional; + using This = HybridLookupTable; + using shared_ptr = boost::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Hybrid Lookup Table object form a HybridConditional. + * + * @param conditional input hybrid conditional + */ + HybridLookupTable(HybridConditional& conditional) : Base(conditional){}; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(HybridValues* parentsValues) const; +}; + +/** A DAG made from hybrid lookup tables, as defined above. Similar to + * DiscreteLookupDAG */ +class GTSAM_EXPORT HybridLookupDAG : public BayesNet { + public: + using Base = BayesNet; + using This = HybridLookupDAG; + using shared_ptr = boost::shared_ptr; + + /// @name Standard Constructors + /// @{ + + /// Construct empty DAG. + HybridLookupDAG() {} + + /// Create from BayesNet with LookupTables + static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet); + + /// Destructor + virtual ~HybridLookupDAG() {} + + /// @} + + /// @name Standard Interface + /// @{ + + /** Add a DiscreteLookupTable */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + + /** + * @brief argmax by back-substitution, optionally given certain variables. + * + * Assumes the DAG is reverse topologically sorted, i.e. last + * conditional will be optimized first *and* that the + * DAG does not contain any conditionals for the given variables. If the DAG + * resulted from eliminating a factor graph, this is true for the elimination + * ordering. + * + * @return given assignment extended w. optimal assignment for all variables. + */ + HybridValues argmax(HybridValues given = HybridValues()) const; + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h new file mode 100644 index 0000000000..5e1bd41646 --- /dev/null +++ b/gtsam/hybrid/HybridValues.h @@ -0,0 +1,127 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridValues.h + * @date Jul 28, 2022 + * @author Shangjie Xue + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/** + * HybridValues represents a collection of DiscreteValues and VectorValues. It + * is typically used to store the variables of a HybridGaussianFactorGraph. + * Optimizing a HybridGaussianBayesNet returns this class. + */ +class GTSAM_EXPORT HybridValues { + public: + // DiscreteValue stored the discrete components of the HybridValues. + DiscreteValues discrete; + + // VectorValue stored the continuous components of the HybridValues. + VectorValues continuous; + + // Default constructor creates an empty HybridValues. + HybridValues() : discrete(), continuous(){}; + + // Construct from DiscreteValues and VectorValues. + HybridValues(const DiscreteValues& dv, const VectorValues& cv) + : discrete(dv), continuous(cv){}; + + // print required by Testable for unit testing + void print(const std::string& s = "HybridValues", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + std::cout << s << ": \n"; + discrete.print(" Discrete", keyFormatter); // print discrete components + continuous.print(" Continuous", + keyFormatter); // print continuous components + }; + + // equals required by Testable for unit testing + bool equals(const HybridValues& other, double tol = 1e-9) const { + return discrete.equals(other.discrete, tol) && + continuous.equals(other.continuous, tol); + } + + // Check whether a variable with key \c j exists in DiscreteValue. + bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); }; + + // Check whether a variable with key \c j exists in VectorValue. + bool existsVector(Key j) { return continuous.exists(j); }; + + // Check whether a variable with key \c j exists. + bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; + + /** Insert a discrete \c value with key \c j. Replaces the existing value if + * the key \c j is already used. + * @param value The vector to be inserted. + * @param j The index with which the value will be associated. */ + void insert(Key j, int value) { discrete[j] = value; }; + + /** Insert a vector \c value with key \c j. Throws an invalid_argument + * exception if the key \c j is already used. + * @param value The vector to be inserted. + * @param j The index with which the value will be associated. */ + void insert(Key j, const Vector& value) { continuous.insert(j, value); } + + // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h + + /** + * Read/write access to the discrete value with key \c j, throws + * std::out_of_range if \c j does not exist. + */ + size_t& atDiscrete(Key j) { return discrete.at(j); }; + + /** + * Read/write access to the vector value with key \c j, throws + * std::out_of_range if \c j does not exist. + */ + Vector& at(Key j) { return continuous.at(j); }; + + /// @name Wrapper support + /// @{ + + /** + * @brief Output as a html table. + * + * @param keyFormatter function that formats keys. + * @return string html output. + */ + std::string html( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + std::stringstream ss; + ss << this->discrete.html(keyFormatter); + ss << this->continuous.html(keyFormatter); + return ss.str(); + }; + + /// @} +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index bbe1e2400d..207f3ff635 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -4,6 +4,22 @@ namespace gtsam { +#include +class HybridValues { + gtsam::DiscreteValues discrete; + gtsam::VectorValues continuous; + HybridValues(); + HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv); + void print(string s = "HybridValues", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridValues& other, double tol) const; + void insert(gtsam::Key j, int value); + void insert(gtsam::Key j, const gtsam::Vector& value); + size_t& atDiscrete(gtsam::Key j); + gtsam::Vector& at(gtsam::Key j); +}; + #include virtual class HybridFactor { void print(string s = "HybridFactor\n", @@ -84,6 +100,7 @@ class HybridBayesNet { size_t size() const; gtsam::KeySet keys() const; const gtsam::HybridConditional* at(size_t i) const; + gtsam::HybridValues optimize() const; void print(string s = "HybridBayesNet\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp index a96aab6b9a..17a2d94d74 100644 --- a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -501,6 +503,27 @@ TEST_DISABLED(HybridGaussianFactorGraph, SwitchingTwoVar) { } } +TEST(HybridGaussianFactorGraph, optimize) { + HybridGaussianFactorGraph hfg; + + DiscreteKey c1(C(1), 2); + + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + DecisionTree dt( + C(1), boost::make_shared(X(1), I_3x3, Z_3x1), + boost::make_shared(X(1), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + + auto result = + hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)})); + + HybridValues hv = result->optimize(); + + EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); +} /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridLookupDAG.cpp b/gtsam/hybrid/tests/testHybridLookupDAG.cpp new file mode 100644 index 0000000000..c472aa22f5 --- /dev/null +++ b/gtsam/hybrid/tests/testHybridLookupDAG.cpp @@ -0,0 +1,272 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridLookupDAG.cpp + * @date Aug, 2022 + * @author Shangjie Xue + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include for test suite +#include + +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +TEST(HybridLookupTable, basics) { + // create a conditional gaussian node + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 2; + S1(0, 1) = 3; + S1(1, 1) = 4; + + Matrix S2(2, 2); + S2(0, 0) = 6; + S2(1, 0) = 0.2; + S2(0, 1) = 8; + S2(1, 1) = 0.4; + + Matrix R1(2, 2); + R1(0, 0) = 0.1; + R1(1, 0) = 0.3; + R1(0, 1) = 0.0; + R1(1, 1) = 0.34; + + Matrix R2(2, 2); + R2(0, 0) = 0.1; + R2(1, 0) = 0.3; + R2(0, 1) = 0.0; + R2(1, 1) = 0.34; + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + Vector2 d1(0.2, 0.5), d2(0.5, 0.2); + + auto conditional0 = boost::make_shared(X(1), d1, R1, + X(2), S1, model), + conditional1 = boost::make_shared(X(1), d2, R2, + X(2), S2, model); + + // Create decision tree + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + // GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals); + + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals)); + + HybridConditional hc(mixtureFactor); + + GaussianMixture::Conditionals conditional2 = + boost::static_pointer_cast(hc.inner())->conditionals(); + + DiscreteValues dv; + dv[1] = 1; + + VectorValues cv; + cv.insert(X(2), Vector2(0.0, 0.0)); + + HybridValues hv(dv, cv); + + // std::cout << conditional2(values).markdown(); + EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6)); + EXPECT(conditional2(dv) == conditionals(dv)); + HybridLookupTable hlt(hc); + + // hlt.argmaxInPlace(&hv); + + HybridLookupDAG dag; + dag.push_back(hlt); + dag.argmax(hv); + + // HybridBayesNet hbn; + // hbn.push_back(hc); + // hbn.optimize(); +} + +TEST(HybridLookupTable, hybrid_argmax) { + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 0; + S1(0, 1) = 0; + S1(1, 1) = 1; + + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional0 = + boost::make_shared(X(1), d1, S1, model), + conditional1 = + boost::make_shared(X(1), d2, S1, model); + + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(1)}, {}, {m1}, conditionals)); + + HybridConditional hc(mixtureFactor); + + DiscreteValues dv; + dv[1] = 1; + VectorValues cv; + // cv.insert(X(2),Vector2(0.0, 0.0)); + HybridValues hv(dv, cv); + + HybridLookupTable hlt(hc); + + hlt.argmaxInPlace(&hv); + + EXPECT(assert_equal(hv.at(X(1)), d2)); +} + +TEST(HybridLookupTable, discrete_argmax) { + DiscreteKey X(0, 2), Y(1, 2); + + auto conditional = boost::make_shared(X | Y = "0/1 3/2"); + + HybridConditional hc(conditional); + + HybridLookupTable hlt(hc); + + DiscreteValues dv; + dv[1] = 0; + VectorValues cv; + // cv.insert(X(2),Vector2(0.0, 0.0)); + HybridValues hv(dv, cv); + + hlt.argmaxInPlace(&hv); + + EXPECT(assert_equal(hv.atDiscrete(0), 1)); + + DecisionTreeFactor f1(X, "2 3"); + auto conditional2 = boost::make_shared(1, f1); + + HybridConditional hc2(conditional2); + + HybridLookupTable hlt2(hc2); + + HybridValues hv2; + + hlt2.argmaxInPlace(&hv2); + + EXPECT(assert_equal(hv2.atDiscrete(0), 1)); +} + +TEST(HybridLookupTable, gaussian_argmax) { + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 0; + S1(0, 1) = 0; + S1(1, 1) = 1; + + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional = + boost::make_shared(X(1), d1, S1, X(2), -S1, model); + + HybridConditional hc(conditional); + + HybridLookupTable hlt(hc); + + DiscreteValues dv; + // dv[1]=0; + VectorValues cv; + cv.insert(X(2), d2); + HybridValues hv(dv, cv); + + hlt.argmaxInPlace(&hv); + + EXPECT(assert_equal(hv.at(X(1)), d1 + d2)); +} + +TEST(HybridLookupDAG, argmax) { + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 0; + S1(0, 1) = 0; + S1(1, 1) = 1; + + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional0 = + boost::make_shared(X(2), d1, S1, model), + conditional1 = + boost::make_shared(X(2), d2, S1, model); + + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(2)}, {}, {m1}, conditionals)); + HybridConditional hc2(mixtureFactor); + HybridLookupTable hlt2(hc2); + + auto conditional2 = + boost::make_shared(X(1), d1, S1, X(2), -S1, model); + + HybridConditional hc1(conditional2); + HybridLookupTable hlt1(hc1); + + DecisionTreeFactor f1(m1, "2 3"); + auto discrete_conditional = boost::make_shared(1, f1); + + HybridConditional hc3(discrete_conditional); + HybridLookupTable hlt3(hc3); + + HybridLookupDAG dag; + dag.push_back(hlt1); + dag.push_back(hlt2); + dag.push_back(hlt3); + auto hv = dag.argmax(); + + EXPECT(assert_equal(hv.atDiscrete(1), 1)); + EXPECT(assert_equal(hv.at(X(2)), d2)); + EXPECT(assert_equal(hv.at(X(1)), d2 + d1)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ \ No newline at end of file diff --git a/gtsam/hybrid/tests/testHybridValues.cpp b/gtsam/hybrid/tests/testHybridValues.cpp new file mode 100644 index 0000000000..9581faaa09 --- /dev/null +++ b/gtsam/hybrid/tests/testHybridValues.cpp @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridValues.cpp + * @date Jul 28, 2022 + * @author Shangjie Xue + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; + +TEST(HybridValues, basics) { + HybridValues values; + values.insert(99, Vector2(2, 3)); + values.insert(100, 3); + EXPECT(assert_equal(values.at(99), Vector2(2, 3))); + EXPECT(assert_equal(values.atDiscrete(100), int(3))); + + values.print(); + + HybridValues values2; + values2.insert(100, 3); + values2.insert(99, Vector2(2, 3)); + EXPECT(assert_equal(values2, values)); + + values2.insert(98, Vector2(2, 3)); + EXPECT(!assert_equal(values2, values)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ \ No newline at end of file diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 781cfd9240..576efa82fd 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -55,6 +55,34 @@ def test_create(self): discrete_conditional = hbn.at(hbn.size() - 1).inner() self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional) + def test_optimize(self): + """Test contruction of hybrid factor graph.""" + noiseModel = gtsam.noiseModel.Unit.Create(3) + dk = gtsam.DiscreteKeys() + dk.push_back((C(0), 2)) + + jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)), + noiseModel) + jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)), + noiseModel) + + gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2]) + + hfg = gtsam.HybridGaussianFactorGraph() + hfg.add(jf1) + hfg.add(jf2) + hfg.push_back(gmf) + + dtf = gtsam.DecisionTreeFactor([(C(0), 2)],"0 1") + hfg.add(dtf) + + hbn = hfg.eliminateSequential( + gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph( + hfg, [C(0)])) + + # print("hbn = ", hbn) + hv = hbn.optimize() + self.assertEqual(hv.atDiscrete(C(0)), 1) if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_HybridValues.py b/python/gtsam/tests/test_HybridValues.py new file mode 100644 index 0000000000..63e7c8e7dd --- /dev/null +++ b/python/gtsam/tests/test_HybridValues.py @@ -0,0 +1,41 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Hybrid Values. +Author: Shangjie Xue +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +import numpy as np +from gtsam.symbol_shorthand import C, X +from gtsam.utils.test_case import GtsamTestCase + + +class TestHybridGaussianFactorGraph(GtsamTestCase): + """Unit tests for HybridValues.""" + + def test_basic(self): + """Test contruction and basic methods of hybrid values.""" + + hv1 = gtsam.HybridValues() + hv1.insert(X(0), np.ones((3,1))) + hv1.insert(C(0), 2) + + hv2 = gtsam.HybridValues() + hv2.insert(C(0), 2) + hv2.insert(X(0), np.ones((3,1))) + + self.assertEqual(hv1.atDiscrete(C(0)), 2) + self.assertEqual(hv1.at(X(0))[0], np.ones((3,1))[0]) + +if __name__ == "__main__": + unittest.main()