-
Notifications
You must be signed in to change notification settings - Fork 779
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hybrid Mixture Error #1318
Hybrid Mixture Error #1318
Changes from 16 commits
64744b0
c41b58f
aa1c65d
5c375f6
d834897
c0eeb0c
9365a02
ca14b7e
281ad31
cff6505
11e4c1e
8fa7f44
9cb225a
551cc0d
07a616d
c2ef4f2
0e1c3b8
098d2ce
d94b319
23ec7ed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -207,4 +207,28 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { | |
conditionals_.root_ = pruned_conditionals.root_; | ||
} | ||
|
||
/* *******************************************************************************/ | ||
AlgebraicDecisionTree<Key> GaussianMixture::error( | ||
const VectorValues &continuousVals) const { | ||
// functor to convert from GaussianConditional to double error value. | ||
auto errorFunc = | ||
[continuousVals](const GaussianConditional::shared_ptr &conditional) { | ||
if (conditional) { | ||
return conditional->error(continuousVals); | ||
} else { | ||
// return arbitrarily large error | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comment on when this would happen? |
||
return 1e50; | ||
} | ||
}; | ||
DecisionTree<Key, double> errorTree(conditionals_, errorFunc); | ||
return errorTree; | ||
} | ||
|
||
/* *******************************************************************************/ | ||
double GaussianMixture::error(const VectorValues &continuousVals, | ||
dellaert marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const DiscreteValues &discreteValues) const { | ||
auto conditional = conditionals_(discreteValues); | ||
return conditional->error(continuousVals); | ||
} | ||
|
||
} // namespace gtsam |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,8 +107,12 @@ class MixtureFactor : public HybridFactor { | |
std::copy(f->keys().begin(), f->keys().end(), | ||
std::inserter(factor_keys_set, factor_keys_set.end())); | ||
|
||
nonlinear_factors.push_back( | ||
boost::dynamic_pointer_cast<NonlinearFactor>(f)); | ||
if (auto nf = boost::dynamic_pointer_cast<NonlinearFactor>(f)) { | ||
dellaert marked this conversation as resolved.
Show resolved
Hide resolved
|
||
nonlinear_factors.push_back(nf); | ||
} else { | ||
throw std::runtime_error( | ||
"Factors passed into MixtureFactor need to be nonlinear!"); | ||
} | ||
} | ||
factors_ = Factors(discreteKeys, nonlinear_factors); | ||
|
||
|
@@ -121,6 +125,22 @@ class MixtureFactor : public HybridFactor { | |
|
||
~MixtureFactor() = default; | ||
|
||
/** | ||
* @brief Compute error of the MixtureFactor as a tree. | ||
* | ||
* @param continuousVals The continuous values for which to compute the error. | ||
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys | ||
* as the factor but leaf values as the error. | ||
*/ | ||
AlgebraicDecisionTree<Key> error(const Values& continuousVals) const { | ||
// functor to convert from sharedFactor to double error value. | ||
auto errorFunc = [continuousVals](const sharedFactor& factor) { | ||
return factor->error(continuousVals); | ||
}; | ||
DecisionTree<Key, double> errorTree(factors_, errorFunc); | ||
return errorTree; | ||
} | ||
|
||
/** | ||
* @brief Compute error of factor given both continuous and discrete values. | ||
* | ||
|
@@ -149,7 +169,7 @@ class MixtureFactor : public HybridFactor { | |
|
||
/// print to stdout | ||
void print( | ||
const std::string& s = "MixtureFactor", | ||
const std::string& s = "", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It anyway prints |
||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { | ||
std::cout << (s.empty() ? "" : s + " "); | ||
Base::print("", keyFormatter); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
convert? -? calculate