Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Handle adjoints for LinearSolve
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 15, 2024
1 parent 5305eed commit c45432f
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/api/)

[![CI](https://github.com/LuxDL/BatchedRoutines.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/BatchedRoutines.jl/actions/workflows/CI.yml)
[![Build status](https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu)](https://buildkite.com/julialang/lux-dot-jl)
[![Buildkite NVIDIA GPU CI](https://img.shields.io/buildkite/0f7d50856fb9a3a52ec010723b16710fb0bb57110b60cc3078.svg?label=gpu&logo=nvidia)](https://buildkite.com/julialang/batchedroutines-dot-jl/)
[![codecov](https://codecov.io/gh/LuxDL/BatchedRoutines.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/BatchedRoutines.jl)
[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/BatchedRoutines)](https://pkgs.genieframework.com?packages=BatchedRoutines)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
Expand Down
27 changes: 22 additions & 5 deletions ext/BatchedRoutinesCUDAExt/factorization.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# LU Factorization
@concrete struct CuBatchedLU{T} <: AbstractBatchedMatrixFactorization
@concrete struct CuBatchedLU{T} <: AbstractBatchedMatrixFactorization{T}
factors
pivot_array
info
size
end

const AdjCuBatchedLU{T} = LinearAlgebra.AdjointFactorization{T, <:CuBatchedLU{T}}
const TransCuBatchedLU{T} = LinearAlgebra.TransposeFactorization{T, <:CuBatchedLU{T}}
const AdjOrTransCuBatchedLU{T} = Union{AdjCuBatchedLU{T}, TransCuBatchedLU{T}}

const AllCuBatchedLU{T} = Union{CuBatchedLU{T}, AdjOrTransCuBatchedLU{T}}

BatchedRoutines.nbatches(LU::CuBatchedLU) = nbatches(LU.factors)
function BatchedRoutines.batchview(LU::CuBatchedLU)
return zip(batchview(LU.factors), batchview(LU.pivot_array), LU.info)
Expand All @@ -15,7 +21,6 @@ function BatchedRoutines.batchview(LU::CuBatchedLU, idx::Int)
end
Base.size(LU::CuBatchedLU) = LU.size
Base.size(LU::CuBatchedLU, i::Integer) = LU.size[i]
Base.eltype(::CuBatchedLU{T}) where {T} = T

function Base.show(io::IO, LU::CuBatchedLU)
return print(io, "CuBatchedLU() with Batch Count: $(nbatches(LU))")
Expand Down Expand Up @@ -43,24 +48,35 @@ function LinearAlgebra.ldiv!(A::CuBatchedLU, b::CuMatrix)
return b
end

function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedLU, b::CuMatrix)
function LinearAlgebra.ldiv!(A::AdjOrTransCuBatchedLU, b::CuMatrix)
@assert nbatches(A) == nbatches(b)
getrs_strided_batched!('T', parent(A).factors, parent(A).pivot_array, b)
return b
end

function LinearAlgebra.ldiv!(X::CuMatrix, A::AllCuBatchedLU, b::CuMatrix)
copyto!(X, b)
return LinearAlgebra.ldiv!(A, X)
end

# QR Factorization
@concrete struct CuBatchedQR{T} <: AbstractBatchedMatrixFactorization
@concrete struct CuBatchedQR{T} <: AbstractBatchedMatrixFactorization{T}
factors
τ
size
end

const AdjCuBatchedQR{T} = LinearAlgebra.AdjointFactorization{T, <:CuBatchedQR{T}}
const TransCuBatchedQR{T} = LinearAlgebra.TransposeFactorization{T, <:CuBatchedQR{T}}
const AdjOrTransCuBatchedQR{T} = Union{AdjCuBatchedQR{T}, TransCuBatchedQR{T}}

const AllCuBatchedQR{T} = Union{CuBatchedQR{T}, AdjOrTransCuBatchedQR{T}}

BatchedRoutines.nbatches(QR::CuBatchedQR) = length(QR.factors)
BatchedRoutines.batchview(QR::CuBatchedQR) = zip(QR.factors, QR.τ)
BatchedRoutines.batchview(QR::CuBatchedQR, idx::Int) = QR.factors[idx], QR.τ[idx]
Base.size(QR::CuBatchedQR) = QR.size
Base.size(QR::CuBatchedQR, i::Integer) = QR.size[i]
Base.eltype(::CuBatchedQR{T}) where {T} = T

function Base.show(io::IO, QR::CuBatchedQR)
return print(io, "CuBatchedQR() with Batch Count: $(nbatches(QR))")
Expand All @@ -75,6 +91,7 @@ function LinearAlgebra.qr!(A::CuUniformBlockDiagonalMatrix, ::NoPivot; kwargs...
return CuBatchedQR{eltype(A)}(factors, τ, size(A))
end

# TODO: Handle Adjoint and Transpose for QR
function LinearAlgebra.ldiv!(A::CuBatchedQR, b::CuMatrix)
@assert nbatches(A) == nbatches(b)
(; τ, factors) = A
Expand Down
49 changes: 38 additions & 11 deletions src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ Base.@propagate_inbounds function Base.setindex!(
end

Base.@propagate_inbounds function Base.setindex!(A::UniformBlockDiagonalMatrix, v, idx::Int)
@show size(A)
return setindex!(A, v, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1)
end

Expand All @@ -130,6 +129,8 @@ function Base.Matrix(A::UniformBlockDiagonalMatrix)
return M
end

Base.Array(A::UniformBlockDiagonalMatrix) = Matrix(A)

Base.collect(A::UniformBlockDiagonalMatrix) = Matrix(A)

function Base.similar(A::UniformBlockDiagonalMatrix, ::Type{T}) where {T}
Expand Down Expand Up @@ -265,36 +266,59 @@ function Base.:*(X::AbstractArray{T, 2}, Y::UniformBlockDiagonalMatrix) where {T
end

# LinearAlgebra
abstract type AbstractBatchedMatrixFactorization end
abstract type AbstractBatchedMatrixFactorization{T} <: LinearAlgebra.Factorization{T} end

const AdjointAbstractBatchedMatrixFactorization{T} = LinearAlgebra.AdjointFactorization{
T, <:AbstractBatchedMatrixFactorization{T}}
const TransposeAbstractBatchedMatrixFactorization{T} = LinearAlgebra.TransposeFactorization{
T, <:AbstractBatchedMatrixFactorization{T}}
const AdjointOrTransposeAbstractBatchedMatrixFactorization{T} = Union{
AdjointAbstractBatchedMatrixFactorization{T},
TransposeAbstractBatchedMatrixFactorization{T}}

const AllAbstractBatchedMatrixFactorization{T} = Union{
AbstractBatchedMatrixFactorization{T},
AdjointOrTransposeAbstractBatchedMatrixFactorization{T}}

nbatches(f::AdjointOrTransposeAbstractBatchedMatrixFactorization) = nbatches(parent(f))
batchview(f::AdjointOrTransposeAbstractBatchedMatrixFactorization) = batchview(parent(f))
function batchview(f::AdjointOrTransposeAbstractBatchedMatrixFactorization, idx::Int)
return batchview(parent(f), idx)
end

function LinearAlgebra.ldiv!(A::AbstractBatchedMatrixFactorization, b::AbstractVector)
return LinearAlgebra.ldiv!(A, reshape(b, :, nbatches(A)))
function LinearAlgebra.ldiv!(A::AllAbstractBatchedMatrixFactorization, b::AbstractVector)
LinearAlgebra.ldiv!(A, reshape(b, :, nbatches(A)))
return b
end

function LinearAlgebra.ldiv!(
X::AbstractVector, A::AbstractBatchedMatrixFactorization, b::AbstractVector)
X::AbstractVector, A::AllAbstractBatchedMatrixFactorization, b::AbstractVector)
LinearAlgebra.ldiv!(reshape(X, :, nbatches(A)), A, reshape(b, :, nbatches(A)))
return X
end

function LinearAlgebra.:\(A::AbstractBatchedMatrixFactorization, b::AbstractVector)
function LinearAlgebra.:\(A::AllAbstractBatchedMatrixFactorization, b::AbstractVector)
X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1))
LinearAlgebra.ldiv!(X, A, b)
return X
end

function LinearAlgebra.:\(A::AbstractBatchedMatrixFactorization, b::AbstractMatrix)
function LinearAlgebra.:\(A::AllAbstractBatchedMatrixFactorization, b::AbstractMatrix)
X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1))
LinearAlgebra.ldiv!(X, A, vec(b))
return reshape(X, :, nbatches(b))
end

struct GenericBatchedFactorization{A, F} <: AbstractBatchedMatrixFactorization
struct GenericBatchedFactorization{T, A, F} <: AbstractBatchedMatrixFactorization{T}
alg::A
fact::Vector{F}

function GenericBatchedFactorization(alg::A, fact::Vector{F}) where {A, F}
return new{A, F}(alg, fact)
function GenericBatchedFactorization(alg, fact)
return GenericBatchedFactorization{eltype(first(fact))}(alg, fact)
end

function GenericBatchedFactorization{T}(alg::A, fact::Vector{F}) where {T, A, F}
return new{T, A, F}(alg, fact)
end
end

Expand All @@ -305,12 +329,15 @@ Base.size(F::GenericBatchedFactorization) = size(first(F.fact)) .* length(F.fact
function Base.size(F::GenericBatchedFactorization, i::Integer)
return size(first(F.fact), i) * length(F.fact)
end
Base.eltype(F::GenericBatchedFactorization) = eltype(first(F.fact))

function LinearAlgebra.issuccess(fact::GenericBatchedFactorization)
return all(LinearAlgebra.issuccess, fact.fact)
end

function Base.adjoint(fact::GenericBatchedFactorization{T}) where {T}
return GenericBatchedFactorization{T}(fact.alg, adjoint.(fact.fact))
end

function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::GenericBatchedFactorization)
println(io, "GenericBatchedFactorization() with Batch Count: $(nbatches(F))")
Base.printstyled(io, "Factorization Function: "; color=:green)
Expand Down
7 changes: 4 additions & 3 deletions test/autodiff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ end
@test Array(gs_fwddiff_x)Array(gs_rdiff[1]) atol=atol rtol=rtol
@test Array(gs_fwddiff_p)Array(gs_rdiff[2]) atol=atol rtol=rtol

__f1 = x -> sum(
abs2, batched_gradient(backend, simple_batched_function, x, p))
__f1 = x -> sum(abs2,
batched_gradient(backend, Base.Fix2(simple_batched_function, p), x))
__f2 = x -> sum(abs2,
batched_gradient(backend, simple_batched_function, x, Array(p)))
batched_gradient(
backend, Base.Fix2(simple_batched_function, Array(p)), x))

gs_zyg_x = only(Zygote.gradient(__f1, X))
gs_rdiff_x = ReverseDiff.gradient(__f2, Array(X))
Expand Down
30 changes: 28 additions & 2 deletions test/integration_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testitem "LinearSolve" setup=[SharedTestSetup] begin
using LinearSolve
using FiniteDiff, LinearSolve, Zygote

rng = get_stable_rng(1001)

Expand All @@ -8,7 +8,7 @@
A1 = UniformBlockDiagonalMatrix(rand(rng, dims...)) |> dev
A2 = Matrix(A1) |> dev
b = rand(rng, size(A1, 1)) |> dev

prob1 = LinearProblem(A1, b)
prob2 = LinearProblem(A2, b)

Expand All @@ -22,6 +22,32 @@
x1 = solve(prob1, solver)
x2 = solve(prob2, solver)
@test x1.u x2.u

test_adjoint = function (A, b)
sol = solve(LinearProblem(A, b), solver)
return sum(abs2, sol.u)
end

dims[1] != dims[2] && continue

∂A_fd = FiniteDiff.finite_difference_gradient(
x -> test_adjoint(x, Array(b)), Array(A1))
∂b_fd = FiniteDiff.finite_difference_gradient(
x -> test_adjoint(Array(A1), x), Array(b))

if solver isa QRFactorization && ongpu
@test_broken begin
∂A, ∂b = Zygote.gradient(test_adjoint, A1, b)

@test Array(∂A)∂A_fd atol=1e-1 rtol=1e-1
@test Array(∂b)∂b_fd atol=1e-1 rtol=1e-1
end
else
∂A, ∂b = Zygote.gradient(test_adjoint, A1, b)

@test Array(∂A)∂A_fd atol=1e-1 rtol=1e-1
@test Array(∂b)∂b_fd atol=1e-1 rtol=1e-1
end
end
end
end
Expand Down

0 comments on commit c45432f

Please sign in to comment.