Skip to content

Commit

Permalink
Merge pull request #1013 from borglab/feature/remove_potentials
Browse files Browse the repository at this point in the history
Remove Potentials
  • Loading branch information
ProfFan authored Jan 7, 2022
2 parents 0683dfa + 79cb4d0 commit 4dafcc5
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 219 deletions.
1 change: 1 addition & 0 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,6 @@ namespace gtsam {
};
// AlgebraicDecisionTree

template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
}
// namespace gtsam
35 changes: 27 additions & 8 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ namespace gtsam {
/* ******************************************************************************** */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) :
DiscreteFactor(keys.indices()), Potentials(keys, potentials) {
DiscreteFactor(keys.indices()), ADT(potentials),
cardinalities_(keys.cardinalities()) {
}

/* *************************************************************************/
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
DiscreteFactor(c.keys()), Potentials(c) {
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
}

/* ************************************************************************* */
Expand All @@ -48,16 +49,24 @@ namespace gtsam {
return false;
}
else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return Potentials::equals(f, tol);
const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol);
}
}

/* ************************************************************************* */
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}

/* ************************************************************************* */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
cout << s;
Potentials::print("Potentials:",formatter);
ADT::print("Potentials:",formatter);
}

/* ************************************************************************* */
Expand Down Expand Up @@ -162,20 +171,20 @@ namespace gtsam {
void DecisionTreeFactor::dot(std::ostream& os,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(os, keyFormatter, valueFormatter, showZero);
ADT::dot(os, keyFormatter, valueFormatter, showZero);
}

/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(name, keyFormatter, valueFormatter, showZero);
ADT::dot(name, keyFormatter, valueFormatter, showZero);
}

/** output to graphviz format string */
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
bool showZero) const {
return Potentials::dot(keyFormatter, valueFormatter, showZero);
return ADT::dot(keyFormatter, valueFormatter, showZero);
}

/* ************************************************************************* */
Expand Down Expand Up @@ -209,5 +218,15 @@ namespace gtsam {
return ss.str();
}

DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}

DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}

/* ************************************************************************* */
} // namespace gtsam
27 changes: 18 additions & 9 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
#pragma once

#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Potentials.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h>

#include <boost/shared_ptr.hpp>
Expand All @@ -35,14 +36,18 @@ namespace gtsam {
/**
* A discrete probabilistic factor
*/
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials {
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {

public:

// typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;

protected:
std::map<Key,size_t> cardinalities_;

public:

Expand All @@ -55,11 +60,11 @@ namespace gtsam {
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);

/** Constructor from Indices and (string or doubles) */
template<class SOURCE>
DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) :
DiscreteFactor(keys.indices()), Potentials(keys, table) {
}
/** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);

/** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);

/// Single-key specialization
template <class SOURCE>
Expand All @@ -71,7 +76,7 @@ namespace gtsam {
: DecisionTreeFactor(DiscreteKeys{key}, row) {}

/** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c);
explicit DecisionTreeFactor(const DiscreteConditional& c);

/// @}
/// @name Testable
Expand All @@ -90,14 +95,18 @@ namespace gtsam {

/// Value is just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values);
return ADT::operator()(values);
}

/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, ADT::Ring::mul);
}

static double safe_div(const double& a, const double& b);

size_t cardinality(Key j) const { return cardinalities_.at(j);}

/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void DiscreteConditional::print(const string& s,
}
}
cout << ")";
Potentials::print("");
ADT::print("");
cout << endl;
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,

/// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values);
return ADT::operator()(values);
}

/** Convert to a factor */
Expand Down
96 changes: 0 additions & 96 deletions gtsam/discrete/Potentials.cpp

This file was deleted.

97 changes: 0 additions & 97 deletions gtsam/discrete/Potentials.h

This file was deleted.

Loading

0 comments on commit 4dafcc5

Please sign in to comment.