diff --git a/gtsam/linear/GaussianBayesNet.cpp b/gtsam/linear/GaussianBayesNet.cpp index 41a734b34b..d42fbe7722 100644 --- a/gtsam/linear/GaussianBayesNet.cpp +++ b/gtsam/linear/GaussianBayesNet.cpp @@ -59,27 +59,30 @@ namespace gtsam { } /* ************************************************************************ */ - VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const { + VectorValues GaussianBayesNet::sample(std::mt19937_64* rng, + const SharedDiagonal& model) const { VectorValues result; // no missing variables -> create an empty vector - return sample(result, rng); + return sample(result, rng, model); } VectorValues GaussianBayesNet::sample(VectorValues result, - std::mt19937_64* rng) const { + std::mt19937_64* rng, + const SharedDiagonal& model) const { // sample each node in reverse topological sort order (parents first) for (auto cg : boost::adaptors::reverse(*this)) { - const VectorValues sampled = cg->sample(result, rng); + const VectorValues sampled = cg->sample(result, rng, model); result.insert(sampled); } return result; } /* ************************************************************************ */ - VectorValues GaussianBayesNet::sample() const { + VectorValues GaussianBayesNet::sample(const SharedDiagonal& model) const { return sample(&kRandomNumberGenerator); } - VectorValues GaussianBayesNet::sample(VectorValues given) const { + VectorValues GaussianBayesNet::sample(VectorValues given, + const SharedDiagonal& model) const { return sample(given, &kRandomNumberGenerator); } diff --git a/gtsam/linear/GaussianBayesNet.h b/gtsam/linear/GaussianBayesNet.h index 83328576f2..570bfef58d 100644 --- a/gtsam/linear/GaussianBayesNet.h +++ b/gtsam/linear/GaussianBayesNet.h @@ -101,7 +101,8 @@ namespace gtsam { * std::mt19937_64 rng(42); * auto sample = gbn.sample(&rng); */ - VectorValues sample(std::mt19937_64* rng) const; + VectorValues sample(std::mt19937_64* rng, + const SharedDiagonal& model = nullptr) const; /** * Sample from an incomplete BayesNet, given missing variables @@ -110,13 +111,15 @@ namespace gtsam { * VectorValues given = ...; * auto sample = gbn.sample(given, &rng); */ - VectorValues sample(VectorValues given, std::mt19937_64* rng) const; + VectorValues sample(VectorValues given, std::mt19937_64* rng, + const SharedDiagonal& model = nullptr) const; /// Sample using ancestral sampling, use default rng - VectorValues sample() const; + VectorValues sample(const SharedDiagonal& model = nullptr) const; /// Sample from an incomplete BayesNet, use default rng - VectorValues sample(VectorValues given) const; + VectorValues sample(VectorValues given, + const SharedDiagonal& model = nullptr) const; /** * Return ordering corresponding to a topological sort. diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 60ddb1b7d0..363d25d112 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -293,39 +293,48 @@ double GaussianConditional::logDeterminant() const { /* ************************************************************************ */ VectorValues GaussianConditional::sample(const VectorValues& parentsValues, - std::mt19937_64* rng) const { + std::mt19937_64* rng, + const SharedDiagonal& model) const { if (nrFrontals() != 1) { throw std::invalid_argument( "GaussianConditional::sample can only be called on single variable " "conditionals"); } - if (!model_) { + + VectorValues solution = solve(parentsValues); + Key key = firstFrontalKey(); + + Vector sigmas; + if (model_) { + sigmas = model_->sigmas(); + } else if (model) { + sigmas = model->sigmas(); + } else { throw std::invalid_argument( "GaussianConditional::sample can only be called if a diagonal noise " "model was specified at construction."); } - VectorValues solution = solve(parentsValues); - Key key = firstFrontalKey(); - const Vector& sigmas = model_->sigmas(); solution[key] += Sampler::sampleDiagonal(sigmas, rng); return solution; } - VectorValues GaussianConditional::sample(std::mt19937_64* rng) const { + VectorValues GaussianConditional::sample(std::mt19937_64* rng, + const SharedDiagonal& model) const { if (nrParents() != 0) throw std::invalid_argument( "sample() can only be invoked on no-parent prior"); VectorValues values; - return sample(values); + return sample(values, rng, model); } /* ************************************************************************ */ - VectorValues GaussianConditional::sample() const { - return sample(&kRandomNumberGenerator); + VectorValues GaussianConditional::sample(const SharedDiagonal& model) const { + return sample(&kRandomNumberGenerator, model); } - VectorValues GaussianConditional::sample(const VectorValues& given) const { - return sample(given, &kRandomNumberGenerator); + VectorValues GaussianConditional::sample(const VectorValues& given, + const SharedDiagonal& model) const { + return sample(given, &kRandomNumberGenerator, model); } /* ************************************************************************ */ diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 8af7f66029..1ca9b7d531 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -188,7 +188,8 @@ namespace gtsam { * std::mt19937_64 rng(42); * auto sample = gbn.sample(&rng); */ - VectorValues sample(std::mt19937_64* rng) const; + VectorValues sample(std::mt19937_64* rng, + const SharedDiagonal& model = nullptr) const; /** * Sample from conditional, given missing variables @@ -198,13 +199,15 @@ namespace gtsam { * auto sample = gbn.sample(given, &rng); */ VectorValues sample(const VectorValues& parentsValues, - std::mt19937_64* rng) const; + std::mt19937_64* rng, + const SharedDiagonal& model = nullptr) const; /// Sample, use default rng - VectorValues sample() const; + VectorValues sample(const SharedDiagonal& model = nullptr) const; /// Sample with given values, use default rng - VectorValues sample(const VectorValues& parentsValues) const; + VectorValues sample(const VectorValues& parentsValues, + const SharedDiagonal& model = nullptr) const; /// @}