Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More DecisionTree improvements #1005

Merged
merged 6 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 117 additions & 48 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <boost/tuple/tuple.hpp>
#include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp>
#include <boost/make_shared.hpp>
#include <cmath>
#include <fstream>
#include <list>
Expand Down Expand Up @@ -82,13 +83,7 @@ namespace gtsam {
return compare(this->constant_, other->constant_);
}

/**
* @brief Print method.
*
* @param s Prefix string.
* @param labelFormatter Functor to format the labels of type L.
* @param valueFormatter Functor to format the values of type Y.
*/
/** print */
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
Expand Down Expand Up @@ -332,7 +327,7 @@ namespace gtsam {

/** apply unary operator */
NodePtr apply(const Unary& op) const override {
boost::shared_ptr<Choice> r(new Choice(label_, *this, op));
auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r);
}

Expand All @@ -347,24 +342,24 @@ namespace gtsam {

// If second argument of binary op is Leaf node, recurse on branches
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
for(NodePtr branch: branches_)
h->push_back(fL.apply_f_op_g(*branch, op));
auto h = boost::make_shared<Choice>(label(), nrChoices());
for (auto&& branch : branches_)
h->push_back(fL.apply_f_op_g(*branch, op));
return Unique(h);
}

// If second argument of binary op is Choice, call constructor
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
boost::shared_ptr<Choice> h(new Choice(fC, *this, op));
auto h = boost::make_shared<Choice>(fC, *this, op);
return Unique(h);
}

// If second argument of binary op is Leaf
template<typename OP>
NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(gL, op));
auto h = boost::make_shared<Choice>(label(), nrChoices());
for (auto&& branch : branches_)
h->push_back(branch->apply_f_op_g(gL, op));
return Unique(h);
}

Expand All @@ -374,9 +369,9 @@ namespace gtsam {
return branches_[index]; // choose branch

// second case, not label of interest, just recurse
boost::shared_ptr<Choice> r(new Choice(label_, branches_.size()));
for(const NodePtr& branch: branches_)
r->push_back(branch->choose(label, index));
auto r = boost::make_shared<Choice>(label_, branches_.size());
for (auto&& branch : branches_)
r->push_back(branch->choose(label, index));
return Unique(r);
}

Expand All @@ -401,23 +396,22 @@ namespace gtsam {
}

/*********************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(//
const L& label, const Y& y1, const Y& y2) {
boost::shared_ptr<Choice> a(new Choice(label, 2));
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = boost::make_shared<Choice>(label, 2);
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
a->push_back(l1);
a->push_back(l2);
root_ = Choice::Unique(a);
}

/*********************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(//
const LabelC& labelC, const Y& y1, const Y& y2) {
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
const Y& y2) {
if (labelC.second != 2) throw std::invalid_argument(
"DecisionTree: binary constructor called with non-binary label");
boost::shared_ptr<Choice> a(new Choice(labelC.first, 2));
auto a = boost::make_shared<Choice>(labelC.first, 2);
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
a->push_back(l1);
a->push_back(l2);
Expand Down Expand Up @@ -465,23 +459,20 @@ namespace gtsam {

/*********************************************************************************/
template <typename L, typename Y>
template <typename X>
template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X) {
Func Y_of_X) {
// Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; };
auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
}

/*********************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>
template <typename M, typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
const std::map<M, L>& map,
std::function<Y(const X&)> Y_of_X) {
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label);
};
const std::map<M, L>& map, Func Y_of_X) {
auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
}

Expand Down Expand Up @@ -511,13 +502,14 @@ namespace gtsam {

// if label is already in correct order, just put together a choice on label
if (!nrChoices || !highestLabel || label > *highestLabel) {
boost::shared_ptr<Choice> choiceOnLabel(new Choice(label, end - begin));
auto choiceOnLabel = boost::make_shared<Choice>(label, end - begin);
for (Iterator it = begin; it != end; it++)
choiceOnLabel->push_back(it->root_);
return Choice::Unique(choiceOnLabel);
} else {
// Set up a new choice on the highest label
boost::shared_ptr<Choice> choiceOnHighestLabel(new Choice(*highestLabel, nrChoices));
auto choiceOnHighestLabel =
boost::make_shared<Choice>(*highestLabel, nrChoices);
// now, for all possible values of highestLabel
for (size_t index = 0; index < nrChoices; index++) {
// make a new set of functions for composing by iterating over the given
Expand Down Expand Up @@ -576,7 +568,7 @@ namespace gtsam {
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
throw std::invalid_argument("DecisionTree::create invalid argument");
}
boost::shared_ptr<Choice> choice(new Choice(begin->first, endY - beginY));
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
for (ValueIt y = beginY; y != endY; y++)
choice->push_back(NodePtr(new Leaf(*y)));
return Choice::Unique(choice);
Expand All @@ -589,7 +581,7 @@ namespace gtsam {
size_t split = size / nrChoices;
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
functions += DecisionTree(f);
functions.emplace_back(f);
}
return compose(functions.begin(), functions.end(), begin->first);
}
Expand All @@ -601,18 +593,16 @@ namespace gtsam {
const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const {
using MX = DecisionTree<M, X>;
using MXLeaf = typename MX::Leaf;
using MXChoice = typename MX::Choice;
using MXNodePtr = typename MX::NodePtr;
using LY = DecisionTree<L, Y>;

// ugliness below because apparently we can't have templated virtual functions
// If leaf, apply unary conversion "op" and create a unique leaf
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant())));
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
return NodePtr(new Leaf(Y_of_X(leaf->constant())));

// Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr");
Expand All @@ -623,14 +613,93 @@ namespace gtsam {

// put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions;
for(const MXNodePtr& branch: choice->branches()) {
LY converted(convertFrom<M, X>(branch, L_of_M, Y_of_X));
functions += converted;
for(auto && branch: choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
}
return LY::compose(functions.begin(), functions.end(), newLabel);
}

/*********************************************************************************/
/**
 * Functor performing a depth-first visit of all leaves, calling a function
 * on each leaf value (no Assignment<L> is passed to the function).
 *
 * @tparam L label type of the tree being visited.
 * @tparam Y value type stored in the leaves.
 */
template <typename L, typename Y>
struct Visit {
  using F = std::function<void(const Y&)>;
  Visit(F f) : f(f) {}  ///< Construct from the function applied to each leaf.
  F f;                  ///< Function object applied to each leaf value.

  /**
   * Do a depth-first visit on the tree rooted at node.
   *
   * @param node root of the (sub)tree to visit.
   * @throws std::invalid_argument if node is neither a Leaf nor a Choice,
   *         which includes a null pointer.
   */
  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
    using Leaf = typename DecisionTree<L, Y>::Leaf;
    if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
      return f(leaf->constant());

    using Choice = typename DecisionTree<L, Y>::Choice;
    auto choice = boost::dynamic_pointer_cast<const Choice>(node);
    // Same policy as convertFrom: fail loudly rather than dereference null.
    if (!choice)
      throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
    for (auto&& branch : choice->branches()) (*this)(branch);  // recurse!
  }
};

/**
 * Visit all leaves by wrapping f in a Visit functor and applying that
 * functor to the root of this tree.
 *
 * @param f callable invoked with each leaf value.
 */
template <typename L, typename Y>
template <typename Func>
void DecisionTree<L, Y>::visit(Func f) const {
  // Named differently from the member function to avoid shadowing it.
  Visit<L, Y> leafVisitor(f);
  leafVisitor(root_);
}

/*********************************************************************************/
/**
 * Functor performing a depth-first visit of all leaves, calling a function
 * with the assignment accumulated on the way down and the leaf value.
 *
 * @tparam L label type of the tree being visited.
 * @tparam Y value type stored in the leaves.
 */
template <typename L, typename Y>
struct VisitWith {
  using Choices = Assignment<L>;
  using F = std::function<void(const Choices&, const Y&)>;
  VisitWith(F f) : f(f) {}  ///< Construct from the function applied per leaf.
  Choices choices;          ///< Assignment, mutating through recursion.
  F f;                      ///< Function object applied to each leaf.

  /**
   * Do a depth-first visit on the tree rooted at node.
   *
   * @param node root of the (sub)tree to visit.
   * @throws std::invalid_argument if node is neither a Leaf nor a Choice,
   *         which includes a null pointer.
   */
  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
    using Leaf = typename DecisionTree<L, Y>::Leaf;
    if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
      return f(choices, leaf->constant());

    using Choice = typename DecisionTree<L, Y>::Choice;
    auto choice = boost::dynamic_pointer_cast<const Choice>(node);
    // Same policy as convertFrom: fail loudly rather than dereference null.
    if (!choice)
      throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");

    for (size_t i = 0; i < choice->nrChoices(); i++) {
      choices[choice->label()] = i;    // Set assignment for label to i
      (*this)(choice->branches()[i]);  // recurse!
    }

    // Backtrack: drop this node's label so it does not leak into the
    // assignment reported for leaves outside this subtree.
    // NOTE(review): relies on Assignment<L> offering map-style erase(key).
    choices.erase(choice->label());
  }
};

/**
 * Visit all leaves by wrapping f in a VisitWith functor and applying that
 * functor to the root of this tree.
 *
 * @param f callable invoked with the current assignment and each leaf value.
 */
template <typename L, typename Y>
template <typename Func>
void DecisionTree<L, Y>::visitWith(Func f) const {
  // The functor owns the assignment that is mutated during the traversal.
  VisitWith<L, Y> assignmentVisitor(f);
  assignmentVisitor(root_);
}

/*********************************************************************************/
/**
 * Fold implemented on top of visit: for every visited leaf value y the
 * accumulator is replaced by f(y, accumulator).
 *
 * @param f binary function taking (leaf value, accumulator).
 * @param x0 initial accumulator, taken by value and updated in place.
 * @return the accumulator after all leaves have been visited.
 */
template <typename L, typename Y>
template <typename Func, typename X>
X DecisionTree<L, Y>::fold(Func f, X x0) const {
  // x0 is our own copy, so it can serve directly as the running accumulator.
  auto accumulate = [&](const Y& leafValue) { x0 = f(leafValue, x0); };
  visit(accumulate);
  return x0;
}

/*********************************************************************************/
/**
 * Collect labels using visitWith: every key found in the assignment handed
 * to the per-leaf callback is inserted into the returned set.
 *
 * @return set of all labels seen in those assignments.
 */
template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const {
  std::set<L> found;
  visitWith([&](const Assignment<L>& assignment, const Y&) {
    // The leaf value itself is irrelevant; only the keys matter.
    for (auto&& entry : assignment) found.insert(entry.first);
  });
  return found;
}

/*********************************************************************************/
template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const {
Expand Down
58 changes: 52 additions & 6 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <map>
#include <sstream>
#include <vector>
#include <set>

namespace gtsam {

Expand Down Expand Up @@ -176,9 +177,8 @@ namespace gtsam {
* @param other The DecisionTree to convert from.
* @param Y_of_X Functor to convert from value type X to type Y.
*/
template <typename X>
DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X);
template <typename X, typename Func>
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);

/**
* @brief Convert from a different value type X to value type Y, also translate
Expand All @@ -190,9 +190,9 @@ namespace gtsam {
* @param L_of_M Map from label type M to type L.
* @param Y_of_X Functor to convert from type X to type Y.
*/
template <typename M, typename X>
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
std::function<Y(const X&)> Y_of_X);
template <typename M, typename X, typename Func>
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
Func Y_of_X);

/// @}
/// @name Testable
Expand Down Expand Up @@ -229,6 +229,52 @@ namespace gtsam {
/** evaluate */
const Y& operator()(const Assignment<L>& x) const;

/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking a value.
*
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
* tree.visit(visitor);
*/
template <typename Func>
void visit(Func f) const;

/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking an assignment and a value.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
void visitWith(Func f) const;

/**
* @brief Fold a binary function over the tree, returning accumulator.
*
* @tparam X type for accumulator.
* @param f binary function: Y * X -> X returning an updated accumulator.
* @param x0 initial value for accumulator.
* @return X final value for accumulator.
*
* @note X is always passed by value.
*
* Example:
* auto add = [](const double& y, double x) { return y + x; };
* double sum = tree.fold(add, 0.0);
*/
template <typename Func, typename X>
X fold(Func f, X x0) const;

/** Retrieve all unique labels as a set. */
std::set<L> labels() const;

/** apply Unary operation "op" to f */
DecisionTree apply(const Unary& op) const;

Expand Down
Loading