diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index a837328838..d17401e44f 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -48,7 +48,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, - bool showZero = false) const; + bool showZero = true) const; std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index c4e5f06bb3..716a77127c 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -100,6 +100,20 @@ TEST(DecisionTreeFactor, enumerate) { EXPECT(actual == expected); } +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, DotWithNames) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; + + for (bool showZero:{true, false}) { + string actual = f.dot(formatter, showZero); + // pretty weak test, as ids are pointers and not stable across platforms. + string expected = "digraph G {"; + EXPECT(actual.substr(0, 11) == expected); + } +} + /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DecisionTreeFactor, markdown) { diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index b6172382a6..ef9efbe026 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -382,6 +382,31 @@ TEST(DiscreteFactorGraph, Dot) { EXPECT(actual == expected); } +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, DotWithNames) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + vector names{"C", "A", "B"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = graph.dot(formatter); + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"C\"];\n" + " var1[label=\"A\"];\n" + " var2[label=\"B\"];\n" + "\n" + " var0--var1;\n" + " var0--var2;\n" + "}\n"; + EXPECT(actual == expected); +} + /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DiscreteFactorGraph, markdown) {