From 527b37075f570c457012a7b09cc1ca505a7f0cae Mon Sep 17 00:00:00 2001 From: Wolfhart Feldmeier Date: Fri, 26 Jan 2024 18:35:30 +0100 Subject: [PATCH] refactor batching interfaces change from vec of array to n+1-dimensional arrays. (much nicer for CUDA) --- Project.toml | 1 + src/DiffPointRasterisation.jl | 1 + src/rasterise.jl | 212 +++++++++++++++++++++------------- test/runtests.jl | 8 +- 4 files changed, 137 insertions(+), 85 deletions(-) diff --git a/Project.toml b/Project.toml index 6bee28b..359ad3b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Wolfhart Feldmeier "] version = "1.0.0-DEV" [deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" diff --git a/src/DiffPointRasterisation.jl b/src/DiffPointRasterisation.jl index 9dd6d00..a78a7b7 100644 --- a/src/DiffPointRasterisation.jl +++ b/src/DiffPointRasterisation.jl @@ -1,5 +1,6 @@ module DiffPointRasterisation +using ArgCheck using ChunkSplitters using FillArrays using SimpleUnPack diff --git a/src/rasterise.jl b/src/rasterise.jl index 1697fdb..13b2b3a 100644 --- a/src/rasterise.jl +++ b/src/rasterise.jl @@ -19,15 +19,33 @@ pixels/voxels of the output array (according to the closeness of the voxel center to the coordinates of point ``\\hat{p}``) via N-linear interpolation. """ +function raster end + raster( grid_size, points::AbstractMatrix{T}, + rotation::AbstractMatrix{<:Number}, + translation, + background=zero(T), + weight=one(T), +) where {T} = raster!( + similar(points, grid_size), + points, rotation, translation, - background=isa(rotation, AbstractMatrix{<:Number}) ? zero(T) : Zeros(T, length(rotation)), - weight=isa(rotation, AbstractMatrix{<:Number}) ? one(T) : Ones(T, length(rotation)), + background, + weight, +) + +raster( + grid_size, + points::AbstractMatrix{T}, + rotation::AbstractArray{<:Number, 3}, + translation, + background=Zeros(T, size(rotation, 3)), + weight=Ones(T, size(rotation, 3)), ) where {T} = raster!( - isa(rotation, AbstractMatrix{<:Number}) ? similar(points, grid_size) : [similar(points, grid_size) for _ in 1:length(rotation)], + similar(points, (grid_size..., size(rotation, 3))), points, rotation, translation, @@ -110,15 +128,33 @@ pixels/voxels of the output array (according to the closeness of the voxel center to the coordinates of point ``\\hat{p}``) via N-1-linear interpolation. """ +function raster_project end + raster_project( grid_size, points::AbstractMatrix{T}, + rotation::AbstractMatrix{<:Number}, + translation, + background=zero(T), + weight=one(T), +) where {T} = raster_project!( + similar(points, grid_size), + points, rotation, translation, - background=isa(rotation, AbstractMatrix{<:Number}) ? zero(T) : Zeros(T, length(rotation)), - weight=isa(rotation, AbstractMatrix{<:Number}) ? one(T) : Ones(T, length(rotation)), + background, + weight, +) + +raster_project( + grid_size, + points::AbstractMatrix{T}, + rotation::AbstractArray{<:Number, 3}, + translation, + background=Zeros(T, size(rotation, 3)), + weight=Ones(T, size(rotation, 3)), ) where {T} = raster_project!( - isa(rotation, AbstractMatrix{<:Number}) ? similar(points, grid_size) : [similar(points, grid_size) for _ in 1:length(rotation)], + similar(points, (grid_size..., size(rotation, 3))), points, rotation, translation, @@ -135,14 +171,14 @@ Inplace version of `raster`. Write output into `out` and return `out`. """ raster!( - out::Union{AbstractArray{T, N_out}, AbstractVector{<:AbstractArray{T, N_out}}}, + out::AbstractArray{T, N_out}, points, - rotation, + rotation::AbstractArray{<:Number, N_rotation}, translation, - background=isa(out, AbstractArray{T}) ? zero(T) : Zeros(T, length(out)), - weight=isa(out, AbstractArray{T}) ? one(T) : Ones(T, length(out)), -) where {N_out, T<:Number} = _raster!( - Val(N_out), + background=N_rotation == 2 ? zero(T) : Zeros(T, size(rotation, 3)), + weight=N_rotation == 2 ? one(T) : Ones(T, size(rotation, 3)), +) where {N_out, N_rotation, T<:Number} = _raster!( + Val(N_out - (N_rotation - 2)), out, points, rotation, @@ -171,14 +207,14 @@ Inplace version of `raster_project`. Write output into `out` and return `out`. """ raster_project!( - out::Union{AbstractArray{T, N_out}, AbstractVector{<:AbstractArray{T, N_out}}}, + out::AbstractArray{T, N_out}, points, - rotation, + rotation::AbstractArray{<:Number, N_rotation}, translation, - background=isa(out, AbstractArray{T}) ? zero(T) : Zeros(T, length(out)), - weight=isa(out, AbstractArray{T}) ? one(T) : Ones(T, length(out)), -) where {N_out, T<:Number} = _raster!( - Val(N_out + 1), + background=N_rotation == 2 ? zero(T) : Zeros(T, size(rotation, 3)), + weight=N_rotation == 2 ? one(T) : Ones(T, size(rotation, 3)), +) where {N_out, N_rotation, T<:Number} = _raster!( + Val(N_out + 1- (N_rotation - 2)), out, points, rotation, @@ -208,8 +244,8 @@ function _raster!( background::Number, weight::Number, ) where {N_in, N_out, T} - @assert size(points, 1) == size(rotation, 1) == size(rotation, 2) == N_in - @assert length(translation) == N_out + @argcheck size(points, 1) == size(rotation, 1) == size(rotation, 2) == N_in + @argcheck length(translation) == N_out fill!(out, background) origin = (-@SVector ones(T, N_out)) - translation @@ -248,16 +284,19 @@ end function _raster!( ::Val{N_in}, - out::AbstractVector{<:AbstractArray}, + out::AbstractArray{<:Number}, points::AbstractMatrix{<:Number}, - rotation::AbstractVector{<:AbstractMatrix{<:Number}}, - translation::AbstractVector{<:AbstractVector{<:Number}}, + rotation::AbstractArray{<:Number, 3}, + translation::AbstractMatrix{<:Number}, background::AbstractVector{<:Number}, weight::AbstractVector{<:Number}, ) where {N_in} - Threads.@threads for (idxs, ichunk) in chunks(eachindex(out, rotation, translation, background, weight), Threads.nthreads()) + out_batch_dim = ndims(out) + @argcheck axes(out, out_batch_dim) == axes(rotation, 3) == axes(translation, 2) == axes(background, 1) == axes(weight, 1) + batch_axis = axes(out, out_batch_dim) + Threads.@threads for (idxs, ichunk) in chunks(batch_axis, Threads.nthreads()) for i in idxs - _raster!(Val(N_in), out[i], points, rotation[i], translation[i], background[i], weight[i]) + _raster!(Val(N_in), selectdim(out, out_batch_dim, i), points, view(rotation, :, :, i), view(translation, :, i), background[i], weight[i]) end end out @@ -269,16 +308,16 @@ end batch_size = batch_size_for_test() - out = [zeros(8, 8, 8) for _ in 1:batch_size] - out_threaded = [zeros(8, 8, 8) for _ in 1:batch_size] + out = zeros(8, 8, 8, batch_size) + out_threaded = zeros(8, 8, 8, batch_size) points = 0.3 .* randn(3, 10) - rotation = [rand(QuatRotation) for _ in 1:batch_size] - translation = [zeros(3) for _ in 1:batch_size] + rotation = stack(rand(QuatRotation, batch_size)) + translation = zeros(3, batch_size) background = zeros(batch_size) weight = ones(batch_size) - for i in 1:batch_size - raster!(out[i], points, rotation[i], translation[i], background[i], weight[i]) + for (out_i, args...) in zip(eachslice(out, dims=4), eachslice(rotation, dims=3), eachcol(translation), background, weight) + raster!(out_i, points, args...) end DiffPointRasterisation.raster!(out_threaded, points, rotation, translation, background, weight) @@ -292,16 +331,16 @@ end batch_size = batch_size_for_test() - out = [zeros(16, 16) for _ in 1:batch_size] - out_threaded = [zeros(16, 16) for _ in 1:batch_size] + out = zeros(16, 16, batch_size) + out_threaded = zeros(16, 16, batch_size) points = 0.3 .* randn(3, 10) - rotation = [rand(QuatRotation) for _ in 1:batch_size] - translation = [zeros(2) for _ in 1:batch_size] + rotation = stack(rand(QuatRotation, batch_size)) + translation = zeros(2, batch_size) background = zeros(batch_size) weight = ones(batch_size) - for i in 1:batch_size - DiffPointRasterisation.raster_project!(out[i], points, rotation[i], translation[i], background[i], weight[i]) + for (out_i, args...) in zip(eachslice(out, dims=3), eachslice(rotation, dims=3), eachcol(translation), background, weight) + DiffPointRasterisation.raster_project!(out_i, points, args...) end DiffPointRasterisation.raster_project!(out_threaded, points, rotation, translation, background, weight) @@ -329,12 +368,16 @@ specified as `ds_d\$INPUT_NAME`, e.g. `ds_dtranslation = [zeros(2) for _ in 1:8] for 2-dimensional points and a batch size of 8. """ raster_pullback!( - ds_dout::Union{AbstractArray{<:Number, N_out}, AbstractVector{<:AbstractArray{<:Number, N_out}}}, + ds_dout::AbstractArray{<:Number, N_out}, + points, + rotation::AbstractArray{<:Number, N_rotation}, args...; prealloc... -) where {N_out} = _raster_pullback!( - Val(N_out), +) where {N_out, N_rotation} = _raster_pullback!( + Val(N_out - (N_rotation - 2)), ds_dout, + points, + rotation, args...; prealloc... ) @@ -360,12 +403,16 @@ specified as `ds_d\$INPUT_NAME`, e.g. `ds_dtranslation = [zeros(2) for _ in 1:8] for 3-dimensional points and a batch size of 8. """ raster_project_pullback!( - ds_dout::Union{AbstractArray{<:Number, N_out}, AbstractVector{<:AbstractArray{<:Number, N_out}}}, + ds_dout::AbstractArray{<:Number, N_out}, + points, + rotation::AbstractArray{<:Number, N_rotation}, args...; prealloc... -) where {N_out} = _raster_pullback!( - Val(N_out + 1), +) where {N_out, N_rotation} = _raster_pullback!( + Val(N_out + 1 - (N_rotation - 2)), ds_dout, + points, + rotation, args...; prealloc... ) @@ -459,33 +506,35 @@ end function _raster_pullback!( ::Val{N_in}, - ds_dout::AbstractVector{<:AbstractArray{T, N_out}}, + ds_dout::AbstractArray{T}, points::AbstractMatrix{<:Number}, - rotation::AbstractVector{<:AbstractMatrix{<:Number}}, - translation::AbstractVector{<:AbstractVector{<:Number}}, + rotation::AbstractArray{<:Number, 3}, + translation::AbstractMatrix{<:Number}, # TODO: for some reason type inference fails if the following # two arrays are FillArrays... - background::AbstractVector{<:Number}=zeros(T, length(rotation)), - weight::AbstractVector{<:Number}=ones(T, length(rotation)); + background::AbstractVector{<:Number}=zeros(T, size(rotation, 3)), + weight::AbstractVector{<:Number}=ones(T, size(rotation, 3)); prealloc... -) where {N_in, N_out, T} +) where {N_in, T} + out_batch_dim = ndims(ds_dout) + batch_axis = axes(ds_dout, out_batch_dim) + @argcheck axes(ds_dout, out_batch_dim) == axes(rotation, 3) == axes(translation, 2) == axes(background, 1) == axes(weight, 1) args = (;points, rotation, translation, background, weight) - batch_size = length(translation) - @unpack ds_dpoints, ds_drotation, ds_dtranslation, ds_dbackground, ds_dweight = _pullback_alloc_threaded(args, NamedTuple(prealloc), min(batch_size, Threads.nthreads())) - @assert isa(ds_dpoints, AbstractVector{<:AbstractMatrix{<:Number}}) + @unpack ds_dpoints, ds_drotation, ds_dtranslation, ds_dbackground, ds_dweight = _pullback_alloc_threaded(args, NamedTuple(prealloc), min(length(batch_axis), Threads.nthreads())) + @assert ndims(ds_dpoints) == 3 + fill!(ds_dpoints, zero(T)) - Threads.@threads for (idxs, ichunk) in chunks(eachindex(ds_dout, rotation, translation, background, weight), length(ds_dpoints)) - fill!(ds_dpoints[ichunk], zero(T)) + Threads.@threads for (idxs, ichunk) in chunks(batch_axis, size(ds_dpoints, 3)) for i in idxs - args_i = (ds_dout[i], points, rotation[i], translation[i], background[i], weight[i]) - result_i = _raster_pullback!(Val(N_in), args_i...; accumulate_prealloc=true, points=ds_dpoints[ichunk]) - ds_drotation[i] .= result_i.rotation - ds_dtranslation[i] = result_i.translation + args_i = (selectdim(ds_dout, out_batch_dim, i), points, view(rotation, :, :, i), view(translation, :, i), background[i], weight[i]) + result_i = _raster_pullback!(Val(N_in), args_i...; accumulate_prealloc=true, points=view(ds_dpoints, :, :, ichunk)) + ds_drotation[:, :, i] .= result_i.rotation + ds_dtranslation[:, i] = result_i.translation ds_dbackground[i] = result_i.background ds_dweight[i] = result_i.weight end end - return (; points=sum(ds_dpoints), rotation=ds_drotation, translation=ds_dtranslation, background=ds_dbackground, weight=ds_dweight) + return (; points=dropdims(sum(ds_dpoints; dims=3); dims=3), rotation=ds_drotation, translation=ds_dtranslation, background=ds_dbackground, weight=ds_dweight) end @testitem "raster_pullback! threaded" begin @@ -494,10 +543,10 @@ end batch_size = batch_size_for_test() - ds_dout = [randn(8, 8, 8) for _ in 1:batch_size] + ds_dout = zeros(8, 8, 8, batch_size) points = 0.3 .* randn(3, 10) - rotation = [rand(QuatRotation) for _ in 1:batch_size] - translation = [zeros(3) for _ in 1:batch_size] + rotation = stack(rand(QuatRotation, batch_size)) + translation = zeros(3, batch_size) background = zeros(batch_size) weight = ones(batch_size) @@ -505,12 +554,14 @@ end ds_dpoints = Matrix{Float64}[] for i in 1:batch_size - ds_dargs_i = raster_pullback!(ds_dout[i], points, rotation[i], translation[i], background[i], weight[i]) + ds_dargs_i = @views raster_pullback!(ds_dout[:, :, :, i], points, rotation[:, :, i], translation[:, i], background[i], weight[i]) push!(ds_dpoints, ds_dargs_i.points) - @test ds_dargs_threaded.rotation[i] ≈ ds_dargs_i.rotation - @test ds_dargs_threaded.translation[i] ≈ ds_dargs_i.translation - @test ds_dargs_threaded.background[i] ≈ ds_dargs_i.background - @test ds_dargs_threaded.weight[i] ≈ ds_dargs_i.weight + @views begin + @test ds_dargs_threaded.rotation[:, :, i] ≈ ds_dargs_i.rotation + @test ds_dargs_threaded.translation[:, i] ≈ ds_dargs_i.translation + @test ds_dargs_threaded.background[i] ≈ ds_dargs_i.background + @test ds_dargs_threaded.weight[i] ≈ ds_dargs_i.weight + end end @test ds_dargs_threaded.points ≈ sum(ds_dpoints) end @@ -521,10 +572,10 @@ end batch_size = batch_size_for_test() - ds_dout = [randn(16, 16) for _ in 1:batch_size] + ds_dout = zeros(16, 16, batch_size) points = 0.3 .* randn(3, 10) - rotation = [rand(QuatRotation) for _ in 1:batch_size] - translation = [zeros(2) for _ in 1:batch_size] + rotation = stack(rand(QuatRotation, batch_size)) + translation = zeros(2, batch_size) background = zeros(batch_size) weight = ones(batch_size) @@ -532,12 +583,14 @@ end ds_dpoints = Matrix{Float64}[] for i in 1:batch_size - ds_dargs_i = DiffPointRasterisation.raster_project_pullback!(ds_dout[i], points, rotation[i], translation[i], background[i], weight[i]) + ds_dargs_i = @views raster_project_pullback!(ds_dout[:, :, i], points, rotation[:, :, i], translation[:, i], background[i], weight[i]) push!(ds_dpoints, ds_dargs_i.points) - @test ds_dargs_threaded.rotation[i] ≈ ds_dargs_i.rotation - @test ds_dargs_threaded.translation[i] ≈ ds_dargs_i.translation - @test ds_dargs_threaded.background[i] ≈ ds_dargs_i.background - @test ds_dargs_threaded.weight[i] ≈ ds_dargs_i.weight + @views begin + @test ds_dargs_threaded.rotation[:, :, i] ≈ ds_dargs_i.rotation + @test ds_dargs_threaded.translation[:, i] ≈ ds_dargs_i.translation + @test ds_dargs_threaded.background[i] ≈ ds_dargs_i.background + @test ds_dargs_threaded.weight[i] ≈ ds_dargs_i.weight + end end @test ds_dargs_threaded.points ≈ sum(ds_dpoints) end @@ -553,7 +606,7 @@ end function _pullback_alloc_others_threaded(need_allocation, ::NamedTuple{}) keys_alloc = prefix.(keys(need_allocation)) - vals = make_similar.(values(need_allocation)) + vals = similar.(values(need_allocation)) NamedTuple{keys_alloc}(vals) end @@ -561,7 +614,7 @@ function _pullback_alloc_others_threaded(args, prealloc) # it's a bit tricky to get this type-stable, but the following does the trick need_allocation = Base.structdiff(args, prealloc) keys_alloc = prefix.(keys(need_allocation)) - vals = make_similar.(values(need_allocation)) + vals = similar.(values(need_allocation)) alloc = NamedTuple{keys_alloc}(vals) keys_prealloc = prefix.(keys(prealloc)) prefixed_prealloc = NamedTuple{keys_prealloc}(values(prealloc)) @@ -570,13 +623,10 @@ end _pullback_alloc_points_serial(args, prealloc) = (;ds_dpoints = get(() -> similar(args.points), prealloc, :points)) -_pullback_alloc_points_threaded(args, prealloc, n) = (;ds_dpoints = get(() -> [similar(args.points) for _ in 1:n], prealloc, :points)) +_pullback_alloc_points_threaded(args, prealloc, n) = (;ds_dpoints = get(() -> similar(args.points, (size(args.points)..., n)), prealloc, :points)) prefix(s::Symbol) = Symbol("ds_d" * string(s)) -make_similar(x) = similar(x) -make_similar(x::AbstractVector{<:AbstractArray}) = [similar(element) for element in x] - function interpolation_weight(n, N, deltas, shift) val = @inbounds shift[n] == 1 ? one(eltype(deltas)) : -one(eltype(deltas)) diff --git a/test/runtests.jl b/test/runtests.jl index 4141a22..e68a450 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,9 +36,9 @@ include("../src/testing.jl") batch_size = 2 grid_size = (8, 8, 8) points = 0.3 .* randn(3, 4) - rotation = [Array(rand(RotMatrix3)) for _ in 1:batch_size] - rotation_tangent = [Array(rand(RotMatrix3)) for _ in 1:batch_size] - translation = [0.1 .* randn(3) for _ in 1:batch_size] + rotation = stack(rand(RotMatrix3, batch_size)) + rotation_tangent = stack(rand(RotMatrix3, batch_size)) + translation = 0.1 .* randn(3, batch_size) background = fill(0.1, batch_size) weight = fill(1.0, batch_size) @@ -48,7 +48,7 @@ include("../src/testing.jl") test_rrule(raster, grid_size, points, rotation ⊢ rotation_tangent, translation) grid_size = (8, 8) - translation = [0.1 .* randn(2) for _ in 1:batch_size] + translation = 0.1 .* randn(2, batch_size) test_rrule(raster_project, grid_size, points, rotation ⊢ rotation_tangent, translation, background, weight)