Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Record number of assignments for each leaf #1154

Merged
merged 1 commit into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,23 @@ namespace gtsam {
/** constant stored in this leaf */
Y constant_;

/** Constructor from constant */
Leaf(const Y& constant) :
constant_(constant) {}
/** The number of assignments contained within this leaf
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;

/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}

/** return the constant */
const Y& constant() const {
return constant_;
}

/// Return the number of assignments contained within this leaf.
size_t nrAssignments() const { return nrAssignments_; }

/// Leaf-Leaf equality
bool sameLeaf(const Leaf& q) const override {
return constant_ == q.constant_;
Expand Down Expand Up @@ -108,14 +116,14 @@ namespace gtsam {

/** apply unary operator */
NodePtr apply(const Unary& op) const override {
NodePtr f(new Leaf(op(constant_)));
NodePtr f(new Leaf(op(constant_), nrAssignments_));
return f;
}

/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_)));
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
return f;
}

Expand All @@ -130,7 +138,8 @@ namespace gtsam {

// Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
// fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
return h;
}

Expand All @@ -141,7 +150,7 @@ namespace gtsam {

/** choose a branch, create new memory ! */
NodePtr choose(const L& label, size_t index) const override {
return NodePtr(new Leaf(constant()));
return NodePtr(new Leaf(constant(), nrAssignments()));
}

bool isLeaf() const override { return true; }
Expand Down Expand Up @@ -178,9 +187,16 @@ namespace gtsam {
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
assert(f0->isLeaf());

size_t nrAssignments = 0;
for(auto branch: f->branches()) {
assert(branch->isLeaf());
nrAssignments +=
boost::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
}
NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf;
} else
#endif
Expand Down Expand Up @@ -640,7 +656,7 @@ namespace gtsam {
// If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
}

// Check if Choice
Expand Down
43 changes: 43 additions & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,49 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int);
}

/* ************************************************************************** */
// Test nrAssignments.
TEST(DecisionTree, NrAssignments) {
pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
EXPECT(tree.root_->isLeaf());
auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());

DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
/* The tree is
Choice(C)
0 Choice(B)
0 0 Leaf 1
0 1 Choice(A)
0 1 0 Leaf 1
0 1 1 Leaf 2
1 Choice(B)
1 0 Choice(A)
1 0 0 Leaf 3
1 0 1 Leaf 4
1 1 Leaf 5
*/

auto root = boost::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
CHECK(root);
auto choice0 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
CHECK(choice0);
EXPECT(choice0->branches()[0]->isLeaf());
auto choice00 = boost::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
CHECK(choice00);
EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());

auto choice1 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
CHECK(choice1);
auto choice10 = boost::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
CHECK(choice10);
auto choice11 = boost::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
CHECK(choice11);
EXPECT(choice11->isLeaf());
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
}

/* ************************************************************************** */
// Test visit.
TEST(DecisionTree, visit) {
Expand Down