diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 4916cad7c0..df94d6908a 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -31,11 +31,12 @@ namespace gtsam { -/** A Bayes net made from discrete conditional distributions. */ - class GTSAM_EXPORT DiscreteBayesNet: public BayesNet - { - public: - +/** + * A Bayes net made from discrete conditional distributions. + * @addtogroup discrete + */ +class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { + public: typedef BayesNet Base; typedef DiscreteBayesNet This; typedef DiscreteConditional ConditionalType; @@ -49,16 +50,20 @@ namespace gtsam { DiscreteBayesNet() {} /** Construct from iterator over conditionals */ - template - DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} - - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - DiscreteBayesNet(const FactorGraph& graph) : Base(graph) {} + template + explicit DiscreteBayesNet(const CONTAINER& conditionals) + : Base(conditionals) {} + + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + DiscreteBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~DiscreteBayesNet() {} diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 80b8df1bc1..56e7248a3f 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -102,6 +102,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + gtsam::Key firstFrontalKey() const; size_t nrFrontals() const; size_t nrParents() const; void printSignature( @@ -156,13 +157,17 @@ class DiscreteBayesNet { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; - string dot(const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; string markdown(const gtsam::KeyFormatter& keyFormatter, @@ -252,14 +257,6 @@ class DiscreteFactorGraph { void print(string s = "") const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; - string dot( - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, - const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const; - void saveGraph( - string s, - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, - const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const; - gtsam::DecisionTreeFactor product() const; double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; @@ -281,6 +278,14 @@ class DiscreteFactorGraph { std::pair eliminatePartialMultifrontal(const gtsam::Ordering& ordering); + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; string markdown(const gtsam::KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index c35d4742c0..cfc9c1bb50 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -150,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) { fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); string actual = fragment.dot(); + cout << actual << endl; EXPECT(actual == - "digraph G{\n" - "0->3\n" - "4->6\n" - "3->5\n" - "6->5\n" + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"0\"];\n" + " var3[label=\"3\"];\n" + " var4[label=\"4\"];\n" + " var5[label=\"5\"];\n" + " var6[label=\"6\"];\n" + "\n" + " var3->var5\n" + " var6->var5\n" + " var4->var6\n" + " var0->var3\n" "}"); } diff --git a/gtsam/inference/BayesNet-inst.h b/gtsam/inference/BayesNet-inst.h index be34b2928f..afde5498dc 100644 --- a/gtsam/inference/BayesNet-inst.h +++ b/gtsam/inference/BayesNet-inst.h @@ -10,41 +10,51 @@ * -------------------------------------------------------------------------- */ /** -* @file BayesNet.h -* @brief Bayes network -* @author Frank Dellaert -* @author Richard Roberts -*/ + * @file BayesNet.h + * @brief Bayes network + * @author Frank Dellaert + * @author Richard Roberts + */ #pragma once -#include #include +#include #include #include +#include namespace gtsam { /* ************************************************************************* */ template -void BayesNet::print( - const std::string& s, const KeyFormatter& formatter) const { +void BayesNet::print(const std::string& s, + const KeyFormatter& formatter) const { Base::print(s, formatter); } /* ************************************************************************* */ template void BayesNet::dot(std::ostream& os, - const KeyFormatter& keyFormatter) const { - os << "digraph G{\n"; + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + writer.digraphPreamble(&os); + + // Create nodes for each variable in the graph + for (Key key : this->keys()) { + auto position = writer.variablePos(key); + writer.drawVariable(key, keyFormatter, position, &os); + } + os << "\n"; - for (auto conditional : *this) { + // Reverse order as typically Bayes nets stored in reverse topological sort. + for (auto conditional : boost::adaptors::reverse(*this)) { auto frontals = conditional->frontals(); const Key me = frontals.front(); auto parents = conditional->parents(); for (const Key& p : parents) - os << keyFormatter(p) << "->" << keyFormatter(me) << "\n"; + os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n"; } os << "}"; @@ -53,18 +63,20 @@ void BayesNet::dot(std::ostream& os, /* ************************************************************************* */ template -std::string BayesNet::dot(const KeyFormatter& keyFormatter) const { +std::string BayesNet::dot(const KeyFormatter& keyFormatter, + const DotWriter& writer) const { std::stringstream ss; - dot(ss, keyFormatter); + dot(ss, keyFormatter, writer); return ss.str(); } /* ************************************************************************* */ template void BayesNet::saveGraph(const std::string& filename, - const KeyFormatter& keyFormatter) const { + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { std::ofstream of(filename.c_str()); - dot(of, keyFormatter); + dot(of, keyFormatter, writer); of.close(); } diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index f987ad51be..219864c547 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -10,77 +10,79 @@ * -------------------------------------------------------------------------- */ /** -* @file BayesNet.h -* @brief Bayes network -* @author Frank Dellaert -* @author Richard Roberts -*/ + * @file BayesNet.h + * @brief Bayes network + * @author Frank Dellaert + * @author Richard Roberts + */ #pragma once -#include - #include -namespace gtsam { +#include +#include - /** - * A BayesNet is a tree of conditionals, stored in elimination order. - * - * todo: how to handle Bayes nets with an optimize function? Currently using global functions. - * \nosubgrouping - */ - template - class BayesNet : public FactorGraph { +namespace gtsam { - private: +/** + * A BayesNet is a tree of conditionals, stored in elimination order. + * @addtogroup inference + */ +template +class BayesNet : public FactorGraph { + private: + typedef FactorGraph Base; - typedef FactorGraph Base; + public: + typedef typename boost::shared_ptr + sharedConditional; ///< A shared pointer to a conditional - public: - typedef typename boost::shared_ptr sharedConditional; ///< A shared pointer to a conditional + protected: + /// @name Standard Constructors + /// @{ - protected: - /// @name Standard Constructors - /// @{ + /** Default constructor as an empty BayesNet */ + BayesNet() {} - /** Default constructor as an empty BayesNet */ - BayesNet() {}; + /** Construct from iterator over conditionals */ + template + BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} - /** Construct from iterator over conditionals */ - template - BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + /// @} - /// @} + public: + /// @name Testable + /// @{ - public: - /// @name Testable - /// @{ + /** print out graph */ + void print( + const std::string& s = "BayesNet", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /** print out graph */ - void print( - const std::string& s = "BayesNet", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; + /// @} - /// @} + /// @name Graph Display + /// @{ - /// @name Graph Display - /// @{ + /// Output to graphviz format, stream version. + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - /// Output to graphviz format, stream version. - void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format string. + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - /// Output to graphviz format string. - std::string dot( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - /// output to file with graphviz format. - void saveGraph(const std::string& filename, - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /// @} - }; + /// @} +}; -} +} // namespace gtsam #include diff --git a/gtsam/inference/DotWriter.cpp b/gtsam/inference/DotWriter.cpp index 18130c35d7..ad53305757 100644 --- a/gtsam/inference/DotWriter.cpp +++ b/gtsam/inference/DotWriter.cpp @@ -16,30 +16,41 @@ * @date December, 2021 */ -#include #include +#include +#include + #include using namespace std; namespace gtsam { -void DotWriter::writePreamble(ostream* os) const { +void DotWriter::graphPreamble(ostream* os) const { *os << "graph {\n"; *os << " size=\"" << figureWidthInches << "," << figureHeightInches << "\";\n\n"; } -void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, +void DotWriter::digraphPreamble(ostream* os) const { + *os << "digraph {\n"; + *os << " size=\"" << figureWidthInches << "," << figureHeightInches + << "\";\n\n"; +} + +void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter, const boost::optional& position, - ostream* os) { + ostream* os) const { // Label the node with the label from the KeyFormatter *os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key) << "\""; if (position) { *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; } + if (boxes.count(key)) { + *os << ", shape=box"; + } *os << "];\n"; } @@ -53,18 +64,35 @@ void DotWriter::DrawFactor(size_t i, const boost::optional& position, } static void ConnectVariables(Key key1, Key key2, - const KeyFormatter& keyFormatter, - ostream* os) { + const KeyFormatter& keyFormatter, ostream* os) { *os << " var" << keyFormatter(key1) << "--" << "var" << keyFormatter(key2) << ";\n"; } static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter, - size_t i, ostream* os) { + size_t i, ostream* os) { *os << " var" << keyFormatter(key) << "--" << "factor" << i << ";\n"; } +/// Return variable position or none +boost::optional DotWriter::variablePos(Key key) const { + boost::optional result = boost::none; + + // Check position hint + Symbol symbol(key); + auto hint = positionHints.find(symbol.chr()); + if (hint != positionHints.end()) + result.reset(Vector2(symbol.index(), hint->second)); + + // Override with explicit position, if given. + auto pos = variablePositions.find(key); + if (pos != variablePositions.end()) + result.reset(pos->second); + + return result; +} + void DotWriter::processFactor(size_t i, const KeyVector& keys, const KeyFormatter& keyFormatter, const boost::optional& position, @@ -74,7 +102,10 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys, ConnectVariables(keys[0], keys[1], keyFormatter, os); } else { // Create dot for the factor. - DrawFactor(i, position, os); + if (!position && factorPositions.count(i)) + DrawFactor(i, factorPositions.at(i), os); + else + DrawFactor(i, position, os); // Make factor-variable connections if (connectKeysToFactor) { diff --git a/gtsam/inference/DotWriter.h b/gtsam/inference/DotWriter.h index 93c229c2b1..23302ee60e 100644 --- a/gtsam/inference/DotWriter.h +++ b/gtsam/inference/DotWriter.h @@ -23,10 +23,15 @@ #include #include +#include +#include namespace gtsam { -/// Graphviz formatter. +/** + * @brief DotWriter is a helper class for writing graphviz .dot files. + * @addtogroup inference + */ struct GTSAM_EXPORT DotWriter { double figureWidthInches; ///< The figure width on paper in inches double figureHeightInches; ///< The figure height on paper in inches @@ -35,6 +40,28 @@ struct GTSAM_EXPORT DotWriter { ///< the dot of the factor bool binaryEdges; ///< just use non-dotted edges for binary factors + /** + * Variable positions can be optionally specified and will be included in the + * dot file with a "!' sign, so "neato" can use it to render them. + */ + std::map variablePositions; + + /** + * The position hints allow one to use symbol character and index to specify + * position. Unless variable positions are specified, if a hint is present for + * a given symbol, it will be used to calculate the positions as (index,hint). + */ + std::map positionHints; + + /** A set of keys that will be displayed as a box */ + std::set boxes; + + /** + * Factor positions can be optionally specified and will be included in the + * dot file with a "!' sign, so "neato" can use it to render them. + */ + std::map factorPositions; + explicit DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, bool plotFactorPoints = true, @@ -45,18 +72,24 @@ struct GTSAM_EXPORT DotWriter { connectKeysToFactor(connectKeysToFactor), binaryEdges(binaryEdges) {} - /// Write out preamble, including size. - void writePreamble(std::ostream* os) const; + /// Write out preamble for graph, including size. + void graphPreamble(std::ostream* os) const; + + /// Write out preamble for digraph, including size. + void digraphPreamble(std::ostream* os) const; /// Create a variable dot fragment. - static void DrawVariable(Key key, const KeyFormatter& keyFormatter, - const boost::optional& position, - std::ostream* os); + void drawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + std::ostream* os) const; /// Create factor dot. static void DrawFactor(size_t i, const boost::optional& position, std::ostream* os); + /// Return variable position or none + boost::optional variablePos(Key key) const; + /// Draw a single factor, specified by its index i and its variable keys. void processFactor(size_t i, const KeyVector& keys, const KeyFormatter& keyFormatter, diff --git a/gtsam/inference/FactorGraph-inst.h b/gtsam/inference/FactorGraph-inst.h index 3ea17fc7ff..a2ae071016 100644 --- a/gtsam/inference/FactorGraph-inst.h +++ b/gtsam/inference/FactorGraph-inst.h @@ -131,11 +131,12 @@ template void FactorGraph::dot(std::ostream& os, const KeyFormatter& keyFormatter, const DotWriter& writer) const { - writer.writePreamble(&os); + writer.graphPreamble(&os); // Create nodes for each variable in the graph for (Key key : keys()) { - writer.DrawVariable(key, keyFormatter, boost::none, &os); + auto position = writer.variablePos(key); + writer.drawVariable(key, keyFormatter, position, &os); } os << "\n"; diff --git a/gtsam/inference/inference.i b/gtsam/inference/inference.i index 5b9cef7efd..5a661d5cf2 100644 --- a/gtsam/inference/inference.i +++ b/gtsam/inference/inference.i @@ -127,6 +127,11 @@ class DotWriter { bool plotFactorPoints; bool connectKeysToFactor; bool binaryEdges; + + std::map variablePositions; + std::map positionHints; + std::set boxes; + std::map factorPositions; }; #include diff --git a/gtsam/linear/GaussianBayesNet.cpp b/gtsam/linear/GaussianBayesNet.cpp index 1e790d0f11..8fd4f2c26d 100644 --- a/gtsam/linear/GaussianBayesNet.cpp +++ b/gtsam/linear/GaussianBayesNet.cpp @@ -205,23 +205,5 @@ namespace gtsam { } /* ************************************************************************* */ - void GaussianBayesNet::saveGraph(const std::string& s, - const KeyFormatter& keyFormatter) const { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; - - for (auto conditional : boost::adaptors::reverse(*this)) { - typename GaussianConditional::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - typename GaussianConditional::Parents parents = conditional->parents(); - for (Key p : parents) - of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl; - } - - of << "}"; - of.close(); - } - - /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/linear/GaussianBayesNet.h b/gtsam/linear/GaussianBayesNet.h index e55a89bcda..6d906d65e3 100644 --- a/gtsam/linear/GaussianBayesNet.h +++ b/gtsam/linear/GaussianBayesNet.h @@ -21,17 +21,22 @@ #pragma once #include +#include #include #include +#include namespace gtsam { - /** A Bayes net made from linear-Gaussian densities */ - class GTSAM_EXPORT GaussianBayesNet: public FactorGraph + /** + * GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals. + * @addtogroup linear + */ + class GTSAM_EXPORT GaussianBayesNet: public BayesNet { public: - typedef FactorGraph Base; + typedef BayesNet Base; typedef GaussianBayesNet This; typedef GaussianConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -44,16 +49,21 @@ namespace gtsam { GaussianBayesNet() {} /** Construct from iterator over conditionals */ - template - GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} + template + explicit GaussianBayesNet(const CONTAINER& conditionals) { + push_back(conditionals); + } - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - GaussianBayesNet(const FactorGraph& graph) : Base(graph) {} + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + explicit GaussianBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~GaussianBayesNet() {} @@ -66,6 +76,13 @@ namespace gtsam { /** Check equality */ bool equals(const This& bn, double tol = 1e-9) const; + /// print graph + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + Base::print(s, formatter); + } + /// @} /// @name Standard Interface @@ -180,23 +197,6 @@ namespace gtsam { */ VectorValues backSubstituteTranspose(const VectorValues& gx) const; - /// print graph - void print( - const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override { - Base::print(s, formatter); - } - - /** - * @brief Save the GaussianBayesNet as an image. Requires `dot` to be - * installed. - * - * @param s The name of the figure. - * @param keyFormatter Formatter to use for styling keys in the graph. - */ - void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = - DefaultKeyFormatter) const; - /// @} private: diff --git a/gtsam/linear/linear.i b/gtsam/linear/linear.i index d2a86ddc8d..b079c3dd18 100644 --- a/gtsam/linear/linear.i +++ b/gtsam/linear/linear.i @@ -437,42 +437,53 @@ class GaussianFactorGraph { pair hessian() const; pair hessian(const gtsam::Ordering& ordering) const; + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + // enabling serialization functionality void serialize() const; }; #include virtual class GaussianConditional : gtsam::JacobianFactor { - //Constructors - GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas); + // Constructors + GaussianConditional(size_t key, Vector d, Matrix R, + const gtsam::noiseModel::Diagonal* sigmas); GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - const gtsam::noiseModel::Diagonal* sigmas); + const gtsam::noiseModel::Diagonal* sigmas); GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas); + size_t name2, Matrix T, + const gtsam::noiseModel::Diagonal* sigmas); - //Constructors with no noise model + // Constructors with no noise model GaussianConditional(size_t key, Vector d, Matrix R); - GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S); - GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - size_t name2, Matrix T); + GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S); + GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, + size_t name2, Matrix T); - //Standard Interface - void print(string s = "GaussianConditional", - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - bool equals(const gtsam::GaussianConditional& cg, double tol) const; - - // Advanced Interface - gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; - gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents, - const gtsam::VectorValues& rhs) const; - void solveTransposeInPlace(gtsam::VectorValues& gy) const; - Matrix R() const; - Matrix S() const; - Vector d() const; + // Standard Interface + void print(string s = "GaussianConditional", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::GaussianConditional& cg, double tol) const; + gtsam::Key firstFrontalKey() const; + + // Advanced Interface + gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; + gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents, + const gtsam::VectorValues& rhs) const; + void solveTransposeInPlace(gtsam::VectorValues& gy) const; + Matrix R() const; + Matrix S() const; + Vector d() const; - // enabling serialization functionality - void serialize() const; + // enabling serialization functionality + void serialize() const; }; #include @@ -524,6 +535,14 @@ virtual class GaussianBayesNet { double logDeterminant() const; gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const; gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include diff --git a/gtsam/linear/tests/testGaussianBayesNet.cpp b/gtsam/linear/tests/testGaussianBayesNet.cpp index 00a338e547..f62da15dde 100644 --- a/gtsam/linear/tests/testGaussianBayesNet.cpp +++ b/gtsam/linear/tests/testGaussianBayesNet.cpp @@ -301,5 +301,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) { } /* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr);} +TEST(GaussianBayesNet, Dot) { + GaussianBayesNet fragment; + DotWriter writer; + writer.variablePositions.emplace(_x_, Vector2(10, 20)); + writer.variablePositions.emplace(_y_, Vector2(50, 20)); + + auto position = writer.variablePos(_x_); + CHECK(position); + EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5)); + + string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer); + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var11[label=\"11\", pos=\"10,20!\"];\n" + " var22[label=\"22\", pos=\"50,20!\"];\n" + "\n" + " var22->var11\n" + "}"); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} /* ************************************************************************* */ diff --git a/gtsam/nonlinear/GraphvizFormatting.cpp b/gtsam/nonlinear/GraphvizFormatting.cpp index e5b81c66b9..ca3466b6a1 100644 --- a/gtsam/nonlinear/GraphvizFormatting.cpp +++ b/gtsam/nonlinear/GraphvizFormatting.cpp @@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values, min.y() = std::numeric_limits::infinity(); for (const Key& key : keys) { if (values.exists(key)) { - boost::optional xy = operator()(values.at(key)); + boost::optional xy = extractPosition(values.at(key)); if (xy) { if (xy->x() < min.x()) min.x() = xy->x(); if (xy->y() < min.y()) min.y() = xy->y(); @@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values, return min; } -boost::optional GraphvizFormatting::operator()( +boost::optional GraphvizFormatting::extractPosition( const Value& value) const { Vector3 t; if (const GenericValue* p = @@ -121,12 +121,11 @@ boost::optional GraphvizFormatting::operator()( return Vector2(x, y); } -// Return affinely transformed variable position if it exists. boost::optional GraphvizFormatting::variablePos(const Values& values, const Vector2& min, Key key) const { - if (!values.exists(key)) return boost::none; - boost::optional xy = operator()(values.at(key)); + if (!values.exists(key)) return DotWriter::variablePos(key); + boost::optional xy = extractPosition(values.at(key)); if (xy) { xy->x() = scale * (xy->x() - min.x()); xy->y() = scale * (xy->y() - min.y()); @@ -134,7 +133,6 @@ boost::optional GraphvizFormatting::variablePos(const Values& values, return xy; } -// Return affinely transformed factor position if it exists. boost::optional GraphvizFormatting::factorPos(const Vector2& min, size_t i) const { if (factorPositions.size() == 0) return boost::none; diff --git a/gtsam/nonlinear/GraphvizFormatting.h b/gtsam/nonlinear/GraphvizFormatting.h index c36b09a8fc..03cdb34694 100644 --- a/gtsam/nonlinear/GraphvizFormatting.h +++ b/gtsam/nonlinear/GraphvizFormatting.h @@ -33,17 +33,14 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter { /// World axes to be assigned to paper axes enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; - Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal - ///< paper axis - Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper - ///< axis + Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal + ///< paper axis + Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper + ///< axis double scale; ///< Scale all positions to reduce / increase density bool mergeSimilarFactors; ///< Merge multiple factors that have the same ///< connectivity - /// (optional for each factor) Manually specify factor "dot" positions: - std::map factorPositions; - /// Default constructor sets up robot coordinates. Paper horizontal is robot /// Y, paper vertical is robot X. Default figure size of 5x5 in. GraphvizFormatting() @@ -55,8 +52,8 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter { // Find bounds Vector2 findBounds(const Values& values, const KeySet& keys) const; - /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 - boost::optional operator()(const Value& value) const; + /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 + boost::optional extractPosition(const Value& value) const; /// Return affinely transformed variable position if it exists. boost::optional variablePos(const Values& values, const Vector2& min, diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index da8935d5fc..dfa54f26f0 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol) void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, const KeyFormatter& keyFormatter, const GraphvizFormatting& writer) const { - writer.writePreamble(&os); + writer.graphPreamble(&os); // Find bounds (imperative) KeySet keys = this->keys(); @@ -111,7 +111,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, // Create nodes for each variable in the graph for (Key key : keys) { auto position = writer.variablePos(values, min, key); - writer.DrawVariable(key, keyFormatter, position, &os); + writer.drawVariable(key, keyFormatter, position, &os); } os << "\n"; diff --git a/gtsam/nonlinear/NonlinearFactorGraph.h b/gtsam/nonlinear/NonlinearFactorGraph.h index ea8748f63b..3237d7c1e0 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.h +++ b/gtsam/nonlinear/NonlinearFactorGraph.h @@ -43,12 +43,14 @@ namespace gtsam { class ExpressionFactor; /** - * A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors, - * which derive from NonlinearFactor. The values structures are typically (in SAM) more general - * than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds. - * Linearizing the non-linear factor graph creates a linear factor graph on the - * tangent vector space at the linearization point. Because the tangent space is a true - * vector space, the config type will be an VectorValues in that linearized factor graph. + * A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors, + * which derive from NonlinearFactor. The values structures are typically (in + * SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects + * in non-linear manifolds. Linearizing the non-linear factor graph creates a + * linear factor graph on the tangent vector space at the linearization point. + * Because the tangent space is a true vector space, the config type will be + * an VectorValues in that linearized factor graph. + * @addtogroup nonlinear */ class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph { @@ -58,6 +60,9 @@ namespace gtsam { typedef NonlinearFactorGraph This; typedef boost::shared_ptr shared_ptr; + /// @name Standard Constructors + /// @{ + /** Default constructor */ NonlinearFactorGraph() {} @@ -76,6 +81,10 @@ namespace gtsam { /// Destructor virtual ~NonlinearFactorGraph() {} + /// @} + /// @name Testable + /// @{ + /** print */ void print( const std::string& str = "NonlinearFactorGraph: ", @@ -90,6 +99,10 @@ namespace gtsam { /** Test equality */ bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; + /// @} + /// @name Standard Interface + /// @{ + /** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */ double error(const Values& values) const; @@ -206,6 +219,7 @@ namespace gtsam { emplace_shared>(key, prior, covariance); } + /// @} /// @name Graph Display /// @{ @@ -215,20 +229,19 @@ namespace gtsam { /// Output to graphviz format, stream version, with Values/extra options. void dot(std::ostream& os, const Values& values, const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const GraphvizFormatting& graphvizFormatting = - GraphvizFormatting()) const; + const GraphvizFormatting& writer = GraphvizFormatting()) const; /// Output to graphviz format string, with Values/extra options. - std::string dot(const Values& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const GraphvizFormatting& graphvizFormatting = - GraphvizFormatting()) const; + std::string dot( + const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; /// output to file with graphviz format, with Values/extra options. - void saveGraph(const std::string& filename, const Values& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const GraphvizFormatting& graphvizFormatting = - GraphvizFormatting()) const; + void saveGraph( + const std::string& filename, const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; /// @} private: @@ -251,6 +264,8 @@ namespace gtsam { public: #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated + /// @{ /** @deprecated */ boost::shared_ptr GTSAM_DEPRECATED linearizeToHessianFactor( const Values& values, boost::none_t, const Dampen& dampen = nullptr) const @@ -275,6 +290,7 @@ namespace gtsam { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { saveGraph(filename, values, keyFormatter, graphvizFormatting); } + /// @} #endif }; diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index 159261713a..eedf421bc7 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -95,18 +95,17 @@ class NonlinearFactorGraph { gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const; gtsam::NonlinearFactorGraph clone() const; - // enabling serialization functionality - void serialize() const; - string dot( const gtsam::Values& values, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, - const GraphvizFormatting& writer = GraphvizFormatting()); - void saveGraph(const string& s, const gtsam::Values& values, - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter, - const GraphvizFormatting& writer = - GraphvizFormatting()) const; + const GraphvizFormatting& formatting = GraphvizFormatting()); + void saveGraph( + const string& s, const gtsam::Values& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const GraphvizFormatting& formatting = GraphvizFormatting()) const; + + // enabling serialization functionality + void serialize() const; }; #include diff --git a/gtsam/symbolic/SymbolicBayesNet.cpp b/gtsam/symbolic/SymbolicBayesNet.cpp index 5bc20ad127..f7113b23a5 100644 --- a/gtsam/symbolic/SymbolicBayesNet.cpp +++ b/gtsam/symbolic/SymbolicBayesNet.cpp @@ -16,41 +16,16 @@ * @author Richard Roberts */ -#include -#include #include - -#include -#include +#include namespace gtsam { - // Instantiate base class - template class FactorGraph; - - /* ************************************************************************* */ - bool SymbolicBayesNet::equals(const This& bn, double tol) const - { - return Base::equals(bn, tol); - } - - /* ************************************************************************* */ - void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const - { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; - - for (auto conditional: boost::adaptors::reverse(*this)) { - SymbolicConditional::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - SymbolicConditional::Parents parents = conditional->parents(); - for(Key p: parents) - of << p << "->" << me << std::endl; - } - - of << "}"; - of.close(); - } - +// Instantiate base class +template class FactorGraph; +/* ************************************************************************* */ +bool SymbolicBayesNet::equals(const This& bn, double tol) const { + return Base::equals(bn, tol); } +} // namespace gtsam diff --git a/gtsam/symbolic/SymbolicBayesNet.h b/gtsam/symbolic/SymbolicBayesNet.h index 464af060b6..2f66b80e22 100644 --- a/gtsam/symbolic/SymbolicBayesNet.h +++ b/gtsam/symbolic/SymbolicBayesNet.h @@ -19,19 +19,19 @@ #pragma once #include +#include #include #include namespace gtsam { - /** Symbolic Bayes Net - * \nosubgrouping + /** + * A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals. + * @addtogroup symbolic */ - class SymbolicBayesNet : public FactorGraph { - - public: - - typedef FactorGraph Base; + class SymbolicBayesNet : public BayesNet { + public: + typedef BayesNet Base; typedef SymbolicBayesNet This; typedef SymbolicConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -44,16 +44,21 @@ namespace gtsam { SymbolicBayesNet() {} /** Construct from iterator over conditionals */ - template - SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} + template + explicit SymbolicBayesNet(const CONTAINER& conditionals) { + push_back(conditionals); + } - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - SymbolicBayesNet(const FactorGraph& graph) : Base(graph) {} + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + explicit SymbolicBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~SymbolicBayesNet() {} @@ -75,13 +80,6 @@ namespace gtsam { /// @} - /// @name Standard Interface - /// @{ - - GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /// @} - private: /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/symbolic/symbolic.i b/gtsam/symbolic/symbolic.i index 771e5309ad..1f1d4b48f9 100644 --- a/gtsam/symbolic/symbolic.i +++ b/gtsam/symbolic/symbolic.i @@ -77,6 +77,14 @@ virtual class SymbolicFactorGraph { const gtsam::KeyVector& key_vector, const gtsam::Ordering& marginalizedVariableOrdering); gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include @@ -98,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor { bool equals(const gtsam::SymbolicConditional& other, double tol) const; // Standard interface + gtsam::Key firstFrontalKey() const; size_t nrFrontals() const; size_t nrParents() const; }; @@ -120,6 +129,14 @@ class SymbolicBayesNet { gtsam::SymbolicConditional* back() const; void push_back(gtsam::SymbolicConditional* conditional); void push_back(const gtsam::SymbolicBayesNet& bayesNet); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include diff --git a/gtsam/symbolic/tests/testSymbolicBayesNet.cpp b/gtsam/symbolic/tests/testSymbolicBayesNet.cpp index a92d66f686..2e13be10eb 100644 --- a/gtsam/symbolic/tests/testSymbolicBayesNet.cpp +++ b/gtsam/symbolic/tests/testSymbolicBayesNet.cpp @@ -15,13 +15,16 @@ * @author Frank Dellaert */ -#include +#include +#include +#include +#include +#include +#include #include -#include -#include -#include +#include using namespace std; using namespace gtsam; @@ -30,7 +33,6 @@ static const Key _L_ = 0; static const Key _A_ = 1; static const Key _B_ = 2; static const Key _C_ = 3; -static const Key _D_ = 4; static SymbolicConditional::shared_ptr B(new SymbolicConditional(_B_)), @@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine ) } /* ************************************************************************* */ -TEST(SymbolicBayesNet, saveGraph) { +TEST(SymbolicBayesNet, Dot) { + using symbol_shorthand::A; + using symbol_shorthand::X; SymbolicBayesNet bn; - bn += SymbolicConditional(_A_, _B_); - KeyVector keys {_B_, _C_, _D_}; - bn += SymbolicConditional::FromKeys(keys,2); - bn += SymbolicConditional(_D_); - - bn.saveGraph("SymbolicBayesNet.dot"); + bn += SymbolicConditional(X(3), X(2), A(2)); + bn += SymbolicConditional(X(2), X(1), A(1)); + bn += SymbolicConditional(X(1)); + + DotWriter writer; + writer.positionHints.emplace('a', 2); + writer.positionHints.emplace('x', 1); + writer.boxes.emplace(A(1)); + writer.boxes.emplace(A(2)); + + auto position = writer.variablePos(A(1)); + CHECK(position); + EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5)); + + string actual = bn.dot(DefaultKeyFormatter, writer); + bn.saveGraph("bn.dot", DefaultKeyFormatter, writer); + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n" + " vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n" + " varx1[label=\"x1\", pos=\"1,1!\"];\n" + " varx2[label=\"x2\", pos=\"2,1!\"];\n" + " varx3[label=\"x3\", pos=\"3,1!\"];\n" + "\n" + " varx1->varx2\n" + " vara1->varx2\n" + " varx2->varx3\n" + " vara2->varx3\n" + "}"); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_GraphvizFormatting.py b/python/gtsam/tests/test_GraphvizFormatting.py index ecdc23b450..5962366efa 100644 --- a/python/gtsam/tests/test_GraphvizFormatting.py +++ b/python/gtsam/tests/test_GraphvizFormatting.py @@ -78,7 +78,7 @@ def test_swapped_axes(self): graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y self.assertEqual(self.graph.dot(self.values, - writer=graphviz_formatting), + formatting=graphviz_formatting), textwrap.dedent(expected_result)) def test_factor_points(self): @@ -100,7 +100,7 @@ def test_factor_points(self): graphviz_formatting.plotFactorPoints = False self.assertEqual(self.graph.dot(self.values, - writer=graphviz_formatting), + formatting=graphviz_formatting), textwrap.dedent(expected_result)) def test_width_height(self): @@ -127,7 +127,7 @@ def test_width_height(self): graphviz_formatting.figureHeightInches = 10 self.assertEqual(self.graph.dot(self.values, - writer=graphviz_formatting), + formatting=graphviz_formatting), textwrap.dedent(expected_result))