From b13c812d40297b4d9f25b7bdc51829ea2c7a1ae4 Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Mon, 14 Oct 2024 17:24:54 -0400 Subject: [PATCH] Apply suggestions from code review --- Project.toml | 1 - ext/OptimizationEnzymeExt.jl | 8 ++++---- test/adtests.jl | 6 +++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 7f3356c..05bcd60 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,6 @@ uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" authors = ["Vaibhav Dixit and contributors"] version = "2.3.0" - [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 9ccab88..4cbea72 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -91,13 +91,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, rmode = if adtype.mode isa Nothing Enzyme.Reverse else - set_runtime_activity2(Enzyme.Reverse) + set_runtime_activity2(Enzyme.Reverse, adtype.mode) end fmode = if adtype.mode isa Nothing Enzyme.Forward else - set_runtime_activity2(Enzyme.Forward) + set_runtime_activity2(Enzyme.Forward. adtype.mode) end if g == true && f.grad === nothing @@ -423,13 +423,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x rmode = if adtype.mode isa Nothing Enzyme.Reverse else - set_runtime_activity2(Enzyme.Reverse) + set_runtime_activity2(Enzyme.Reverse, adtype.mode) end fmode = if adtype.mode isa Nothing Enzyme.Forward else - set_runtime_activity2(Enzyme.Forward) + set_runtime_activity2(Enzyme.Forward, adtype.mode) end if g == true && f.grad === nothing diff --git a/test/adtests.jl b/test/adtests.jl index fc85ad1..1075e7f 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -1,6 +1,6 @@ using OptimizationBase, Test, DifferentiationInterface, SparseArrays, Symbolics using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker -using ModelingToolkit, Enzyme, Random +using Enzyme, Random x0 = zeros(2) rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2 @@ -1174,8 +1174,8 @@ using MLUtils optf = OptimizationBase.instantiate_function( optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true) G0 = zeros(3) - @test_broken optf.grad(G0, ones(3), (x, y)) - stochgrads = [] + @test_broken optf.grad(G0, ones(3), (x0, y0)) + # stochgrads = [] # for (x,y) in data # G = zeros(3) # optf.grad(G, ones(3), (x,y))