Skip to content

Commit

Permalink
Merge pull request #58 from invenia/mz/mul
Browse files Browse the repository at this point in the history
rrule for BlockDiagonal * Vector multiplication
  • Loading branch information
mzgubic authored Feb 12, 2021
2 parents 57d8490 + 7ce3448 commit fda7259
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/Manifest.toml
docs/build/
dev/*
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockDiagonals"
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.13"
version = "0.1.14"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
ChainRulesCore = "0.9"
ChainRulesTestUtils = "0.6"
ChainRulesTestUtils = "0.6.3"
FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10"
julia = "1"

Expand Down
35 changes: 35 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,38 @@ function ChainRulesCore.rrule(::Type{<:Base.Matrix}, B::T) where {T<:BlockDiagon
return Matrix(B), Matrix_pullback
end

# multiplication
function ChainRulesCore.rrule(
::typeof(*),
bm::BlockDiagonal{T, V},
v::StridedVector{T}
) where {T<:Union{Real, Complex}, V<:Matrix{T}}

y = bm * v

# needed for computing Δ * v' blockwise
nrows = size.(bm.blocks, 1)
ncols = size.(bm.blocks, 2)
row_idxs = cumsum(nrows) .- nrows .+ 1
col_idxs = cumsum(ncols) .- ncols .+ 1

function bm_vector_mul_pullback(Δ)
Δblocks = map(eachindex(nrows)) do i
block_rows = row_idxs[i]:(row_idxs[i] + nrows[i] - 1)
block_cols = col_idxs[i]:(col_idxs[i] + ncols[i] - 1)
return InplaceableThunk(
@thunk(Δ[block_rows] * v[block_cols]'),
-> mul!(X̄, Δ[block_rows], v[block_cols]', true, true)
)
end
return (
NO_FIELDS,
Composite{BlockDiagonal{T, V}}(;blocks=Δblocks),
InplaceableThunk(
@thunk(bm' * Δ),
-> mul!(X̄, bm', Δ, true, true)
),
)
end
return y, bm_vector_mul_pullback
end
6 changes: 0 additions & 6 deletions test/blockdiagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@ using BlockDiagonals: isequal_blocksizes
using Random
using Test

function FiniteDifferences.to_vec(X::BlockDiagonal)
x, blocks_from_vec = to_vec(X.blocks)
BlockDiagonal_from_vec(x_vec) = BlockDiagonal(blocks_from_vec(x_vec))
return x, BlockDiagonal_from_vec
end

@testset "blockdiagonal.jl" begin
rng = MersenneTwister(123456)
N1, N2, N3 = 3, 4, 5
Expand Down
6 changes: 6 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,10 @@
D = BlockDiagonal([randn(1, 2), randn(2, 2)])
test_rrule(Matrix, D)
end

@testset "BlockDiagonal * Vector" begin
D = BlockDiagonal([rand(2, 3), rand(3, 3)])
v = rand(6)
test_rrule(*, D, v)
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ using FiniteDifferences # For overloading to_vec
using Test
using LinearAlgebra

function FiniteDifferences.to_vec(X::BlockDiagonal)
x, blocks_from_vec = to_vec(X.blocks)
BlockDiagonal_from_vec(x_vec) = BlockDiagonal(blocks_from_vec(x_vec))
return x, BlockDiagonal_from_vec
end

@testset "BlockDiagonals" begin
# The doctests fail on x86, so only run them on 64-bit hardware
Sys.WORD_SIZE == 64 && doctest(BlockDiagonals)
Expand Down

0 comments on commit fda7259

Please sign in to comment.