diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 11ecbf1834..3c74e57fda 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -82,13 +83,7 @@ namespace gtsam { return compare(this->constant_, other->constant_); } - /** - * @brief Print method. - * - * @param s Prefix string. - * @param labelFormatter Functor to format the labels of type L. - * @param valueFormatter Functor to format the values of type Y. - */ + /** print */ void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; @@ -332,7 +327,7 @@ namespace gtsam { /** apply unary operator */ NodePtr apply(const Unary& op) const override { - boost::shared_ptr r(new Choice(label_, *this, op)); + auto r = boost::make_shared(label_, *this, op); return Unique(r); } @@ -347,24 +342,24 @@ namespace gtsam { // If second argument of binary op is Leaf node, recurse on branches NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { - boost::shared_ptr h(new Choice(label(), nrChoices())); - for(NodePtr branch: branches_) - h->push_back(fL.apply_f_op_g(*branch, op)); + auto h = boost::make_shared(label(), nrChoices()); + for (auto&& branch : branches_) + h->push_back(fL.apply_f_op_g(*branch, op)); return Unique(h); } // If second argument of binary op is Choice, call constructor NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { - boost::shared_ptr h(new Choice(fC, *this, op)); + auto h = boost::make_shared(fC, *this, op); return Unique(h); } // If second argument of binary op is Leaf template NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const { - boost::shared_ptr h(new Choice(label(), nrChoices())); - for(const NodePtr& branch: branches_) - h->push_back(branch->apply_f_op_g(gL, op)); + auto h = boost::make_shared(label(), nrChoices()); + for (auto&& branch : branches_) + h->push_back(branch->apply_f_op_g(gL, op)); return Unique(h); } @@ -374,9 +369,9 @@ namespace gtsam { return branches_[index]; // choose branch // second case, not label of interest, just recurse - boost::shared_ptr r(new Choice(label_, branches_.size())); - for(const NodePtr& branch: branches_) - r->push_back(branch->choose(label, index)); + auto r = boost::make_shared(label_, branches_.size()); + for (auto&& branch : branches_) + r->push_back(branch->choose(label, index)); return Unique(r); } @@ -401,10 +396,9 @@ namespace gtsam { } /*********************************************************************************/ - template - DecisionTree::DecisionTree(// - const L& label, const Y& y1, const Y& y2) { - boost::shared_ptr a(new Choice(label, 2)); + template + DecisionTree::DecisionTree(const L& label, const Y& y1, const Y& y2) { + auto a = boost::make_shared(label, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); a->push_back(l1); a->push_back(l2); @@ -412,12 +406,12 @@ namespace gtsam { } /*********************************************************************************/ - template - DecisionTree::DecisionTree(// - const LabelC& labelC, const Y& y1, const Y& y2) { + template + DecisionTree::DecisionTree(const LabelC& labelC, const Y& y1, + const Y& y2) { if (labelC.second != 2) throw std::invalid_argument( "DecisionTree: binary constructor called with non-binary label"); - boost::shared_ptr a(new Choice(labelC.first, 2)); + auto a = boost::make_shared(labelC.first, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); a->push_back(l1); a->push_back(l2); @@ -465,23 +459,20 @@ namespace gtsam { /*********************************************************************************/ template - template + template DecisionTree::DecisionTree(const DecisionTree& other, - std::function Y_of_X) { + Func Y_of_X) { // Define functor for identity mapping of node label. - auto L_of_L = [](const L& label) { return label; }; + auto L_of_L = [](const L& label) { return label; }; root_ = convertFrom(other.root_, L_of_L, Y_of_X); } /*********************************************************************************/ template - template + template DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, - std::function Y_of_X) { - std::function L_of_M = [&map](const M& label) -> L { - return map.at(label); - }; + const std::map& map, Func Y_of_X) { + auto L_of_M = [&map](const M& label) -> L { return map.at(label); }; root_ = convertFrom(other.root_, L_of_M, Y_of_X); } @@ -511,13 +502,14 @@ namespace gtsam { // if label is already in correct order, just put together a choice on label if (!nrChoices || !highestLabel || label > *highestLabel) { - boost::shared_ptr choiceOnLabel(new Choice(label, end - begin)); + auto choiceOnLabel = boost::make_shared(label, end - begin); for (Iterator it = begin; it != end; it++) choiceOnLabel->push_back(it->root_); return Choice::Unique(choiceOnLabel); } else { // Set up a new choice on the highest label - boost::shared_ptr choiceOnHighestLabel(new Choice(*highestLabel, nrChoices)); + auto choiceOnHighestLabel = + boost::make_shared(*highestLabel, nrChoices); // now, for all possible values of highestLabel for (size_t index = 0; index < nrChoices; index++) { // make a new set of functions for composing by iterating over the given @@ -576,7 +568,7 @@ namespace gtsam { std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; throw std::invalid_argument("DecisionTree::create invalid argument"); } - boost::shared_ptr choice(new Choice(begin->first, endY - beginY)); + auto choice = boost::make_shared(begin->first, endY - beginY); for (ValueIt y = beginY; y != endY; y++) choice->push_back(NodePtr(new Leaf(*y))); return Choice::Unique(choice); @@ -589,7 +581,7 @@ namespace gtsam { size_t split = size / nrChoices; for (size_t i = 0; i < nrChoices; i++, beginY += split) { NodePtr f = create(labelC, end, beginY, beginY + split); - functions += DecisionTree(f); + functions.emplace_back(f); } return compose(functions.begin(), functions.end(), begin->first); } @@ -601,18 +593,16 @@ namespace gtsam { const typename DecisionTree::NodePtr& f, std::function L_of_M, std::function Y_of_X) const { - using MX = DecisionTree; - using MXLeaf = typename MX::Leaf; - using MXChoice = typename MX::Choice; - using MXNodePtr = typename MX::NodePtr; using LY = DecisionTree; // ugliness below because apparently we can't have templated virtual functions // If leaf, apply unary conversion "op" and create a unique leaf - auto leaf = boost::dynamic_pointer_cast(f); - if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + using MXLeaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(f)) + return NodePtr(new Leaf(Y_of_X(leaf->constant()))); // Check if Choice + using MXChoice = typename DecisionTree::Choice; auto choice = boost::dynamic_pointer_cast(f); if (!choice) throw std::invalid_argument( "DecisionTree::Convert: Invalid NodePtr"); @@ -623,14 +613,93 @@ namespace gtsam { // put together via Shannon expansion otherwise not sorted. std::vector functions; - for(const MXNodePtr& branch: choice->branches()) { - LY converted(convertFrom(branch, L_of_M, Y_of_X)); - functions += converted; + for(auto && branch: choice->branches()) { + functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } return LY::compose(functions.begin(), functions.end(), newLabel); } /*********************************************************************************/ + // Functor performing depth-first visit without Assignment argument. + template + struct Visit { + using F = std::function; + Visit(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! + } + }; + + template + template + void DecisionTree::visit(Func f) const { + Visit visit(f); + visit(root_); + } + + /*********************************************************************************/ + // Functor performing depth-first visit with Assignment argument. + template + struct VisitWith { + using Choices = Assignment; + using F = std::function; + VisitWith(F f) : f(f) {} ///< Construct from folding function. + Choices choices; ///< Assignment, mutating through recursion. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(choices, leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + for (size_t i = 0; i < choice->nrChoices(); i++) { + choices[choice->label()] = i; // Set assignment for label to i + (*this)(choice->branches()[i]); // recurse! + } + } + }; + + template + template + void DecisionTree::visitWith(Func f) const { + VisitWith visit(f); + visit(root_); + } + + /*********************************************************************************/ + // fold is just done with a visit + template + template + X DecisionTree::fold(Func f, X x0) const { + visit([&](const Y& y) { x0 = f(y, x0); }); + return x0; + } + + /*********************************************************************************/ + // labels is just done with a visit + template + std::set DecisionTree::labels() const { + std::set unique; + auto f = [&](const Assignment& choices, const Y&) { + for (auto&& kv : choices) unique.insert(kv.first); + }; + visitWith(f); + return unique; + } + +/*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, const CompareFunc& compare) const { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index db8a12a200..9692094e19 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -28,6 +28,7 @@ #include #include #include +#include namespace gtsam { @@ -176,9 +177,8 @@ namespace gtsam { * @param other The DecisionTree to convert from. * @param Y_of_X Functor to convert from value type X to type Y. */ - template - DecisionTree(const DecisionTree& other, - std::function Y_of_X); + template + DecisionTree(const DecisionTree& other, Func Y_of_X); /** * @brief Convert from a different value type X to value type Y, also transate @@ -190,9 +190,9 @@ namespace gtsam { * @param L_of_M Map from label type M to type L. * @param Y_of_X Functor to convert from type X to type Y. */ - template - DecisionTree(const DecisionTree& other, const std::map& L_of_M, - std::function Y_of_X); + template + DecisionTree(const DecisionTree& other, const std::map& map, + Func Y_of_X); /// @} /// @name Testable @@ -229,6 +229,52 @@ namespace gtsam { /** evaluate */ const Y& operator()(const Assignment& x) const; + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking a value. + * + * Example: + * int sum = 0; + * auto visitor = [&](int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visit(Func f) const; + + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking an assignment and a value. + * + * Example: + * int sum = 0; + * auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visitWith(Func f) const; + + /** + * @brief Fold a binary function over the tree, returning accumulator. + * + * @tparam X type for accumulator. + * @param f binary function: Y * X -> X returning an updated accumulator. + * @param x0 initial value for accumulator. + * @return X final value for accumulator. + * + * @note X is always passed by value. + * + * Example: + * auto add = [](const double& y, double x) { return y + x; }; + * double sum = tree.fold(add, 0.0); + */ + template + X fold(Func f, X x0) const; + + /** Retrieve all unique labels as a set. */ + std::set labels() const; + /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 5976ea2d48..2e6ec59f72 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -123,8 +123,7 @@ struct Ring { /* ******************************************************************************** */ // test DT -TEST(DT, example) -{ +TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); @@ -231,13 +230,10 @@ TEST(DT, example) /* ******************************************************************************** */ // test Conversion of values -std::function bool_of_int = [](const int& y) { - return y != 0; -}; +bool bool_of_int(const int& y) { return y != 0; }; typedef DecisionTree StringBoolTree; -TEST(DT, ConvertValuesOnly) -{ +TEST(DecisionTree, ConvertValuesOnly) { // Create labels string A("A"), B("B"); @@ -260,8 +256,7 @@ enum Label { }; typedef DecisionTree LabelBoolTree; -TEST(DT, ConvertBoth) -{ +TEST(DecisionTree, ConvertBoth) { // Create labels string A("A"), B("B"); @@ -272,7 +267,7 @@ TEST(DT, ConvertBoth) map ordering; ordering[A] = X; ordering[B] = Y; - LabelBoolTree f2(f1, ordering, bool_of_int); + LabelBoolTree f2(f1, ordering, &bool_of_int); // Check some values Assignment