Skip to content

Commit

Permalink
simplify test
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 24, 2023
1 parent 9d19db2 commit f9b0784
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
b2 = rand(n);
db2 = zeros(n);

function f(A, b1, b2; alg = LUFactorization())
function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)
Expand All @@ -18,16 +16,15 @@ function f(A, b1, b2; alg = LUFactorization())
norm(s1)
end

f(A, b1, b2) # Uses BLAS
f(A, b1) # Uses BLAS

Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A))
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1))

@test dA dA2
@test db1 db12
@test db2 == zeros(4)

A = rand(n, n);
dA = zeros(n, n);
Expand All @@ -36,9 +33,6 @@ b1 = rand(n);
db1 = zeros(n);
db12 = zeros(n);

b2 = rand(n);
db2 = zeros(n);
db22 = zeros(n);

@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)), BatchDuplicated(copy(b2), (db2, db22)))
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
# This is not legal, all args need to be batch'd at the same size
@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)))
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1))

0 comments on commit f9b0784

Please sign in to comment.