Skip to content

Commit dbeb7c7

Browse files
authored
Merge pull request #217 from JuliaStats/dw/three_arg_mul_ldiv
Rewrite `whiten!` and `unwhiten!` to work around ForwardDiff bug
2 parents 8bd974f + 7331e62 commit dbeb7c7

File tree

6 files changed

+46
-14
lines changed

6 files changed

+46
-14
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "PDMats"
22
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
3-
version = "0.11.33"
3+
version = "0.11.34"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -9,16 +9,20 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
99

1010
[compat]
1111
BandedMatrices = "0.15, 1"
12+
FiniteDifferences = "0.12"
13+
ForwardDiff = "0.10, 1"
1214
Random = "<0.0.1, 1"
1315
StaticArrays = "1"
1416
Test = "<0.0.1, 1"
1517
julia = "1.10"
1618

1719
[extras]
1820
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
21+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
22+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1923
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2024
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2125
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2226

2327
[targets]
24-
test = ["BandedMatrices", "StaticArrays", "Random", "Test"]
28+
test = ["BandedMatrices", "FiniteDifferences", "ForwardDiff", "StaticArrays", "Random", "Test"]

src/chol.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ function quad!(r::AbstractArray, A::Cholesky, X::AbstractMatrix)
6161
aU = chol_upper(A)
6262
z = similar(r, size(A, 1)) # buffer to save allocations
6363
@inbounds for i in axes(X, 2)
64-
copyto!(z, view(X, :, i))
65-
lmul!(aU, z)
64+
mul!(z, aU, view(X, :, i))
6665
r[i] = sum(abs2, z)
6766
end
6867
return r

src/pdmat.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,20 @@ LinearAlgebra.sqrt(A::PDMat) = PDMat(sqrt(Hermitian(A.mat)))
125125
function whiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat)
126126
@check_argdims axes(r) == axes(x)
127127
@check_argdims a.dim == size(x, 1)
128-
v = _rcopy!(r, x)
129-
return ldiv!(chol_lower(cholesky(a)), v)
128+
if r === x
129+
return ldiv!(chol_lower(cholesky(a)), r)
130+
else
131+
return ldiv!(r, chol_lower(cholesky(a)), x)
132+
end
130133
end
131134
function unwhiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat)
132135
@check_argdims axes(r) == axes(x)
133136
@check_argdims a.dim == size(x, 1)
134-
v = _rcopy!(r, x)
135-
return lmul!(chol_lower(cholesky(a)), v)
137+
if r === x
138+
return lmul!(chol_lower(cholesky(a)), r)
139+
else
140+
return mul!(r, chol_lower(cholesky(a)), x)
141+
end
136142
end
137143

138144
function whiten(a::PDMat, x::AbstractVecOrMat)
@@ -162,8 +168,7 @@ function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
162168
aU = chol_upper(cholesky(a))
163169
z = similar(r, a.dim) # buffer to save allocations
164170
@inbounds for i in axes(x, 2)
165-
copyto!(z, view(x, :, i))
166-
lmul!(aU, z)
171+
mul!(z, aU, view(x, :, i))
167172
r[i] = sum(abs2, z)
168173
end
169174
return r

src/utils.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ macro check_argdims(cond)
77
end
88
end
99

10-
_rcopy!(r, x) = (r === x || copyto!(r, x); r)
11-
12-
1310
function _addscal!(r::Matrix, a::Matrix, b::Union{Matrix, SparseMatrixCSC}, c::Real)
1411
if c == one(c)
1512
for i in eachindex(a)

test/ad.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using PDMats
2+
using FiniteDifferences
3+
using ForwardDiff
4+
5+
using LinearAlgebra
6+
using Test
7+
8+
# issue #217
9+
@testset "PDMat: (un)whiten" begin
10+
a = vec(Matrix{Float64}(I, 4, 4))
11+
fdm = central_fdm(5, 1)
12+
13+
for (f, f!) in ((whiten, whiten!), (unwhiten, unwhiten!))
14+
apply_f = let f = f
15+
a -> f(PDMat(Symmetric(reshape(a, 4, 4))), ones(4))
16+
end
17+
apply_f! = let f! = f!
18+
a -> f!(Vector{promote_type(eltype(a),Float64)}(undef, 4), PDMat(Symmetric(reshape(a, 4, 4))), ones(4))
19+
end
20+
21+
J = only(FiniteDifferences.jacobian(fdm, apply_f, a))
22+
@test only(FiniteDifferences.jacobian(fdm, apply_f!, a)) J
23+
24+
@test ForwardDiff.jacobian(apply_f, a) J
25+
@test ForwardDiff.jacobian(apply_f!, a) J
26+
end
27+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
include("testutils.jl")
2-
tests = ["pdmtypes", "abstracttypes", "addition", "generics", "kron", "chol", "specialarrays", "sqrt"]
2+
tests = ["pdmtypes", "abstracttypes", "addition", "generics", "kron", "chol", "specialarrays", "sqrt", "ad"]
33
println("Running tests ...")
44

55
for t in tests

0 commit comments

Comments
 (0)