Skip to content

Commit

Permalink
Restrict values_as_in_model API
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Jan 15, 2025
1 parent e673b69 commit c2916fa
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.33.1"
version = "0.34.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
45 changes: 17 additions & 28 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ wants to extract the realization of a model in a constrained space.
# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
"values that are extracted from the model"
values::T
values::OrderedDict
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
"child context"
Expand Down Expand Up @@ -95,7 +95,7 @@ function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi)
value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi)

# Save the value.
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
_, _, _vns = unwrap_right_left_vns(right, var, vn)
broadcast_push!(context, _vns, value)

return value, logp, vi
Expand All @@ -107,41 +107,39 @@ function dot_tilde_assume(
rng, childcontext(context), sampler, right, left, vn, vi
)
# Save the value.
_right, _left, _vns = unwrap_right_left_vns(right, left, vn)
_, _, _vns = unwrap_right_left_vns(right, left, vn)

Check warning on line 110 in src/values_as_in_model.jl

View check run for this annotation

Codecov / codecov/patch

src/values_as_in_model.jl#L110

Added line #L110 was not covered by tests
broadcast_push!(context, _vns, value)

return value, logp, vi
end

"""
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
Get the values of `varinfo` as they would be seen in the model.
If no `varinfo` is provided, then this is effectively the same as
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
More specifically, this method attempts to extract the realization _as seen in
the model_. For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a
realization that is compatible with `truncated(Normal(); lower=0)` -- i.e. one
where the value of `x[1]` is positive -- regardless of whether `varinfo` is
working in unconstrained space.
More specifically, this method attempts to extract the realization _as seen in the model_.
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
space.
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
of additional model evaluations.
Hence this method is a "safe" way of obtaining realizations in constrained
space at the cost of additional model evaluations.
# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
- `context::AbstractContext`: base context to use for the extraction. Defaults
to `DynamicPPL.DefaultContext()`.
# Examples
## When `VarInfo` fails
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
The following demonstrates a common pitfall when working with [`VarInfo`](@ref)
and constrained variables.
```jldoctest
julia> using Distributions, StableRNGs
Expand Down Expand Up @@ -191,19 +189,10 @@ true
function values_as_in_model(
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
varinfo::AbstractVarInfo,
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(include_colon_eq, context)
evaluate!!(model, varinfo, context)
return context.values
end
function values_as_in_model(
rng::Random.AbstractRNG,
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
end
16 changes: 0 additions & 16 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,22 +429,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end
end

@testset "check that sampling obeys rng if passed" begin
@model function f()
x ~ Normal(0)
return y ~ Normal(x)
end
model = f()
# Call values_as_in_model with the rng
values = values_as_in_model(Random.Xoshiro(43), model, false)
# Check that they match the values that would be used if vi was seeded
# with that seed instead
expected_vi = VarInfo(Random.Xoshiro(43), model)
for vn in keys(values)
@test values[vn] == expected_vi[vn]
end
end
end

@testset "Erroneous model call" begin
Expand Down

0 comments on commit c2916fa

Please sign in to comment.