Skip to content

Commit

Permalink
Remove buffer pool for tiled GEMM (#42309)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Sep 20, 2021
1 parent 1843201 commit 6893f21
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 12 deletions.
3 changes: 0 additions & 3 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,6 @@ function __init__()
BLAS.lbt_forward(liblapack_path)
end
BLAS.check()
Threads.resize_nthreads!(Abuf)
Threads.resize_nthreads!(Bbuf)
Threads.resize_nthreads!(Cbuf)
catch ex
Base.showerror_nostdio(ex, "WARNING: Error during initialization of module LinearAlgebra")
end
Expand Down
12 changes: 3 additions & 9 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -726,10 +726,6 @@ function generic_matmatmul(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
end

const tilebufsize = 10800 # Approximately 32k/3
# per-thread arrays of buffers resized by __init__ if needed
const Abuf = [Vector{UInt8}(undef, tilebufsize)]
const Bbuf = [Vector{UInt8}(undef, tilebufsize)]
const Cbuf = [Vector{UInt8}(undef, tilebufsize)]

function generic_matmatmul!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul=MulAddMul())
Expand Down Expand Up @@ -775,9 +771,8 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat
@inbounds begin
if tile_size > 0
sz = (tile_size, tile_size)
# FIXME: This code is completely invalid!!!
Atile = unsafe_wrap(Array, convert(Ptr{T}, pointer(Abuf[Threads.threadid()])), sz)
Btile = unsafe_wrap(Array, convert(Ptr{S}, pointer(Bbuf[Threads.threadid()])), sz)
Atile = Array{T}(undef, sz)
Btile = Array{S}(undef, sz)

z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1])
z = convert(promote_type(typeof(z1), R), z1)
Expand All @@ -797,8 +792,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat
end
end
else
# FIXME: This code is completely invalid!!!
Ctile = unsafe_wrap(Array, convert(Ptr{R}, pointer(Cbuf[Threads.threadid()])), sz)
Ctile = Array{R}(undef, sz)
for jb = 1:tile_size:nB
jlim = min(jb+tile_size-1,nB)
jlen = jlim-jb+1
Expand Down

0 comments on commit 6893f21

Please sign in to comment.