Merge pull request #1347 from borglab/hybrid/sample
varunagrawal authored Dec 24, 2022
2 parents df5b894 + da86e06 commit 2abc0d9
Showing 6 changed files with 193 additions and 53 deletions.
112 changes: 72 additions & 40 deletions gtsam/hybrid/HybridBayesNet.cpp
@@ -8,7 +8,7 @@

/**
* @file HybridBayesNet.cpp
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
@@ -20,18 +20,20 @@
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

// In the wrappers we have no access to this, so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);

namespace gtsam {

/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;

// The canonical decision tree factor which will get the discrete conditionals
// added to it.
// The canonical decision tree factor which will get
// the discrete conditionals added to it.
DecisionTreeFactor dtFactor;

for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
@@ -53,7 +55,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
// and the Gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());

@@ -62,7 +64,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
double probability) -> double {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the gaussian mixture has the same
// Case where the Gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) {
@@ -101,6 +103,7 @@ void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys();

// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
@@ -153,7 +156,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();

// Make a copy of the gaussian mixture and prune it!
// Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*decisionTree);
@@ -173,35 +176,35 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture();
return at(i)->asMixture();
}

/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
return at(i)->asGaussian();
}

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional();
return at(i)->asDiscreteConditional();
}

/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx);
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
// If conditional is hybrid, select based on assignment.
GaussianMixture gm = *conditional->asMixture();
gbn.push_back(gm(assignment));

} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, add gaussian conditional.
gbn.push_back((this->atGaussian(idx)));
} else if (conditional->isContinuous()) {
// If continuous only, add Gaussian conditional.
gbn.push_back((conditional->asGaussian()));

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
} else if (conditional->isDiscrete()) {
// If conditional is discrete-only, we simply continue.
continue;
}
}
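// Usage sketch (illustration only, not part of this change): choose() picks
// the Gaussian Bayes net corresponding to a discrete assignment, which can
// then be solved. `hbn` and the discrete key M(0) are assumed.
//   DiscreteValues mode;
//   mode[M(0)] = 1;
//   GaussianBayesNet gbn = hbn.choose(mode);
//   VectorValues x = gbn.optimize();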
@@ -213,7 +216,7 @@ GaussianBayesNet HybridBayesNet::choose(
HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE
DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
}
@@ -232,6 +235,41 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize();
}
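// Usage sketch (illustration only, not part of this change): optimize()
// solves for the MPE over all variables, while optimize(assignment) solves
// only the continuous variables under a fixed mode. `hbn` and M(0) are assumed.
//   HybridValues mpe = hbn.optimize();
//   DiscreteValues mode;
//   mode[M(0)] = 0;
//   VectorValues x = hbn.optimize(mode);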

/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscreteConditional());
}
}
// Sample a discrete assignment.
const DiscreteValues assignment = dbn.sample(given.discrete());
// Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = choose(assignment);
// Sample from the Gaussian Bayes net.
VectorValues sample = gbn.sample(given.continuous(), rng);
return {assignment, sample};
}

/* ************************************************************************* */
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
HybridValues given;
return sample(given, rng);
}

/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given) const {
return sample(given, &kRandomNumberGenerator);
}

/* ************************************************************************* */
HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}
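// Usage sketch (illustration only, not part of this change): ancestral
// sampling first draws a discrete assignment from the discrete conditionals,
// then samples the Gaussian Bayes net selected by that assignment.
// `hbn` and the discrete key M(0) are assumed.
//   std::mt19937_64 rng(11);
//   HybridValues s = hbn.sample(&rng);
//   size_t mode = s.discrete()[M(0)];    // sampled discrete mode
//   VectorValues x = s.continuous();     // sampled continuous values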

/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
@@ -242,34 +280,28 @@ double HybridBayesNet::error(const VectorValues &continuousValues,
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree;

// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> conditional_error;
AlgebraicDecisionTree<Key> error_tree(0.0);

if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
// If conditional is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = conditional->asMixture();
AlgebraicDecisionTree<Key> conditional_error =
gm->error(continuousValues);

// Assign for the first index, add error for subsequent ones.
if (idx == 0) {
error_tree = conditional_error;
} else {
error_tree = error_tree + conditional_error;
}
error_tree = error_tree + conditional_error;

} else if (factors_.at(idx)->isContinuous()) {
} else if (conditional->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues);
double error = conditional->asGaussian()->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
} else if (conditional->isDiscrete()) {
// Conditional is discrete-only, we skip.
continue;
}
}
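// Usage sketch (illustration only, not part of this change): evaluate the
// error surface over all discrete assignments at fixed continuous values.
// `hbn`, `delta`, and the discrete key M(0) are assumed.
//   AlgebraicDecisionTree<Key> errors = hbn.error(delta);
//   DiscreteValues mode;
//   mode[M(0)] = 1;
//   double e = errors(mode);    // error for this particular mode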
46 changes: 43 additions & 3 deletions gtsam/hybrid/HybridBayesNet.h
@@ -8,7 +8,7 @@

/**
* @file HybridBayesNet.h
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Varun Agrawal
* @author Fan Jiang
* @author Frank Dellaert
@@ -43,7 +43,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Standard Constructors
/// @{

/** Construct empty bayes net */
/** Construct empty Bayes net */
HybridBayesNet() = default;

/// @}
@@ -120,7 +120,47 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;

public:
/**
* @brief Sample from an incomplete BayesNet, given missing variables.
*
* Example:
* std::mt19937_64 rng(42);
* HybridValues given = ...;
* auto sample = bn.sample(given, &rng);
*
* @param given Values of missing variables.
* @param rng The pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(const HybridValues &given, std::mt19937_64 *rng) const;

/**
* @brief Sample using ancestral sampling.
*
* Example:
* std::mt19937_64 rng(42);
* auto sample = bn.sample(&rng);
*
* @param rng The pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(std::mt19937_64 *rng) const;

/**
* @brief Sample from an incomplete BayesNet, use default rng.
*
* @param given Values of missing variables.
* @return HybridValues
*/
HybridValues sample(const HybridValues &given) const;

/**
* @brief Sample using ancestral sampling, use default rng.
*
* @return HybridValues
*/
HybridValues sample() const;

/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves);

69 changes: 69 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
@@ -316,6 +316,75 @@ TEST(HybridBayesNet, Serialization) {
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}

/* ****************************************************************************/
// Test HybridBayesNet sampling.
TEST(HybridBayesNet, Sampling) {
HybridNonlinearFactorGraph nfg;

auto noise_model = noiseModel::Diagonal::Sigmas(Vector1(1.0));
auto zero_motion =
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
auto one_motion =
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
std::vector<NonlinearFactor::shared_ptr> factors = {zero_motion, one_motion};
nfg.emplace_nonlinear<PriorFactor<double>>(X(0), 0.0, noise_model);
nfg.emplace_hybrid<MixtureFactor>(
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);

DiscreteKey mode(M(0), 2);
auto discrete_prior = boost::make_shared<DiscreteDistribution>(mode, "1/1");
nfg.push_discrete(discrete_prior);

Values initial;
double z0 = 0.0, z1 = 1.0;
initial.insert<double>(X(0), z0);
initial.insert<double>(X(1), z1);

// Linearize the nonlinear factor graph into a hybrid Gaussian factor graph.
HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial);
// Eliminate into BN
Ordering ordering = fg->getHybridOrdering();
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);

// Set up sampling
std::mt19937_64 gen(11);

// Initialize containers for computing the mean values.
vector<double> discrete_samples;
VectorValues average_continuous;

size_t num_samples = 1000;
for (size_t i = 0; i < num_samples; i++) {
// Sample
HybridValues sample = bn->sample(&gen);

discrete_samples.push_back(sample.discrete()[M(0)]);

if (i == 0) {
average_continuous.insert(sample.continuous());
} else {
average_continuous += sample.continuous();
}
}

EXPECT_LONGS_EQUAL(2, average_continuous.size());
EXPECT_LONGS_EQUAL(num_samples, discrete_samples.size());

// Regressions don't work across platforms :-(
// // regression for specific RNG seed
// double discrete_sum =
// std::accumulate(discrete_samples.begin(), discrete_samples.end(),
// decltype(discrete_samples)::value_type(0));
// EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9);

// VectorValues expected;
// expected.insert({X(0), Vector1(-0.0131207162712)});
// expected.insert({X(1), Vector1(-0.499026377568)});
// // regression for specific RNG seed
// EXPECT(assert_equal(expected, average_continuous.scale(1.0 /
// num_samples)));
}

/* ************************************************************************* */
int main() {
TestResult tr;
5 changes: 3 additions & 2 deletions gtsam/linear/GaussianBayesNet.cpp
@@ -64,8 +64,9 @@ namespace gtsam {
return sample(result, rng);
}

VectorValues GaussianBayesNet::sample(VectorValues result,
VectorValues GaussianBayesNet::sample(const VectorValues& given,
std::mt19937_64* rng) const {
VectorValues result(given);
// sample each node in reverse topological sort order (parents first)
for (auto cg : boost::adaptors::reverse(*this)) {
const VectorValues sampled = cg->sample(result, rng);
@@ -79,7 +80,7 @@ namespace gtsam {
return sample(&kRandomNumberGenerator);
}

VectorValues GaussianBayesNet::sample(VectorValues given) const {
VectorValues GaussianBayesNet::sample(const VectorValues& given) const {
return sample(given, &kRandomNumberGenerator);
}

4 changes: 2 additions & 2 deletions gtsam/linear/GaussianBayesNet.h
@@ -110,13 +110,13 @@ namespace gtsam {
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
*/
VectorValues sample(VectorValues given, std::mt19937_64* rng) const;
VectorValues sample(const VectorValues& given, std::mt19937_64* rng) const;

/// Sample using ancestral sampling, use default rng
VectorValues sample() const;

/// Sample from an incomplete BayesNet, use default rng
VectorValues sample(VectorValues given) const;
VectorValues sample(const VectorValues& given) const;

/**
* Return ordering corresponding to a topological sort.