Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linear HybridBayesNet optimization #1270

Merged
merged 7 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @date January 2022
*/

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridLookupDAG.h>

namespace gtsam {

Expand All @@ -40,4 +43,10 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn;
}

/* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this);
return dag.argmax();
}

} // namespace gtsam
6 changes: 6 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/linear/GaussianBayesNet.h>

Expand Down Expand Up @@ -61,6 +62,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @return GaussianBayesNet
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;

/// Solve the HybridBayesNet by back-substitution.
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
HybridValues optimize() const;
};

} // namespace gtsam
76 changes: 76 additions & 0 deletions gtsam/hybrid/HybridLookupDAG.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/* ----------------------------------------------------------------------------

* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)

* See LICENSE for the license information

* -------------------------------------------------------------------------- */

/**
* @file DiscreteLookupDAG.cpp
* @date Aug, 2022
* @author Shangjie Xue
*/

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/VectorValues.h>

#include <string>
#include <utility>

using std::pair;
using std::vector;

namespace gtsam {

/* ************************************************************************** */
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
// For discrete conditional, uses argmaxInPlace() method in
// DiscreteLookupTable.
if (isDiscrete()) {
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
&(values->discrete));
} else if (isContinuous()) {
// For Gaussian conditional, uses solve() method in GaussianConditional.
values->continuous.insert(
boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
values->continuous));
} else if (isHybrid()) {
// For hybrid conditional, since children should not contain discrete
// variable, we can condition on the discrete variable in the parents and
// solve the resulting GaussianConditional.
auto conditional =
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
values->discrete);
values->continuous.insert(conditional->solve(values->continuous));
}
}

/* ************************************************************************** */
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
HybridLookupDAG dag;
for (auto&& conditional : bayesNet) {
HybridLookupTable hlt(*conditional);
dag.push_back(hlt);
}
return dag;
}

/* ************************************************************************** */
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}

} // namespace gtsam
119 changes: 119 additions & 0 deletions gtsam/hybrid/HybridLookupDAG.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/* ----------------------------------------------------------------------------

* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)

* See LICENSE for the license information

* -------------------------------------------------------------------------- */

/**
* @file HybridLookupDAG.h
* @date Aug, 2022
* @author Shangjie Xue
*/

#pragma once

#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>

#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>

namespace gtsam {

/**
* @brief HybridLookupTable table for max-product
*
* Similar to DiscreteLookupTable, inherits from hybrid conditional for
* convenience. Is used in the max-product algorithm.
*/
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
public:
using Base = HybridConditional;
using This = HybridLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;

/**
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional.
*
* @param conditional input hybrid conditional
*/
HybridLookupTable(HybridConditional& conditional) : Base(conditional){};

/**
* @brief Calculate assignment for frontal variables that maximizes value.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xsj01 What would be the frontal variables here? They should all be continuous variables since we eliminate those first, so then this doesn't make sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it seems the HybridLookupTable is not necessary. Just had some discussions with Fan, I will submit a new PR to fix this.

* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(HybridValues* parentsValues) const;
};

/** A DAG made from hybrid lookup tables, as defined above. Similar to
* DiscreteLookupDAG */
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs for this class are severely lacking. @xsj01 you have to be better than just copy-pasting the same docstrings, since now the motivation for this class is lost.

public:
using Base = BayesNet<HybridLookupTable>;
using This = HybridLookupDAG;
using shared_ptr = boost::shared_ptr<This>;

/// @name Standard Constructors
/// @{

/// Construct empty DAG.
HybridLookupDAG() {}

/// Create from BayesNet with LookupTables
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet);

/// Destructor
virtual ~HybridLookupDAG() {}

/// @}

/// @name Standard Interface
/// @{

/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...);
}

/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
HybridValues argmax(HybridValues given = HybridValues()) const;
/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};

// traits
template <>
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {};

} // namespace gtsam
127 changes: 127 additions & 0 deletions gtsam/hybrid/HybridValues.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------

* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)

* See LICENSE for the license information

* -------------------------------------------------------------------------- */

/**
* @file HybridValues.h
* @date Jul 28, 2022
* @author Shangjie Xue
*/

#pragma once

#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>

#include <map>
#include <string>
#include <vector>

namespace gtsam {

/**
* HybridValues represents a collection of DiscreteValues and VectorValues. It
* is typically used to store the variables of a HybridGaussianFactorGraph.
* Optimizing a HybridGaussianBayesNet returns this class.
*/
class GTSAM_EXPORT HybridValues {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for thoughts: I always wondered can we just get rid of DiscreteKey and stuff, and store the cardinality of variables somewhere else, like along with the value itself.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the use of this class. Can't we simply have optimize run MPE on the discrete tree and then run a regular optimization on the continuous values corresponding to the assignment? That way we can just return a pair<Values, DiscreteValues>.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think that's also an option. But I feel like the api is not very user friendly while accessing/assigning values. We can talk about it in our meeting.

public:
// DiscreteValue stored the discrete components of the HybridValues.
DiscreteValues discrete;

// VectorValue stored the continuous components of the HybridValues.
VectorValues continuous;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we name the class HybridVectorValues? This class is linear only. (comment only)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I can change VectorValues to Values?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is more like a design question: Is this limited to linear hybrid systems? If so we should name it Gaussian.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will extend that for nonlinear hybrid later, but currently it's only for linear hybrid system.
If we rename this to HybridVectorValues, then we will also need HybridValues for nonlinear hybrid system later. Is it necessary to have both HybridVectorValues and HybridValues?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the spirit for linear and nonlinear though. because VectorValues and values are different we should probably have the same


// Default constructor creates an empty HybridValues.
HybridValues() : discrete(), continuous(){};

// Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete(dv), continuous(cv){};

// print required by Testable for unit testing
void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n";
discrete.print(" Discrete", keyFormatter); // print discrete components
continuous.print(" Continuous",
keyFormatter); // print continuous components
};

// equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete.equals(other.discrete, tol) &&
continuous.equals(other.continuous, tol);
}

// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); };

// Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j) { return continuous.exists(j); };

// Check whether a variable with key \c j exists.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };

/** Insert a discrete \c value with key \c j. Replaces the existing value if
* the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, int value) { discrete[j] = value; };

/** Insert a vector \c value with key \c j. Throws an invalid_argument
* exception if the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous.insert(j, value); }

// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h

/**
* Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
size_t& atDiscrete(Key j) { return discrete.at(j); };

/**
* Read/write access to the vector value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
Vector& at(Key j) { return continuous.at(j); };

/// @name Wrapper support
/// @{

/**
* @brief Output as a html table.
*
* @param keyFormatter function that formats keys.
* @return string html output.
*/
std::string html(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::stringstream ss;
ss << this->discrete.html(keyFormatter);
ss << this->continuous.html(keyFormatter);
return ss.str();
};

/// @}
};

// traits
template <>
struct traits<HybridValues> : public Testable<HybridValues> {};

} // namespace gtsam
Loading