Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Oct 1, 2024
1 parent 043a01c commit c850508
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
8 changes: 4 additions & 4 deletions ext/OptimizationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ function OptimizationBase.instantiate_function(
if hv == true && f.hv === nothing
prep_hvp = prepare_hvp(f.f, soadtype, x, (zeros(eltype(x), size(x)),), Constant(p))
function hv!(H, θ, v)
hvp!(f.f, H, prep_hvp, soadtype, θ, (v,), Constant(p))
hvp!(f.f, (H,), prep_hvp, soadtype, θ, (v,), Constant(p))
end
if p !== SciMLBase.NullParameters() && p !== nothing
function hv!(H, θ, v, p)
hvp!(f.f, H, prep_hvp, soadtype, θ, (v,), Constant(p))
hvp!(f.f, (H,), prep_hvp, soadtype, θ, (v,), Constant(p))
end
end
elseif hv == true
Expand Down Expand Up @@ -141,9 +141,9 @@ function OptimizationBase.instantiate_function(
cons_jac_prototype = f.cons_jac_prototype
cons_jac_colorvec = f.cons_jac_colorvec
if cons !== nothing && cons_j == true && f.cons_j === nothing
prep_jac = prepare_jacobian(cons_oop, adtype, x, Constant(p))
prep_jac = prepare_jacobian(cons_oop, adtype, x)
function cons_j!(J, θ)
jacobian!(cons_oop, J, prep_jac, adtype, θ, Constant(p))
jacobian!(cons_oop, J, prep_jac, adtype, θ)
if size(J, 1) == 1
J = vec(J)
end
Expand Down
3 changes: 1 addition & 2 deletions src/OptimizationDISparseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,8 @@ function instantiate_function(
end

function cons_oop(x, i)
_res = zeros(eltype(x))
_res = zeros(eltype(x), num_cons)
f.cons(_res, x, p)
@show _res
return _res[i]
end

Expand Down
2 changes: 1 addition & 1 deletion test/adtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ optprob.cons_h(H3, x0)
optprob.cons(res, x0)
@test res == [0.0]
J = Array{Float64}(undef, 2)
@test optprob.cons_j(J, [5.0, 3.0])
optprob.cons_j(J, [5.0, 3.0])
@test J == [10.0, 6.0]
vJ = Array{Float64}(undef, 2)
optprob.cons_vjp(vJ, [5.0, 3.0], [1.0])
Expand Down

0 comments on commit c850508

Please sign in to comment.