Skip to content

Commit

Permalink
fix: fix bug in remake_buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 26, 2024
1 parent e1befe0 commit 628de91
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ function MTKParameters(
if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
pdeps = Dict(k => fixpoint_sub(v, pdeps) for (k, v) in pdeps)
dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer)...)
dep_exprs = ArrayPartition((Any[0 for _ in eachindex(v)] for v in dep_buffer)...)
for (sym, val) in pdeps
i, j = ic.dependent_idx[sym]
dep_exprs.x[i][j] = wrap(val)
dep_exprs.x[i][j] = unwrap(val)
end
p = reorder_parameters(ic, full_parameters(sys))
oop, iip = build_function(dep_exprs, p...)
Expand Down Expand Up @@ -398,7 +398,10 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
@set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.(
oldbuf.nonnumeric, newbuf.nonnumeric)
if newbuf.dependent_update_oop !== nothing
@set! newbuf.dependent = newbuf.dependent_update_oop(newbuf...)
@set! newbuf.dependent = narrow_buffer_type_and_fallback_undefs.(
oldbuf.dependent,
split_into_buffers(
newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(false)))
end
return newbuf
end
Expand All @@ -422,6 +425,7 @@ _num_subarrays(v::Tuple) = length(v)
# getindex indexes the vectors, setindex! linearly indexes values
# it's inconsistent, but we need it to be this way
function Base.getindex(buf::MTKParameters, i)
i_orig = i
if !isempty(buf.tunable)
i <= _num_subarrays(buf.tunable) && return _subarrays(buf.tunable)[i]
i -= _num_subarrays(buf.tunable)
Expand All @@ -442,7 +446,7 @@ function Base.getindex(buf::MTKParameters, i)
i <= _num_subarrays(buf.dependent) && return _subarrays(buf.dependent)[i]
i -= _num_subarrays(buf.dependent)
end
throw(BoundsError(buf, i))
throw(BoundsError(buf, i_orig))

Check warning on line 449 in src/systems/parameter_buffer.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/parameter_buffer.jl#L449

Added line #L449 was not covered by tests
end
function Base.setindex!(p::MTKParameters, val, i)
function _helper(buf)
Expand Down Expand Up @@ -526,9 +530,6 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
for (i, val) in zip(input_idxs, p_small_inner)
_set_parameter_unchecked!(p_big, val, i)
end
# tunable, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p_big)
# tunable[input_idxs] .= p_small_inner
# p_big = repack(tunable)
return if pf isa SciMLBase.ParamJacobianWrapper
buffer = Array{dualtype}(undef, size(pf.u))
pf(buffer, p_big)
Expand All @@ -538,8 +539,6 @@ function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where
end
end
end
# tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
# p_small = tunable[input_idxs]
p_small = parameter_values.((p,), input_idxs)
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
Expand Down
17 changes: 17 additions & 0 deletions test/mtkparameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,20 @@ function loss(x)
end

@test_nowarn ForwardDiff.gradient(loss, collect(tunables))

# Ensure dependent parameters are `Tuple{...}` and not `ArrayPartition` when using
# `remake_buffer`.
@parameters p1 p2 p3[1:2] p4[1:2]
@named sys = ODESystem(
Equation[], t, [], [p1, p2, p3, p4]; parameter_dependencies = [p2 => 2p1, p4 => 3p3])
sys = complete(sys)
ps = MTKParameters(sys, [p1 => 1.0, p3 => [2.0, 3.0]])
@test ps[parameter_index(sys, p2)] == 2.0
@test ps[parameter_index(sys, p4)] == [6.0, 9.0]

newps = remake_buffer(
sys, ps, Dict(p1 => ForwardDiff.Dual(2.0), p3 => ForwardDiff.Dual.([3.0, 4.0])))

VDual = Vector{<:ForwardDiff.Dual}
VVDual = Vector{<:Vector{<:ForwardDiff.Dual}}
@test newps.dependent isa Union{Tuple{VDual, VVDual}, Tuple{VVDual, VDual}}

0 comments on commit 628de91

Please sign in to comment.