Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Oct 15, 2024
1 parent 97ec67b commit b13c812
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
version = "2.3.0"


[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
8 changes: 4 additions & 4 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
rmode = if adtype.mode isa Nothing
Enzyme.Reverse
else
set_runtime_activity2(Enzyme.Reverse)
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
end

fmode = if adtype.mode isa Nothing
Enzyme.Forward
else
set_runtime_activity2(Enzyme.Forward)
set_runtime_activity2(Enzyme.Forward. adtype.mode)
end

if g == true && f.grad === nothing
Expand Down Expand Up @@ -423,13 +423,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
rmode = if adtype.mode isa Nothing
Enzyme.Reverse
else
set_runtime_activity2(Enzyme.Reverse)
set_runtime_activity2(Enzyme.Reverse, adtype.mode)
end

fmode = if adtype.mode isa Nothing
Enzyme.Forward
else
set_runtime_activity2(Enzyme.Forward)
set_runtime_activity2(Enzyme.Forward, adtype.mode)
end

if g == true && f.grad === nothing
Expand Down
6 changes: 3 additions & 3 deletions test/adtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using OptimizationBase, Test, DifferentiationInterface, SparseArrays, Symbolics
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
using ModelingToolkit, Enzyme, Random
using Enzyme, Random

x0 = zeros(2)
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
Expand Down Expand Up @@ -1174,8 +1174,8 @@ using MLUtils
optf = OptimizationBase.instantiate_function(
optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true)
G0 = zeros(3)
@test_broken optf.grad(G0, ones(3), (x, y))
stochgrads = []
@test_broken optf.grad(G0, ones(3), (x0, y0))
# stochgrads = []
# for (x,y) in data
# G = zeros(3)
# optf.grad(G, ones(3), (x,y))
Expand Down

0 comments on commit b13c812

Please sign in to comment.