Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify tests of views and subarrays #17

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 55 additions & 127 deletions src/nlp/view-subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,12 @@ Check that the API work with views, and that the results is correct.
"""
function view_subarray_nlp(nlp; exclude = [])
@testset "Test view subarray of NLPs" begin
n, m = nlp.meta.nvar, nlp.meta.ncon
N = 2n
Vidxs = [1:2:N, collect(N:-2:1)]
Cidxs = if m > 0
N = 2m
[1:2:N, collect(N:-2:1)]
else
[]
end
n, m, nnzh = nlp.meta.nvar, nlp.meta.ncon, nlp.meta.nnzh

# Inputs
x = [-(-1.1)^i for i = 1:(2n)] # Instead of [1, -1, …], because it needs to
v = [-(-1.1)^i for i = 1:(2n)] # access different parts of the vector and
y = [-(-1.1)^i for i = 1:(2m)] # make a difference
x = [-(-1.1)^i for i = 1:n] # Instead of [1, -1, …], because it needs to
v = [-(-1.1)^i for i = 1:n] # access different parts of the vector and
y = [-(-1.1)^i for i = 1:m] # make a difference

# Outputs
g = zeros(n)
Expand All @@ -33,128 +25,64 @@ function view_subarray_nlp(nlp; exclude = [])
jty2 = zeros(2n)
hv = zeros(n)
hv2 = zeros(2n)
hval = zeros(nnzh)
hval2 = zeros(2nnzh)

for I in Vidxs
xv = @view x[I]
for foo in setdiff([obj, grad, hess], exclude)
@test foo(nlp, x[I]) ≈ foo(nlp, xv)
end

if hess_coord ∉ exclude
vals1 = hess_coord(nlp, x[I])
vals2 = hess_coord(nlp, xv)
@test vals1 ≈ vals2
end

if m > 0
for foo in setdiff([cons, jac], exclude)
@test foo(nlp, x[I]) ≈ foo(nlp, xv)
end
if jac_coord ∉ exclude
vals1 = jac_coord(nlp, x[I])
vals2 = jac_coord(nlp, xv)
@test vals1 ≈ vals2
end
end

if hess ∉ exclude
for J in Cidxs
yv = @view y[J]
@test hess(nlp, x[I], y[J]) ≈ hess(nlp, xv, yv)
yv = @view y[J]
end
end
I = 1:2:2n
Iv = 1:nnzh
J = 1:2:2m

if hess_coord ∉ exclude
Hval = @view hval2[Iv]
vals1 = hess_coord!(nlp, x, Hval)
vals2 = hess_coord!(nlp, x, hval)
@test hval ≈ hval2[Iv]
end

if hess_coord ∉ exclude
for J in Cidxs
yv = @view y[J]
vals1 = hess_coord(nlp, x[I], y[J])
vals2 = hess_coord(nlp, xv, yv)
@test vals1 ≈ vals2
end
end
if hess_coord ∉ exclude && m > 0
Hval = @view hval2[Iv]
vals1 = hess_coord!(nlp, x, y, Hval)
vals2 = hess_coord!(nlp, x, y, hval)
@test hval ≈ hval2[Iv]
end

# Inplace methods can have input and output as view, so 4 possibilities
if grad ∉ exclude
for J in Vidxs
gv = @view g2[J]
grad!(nlp, x[I], g)
grad!(nlp, x[I], gv)
@test g ≈ g2[J]
grad!(nlp, xv, gv)
@test g ≈ g2[J]
grad!(nlp, xv, g)
@test g ≈ g2[J]
end
end
if grad ∉ exclude
gv = @view g2[I]
grad!(nlp, x, gv)
grad!(nlp, x, g)
@test g ≈ g2[I]
end

if cons ∉ exclude
for J in Cidxs
cv = @view c2[J]
cons!(nlp, x[I], c)
cons!(nlp, x[I], cv)
@test c ≈ c2[J]
cons!(nlp, xv, cv)
@test c ≈ c2[J]
cons!(nlp, xv, c)
@test c ≈ c2[J]
end
end
if cons ∉ exclude && m > 0
cv = @view c2[J]
cons!(nlp, x, cv)
cons!(nlp, x, c)
@test c ≈ c2[J]
end

if jprod ∉ exclude
for J in Cidxs, K in Vidxs
vv = @view v[K]
jvv = @view jv2[J]
@test jprod(nlp, x[I], v[K]) ≈ jprod(nlp, xv, vv)
jprod!(nlp, x[I], v[K], jv)
jprod!(nlp, x[I], v[K], jvv)
@test jv ≈ jv2[J]
jprod!(nlp, xv, vv, jvv)
@test jv ≈ jv2[J]
jprod!(nlp, xv, vv, jv)
@test jv ≈ jv2[J]
end
end
if jprod ∉ exclude && m > 0
jvv = @view jv2[J]
jprod!(nlp, x, v, jvv)
jprod!(nlp, x, v, jv)
@test jv ≈ jv2[J]
end

if jtprod ∉ exclude
for J in Cidxs, K in Vidxs
yv = @view y[J]
jtyv = @view jty2[K]
@test jtprod(nlp, x[I], y[J]) ≈ jtprod(nlp, xv, yv)
jtprod!(nlp, x[I], y[J], jty)
jtprod!(nlp, x[I], y[J], jtyv)
@test jty ≈ jty2[K]
jtprod!(nlp, xv, yv, jtyv)
@test jty ≈ jty2[K]
jtprod!(nlp, xv, yv, jty)
@test jty ≈ jty2[K]
end
end
if jtprod ∉ exclude && m > 0
jtyv = @view jty2[I]
jtprod!(nlp, x, y, jtyv)
jtprod!(nlp, x, y, jty)
@test jty ≈ jty2[I]
end

if hprod ∉ exclude
for J in Vidxs, K in Vidxs
vv = @view v[J]
hvv = @view hv2[K]
@test hprod(nlp, x[I], v[J]) ≈ hprod(nlp, xv, vv)
hprod!(nlp, x[I], v[J], hv)
hprod!(nlp, x[I], v[J], hvv)
@test hv ≈ hv2[K]
hprod!(nlp, xv, vv, hvv)
@test hv ≈ hv2[K]
hprod!(nlp, xv, vv, hv)
@test hv ≈ hv2[K]
for P in Cidxs
yv = @view y[P]
@test hprod(nlp, x[I], y[P], v[J]) ≈ hprod(nlp, xv, yv, vv)
hprod!(nlp, x[I], y[P], v[J], hv)
hprod!(nlp, x[I], y[P], v[J], hvv)
@test hv ≈ hv2[K]
hprod!(nlp, xv, yv, vv, hvv)
@test hv ≈ hv2[K]
hprod!(nlp, xv, yv, vv, hv)
@test hv ≈ hv2[K]
end
end
if hprod ∉ exclude
hvv = @view hv2[I]
hprod!(nlp, x, v, hvv)
hprod!(nlp, x, v, hv)
@test hv ≈ hv2[I]
if m > 0
hprod!(nlp, x, y, v, hvv)
hprod!(nlp, x, y, v, hv)
@test hv ≈ hv2[I]
end
end
end
Expand Down
104 changes: 32 additions & 72 deletions src/nls/view-subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@ Check that the API work with views, and that the results is correct.
"""
function view_subarray_nls(nls; exclude = [])
@testset "Test view subarray of NLSs" begin
n, ne = nls.meta.nvar, nls.nls_meta.nequ
N = 2n
Vidxs = [1:n, n .+ (1:n), 1:2:N, collect(N:-2:1)]
N = 2ne
Fidxs = [1:ne, ne .+ (1:ne), 1:2:N, collect(N:-2:1)]
n, ne = nls.meta.nvar, nls.nls_meta.nequ

# Inputs
x = [-(-1.1)^i for i = 1:(2n)] # Instead of [1, -1, …], because it needs to
v = [-(-1.1)^i for i = 1:(2n)] # access different parts of the vector and
y = [-(-1.1)^i for i = 1:(2ne)] # make a difference
x = [-(-1.1)^i for i = 1:n] # Instead of [1, -1, …], because it needs to
v = [-(-1.1)^i for i = 1:n] # access different parts of the vector and
y = [-(-1.1)^i for i = 1:ne] # make a difference

# Outputs
F = zeros(ne)
Expand All @@ -28,74 +24,38 @@ function view_subarray_nls(nls; exclude = [])
hv = zeros(n)
hv2 = zeros(2n)

for I in Vidxs
xv = @view x[I]
for foo in setdiff([residual, jac_residual], exclude)
@test foo(nls, x[I]) ≈ foo(nls, xv)
end

# Inplace methods can have input and output as view, so 4 possibilities
if residual ∉ exclude
for J in Fidxs
Fv = @view F2[J]
residual!(nls, x[I], F)
residual!(nls, x[I], Fv)
@test F ≈ F2[J]
residual!(nls, xv, Fv)
@test F ≈ F2[J]
residual!(nls, xv, F)
@test F ≈ F2[J]
end
end
# Vidxs = [1:n, n .+ (1:n), 1:2:N, collect(N:-2:1)]
I = collect(2n:-2:1)
# Fidxs = [1:ne, ne .+ (1:ne), 1:2:N, collect(N:-2:1)]
J = ne .+ (1:ne)

if jprod_residual ∉ exclude
for J in Fidxs, K in Vidxs
vv = @view v[K]
jvv = @view jv2[J]
@test jprod_residual(nls, x[I], v[K]) ≈ jprod_residual(nls, xv, vv)
jprod_residual!(nls, x[I], v[K], jv)
jprod_residual!(nls, x[I], v[K], jvv)
@test jv ≈ jv2[J]
jprod_residual!(nls, xv, vv, jvv)
@test jv ≈ jv2[J]
jprod_residual!(nls, xv, vv, jv)
@test jv ≈ jv2[J]
end
end
if residual ∉ exclude
Fv = @view F2[J]
residual!(nls, x, Fv)
residual!(nls, x, F)
@test F ≈ F2[J]
end

if jtprod_residual ∉ exclude
for J in Fidxs, K in Vidxs
yv = @view y[J]
jtyv = @view jty2[K]
@test jtprod_residual(nls, x[I], y[J]) ≈ jtprod_residual(nls, xv, yv)
jtprod_residual!(nls, x[I], y[J], jty)
jtprod_residual!(nls, x[I], y[J], jtyv)
@test jty ≈ jty2[K]
jtprod_residual!(nls, xv, yv, jtyv)
@test jty ≈ jty2[K]
jtprod_residual!(nls, xv, yv, jty)
@test jty ≈ jty2[K]
end
end
if jprod_residual ∉ exclude
jvv = @view jv2[J]
jprod_residual!(nls, x, v, jvv)
jprod_residual!(nls, x, v, jv)
@test jv ≈ jv2[J]
end

for i = 1:ne
@test jth_hess_residual ∈ exclude ||
jth_hess_residual(nls, x[I], i) ≈ jth_hess_residual(nls, xv, i)
if jtprod_residual ∉ exclude
jtyv = @view jty2[I]
jtprod_residual!(nls, x, y, jtyv)
jtprod_residual!(nls, x, y, jty)
@test jty ≈ jty2[I]
end

if hprod_residual ∉ exclude
for J in Vidxs, K in Vidxs
vv = @view v[J]
hvv = @view hv2[K]
@test hprod_residual(nls, x[I], i, v[J]) ≈ hprod_residual(nls, xv, i, vv)
hprod_residual!(nls, x[I], i, v[J], hv)
hprod_residual!(nls, x[I], i, v[J], hvv)
@test hv ≈ hv2[K]
hprod_residual!(nls, xv, i, vv, hvv)
@test hv ≈ hv2[K]
hprod_residual!(nls, xv, i, vv, hv)
@test hv ≈ hv2[K]
end
end
for i = 1:ne
if hprod_residual ∉ exclude
hvv = @view hv2[I]
hprod_residual!(nls, x, i, v, hvv)
hprod_residual!(nls, x, i, v, hv)
@test hv ≈ hv2[I]
end
end
end
Expand Down