From 04b03cdc2e0e782dfc6d920c277c740e4ccf5f06 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 26 Oct 2023 19:34:38 +0100 Subject: [PATCH] Remove `tonamedtuple` (#547) * Remove dependencies to `tonamedtuple` * Remove `tonamedtuple`s * Minor version bump --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- docs/src/api.md | 1 - src/DynamicPPL.jl | 1 - src/abstract_varinfo.jl | 15 --------------- src/simple_varinfo.jl | 38 -------------------------------------- src/threadsafe.jl | 2 -- src/varinfo.jl | 16 ---------------- test/test_util.jl | 30 +++++++++++++++--------------- 8 files changed, 16 insertions(+), 89 deletions(-) diff --git a/Project.toml b/Project.toml index 47246ce11..6a7cda61b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.21" +version = "0.24.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/docs/src/api.md b/docs/src/api.md index a729ee754..9b98f9dc6 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -258,7 +258,6 @@ DynamicPPL.reconstruct Base.merge(::AbstractVarInfo) DynamicPPL.subset DynamicPPL.unflatten -DynamicPPL.tonamedtuple DynamicPPL.varname_leaves DynamicPPL.varname_and_value_leaves ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9853d8140..b90381dea 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -71,7 +71,6 @@ export AbstractVarInfo, invlink, invlink!, invlink!!, - tonamedtuple, values_as, # VarName (reexport from AbstractPPL) VarName, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 0218a1882..67c2f3fcb 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -738,21 +738,6 @@ function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::Abstrac return unflatten(varinfo, sampler, θ) end -""" - tonamedtuple(vi::AbstractVarInfo) - -Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and -indexing string of the variable. - -For example, a model that had a vector of vector-valued -variables `x` would return - -```julia -(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), ) -``` -""" -function tonamedtuple end - # TODO: Clean up all this linking stuff once and for all! """ with_logabsdet_jacobian_and_reconstruct([f, ]dist, x) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 400dd93fe..93c211483 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -532,44 +532,6 @@ function dot_assume( return value, lp, vi end -# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl. -# TODO: Move away from using these `tonamedtuple` methods. -function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names} - nt_vals = map(keys(vi)) do vn - val = vi[vn] - vns = collect(TestUtils.varname_leaves(vn, val)) - vals = map(copy ∘ Base.Fix1(getindex, vi), vns) - (vals, map(string, vns)) - end - - return NamedTuple{names}(nt_vals) -end - -function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) - syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}() - for vn in keys(vi) - # Extract the leaf varnames and values. - val = vi[vn] - vns = collect(TestUtils.varname_leaves(vn, val)) - vals = map(copy ∘ Base.Fix1(getindex, vi), vns) - - # Determine the corresponding symbol. - sym = only(unique(map(getsym, vns))) - - # Initialize entry if not yet initialized. - if !haskey(syms_to_result, sym) - syms_to_result[sym] = (Real[], String[]) - end - - # Combine with old result. - old_vals, old_string_vns = syms_to_result[sym] - syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns))) - end - - # Construct `NamedTuple`. - return NamedTuple(pairs(syms_to_result)) -end - # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index ab504de23..fb1cc1c0c 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -209,8 +209,6 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) end -tonamedtuple(vi::ThreadSafeVarInfo) = tonamedtuple(vi.varinfo) - # Transformations. function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0d5dce7aa..590626df3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1506,22 +1506,6 @@ end return expr end -# TODO: Remove this completely. -tonamedtuple(varinfo::VarInfo) = tonamedtuple(varinfo.metadata, varinfo) -function tonamedtuple(metadata::NamedTuple{names}, varinfo::VarInfo) where {names} - length(names) === 0 && return NamedTuple() - - vals_tuple = map(values(metadata)) do x - # NOTE: `tonamedtuple` is really only used in Turing.jl to convert to - # a "transition". This means that we really don't mutations of the values - # in `varinfo` to propoagate the previous samples. Hence we `copy.` - vals = map(copy ∘ Base.Fix1(getindex, varinfo), x.vns) - return vals, map(string, x.vns) - end - - return NamedTuple{names}(vals_tuple) -end - @inline function findvns(vi, f_vns) if length(f_vns) == 0 throw("Unidentified error, please report this error in an issue.") diff --git a/test/test_util.jl b/test/test_util.jl index 7a7028536..64832f51e 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -58,22 +58,22 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1) DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) θ_new = var_info[spl] @test θ_old != θ_new - nt = DynamicPPL.tonamedtuple(var_info) - for (k, (vals, names)) in pairs(nt) - for (n, v) in zip(names, vals) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val + vals = DynamicPPL.values_as(var_info, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + for (n, v) in mapreduce(collect, vcat, iters) + n = string(n) + if Symbol(n) ∉ keys(chain) + # Assume it's a group + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) + v_true = vec(v) + else + chain_val = chain[sample_idx, n, chain_idx] + v_true = v end + + @test v_true == chain_val end end