diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 0000000..ce1243a --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,2 @@ +style = "blue" +always_use_return = false \ No newline at end of file diff --git a/Project.toml b/Project.toml index 07f790a..f4776e0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffPointRasterisation" uuid = "f984992d-3c45-4382-99a1-cf20f5c47c61" authors = ["Wolfhart Feldmeier "] -version = "0.2.0" +version = "0.2.1" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -9,7 +9,6 @@ Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" @@ -23,22 +22,35 @@ DiffPointRasterisationCUDAExt = "CUDA" DiffPointRasterisationChainRulesCoreExt = "ChainRulesCore" [compat] +Adapt = "4" +Aqua = "0.8" +ArgCheck = "2.3" +Atomix = "0.1" +BenchmarkTools = "1" +CUDA = "5.3" +ChainRulesCore = "1.23" +ChainRulesTestUtils = "1.12" ChunkSplitters = "2" FillArrays = "1.9.3" -Rotations = "1.6.1" +KernelAbstractions = "0.9.18" +Rotations = "1.7" SimpleUnPack = "1.1" StaticArrays = "1.9.1" +Test = "1" +TestItemRunner = "0.2" TestItems = "0.1.1" -julia = "1" +julia = "^1.9" [extras] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" [targets] -test = ["Adapt", "BenchmarkTools", "ChainRulesCore", "ChainRulesTestUtils", "CUDA", "Test", "TestItemRunner"] +test = ["Adapt", "Aqua", "BenchmarkTools", "ChainRulesCore", "ChainRulesTestUtils", "CUDA", "Rotations", "Test", "TestItemRunner"] diff --git a/README.md b/README.md index 53f40e3..8d9ec13 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![Build Status](https://github.com/microscopic-image-analysis/DiffPointRasterisation.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/microscopic-image-analysis/DiffPointRasterisation.jl/actions/workflows/CI.yml?query=branch%3Amain) [![](https://img.shields.io/badge/docs-main-blue.svg)](https://microscopic-image-analysis.github.io/DiffPointRasterisation.jl/dev) + [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) ![](logo.gif) diff --git a/ext/DiffPointRasterisationCUDAExt.jl b/ext/DiffPointRasterisationCUDAExt.jl index da4a1be..74efc52 100644 --- a/ext/DiffPointRasterisationCUDAExt.jl +++ b/ext/DiffPointRasterisationCUDAExt.jl @@ -12,19 +12,16 @@ using ArgCheck using FillArrays using StaticArrays +const CuOrFillArray{T,N} = Union{CuArray{T,N},FillArrays.AbstractFill{T,N}} -const CuOrFillArray{T, N} = Union{CuArray{T, N}, FillArrays.AbstractFill{T, N}} - - -const CuOrFillVector{T} = CuOrFillArray{T, 1} - +const CuOrFillVector{T} = CuOrFillArray{T,1} function raster_pullback_kernel!( ::Type{T}, ds_dout, points::AbstractVector{<:StaticVector{N_in}}, - rotations::AbstractVector{<:StaticMatrix{N_out, N_in, TR}}, - translations::AbstractVector{<:StaticVector{N_out, TT}}, + rotations::AbstractVector{<:StaticMatrix{N_out,N_in,TR}}, + translations::AbstractVector{<:StaticVector{N_out,TT}}, out_weights, point_weights, shifts, @@ -35,8 +32,7 @@ function raster_pullback_kernel!( ds_dtranslation, ds_dout_weight, ds_dpoint_weight, - -) where {T, TR, TT, N_in, N_out} +) where {T,TR,TT,N_in,N_out} n_voxel = blockDim().z points_per_workgroup = blockDim().x batchsize_per_workgroup = blockDim().y @@ -74,24 +70,27 @@ function raster_pullback_kernel!( origin = (-@SVector ones(TT, N_out)) - translation coord_reference_voxel, deltas = DiffPointRasterisation.reference_coordinate_and_deltas( - point, - rotation, - origin, - scale, + point, rotation, origin, scale + ) + voxel_idx = CartesianIndex( + CartesianIndex(Tuple(coord_reference_voxel)) + CartesianIndex(shift), batch_idx ) - voxel_idx = CartesianIndex(CartesianIndex(Tuple(coord_reference_voxel)) + CartesianIndex(shift), batch_idx) - ds_dweight_local = zero(T) if voxel_idx in CartesianIndices(ds_dout) @inbounds ds_dweight_local = DiffPointRasterisation.voxel_weight( - deltas, - shift, - ds_dout[voxel_idx], + deltas, shift, ds_dout[voxel_idx] ) factor = ds_dout[voxel_idx] * out_weight * point_weight - ds_dcoord_part = SVector(factor .* ntuple(n -> DiffPointRasterisation.interpolation_weight(n, N_out, deltas, shift), Val(N_out))) + ds_dcoord_part = SVector( + factor .* ntuple( + n -> DiffPointRasterisation.interpolation_weight( + n, N_out, deltas, shift + ), + Val(N_out), + ), + ) @inbounds ds_dpoint_rot_shared[:, s, b] .= ds_dcoord_part .* scale else @inbounds ds_dpoint_rot_shared[:, s, b] .= zero(T) @@ -136,7 +135,7 @@ function raster_pullback_kernel!( j = 1 while j <= N_in val = coef * point[j] - @inbounds CUDA.@atomic ds_drotation[dim, j, batch_idx] += val + @inbounds CUDA.@atomic ds_drotation[dim, j, batch_idx] += val j += 1 end end @@ -161,7 +160,7 @@ function raster_pullback_kernel!( sync_threads() idx = 2 * stride * (b - 1) + 1 if idx <= batchsize_per_workgroup - dim = s + dim = s while dim <= N_in other_val_p = if idx + stride <= batchsize_per_workgroup ds_dpoint_shared[dim, idx + stride] @@ -181,7 +180,7 @@ function raster_pullback_kernel!( sync_threads() idx = 2 * stride * (thread - 1) + 1 if idx <= n_threads_per_workgroup - other_val_w = if idx + stride <= n_threads_per_workgroup + other_val_w = if idx + stride <= n_threads_per_workgroup ds_dpoint_weight_shared[idx + stride] else zero(T) @@ -207,54 +206,68 @@ function raster_pullback_kernel!( @inbounds CUDA.@atomic ds_dpoint_weight[point_idx] += val_w end - nothing + return nothing end # single image -raster_pullback!( - ds_dout::CuArray{<:Number, N_out}, - points::AbstractVector{<:StaticVector{N_in, <:Number}}, - rotation::StaticMatrix{N_out, N_in, <:Number}, - translation::StaticVector{N_out, <:Number}, +function raster_pullback!( + ds_dout::CuArray{<:Number,N_out}, + points::AbstractVector{<:StaticVector{N_in,<:Number}}, + rotation::StaticMatrix{N_out,N_in,<:Number}, + translation::StaticVector{N_out,<:Number}, background::Number, out_weight::Number, point_weight::CuOrFillVector{<:Number}, ds_dpoints::AbstractMatrix{<:Number}, ds_dpoint_weight::AbstractVector{<:Number}; - kwargs... -) where {N_in, N_out} = error("Not implemented: raster_pullback! for single image not implemented on GPU. Consider using CPU arrays") + kwargs..., +) where {N_in,N_out} + return error( + "Not implemented: raster_pullback! for single image not implemented on GPU. Consider using CPU arrays", + ) +end # batch of images function DiffPointRasterisation.raster_pullback!( - ds_dout::CuArray{<:Number, N_out_p1}, - points::CuVector{<:StaticVector{N_in, <:Number}}, - rotation::CuVector{<:StaticMatrix{N_out, N_in, <:Number}}, - translation::CuVector{<:StaticVector{N_out, <:Number}}, + ds_dout::CuArray{<:Number,N_out_p1}, + points::CuVector{<:StaticVector{N_in,<:Number}}, + rotation::CuVector{<:StaticMatrix{N_out,N_in,<:Number}}, + translation::CuVector{<:StaticVector{N_out,<:Number}}, background::CuOrFillVector{<:Number}, out_weight::CuOrFillVector{<:Number}, point_weight::CuOrFillVector{<:Number}, ds_dpoints::CuMatrix{TP}, - ds_drotation::CuArray{TR, 3}, + ds_drotation::CuArray{TR,3}, ds_dtranslation::CuMatrix{TT}, ds_dbackground::CuVector{<:Number}, ds_dout_weight::CuVector{OW}, ds_dpoint_weight::CuVector{PW}, -) where {N_in, N_out, N_out_p1, TP<:Number, TR<:Number, TT<:Number, OW<:Number, PW<:Number} +) where {N_in,N_out,N_out_p1,TP<:Number,TR<:Number,TT<:Number,OW<:Number,PW<:Number} T = promote_type(eltype(ds_dout), TP, TR, TT, OW, PW) batch_axis = axes(ds_dout, N_out_p1) @argcheck N_out == N_out_p1 - 1 - @argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(out_weight, 1) - @argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dout_weight, 1) + @argcheck batch_axis == + axes(rotation, 1) == + axes(translation, 1) == + axes(background, 1) == + axes(out_weight, 1) + @argcheck batch_axis == + axes(ds_drotation, 3) == + axes(ds_dtranslation, 2) == + axes(ds_dbackground, 1) == + axes(ds_dout_weight, 1) @argcheck N_out == N_out_p1 - 1 n_points = length(points) @argcheck length(ds_dpoint_weight) == n_points batch_size = length(batch_axis) - ds_dbackground = vec(sum!(reshape(ds_dbackground, ntuple(_ -> 1, Val(N_out))..., batch_size), ds_dout)) + ds_dbackground = vec( + sum!(reshape(ds_dbackground, ntuple(_ -> 1, Val(N_out))..., batch_size), ds_dout) + ) - scale = SVector{N_out, T}(size(ds_dout)[1:end-1]) / T(2) - shifts=DiffPointRasterisation.voxel_shifts(Val(N_out)) + scale = SVector{N_out,T}(size(ds_dout)[1:(end - 1)]) / T(2) + shifts = DiffPointRasterisation.voxel_shifts(Val(N_out)) ds_dpoints = fill!(ds_dpoints, zero(TP)) ds_drotation = fill!(ds_drotation, zero(TR)) @@ -262,19 +275,34 @@ function DiffPointRasterisation.raster_pullback!( ds_dout_weight = fill!(ds_dout_weight, zero(OW)) ds_dpoint_weight = fill!(ds_dpoint_weight, zero(PW)) - args = (T, ds_dout, points, rotation, translation, out_weight, point_weight, shifts, scale, ds_dpoints, ds_drotation, ds_dtranslation, ds_dout_weight, ds_dpoint_weight) + args = ( + T, + ds_dout, + points, + rotation, + translation, + out_weight, + point_weight, + shifts, + scale, + ds_dpoints, + ds_drotation, + ds_dtranslation, + ds_dout_weight, + ds_dpoint_weight, + ) ndrange = (n_points, batch_size, 2^N_out) workgroup_size(threads) = (1, min(threads ÷ (2^N_out), batch_size), 2^N_out) function shmem(threads) - _, bs_p_wg, n_voxel = workgroup_size(threads) - ((N_out + 1) * n_voxel + N_in) * bs_p_wg * sizeof(T) + _, bs_p_wg, n_voxel = workgroup_size(threads) + return ((N_out + 1) * n_voxel + N_in) * bs_p_wg * sizeof(T) # ((N_out + 1) * threads + N_in * bs_p_wg) * sizeof(T) end - let kernel = @cuda launch=false raster_pullback_kernel!(args...) + let kernel = @cuda launch = false raster_pullback_kernel!(args...) config = CUDA.launch_configuration(kernel.fun; shmem) workgroup_sz = workgroup_size(config.threads) blocks = cld.(ndrange, workgroup_sz) @@ -292,9 +320,16 @@ function DiffPointRasterisation.raster_pullback!( ) end +function DiffPointRasterisation.default_ds_dpoints_batched( + points::CuVector{<:AbstractVector{TP}}, N_in, batch_size +) where {TP<:Number} + return similar(points, TP, (N_in, length(points))) +end -DiffPointRasterisation.default_ds_dpoints_batched(points::CuVector{<:AbstractVector{TP}}, N_in, batch_size) where {TP<:Number} = similar(points, TP, (N_in, length(points))) - -DiffPointRasterisation.default_ds_dpoint_weight_batched(points::CuVector{<:AbstractVector{<:Number}}, T, batch_size) = similar(points, T) +function DiffPointRasterisation.default_ds_dpoint_weight_batched( + points::CuVector{<:AbstractVector{<:Number}}, T, batch_size +) + return similar(points, T) +end end # module \ No newline at end of file diff --git a/ext/DiffPointRasterisationChainRulesCoreExt.jl b/ext/DiffPointRasterisationChainRulesCoreExt.jl index 3c94989..e64341b 100644 --- a/ext/DiffPointRasterisationChainRulesCoreExt.jl +++ b/ext/DiffPointRasterisationChainRulesCoreExt.jl @@ -4,89 +4,91 @@ using DiffPointRasterisation, ChainRulesCore, StaticArrays # single image function ChainRulesCore.rrule( - ::typeof(DiffPointRasterisation.raster), + ::typeof(DiffPointRasterisation.raster), grid_size, - points::AbstractVector{<:StaticVector{N_in, T}}, + points::AbstractVector{<:StaticVector{N_in,T}}, rotation::AbstractMatrix{<:Number}, translation::AbstractVector{<:Number}, - optional_args... -) where {N_in, T<:Number} + optional_args..., +) where {N_in,T<:Number} out = raster(grid_size, points, rotation, translation, optional_args...) function raster_pullback(ds_dout) out_pb = raster_pullback!( - unthunk(ds_dout), - points, - rotation, - translation, - optional_args..., + unthunk(ds_dout), points, rotation, translation, optional_args... ) - ds_dpoints = reinterpret(reshape, SVector{N_in, T}, out_pb.points) - return NoTangent(), NoTangent(), ds_dpoints, values(out_pb)[2:3+length(optional_args)]... + ds_dpoints = reinterpret(reshape, SVector{N_in,T}, out_pb.points) + return NoTangent(), + NoTangent(), ds_dpoints, + values(out_pb)[2:(3 + length(optional_args))]... end return out, raster_pullback end -ChainRulesCore.rrule( - f::typeof(DiffPointRasterisation.raster), +function ChainRulesCore.rrule( + f::typeof(DiffPointRasterisation.raster), grid_size, points::AbstractVector{<:AbstractVector{<:Number}}, rotation::AbstractMatrix{<:Number}, translation::AbstractVector{<:Number}, - optional_args... -) = ChainRulesCore.rrule( - f, - grid_size, - DiffPointRasterisation.inner_to_sized(points), - rotation, - translation, - optional_args... + optional_args..., ) + return ChainRulesCore.rrule( + f, + grid_size, + DiffPointRasterisation.inner_to_sized(points), + rotation, + translation, + optional_args..., + ) +end # batch of images function ChainRulesCore.rrule( - ::typeof(DiffPointRasterisation.raster), + ::typeof(DiffPointRasterisation.raster), grid_size, - points::AbstractVector{<:StaticVector{N_in, TP}}, - rotation::AbstractVector{<:StaticMatrix{N_out, N_in, TR}}, - translation::AbstractVector{<:StaticVector{N_out, TT}}, - optional_args... -) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number} + points::AbstractVector{<:StaticVector{N_in,TP}}, + rotation::AbstractVector{<:StaticMatrix{N_out,N_in,TR}}, + translation::AbstractVector{<:StaticVector{N_out,TT}}, + optional_args..., +) where {N_in,N_out,TP<:Number,TR<:Number,TT<:Number} out = raster(grid_size, points, rotation, translation, optional_args...) function raster_pullback(ds_dout) out_pb = raster_pullback!( - unthunk(ds_dout), - points, - rotation, - translation, - optional_args..., + unthunk(ds_dout), points, rotation, translation, optional_args... ) - ds_dpoints = reinterpret(reshape, SVector{N_in, TP}, out_pb.points) + ds_dpoints = reinterpret(reshape, SVector{N_in,TP}, out_pb.points) L = N_out * N_in - ds_drotation = reinterpret(reshape, SMatrix{N_out, N_in, TR, L}, reshape(out_pb.rotation, L, :)) - ds_dtranslation = reinterpret(reshape, SVector{N_out, TT}, out_pb.translation) - return NoTangent(), NoTangent(), ds_dpoints, ds_drotation, ds_dtranslation, values(out_pb)[4:3+length(optional_args)]... + ds_drotation = reinterpret( + reshape, SMatrix{N_out,N_in,TR,L}, reshape(out_pb.rotation, L, :) + ) + ds_dtranslation = reinterpret(reshape, SVector{N_out,TT}, out_pb.translation) + return NoTangent(), + NoTangent(), ds_dpoints, ds_drotation, ds_dtranslation, + values(out_pb)[4:(3 + length(optional_args))]... end return out, raster_pullback end -ChainRulesCore.rrule( - f::typeof(DiffPointRasterisation.raster), +function ChainRulesCore.rrule( + f::typeof(DiffPointRasterisation.raster), grid_size, points::AbstractVector{<:AbstractVector{<:Number}}, rotation::AbstractVector{<:AbstractMatrix{<:Number}}, translation::AbstractVector{<:AbstractVector{<:Number}}, - optional_args... -) = ChainRulesCore.rrule( - f, - grid_size, - DiffPointRasterisation.inner_to_sized(points), - DiffPointRasterisation.inner_to_sized(rotation), - DiffPointRasterisation.inner_to_sized(translation), - optional_args... + optional_args..., ) + return ChainRulesCore.rrule( + f, + grid_size, + DiffPointRasterisation.inner_to_sized(points), + DiffPointRasterisation.inner_to_sized(rotation), + DiffPointRasterisation.inner_to_sized(translation), + optional_args..., + ) +end end # module DiffPointRasterisationChainRulesCoreExt \ No newline at end of file diff --git a/src/interface.jl b/src/interface.jl index 0d5e4ea..7bbeb63 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -59,10 +59,7 @@ function raster! end # Step 1: Allocate output ############################################### -function raster( - grid_size::Tuple, - args..., -) +function raster(grid_size::Tuple, args...) eltypes = deep_eltype.(args) T = promote_type(eltypes...) points = args[1] @@ -76,21 +73,23 @@ function raster( batch_size = length(rotation) out = similar(points, T, (grid_size..., batch_size)) end - raster!(out, args...) + return raster!(out, args...) end deep_eltype(el) = deep_eltype(typeof(el)) deep_eltype(t::Type) = t deep_eltype(t::Type{<:AbstractArray}) = deep_eltype(eltype(t)) - ############################################### # Step 2: Fill default arguments if necessary ############################################### -@inline raster!(out::AbstractArray{<:Number}, args::Vararg{Any, 3}) = raster!(out, args..., default_background(args[2])) -@inline raster!(out::AbstractArray{<:Number}, args::Vararg{Any, 4}) = raster!(out, args..., default_out_weight(args[2])) -@inline raster!(out::AbstractArray{<:Number}, args::Vararg{Any, 5}) = raster!(out, args..., default_point_weight(args[1])) +@inline raster!(out::AbstractArray{<:Number}, args::Vararg{Any,3}) = + raster!(out, args..., default_background(args[2])) +@inline raster!(out::AbstractArray{<:Number}, args::Vararg{Any,4}) = + raster!(out, args..., default_out_weight(args[2])) +@inline raster!(out::AbstractArray{<:Number}, args::Vararg{Any,5}) = + raster!(out, args..., default_point_weight(args[1])) ############################################### # Step 3: Convenience interface for single image: @@ -98,7 +97,7 @@ deep_eltype(t::Type{<:AbstractArray}) = deep_eltype(eltype(t)) # length-1 vec of arguments ############################################### -raster!( +function raster!( out::AbstractArray{<:Number}, points::AbstractVector{<:AbstractVector{<:Number}}, rotation::AbstractMatrix{<:Number}, @@ -106,24 +105,28 @@ raster!( background::Number, weight::Number, point_weight::AbstractVector{<:Number}, -) = drop_last_dim( - raster!( - append_singleton_dim(out), - points, - @SVector([rotation]), - @SVector([translation]), - @SVector([background]), - @SVector([weight]), - point_weight, - ) ) + return drop_last_dim( + raster!( + append_singleton_dim(out), + points, + @SVector([rotation]), + @SVector([translation]), + @SVector([background]), + @SVector([weight]), + point_weight, + ), + ) +end ############################################### # Step 4: Convert arguments to canonical form, # i.e. vectors of statically sized arrays ############################################### -raster!(out::AbstractArray{<:Number}, args::Vararg{AbstractVector, 6}) = raster!(out, inner_to_sized.(args)...) +function raster!(out::AbstractArray{<:Number}, args::Vararg{AbstractVector,6}) + return raster!(out, inner_to_sized.(args)...) +end ############################################### # Step 5: Error on inconsistent dimensions @@ -132,24 +135,30 @@ raster!(out::AbstractArray{<:Number}, args::Vararg{AbstractVector, 6}) = raster! # if N_out_rot == N_out_trans this should not be called # because the actual implementation specializes on N_out function raster!( - ::AbstractArray{<:Number, N_out}, - ::AbstractVector{<:StaticVector{N_in, <:Number}}, - ::AbstractVector{<:StaticMatrix{N_out_rot, N_in_rot, <:Number}}, - ::AbstractVector{<:StaticVector{N_out_trans, <:Number}}, + ::AbstractArray{<:Number,N_out}, + ::AbstractVector{<:StaticVector{N_in,<:Number}}, + ::AbstractVector{<:StaticMatrix{N_out_rot,N_in_rot,<:Number}}, + ::AbstractVector{<:StaticVector{N_out_trans,<:Number}}, ::AbstractVector{<:Number}, ::AbstractVector{<:Number}, ::AbstractVector{<:Number}, -) where {N_in, N_out, N_in_rot, N_out_rot, N_out_trans} +) where {N_in,N_out,N_in_rot,N_out_rot,N_out_trans} if N_out_trans != N_out - error("Dimension of translation (got $N_out_trans) and output dimentsion (got $N_out) must agree!") + error( + "Dimension of translation (got $N_out_trans) and output dimentsion (got $N_out) must agree!", + ) end if N_out_rot != N_out - error("Row dimension of rotation (got $N_out_rot) and output dimentsion (got $N_out) must agree!") + error( + "Row dimension of rotation (got $N_out_rot) and output dimentsion (got $N_out) must agree!", + ) end if N_in_rot != N_in - error("Column dimension of rotation (got $N_in_rot) and points (got $N_in) must agree!") + error( + "Column dimension of rotation (got $N_in_rot) and points (got $N_in) must agree!", + ) end - error("Dispatch error. Should not arrive here. Please file a bug.") + return error("Dispatch error. Should not arrive here. Please file a bug.") end # now similar for pullback @@ -180,15 +189,16 @@ See also [Raster a single point cloud to a batch of poses](@ref) """ function raster_pullback! end - ############################################### # Step 1: Fill default arguments if necessary ############################################### -@inline raster_pullback!(ds_out::AbstractArray{<:Number}, args::Vararg{Any, 3}; kwargs...) = raster_pullback!(ds_out, args..., default_background(args[2]); kwargs...) -@inline raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{Any, 4}; kwargs...) = raster_pullback!(ds_dout, args..., default_out_weight(args[2]); kwargs...) -@inline raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{Any, 5}; kwargs...) = raster_pullback!(ds_dout, args..., default_point_weight(args[1]); kwargs...) - +@inline raster_pullback!(ds_out::AbstractArray{<:Number}, args::Vararg{Any,3}; kwargs...) = + raster_pullback!(ds_out, args..., default_background(args[2]); kwargs...) +@inline raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{Any,4}; kwargs...) = + raster_pullback!(ds_dout, args..., default_out_weight(args[2]); kwargs...) +@inline raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{Any,5}; kwargs...) = + raster_pullback!(ds_dout, args..., default_point_weight(args[1]); kwargs...) ############################################### # Step 2: Convert arguments to canonical form, @@ -196,7 +206,7 @@ function raster_pullback! end ############################################### # single image -raster_pullback!( +function raster_pullback!( ds_dout::AbstractArray{<:Number}, points::AbstractVector{<:AbstractVector{<:Number}}, rotation::AbstractMatrix{<:Number}, @@ -204,163 +214,202 @@ raster_pullback!( background::Number, out_weight::Number, point_weight::AbstractVector{<:Number}; - kwargs... -) = raster_pullback!( - ds_dout, - inner_to_sized(points), - to_sized(rotation), - to_sized(translation), - background, - out_weight, - point_weight; - kwargs... + kwargs..., ) + return raster_pullback!( + ds_dout, + inner_to_sized(points), + to_sized(rotation), + to_sized(translation), + background, + out_weight, + point_weight; + kwargs..., + ) +end # batch of images -raster_pullback!(ds_dout::AbstractArray{<:Number}, args::Vararg{AbstractVector, 6}; kwargs...) = raster_pullback!(ds_dout, inner_to_sized.(args)...; kwargs...) - +function raster_pullback!( + ds_dout::AbstractArray{<:Number}, args::Vararg{AbstractVector,6}; kwargs... +) + return raster_pullback!(ds_dout, inner_to_sized.(args)...; kwargs...) +end ############################################### # Step 3: Allocate output ############################################### # single image -raster_pullback!( - ds_dout::AbstractArray{<:Number, N_out}, - inp_points::AbstractVector{<:StaticVector{N_in, TP}}, - inp_rotation::StaticMatrix{N_out, N_in, <:Number}, - inp_translation::StaticVector{N_out, <:Number}, +function raster_pullback!( + ds_dout::AbstractArray{<:Number,N_out}, + inp_points::AbstractVector{<:StaticVector{N_in,TP}}, + inp_rotation::StaticMatrix{N_out,N_in,<:Number}, + inp_translation::StaticVector{N_out,<:Number}, inp_background::Number, inp_out_weight::Number, inp_point_weight::AbstractVector{PW}; - points::AbstractMatrix{TP} = default_ds_dpoints_single(inp_points, N_in), - point_weight::AbstractVector{PW} = similar(inp_points, PW), - kwargs... -) where {N_in, N_out, TP<:Number, PW<:Number} = raster_pullback!( - ds_dout, - inp_points, - inp_rotation, - inp_translation, - inp_background, - inp_out_weight, - inp_point_weight, - points, - point_weight; - kwargs... -) + points::AbstractMatrix{TP}=default_ds_dpoints_single(inp_points, N_in), + point_weight::AbstractVector{PW}=similar(inp_points, PW), + kwargs..., +) where {N_in,N_out,TP<:Number,PW<:Number} + return raster_pullback!( + ds_dout, + inp_points, + inp_rotation, + inp_translation, + inp_background, + inp_out_weight, + inp_point_weight, + points, + point_weight; + kwargs..., + ) +end # batch of images -raster_pullback!( +function raster_pullback!( ds_dout::AbstractArray{<:Number}, - inp_points::AbstractVector{<:StaticVector{N_in, TP}}, - inp_rotation::AbstractVector{<:StaticMatrix{N_out, N_in, TR}}, - inp_translation::AbstractVector{<:StaticVector{N_out, TT}}, + inp_points::AbstractVector{<:StaticVector{N_in,TP}}, + inp_rotation::AbstractVector{<:StaticMatrix{N_out,N_in,TR}}, + inp_translation::AbstractVector{<:StaticVector{N_out,TT}}, inp_background::AbstractVector{TB}, inp_out_weight::AbstractVector{OW}, inp_point_weight::AbstractVector{PW}; - points::AbstractArray{TP} = default_ds_dpoints_batched(inp_points, N_in, length(inp_rotation)), - rotation::AbstractArray{TR, 3} = similar(inp_points, TR, (N_out, N_in, length(inp_rotation))), - translation::AbstractMatrix{TT} = similar(inp_points, TT, (N_out, length(inp_translation))), - background::AbstractVector{TB} = similar(inp_points, TB, (length(inp_background))), - out_weight::AbstractVector{OW} = similar(inp_points, OW, (length(inp_out_weight))), - point_weight::AbstractArray{PW} = default_ds_dpoint_weight_batched(inp_points, PW, length(inp_rotation)), -) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number, TB<:Number, OW<:Number, PW<:Number} = raster_pullback!( - ds_dout, - inp_points, - inp_rotation, - inp_translation, - inp_background, - inp_out_weight, - inp_point_weight, - points, - rotation, - translation, - background, - out_weight, - point_weight, -) - + points::AbstractArray{TP}=default_ds_dpoints_batched( + inp_points, N_in, length(inp_rotation) + ), + rotation::AbstractArray{TR,3}=similar( + inp_points, TR, (N_out, N_in, length(inp_rotation)) + ), + translation::AbstractMatrix{TT}=similar( + inp_points, TT, (N_out, length(inp_translation)) + ), + background::AbstractVector{TB}=similar(inp_points, TB, (length(inp_background))), + out_weight::AbstractVector{OW}=similar(inp_points, OW, (length(inp_out_weight))), + point_weight::AbstractArray{PW}=default_ds_dpoint_weight_batched( + inp_points, PW, length(inp_rotation) + ), +) where {N_in,N_out,TP<:Number,TR<:Number,TT<:Number,TB<:Number,OW<:Number,PW<:Number} + return raster_pullback!( + ds_dout, + inp_points, + inp_rotation, + inp_translation, + inp_background, + inp_out_weight, + inp_point_weight, + points, + rotation, + translation, + background, + out_weight, + point_weight, + ) +end ############################################### # Step 4: Error on inconsistent dimensions ############################################### # single image -raster_pullback!( - ::AbstractArray{<:Number, N_out}, - ::AbstractVector{<:StaticVector{N_in, <:Number}}, - ::StaticMatrix{N_out_rot, N_in_rot, <:Number}, - ::StaticVector{N_out_trans, <:Number}, +function raster_pullback!( + ::AbstractArray{<:Number,N_out}, + ::AbstractVector{<:StaticVector{N_in,<:Number}}, + ::StaticMatrix{N_out_rot,N_in_rot,<:Number}, + ::StaticVector{N_out_trans,<:Number}, ::Number, ::Number, ::AbstractVector{<:Number}, ::AbstractMatrix{<:Number}, ::AbstractVector{<:Number}; - kwargs... -) where {N_in, N_out, N_in_rot, N_out_rot, N_out_trans} = error_dimensions( - N_in, - N_out, - N_in_rot, - N_out_rot, - N_out_trans -) + kwargs..., +) where {N_in,N_out,N_in_rot,N_out_rot,N_out_trans} + return error_dimensions(N_in, N_out, N_in_rot, N_out_rot, N_out_trans) +end # batch of images -raster_pullback!( - ::AbstractArray{<:Number, N_out_p1}, - ::AbstractVector{<:StaticVector{N_in, <:Number}}, - ::AbstractVector{<:StaticMatrix{N_out_rot, N_in_rot, <:Number}}, - ::AbstractVector{<:StaticVector{N_out_trans, <:Number}}, +function raster_pullback!( + ::AbstractArray{<:Number,N_out_p1}, + ::AbstractVector{<:StaticVector{N_in,<:Number}}, + ::AbstractVector{<:StaticMatrix{N_out_rot,N_in_rot,<:Number}}, + ::AbstractVector{<:StaticVector{N_out_trans,<:Number}}, ::AbstractVector{<:Number}, ::AbstractVector{<:Number}, ::AbstractVector{<:Number}, ::AbstractArray{<:Number}, - ::AbstractArray{<:Number, 3}, + ::AbstractArray{<:Number,3}, ::AbstractMatrix{<:Number}, ::AbstractVector{<:Number}, ::AbstractVector{<:Number}, ::AbstractArray{<:Number}, -) where {N_in, N_out_p1, N_in_rot, N_out_rot, N_out_trans} = error_dimensions( - N_in, - N_out_p1 - 1, - N_in_rot, - N_out_rot, - N_out_trans -) +) where {N_in,N_out_p1,N_in_rot,N_out_rot,N_out_trans} + return error_dimensions(N_in, N_out_p1 - 1, N_in_rot, N_out_rot, N_out_trans) +end function error_dimensions(N_in, N_out, N_in_rot, N_out_rot, N_out_trans) if N_out_trans != N_out - error("Dimension of translation (got $N_out_trans) and output dimentsion (got $N_out) must agree!") + error( + "Dimension of translation (got $N_out_trans) and output dimentsion (got $N_out) must agree!", + ) end if N_out_rot != N_out - error("Row dimension of rotation (got $N_out_rot) and output dimentsion (got $N_out) must agree!") + error( + "Row dimension of rotation (got $N_out_rot) and output dimentsion (got $N_out) must agree!", + ) end if N_in_rot != N_in - error("Column dimension of rotation (got $N_in_rot) and points (got $N_in) must agree!") + error( + "Column dimension of rotation (got $N_in_rot) and points (got $N_in) must agree!", + ) end - error("Dispatch error. Should not arrive here. Please file a bug.") + return error("Dispatch error. Should not arrive here. Please file a bug.") end default_background(rotation::AbstractMatrix, T=eltype(rotation)) = zero(T) -default_background(rotation::AbstractVector{<:AbstractMatrix}, T=eltype(eltype(rotation))) = Zeros(T, length(rotation)) +function default_background( + rotation::AbstractVector{<:AbstractMatrix}, T=eltype(eltype(rotation)) +) + return Zeros(T, length(rotation)) +end -default_background(rotation::AbstractArray{_T, 3} where _T, T=eltype(rotation)) = Zeros(T, size(rotation, 3)) +function default_background(rotation::AbstractArray{_T,3} where {_T}, T=eltype(rotation)) + return Zeros(T, size(rotation, 3)) +end default_out_weight(rotation::AbstractMatrix, T=eltype(rotation)) = one(T) -default_out_weight(rotation::AbstractVector{<:AbstractMatrix}, T=eltype(eltype(rotation))) = Ones(T, length(rotation)) - -default_out_weight(rotation::AbstractArray{_T, 3} where _T, T=eltype(rotation)) = Ones(T, size(rotation, 3)) +function default_out_weight( + rotation::AbstractVector{<:AbstractMatrix}, T=eltype(eltype(rotation)) +) + return Ones(T, length(rotation)) +end -default_point_weight(points::AbstractVector{<:AbstractVector{T}}) where {T<:Number} = Ones(T, length(points)) +function default_out_weight(rotation::AbstractArray{_T,3} where {_T}, T=eltype(rotation)) + return Ones(T, size(rotation, 3)) +end -default_ds_dpoints_single(points::AbstractVector{<:AbstractVector{TP}}, N_in) where {TP<:Number} = similar(points, TP, (N_in, length(points))) +function default_point_weight(points::AbstractVector{<:AbstractVector{T}}) where {T<:Number} + return Ones(T, length(points)) +end -default_ds_dpoints_batched(points::AbstractVector{<:AbstractVector{TP}}, N_in, batch_size) where {TP<:Number} = similar(points, TP, (N_in, length(points), min(batch_size, Threads.nthreads()))) +function default_ds_dpoints_single( + points::AbstractVector{<:AbstractVector{TP}}, N_in +) where {TP<:Number} + return similar(points, TP, (N_in, length(points))) +end -default_ds_dpoint_weight_batched(points::AbstractVector{<:AbstractVector{<:Number}}, T, batch_size) = similar(points, T, (length(points), min(batch_size, Threads.nthreads()))) +function default_ds_dpoints_batched( + points::AbstractVector{<:AbstractVector{TP}}, N_in, batch_size +) where {TP<:Number} + return similar(points, TP, (N_in, length(points), min(batch_size, Threads.nthreads()))) +end +function default_ds_dpoint_weight_batched( + points::AbstractVector{<:AbstractVector{<:Number}}, T, batch_size +) + return similar(points, T, (length(points), min(batch_size, Threads.nthreads()))) +end @testitem "raster interface" begin include("../test/data.jl") @@ -451,12 +500,7 @@ default_ds_dpoint_weight_batched(points::AbstractVector{<:AbstractVector{<:Numbe ) end @testset "default arguments all as non-static array" begin - @test out ≈ raster( - D.grid_size_3d, - D.points, - D.rotations, - D.translations_3d, - ) + @test out ≈ raster(D.grid_size_3d, D.points, D.rotations, D.translations_3d) end end @@ -546,12 +590,7 @@ default_ds_dpoint_weight_batched(points::AbstractVector{<:AbstractVector{<:Numbe ) end @testset "default arguments all as non-static array" begin - @test out ≈ raster( - D.grid_size_2d, - D.points, - D.projections, - D.translations_2d, - ) + @test out ≈ raster(D.grid_size_2d, D.points, D.projections, D.translations_2d) end end end \ No newline at end of file diff --git a/src/raster.jl b/src/raster.jl index a3148d1..8cc72dc 100644 --- a/src/raster.jl +++ b/src/raster.jl @@ -3,33 +3,46 @@ ############################################### function raster!( - out::AbstractArray{T, N_out_p1}, - points::AbstractVector{<:StaticVector{N_in, <:Number}}, - rotation::AbstractVector{<:StaticMatrix{N_out, N_in, <:Number}}, - translation::AbstractVector{<:StaticVector{N_out, <:Number}}, + out::AbstractArray{T,N_out_p1}, + points::AbstractVector{<:StaticVector{N_in,<:Number}}, + rotation::AbstractVector{<:StaticMatrix{N_out,N_in,<:Number}}, + translation::AbstractVector{<:StaticVector{N_out,<:Number}}, background::AbstractVector{<:Number}, out_weight::AbstractVector{<:Number}, point_weight::AbstractVector{<:Number}, -) where {T<:Number, N_in, N_out, N_out_p1} +) where {T<:Number,N_in,N_out,N_out_p1} @argcheck N_out == N_out_p1 - 1 DimensionMismatch out_batch_dim = ndims(out) batch_size = size(out, out_batch_dim) - @argcheck batch_size == length(rotation) == length(translation) == length(background) == length(out_weight) DimensionMismatch + @argcheck batch_size == + length(rotation) == + length(translation) == + length(background) == + length(out_weight) DimensionMismatch n_points = length(points) @argcheck length(point_weight) == n_points - scale = SVector{N_out, T}(size(out)[1:end-1]) / T(2) - shifts=voxel_shifts(Val(N_out)) + scale = SVector{N_out,T}(size(out)[1:(end - 1)]) / T(2) + shifts = voxel_shifts(Val(N_out)) out .= reshape(background, ntuple(_ -> 1, Val(N_out))..., length(background)) args = (out, points, rotation, translation, out_weight, point_weight, shifts, scale) backend = get_backend(out) ndrange = (2^N_out, n_points, batch_size) - workgroup_size = 1024 + workgroup_size = 1024 raster_kernel!(backend, workgroup_size, ndrange)(args...) - out + return out end -@kernel function raster_kernel!(out::AbstractArray{T}, points, rotations, translations::AbstractVector{<:StaticVector{N_out}}, out_weights, point_weights, shifts, scale) where {T, N_out} +@kernel function raster_kernel!( + out::AbstractArray{T}, + points, + rotations, + translations::AbstractVector{<:StaticVector{N_out}}, + out_weights, + point_weights, + shifts, + scale, +) where {T,N_out} neighbor_voxel_id, point_idx, batch_idx = @index(Global, NTuple) point = @inbounds points[point_idx] @@ -40,18 +53,16 @@ end origin = (-@SVector ones(T, N_out)) - translation coord_reference_voxel, deltas = reference_coordinate_and_deltas( - point, - rotation, - origin, - scale, + point, rotation, origin, scale + ) + voxel_idx = CartesianIndex( + CartesianIndex(Tuple(coord_reference_voxel)) + CartesianIndex(shift), batch_idx ) - voxel_idx = CartesianIndex(CartesianIndex(Tuple(coord_reference_voxel)) + CartesianIndex(shift), batch_idx) if voxel_idx in CartesianIndices(out) val = voxel_weight(deltas, shift, weight) @inbounds Atomix.@atomic out[voxel_idx] += val end - nothing end """ @@ -72,10 +83,7 @@ Before `point` is discretized into this grid, it is first translated by `-origin` and then scaled by `scale`. """ @inline function reference_coordinate_and_deltas( - point::AbstractVector{T}, - rotation, - origin, - scale, + point::AbstractVector{T}, rotation, origin, scale ) where {T} projected_point = rotation * point # coordinate of transformed point in output coordinate system @@ -89,14 +97,14 @@ Before `point` is discretized into this grid, it is first translated by deltas_lower = coord - (coord_reference_voxel .- T(0.5)) # distances to lower (first column) and upper (second column) integer coordinates deltas = [deltas_lower one(T) .- deltas_lower] - coord_reference_voxel, deltas + return coord_reference_voxel, deltas end -@inline function voxel_weight(deltas, shift::NTuple{N, Int}, point_weight) where {N} +@inline function voxel_weight(deltas, shift::NTuple{N,Int}, point_weight) where {N} lower_upper = mod1.(shift, 2) delta_idxs = SVector{N}(CartesianIndex.(ntuple(identity, Val(N)), lower_upper)) val = prod(@inbounds @view deltas[delta_idxs]) * point_weight - val + return val end @testitem "raster correctness" begin @@ -113,12 +121,15 @@ end points_four_cross = reduce( vcat, [ - points_single_1pix_right, points_single_1pix_up, points_single_1pix_left, points_single_1pix_down - ] + points_single_1pix_right, + points_single_1pix_up, + points_single_1pix_left, + points_single_1pix_down, + ], ) - no_rotation = Float64[1;0;;0;1;;] - rotation_90_deg = Float64[0;1;;-1;0;;] + no_rotation = Float64[1; 0;; 0; 1] + rotation_90_deg = Float64[0; 1;; -1; 0] no_translation = zeros(2) translation_halfpix_right = [0.0, 0.2] @@ -129,7 +140,14 @@ end # -------- interpolations --------- - out = raster(grid_size, points_single_center, no_rotation, no_translation, zero_background, out_weight) + out = raster( + grid_size, + points_single_center, + no_rotation, + no_translation, + zero_background, + out_weight, + ) @test out ≈ [ 0 0 0 0 0 0 0 0 0 0 @@ -137,8 +155,15 @@ end 0 0 0 0 0 0 0 0 0 0 ] - - out = raster(grid_size, points_single_1pix_right, no_rotation, no_translation, zero_background, out_weight) + + out = raster( + grid_size, + points_single_1pix_right, + no_rotation, + no_translation, + zero_background, + out_weight, + ) @test out ≈ [ 0 0 0 0 0 0 0 0 0 0 @@ -147,7 +172,14 @@ end 0 0 0 0 0 ] - out = raster(grid_size, points_single_halfpix_down, no_rotation, no_translation, zero_background, out_weight) + out = raster( + grid_size, + points_single_halfpix_down, + no_rotation, + no_translation, + zero_background, + out_weight, + ) @test out ≈ [ 0 0 0 0 0 0 0 0 0 0 @@ -156,7 +188,14 @@ end 0 0 0 0 0 ] - out = raster(grid_size, points_single_halfpix_down_and_right, no_rotation, no_translation, zero_background, out_weight) + out = raster( + grid_size, + points_single_halfpix_down_and_right, + no_rotation, + no_translation, + zero_background, + out_weight, + ) @test out ≈ [ 0 0 0 0 0 0 0 0 0 0 @@ -167,7 +206,14 @@ end # -------- translations --------- - out = raster(grid_size, points_four_cross, no_rotation, no_translation, zero_background, out_weight) + out = raster( + grid_size, + points_four_cross, + no_rotation, + no_translation, + zero_background, + out_weight, + ) @test out ≈ [ 0 0 0 0 0 0 0 4 0 0 @@ -176,7 +222,14 @@ end 0 0 0 0 0 ] - out = raster(grid_size, points_four_cross, no_rotation, translation_halfpix_right, zero_background, out_weight) + out = raster( + grid_size, + points_four_cross, + no_rotation, + translation_halfpix_right, + zero_background, + out_weight, + ) @test out ≈ [ 0 0 0 0 0 0 0 2 2 0 @@ -185,7 +238,14 @@ end 0 0 0 0 0 ] - out = raster(grid_size, points_four_cross, no_rotation, translation_1pix_down, zero_background, out_weight) + out = raster( + grid_size, + points_four_cross, + no_rotation, + translation_1pix_down, + zero_background, + out_weight, + ) @test out ≈ [ 0 0 0 0 0 0 0 0 0 0 @@ -196,7 +256,14 @@ end # -------- rotations --------- - out = raster(grid_size, points_single_1pix_right, rotation_90_deg, no_translation, zero_background, out_weight) + out = raster( + grid_size, + points_single_1pix_right, + rotation_90_deg, + no_translation, + zero_background, + out_weight, + ) @test out ≈ [ 0 0 0 0 0 0 0 4 0 0 @@ -207,7 +274,15 @@ end # -------- point weights --------- - out = raster(grid_size, points_four_cross, no_rotation, no_translation, zero_background, 1.0, [1.0, 2.0, 3.0, 4.0]) + out = raster( + grid_size, + points_four_cross, + no_rotation, + no_translation, + zero_background, + 1.0, + [1.0, 2.0, 3.0, 4.0], + ) @test out ≈ [ 0 0 0 0 0 0 0 2 0 0 @@ -216,7 +291,15 @@ end 0 0 0 0 0 ] - out = raster(grid_size, points_four_cross, no_rotation, translation_halfpix_right, zero_background, 2.0, [1.0, 2.0, 3.0, 4.0]) + out = raster( + grid_size, + points_four_cross, + no_rotation, + translation_halfpix_right, + zero_background, + 2.0, + [1.0, 2.0, 3.0, 4.0], + ) @test out ≈ [ 0 0 0 0 0 0 0 2 2 0 @@ -226,7 +309,6 @@ end ] end - @testitem "raster inference and allocations" begin using BenchmarkTools, CUDA, StaticArrays include("../test/data.jl") @@ -234,38 +316,70 @@ end # check type stability # single image - @inferred DiffPointRasterisation.raster(D.grid_size_3d, D.points_static, D.rotation, D.translation_3d) - @inferred DiffPointRasterisation.raster(D.grid_size_2d, D.points_static, D.projection, D.translation_2d) + @inferred DiffPointRasterisation.raster( + D.grid_size_3d, D.points_static, D.rotation, D.translation_3d + ) + @inferred DiffPointRasterisation.raster( + D.grid_size_2d, D.points_static, D.projection, D.translation_2d + ) # batched canonical - @inferred DiffPointRasterisation.raster(D.grid_size_3d, D.points_static, D.rotations_static, D.translations_3d_static) - @inferred DiffPointRasterisation.raster(D.grid_size_2d, D.points_static, D.projections_static, D.translations_2d_static) + @inferred DiffPointRasterisation.raster( + D.grid_size_3d, D.points_static, D.rotations_static, D.translations_3d_static + ) + @inferred DiffPointRasterisation.raster( + D.grid_size_2d, D.points_static, D.projections_static, D.translations_2d_static + ) # batched reinterpret reshape - @inferred DiffPointRasterisation.raster(D.grid_size_3d, D.points_reinterp, D.rotations_reinterp, D.translations_3d_reinterp) - @inferred DiffPointRasterisation.raster(D.grid_size_2d, D.points_reinterp, D.projections_reinterp, D.translations_2d_reinterp) + @inferred DiffPointRasterisation.raster( + D.grid_size_3d, D.points_reinterp, D.rotations_reinterp, D.translations_3d_reinterp + ) + @inferred DiffPointRasterisation.raster( + D.grid_size_2d, + D.points_reinterp, + D.projections_reinterp, + D.translations_2d_reinterp, + ) if CUDA.functional() # single image - @inferred DiffPointRasterisation.raster(D.grid_size_3d, cu(D.points_static), cu(D.rotation), cu(D.translation_3d)) - @inferred DiffPointRasterisation.raster(D.grid_size_2d, cu(D.points_static), cu(D.projection), cu(D.translation_2d)) + @inferred DiffPointRasterisation.raster( + D.grid_size_3d, cu(D.points_static), cu(D.rotation), cu(D.translation_3d) + ) + @inferred DiffPointRasterisation.raster( + D.grid_size_2d, cu(D.points_static), cu(D.projection), cu(D.translation_2d) + ) # batched - @inferred DiffPointRasterisation.raster(D.grid_size_3d, cu(D.points_static), cu(D.rotations_static), cu(D.translations_3d_static)) - @inferred DiffPointRasterisation.raster(D.grid_size_2d, cu(D.points_static), cu(D.projections_static), cu(D.translations_2d_static)) + @inferred DiffPointRasterisation.raster( + D.grid_size_3d, + cu(D.points_static), + cu(D.rotations_static), + cu(D.translations_3d_static), + ) + @inferred DiffPointRasterisation.raster( + D.grid_size_2d, + cu(D.points_static), + cu(D.projections_static), + cu(D.translations_2d_static), + ) end # Ideally the sinlge image (non batched) case would be allocation-free. # The switch to KernelAbstractions made this allocating. # set test to broken for now. - out_3d = Array{Float64, 3}(undef, D.grid_size_3d...) - out_2d = Array{Float64, 2}(undef, D.grid_size_2d...) - allocations = @ballocated DiffPointRasterisation.raster!($out_3d, $D.points_static, $D.rotation, $D.translation_3d) evals=1 samples=1 - @test allocations == 0 broken=true - allocations = @ballocated DiffPointRasterisation.raster!($out_2d, $D.points_static, $D.projection, $D.translation_2d) evals=1 samples=1 - @test allocations == 0 broken=true + out_3d = Array{Float64,3}(undef, D.grid_size_3d...) + out_2d = Array{Float64,2}(undef, D.grid_size_2d...) + allocations = @ballocated DiffPointRasterisation.raster!( + $out_3d, $D.points_static, $D.rotation, $D.translation_3d + ) evals = 1 samples = 1 + @test allocations == 0 broken = true + allocations = @ballocated DiffPointRasterisation.raster!( + $out_2d, $D.points_static, $D.projection, $D.translation_2d + ) evals = 1 samples = 1 + @test allocations == 0 broken = true end - @testitem "raster batched consistency" begin include("../test/data.jl") @@ -273,21 +387,45 @@ end out_3d = zeros(D.grid_size_3d..., D.batch_size) out_3d_batched = zeros(D.grid_size_3d..., D.batch_size) - for (out_i, args...) in zip(eachslice(out_3d, dims=4), D.rotations, D.translations_3d, D.backgrounds, D.weights) + for (out_i, args...) in zip( + eachslice(out_3d; dims=4), D.rotations, D.translations_3d, D.backgrounds, D.weights + ) raster!(out_i, D.more_points, args..., D.more_point_weights) end - DiffPointRasterisation.raster!(out_3d_batched, D.more_points, D.rotations, D.translations_3d, D.backgrounds, D.weights, D.more_point_weights) + DiffPointRasterisation.raster!( + out_3d_batched, + D.more_points, + D.rotations, + D.translations_3d, + D.backgrounds, + D.weights, + D.more_point_weights, + ) # raster_project out_2d = zeros(D.grid_size_2d..., D.batch_size) out_2d_batched = zeros(D.grid_size_2d..., D.batch_size) - for (out_i, args...) in zip(eachslice(out_2d, dims=3), D.projections, D.translations_2d, D.backgrounds, D.weights) + for (out_i, args...) in zip( + eachslice(out_2d; dims=3), + D.projections, + D.translations_2d, + D.backgrounds, + D.weights, + ) DiffPointRasterisation.raster!(out_i, D.more_points, args..., D.more_point_weights) end - DiffPointRasterisation.raster!(out_2d_batched, D.more_points, D.projections, D.translations_2d, D.backgrounds, D.weights, D.more_point_weights) + DiffPointRasterisation.raster!( + out_2d_batched, + D.more_points, + D.projections, + D.translations_2d, + D.backgrounds, + D.weights, + D.more_point_weights, + ) @test out_2d_batched ≈ out_2d end \ No newline at end of file diff --git a/src/raster_pullback.jl b/src/raster_pullback.jl index 4678a8c..f981b6f 100644 --- a/src/raster_pullback.jl +++ b/src/raster_pullback.jl @@ -1,19 +1,22 @@ # single image function raster_pullback!( - ds_dout::AbstractArray{<:Number, N_out}, - points::AbstractVector{<:StaticVector{N_in, <:Number}}, - rotation::StaticMatrix{N_out, N_in, TR}, - translation::StaticVector{N_out, TT}, + ds_dout::AbstractArray{<:Number,N_out}, + points::AbstractVector{<:StaticVector{N_in,<:Number}}, + rotation::StaticMatrix{N_out,N_in,TR}, + translation::StaticVector{N_out,TT}, background::Number, out_weight::OW, point_weight::AbstractVector{<:Number}, ds_dpoints::AbstractMatrix{TP}, ds_dpoint_weight::AbstractVector{PW}; accumulate_ds_dpoints=false, -) where {N_in, N_out, TP<:Number, TR<:Number, TT<:Number, OW<:Number, PW<:Number} +) where {N_in,N_out,TP<:Number,TR<:Number,TT<:Number,OW<:Number,PW<:Number} T = promote_type(eltype(ds_dout), TP, TR, TT, OW, PW) @argcheck size(ds_dpoints, 1) == N_in - @argcheck length(point_weight) == length(points) == length(ds_dpoint_weight) == size(ds_dpoints, 2) + @argcheck length(point_weight) == + length(points) == + length(ds_dpoint_weight) == + size(ds_dpoints, 2) # The strategy followed here is to redo some of the calculations # made in the forward pass instead of caching them in the forward # pass and reusing them here. @@ -23,8 +26,8 @@ function raster_pullback!( end origin = (-@SVector ones(TT, N_out)) - translation - scale = SVector{N_out, T}(size(ds_dout)) / 2 - shifts=voxel_shifts(Val(N_out)) + scale = SVector{N_out,T}(size(ds_dout)) / 2 + shifts = voxel_shifts(Val(N_out)) all_density_idxs = CartesianIndices(ds_dout) # initialize some output for accumulation @@ -34,13 +37,10 @@ function raster_pullback!( # loop over points for (pt_idx, point) in enumerate(points) - point = SVector{N_in, TP}(point) + point = SVector{N_in,TP}(point) point_weight_i = point_weight[pt_idx] coord_reference_voxel, deltas = reference_coordinate_and_deltas( - point, - rotation, - origin, - scale, + point, rotation, origin, scale ) ds_dcoord = @SVector zeros(T, N_out) @@ -52,49 +52,63 @@ function raster_pullback!( ds_dout_i = ds_dout[voxel_idx] - ds_dweight = voxel_weight( - deltas, - shift, - ds_dout_i, - ) + ds_dweight = voxel_weight(deltas, shift, ds_dout_i) ds_dout_weight += ds_dweight * point_weight_i ds_dpoint_weight_i += ds_dweight * out_weight factor = ds_dout_i * out_weight * point_weight_i # loop over dimensions of point - ds_dcoord += SVector(factor .* ntuple(n -> interpolation_weight(n, N_out, deltas, shift), Val(N_out))) + ds_dcoord += SVector( + factor .* + ntuple(n -> interpolation_weight(n, N_out, deltas, shift), Val(N_out)), + ) end scaled = ds_dcoord .* scale ds_dtranslation += scaled ds_drotation += scaled * point' - ds_dpoint = rotation' * scaled + ds_dpoint = rotation' * scaled @view(ds_dpoints[:, pt_idx]) .+= ds_dpoint ds_dpoint_weight[pt_idx] += ds_dpoint_weight_i end - return (; points=ds_dpoints, rotation=ds_drotation, translation=ds_dtranslation, background=sum(ds_dout), out_weight=ds_dout_weight, point_weight=ds_dpoint_weight) + return (; + points=ds_dpoints, + rotation=ds_drotation, + translation=ds_dtranslation, + background=sum(ds_dout), + out_weight=ds_dout_weight, + point_weight=ds_dpoint_weight, + ) end # batch of images function raster_pullback!( - ds_dout::AbstractArray{<:Number, N_out_p1}, - points::AbstractVector{<:StaticVector{N_in, <:Number}}, - rotation::AbstractVector{<:StaticMatrix{N_out, N_in, <:Number}}, - translation::AbstractVector{<:StaticVector{N_out, <:Number}}, + ds_dout::AbstractArray{<:Number,N_out_p1}, + points::AbstractVector{<:StaticVector{N_in,<:Number}}, + rotation::AbstractVector{<:StaticMatrix{N_out,N_in,<:Number}}, + translation::AbstractVector{<:StaticVector{N_out,<:Number}}, background::AbstractVector{<:Number}, out_weight::AbstractVector{<:Number}, point_weight::AbstractVector{<:Number}, - ds_dpoints::AbstractArray{<:Number, 3}, - ds_drotation::AbstractArray{<:Number, 3}, + ds_dpoints::AbstractArray{<:Number,3}, + ds_drotation::AbstractArray{<:Number,3}, ds_dtranslation::AbstractMatrix{<:Number}, ds_dbackground::AbstractVector{<:Number}, ds_dout_weight::AbstractVector{<:Number}, ds_dpoint_weight::AbstractMatrix{<:Number}, -) where {N_in, N_out, N_out_p1} +) where {N_in,N_out,N_out_p1} batch_axis = axes(ds_dout, N_out_p1) @argcheck N_out == N_out_p1 - 1 - @argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(out_weight, 1) - @argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dout_weight, 1) + @argcheck batch_axis == + axes(rotation, 1) == + axes(translation, 1) == + axes(background, 1) == + axes(out_weight, 1) + @argcheck batch_axis == + axes(ds_drotation, 3) == + axes(ds_dtranslation, 2) == + axes(ds_dbackground, 1) == + axes(ds_dout_weight, 1) fill!(ds_dpoints, zero(eltype(ds_dpoints))) fill!(ds_dpoint_weight, zero(eltype(ds_dpoint_weight))) @@ -102,18 +116,37 @@ function raster_pullback!( Threads.@threads for (idxs, ichunk) in chunks(batch_axis, n_threads) for i in idxs - args_i = (selectdim(ds_dout, N_out_p1, i), points, rotation[i], translation[i], background[i], out_weight[i], point_weight) - result_i = raster_pullback!(args_i..., view(ds_dpoints, :, :, ichunk), view(ds_dpoint_weight, :, ichunk); accumulate_ds_dpoints=true) + args_i = ( + selectdim(ds_dout, N_out_p1, i), + points, + rotation[i], + translation[i], + background[i], + out_weight[i], + point_weight, + ) + result_i = raster_pullback!( + args_i..., + view(ds_dpoints, :, :, ichunk), + view(ds_dpoint_weight, :, ichunk); + accumulate_ds_dpoints=true, + ) ds_drotation[:, :, i] .= result_i.rotation ds_dtranslation[:, i] = result_i.translation ds_dbackground[i] = result_i.background ds_dout_weight[i] = result_i.out_weight end end - return (; points=dropdims(sum(ds_dpoints; dims=3); dims=3), rotation=ds_drotation, translation=ds_dtranslation, background=ds_dbackground, out_weight=ds_dout_weight, point_weight=dropdims(sum(ds_dpoint_weight; dims=2); dims=2)) + return (; + points=dropdims(sum(ds_dpoints; dims=3); dims=3), + rotation=ds_drotation, + translation=ds_dtranslation, + background=ds_dbackground, + out_weight=ds_dout_weight, + point_weight=dropdims(sum(ds_dpoint_weight; dims=2); dims=2), + ) end - function interpolation_weight(n, N, deltas, shift) val = @inbounds shift[n] == 1 ? one(eltype(deltas)) : -one(eltype(deltas)) # loop over other dimensions @@ -123,7 +156,7 @@ function interpolation_weight(n, N, deltas, shift) end val *= deltas[other_n, mod1(shift[other_n], 2)] end - val + return val end @testitem "raster_pullback! inference and allocations" begin @@ -135,7 +168,9 @@ end ds_dout_2d_batched = randn(D.grid_size_2d..., D.batch_size) ds_dpoints = similar(D.points_array) - ds_dpoints_batched = similar(D.points_array, (size(D.points_array)..., Threads.nthreads())) + ds_dpoints_batched = similar( + D.points_array, (size(D.points_array)..., Threads.nthreads()) + ) ds_drotations = similar(D.rotations_array) ds_dprojections = similar(D.projections_array) ds_dtranslations_3d = similar(D.translations_3d_array) @@ -143,7 +178,9 @@ end ds_dbackgrounds = similar(D.backgrounds) ds_dweights = similar(D.weights) ds_dpoint_weights = similar(D.point_weights) - ds_dpoint_weights_batched = similar(D.point_weights, (size(D.point_weights)..., Threads.nthreads())) + ds_dpoint_weights_batched = similar( + D.point_weights, (size(D.point_weights)..., Threads.nthreads()) + ) args_batched_3d = ( ds_dout_3d_batched, @@ -179,13 +216,33 @@ end function to_cuda(args) args_cu = adapt(CuArray, args) args_cu = Base.setindex(args_cu, args_cu[8][:, :, 1], 8) # ds_dpoint without batch dim - args_cu = Base.setindex(args_cu, args_cu[13][:, 1], 13) # ds_dpoint_weight without batch dim + return args_cu = Base.setindex(args_cu, args_cu[13][:, 1], 13) # ds_dpoint_weight without batch dim end # check type stability # single image - @inferred DiffPointRasterisation.raster_pullback!(ds_dout_3d, D.points_static, D.rotation, D.translation_3d, D.background, D.weight, D.point_weights, ds_dpoints, ds_dpoint_weights) - @inferred DiffPointRasterisation.raster_pullback!(ds_dout_2d, D.points_static, D.projection, D.translation_2d, D.background, D.weight, D.point_weights, ds_dpoints, ds_dpoint_weights) + @inferred DiffPointRasterisation.raster_pullback!( + ds_dout_3d, + D.points_static, + D.rotation, + D.translation_3d, + D.background, + D.weight, + D.point_weights, + ds_dpoints, + ds_dpoint_weights, + ) + @inferred DiffPointRasterisation.raster_pullback!( + ds_dout_2d, + D.points_static, + D.projection, + D.translation_2d, + D.background, + D.weight, + D.point_weights, + ds_dpoints, + ds_dpoint_weights, + ) # batched @inferred DiffPointRasterisation.raster_pullback!(args_batched_3d...) @inferred DiffPointRasterisation.raster_pullback!(args_batched_2d...) @@ -207,22 +264,37 @@ end $(D.point_weights), $ds_dpoints, $ds_dpoint_weights, - ) evals=1 samples=1 + ) evals = 1 samples = 1 @test allocations == 0 end - @testitem "raster_pullback! threaded" begin include("../test/data.jl") ds_dout = randn(D.grid_size_3d..., D.batch_size) - ds_dargs_threaded = DiffPointRasterisation.raster_pullback!(ds_dout, D.more_points, D.rotations, D.translations_3d, D.backgrounds, D.weights, D.more_point_weights) + ds_dargs_threaded = DiffPointRasterisation.raster_pullback!( + ds_dout, + D.more_points, + D.rotations, + D.translations_3d, + D.backgrounds, + D.weights, + D.more_point_weights, + ) ds_dpoints = Matrix{Float64}[] ds_dpoint_weight = Vector{Float64}[] - for i in 1:D.batch_size - ds_dargs_i = @views raster_pullback!(ds_dout[:, :, :, i], D.more_points, D.rotations[i], D.translations_3d[i], D.backgrounds[i], D.weights[i], D.more_point_weights) + for i in 1:(D.batch_size) + ds_dargs_i = @views raster_pullback!( + ds_dout[:, :, :, i], + D.more_points, + D.rotations[i], + D.translations_3d[i], + D.backgrounds[i], + D.weights[i], + D.more_point_weights, + ) push!(ds_dpoints, ds_dargs_i.points) push!(ds_dpoint_weight, ds_dargs_i.point_weight) @views begin @@ -237,12 +309,28 @@ end ds_dout = zeros(D.grid_size_2d..., D.batch_size) - ds_dargs_threaded = DiffPointRasterisation.raster_pullback!(ds_dout, D.more_points, D.projections, D.translations_2d, D.backgrounds, D.weights, D.more_point_weights) + ds_dargs_threaded = DiffPointRasterisation.raster_pullback!( + ds_dout, + D.more_points, + D.projections, + D.translations_2d, + D.backgrounds, + D.weights, + D.more_point_weights, + ) ds_dpoints = Matrix{Float64}[] ds_dpoint_weight = Vector{Float64}[] - for i in 1:D.batch_size - ds_dargs_i = @views raster_pullback!(ds_dout[:, :, i], D.more_points, D.projections[i], D.translations_2d[i], D.backgrounds[i], D.weights[i], D.more_point_weights) + for i in 1:(D.batch_size) + ds_dargs_i = @views raster_pullback!( + ds_dout[:, :, i], + D.more_points, + D.projections[i], + D.translations_2d[i], + D.backgrounds[i], + D.weights[i], + D.more_point_weights, + ) push!(ds_dpoints, ds_dargs_i.points) push!(ds_dpoint_weight, ds_dargs_i.point_weight) @views begin diff --git a/src/util.jl b/src/util.jl index 5482a8e..50bd068 100644 --- a/src/util.jl +++ b/src/util.jl @@ -4,12 +4,13 @@ Return a N-tuple containing the bit-representation of k """ -digitstuple(k, ::Val{N}, int_type=Int64) where {N} = ntuple(i -> int_type(k>>(i-1) % 2), N) +digitstuple(k, ::Val{N}, int_type=Int64) where {N} = + ntuple(i -> int_type(k >> (i - 1) % 2), N) @testitem "digitstuple" begin - @test DiffPointRasterisation.digitstuple(5, Val(3)) == (1, 0, 1) - @test DiffPointRasterisation.digitstuple(2, Val(2)) == (0, 1) - @test DiffPointRasterisation.digitstuple(2, Val(4)) == (0, 1, 0, 0) + @test DiffPointRasterisation.digitstuple(5, Val(3)) == (1, 0, 1) + @test DiffPointRasterisation.digitstuple(2, Val(2)) == (0, 1) + @test DiffPointRasterisation.digitstuple(2, Val(4)) == (0, 1, 0, 0) end """ @@ -22,7 +23,8 @@ For a N-dimensional voxel grid, return a 2^N-tuple of N-tuples, where each element of the outer tuple is a cartesian coordinate shift from the "upper left" voxel. """ -voxel_shifts(::Val{N}, int_type=Int64) where {N} = ntuple(k -> digitstuple(k-1, Val(N), int_type), Val(2^N)) +voxel_shifts(::Val{N}, int_type=Int64) where {N} = + ntuple(k -> digitstuple(k - 1, Val(N), int_type), Val(2^N)) @testitem "voxel_shifts" begin @inferred DiffPointRasterisation.voxel_shifts(Val(4)) @@ -31,20 +33,35 @@ voxel_shifts(::Val{N}, int_type=Int64) where {N} = ntuple(k -> digitstuple(k-1, @test DiffPointRasterisation.voxel_shifts(Val(2)) == ((0, 0), (1, 0), (0, 1), (1, 1)) - @test DiffPointRasterisation.voxel_shifts(Val(3)) == ((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1)) + @test DiffPointRasterisation.voxel_shifts(Val(3)) == ( + (0, 0, 0), + (1, 0, 0), + (0, 1, 0), + (1, 1, 0), + (0, 0, 1), + (1, 0, 1), + (0, 1, 1), + (1, 1, 1), + ) end -to_sized(arg::StaticArray{<:Any, <:Number}) = arg +to_sized(arg::StaticArray{<:Any,<:Number}) = arg -to_sized(arg::AbstractArray{T}) where {T<:Number} = SizedArray{Tuple{size(arg)...}, T}(arg) +to_sized(arg::AbstractArray{T}) where {T<:Number} = SizedArray{Tuple{size(arg)...},T}(arg) inner_to_sized(arg::AbstractVector{<:Number}) = arg inner_to_sized(arg::AbstractVector{<:StaticArray}) = arg -inner_to_sized(arg::AbstractVector{<:AbstractArray{<:Number}}) = inner_to_sized(arg, Val(size(arg[1]))) +function inner_to_sized(arg::AbstractVector{<:AbstractArray{<:Number}}) + return inner_to_sized(arg, Val(size(arg[1]))) +end -inner_to_sized(arg::AbstractVector{<:AbstractArray{T}}, ::Val{sz}) where {sz, T<:Number} = SizedArray{Tuple{sz...}, T}.(arg) +function inner_to_sized( + arg::AbstractVector{<:AbstractArray{T}}, ::Val{sz} +) where {sz,T<:Number} + return SizedArray{Tuple{sz...},T}.(arg) +end @testitem "inner_to_sized" begin using StaticArrays @@ -74,11 +91,10 @@ inner_to_sized(arg::AbstractVector{<:AbstractArray{T}}, ::Val{sz}) where {sz, T< inp = [randn(3, 2) for _ in 1:5] out = DiffPointRasterisation.inner_to_sized(inp) @test out == inp - @test out isa Vector{<:StaticMatrix{3, 2}} + @test out isa Vector{<:StaticMatrix{3,2}} end end - @inline append_singleton_dim(a) = reshape(a, size(a)..., 1) @inline append_singleton_dim(a::Number) = [a] @@ -88,9 +104,13 @@ end @testitem "append drop dim" begin using BenchmarkTools a = randn(2, 3, 4) - a2 = DiffPointRasterisation.drop_last_dim(DiffPointRasterisation.append_singleton_dim(a)) - @test a2 === a broken=true - - allocations = @ballocated DiffPointRasterisation.drop_last_dim(DiffPointRasterisation.append_singleton_dim($a)) evals=1 samples=1 - @test allocations == 0 broken=true + a2 = DiffPointRasterisation.drop_last_dim( + DiffPointRasterisation.append_singleton_dim(a) + ) + @test a2 === a broken = true + + allocations = @ballocated DiffPointRasterisation.drop_last_dim( + DiffPointRasterisation.append_singleton_dim($a) + ) evals = 1 samples = 1 + @test allocations == 0 broken = true end \ No newline at end of file diff --git a/test/chainrules.jl b/test/chainrules.jl index 47a2ea0..6b0bbdd 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -3,30 +3,88 @@ using ChainRulesTestUtils, ChainRulesCore include("data.jl") - test_rrule(raster, D.grid_size_3d, D.points_static, D.rotation ⊢ D.rotation_tangent, D.translation_3d, D.background, D.weight, D.point_weights) + test_rrule( + raster, + D.grid_size_3d, + D.points_static, + D.rotation ⊢ D.rotation_tangent, + D.translation_3d, + D.background, + D.weight, + D.point_weights, + ) # default arguments - test_rrule(raster, D.grid_size_3d, D.points_static, D.rotation ⊢ D.rotation_tangent, D.translation_3d) + test_rrule( + raster, + D.grid_size_3d, + D.points_static, + D.rotation ⊢ D.rotation_tangent, + D.translation_3d, + ) - - test_rrule(raster, D.grid_size_2d, D.points_static, D.projection ⊢ D.projection_tangent, D.translation_2d, D.background, D.weight, D.point_weights) + test_rrule( + raster, + D.grid_size_2d, + D.points_static, + D.projection ⊢ D.projection_tangent, + D.translation_2d, + D.background, + D.weight, + D.point_weights, + ) # default arguments - test_rrule(raster, D.grid_size_2d, D.points_static, D.projection ⊢ D.projection_tangent, D.translation_2d) + test_rrule( + raster, + D.grid_size_2d, + D.points_static, + D.projection ⊢ D.projection_tangent, + D.translation_2d, + ) end @testitem "ChainRules batch" begin using ChainRulesTestUtils include("data.jl") - test_rrule(raster, D.grid_size_3d, D.points_static, D.rotations_static ⊢ D.rotation_tangents_static, D.translations_3d_static, D.backgrounds, D.weights, D.point_weights) + test_rrule( + raster, + D.grid_size_3d, + D.points_static, + D.rotations_static ⊢ D.rotation_tangents_static, + D.translations_3d_static, + D.backgrounds, + D.weights, + D.point_weights, + ) # default arguments - test_rrule(raster, D.grid_size_3d, D.points_static, D.rotations_static ⊢ D.rotation_tangents_static, D.translations_3d_static) - + test_rrule( + raster, + D.grid_size_3d, + D.points_static, + D.rotations_static ⊢ D.rotation_tangents_static, + D.translations_3d_static, + ) - test_rrule(raster, D.grid_size_2d, D.points_static, D.projections_static ⊢ D.projection_tangents_static, D.translations_2d_static, D.backgrounds, D.weights, D.point_weights) + test_rrule( + raster, + D.grid_size_2d, + D.points_static, + D.projections_static ⊢ D.projection_tangents_static, + D.translations_2d_static, + D.backgrounds, + D.weights, + D.point_weights, + ) # default arguments - test_rrule(raster, D.grid_size_2d, D.points_static, D.projections_static ⊢ D.projection_tangents_static, D.translations_2d_static) + test_rrule( + raster, + D.grid_size_2d, + D.points_static, + D.projections_static ⊢ D.projection_tangents_static, + D.translations_2d_static, + ) end \ No newline at end of file diff --git a/test/cuda.jl b/test/cuda.jl index 1934b87..df74c43 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -5,18 +5,34 @@ include("data.jl") include("util.jl") cuda_available = CUDA.functional() - + # no projection - args = (D.grid_size_3d, D.more_points, D.rotations_static, D.translations_3d_static, D.backgrounds, D.weights, D.more_point_weights) - @test cuda_cpu_agree(raster, args...) skip=!cuda_available + args = ( + D.grid_size_3d, + D.more_points, + D.rotations_static, + D.translations_3d_static, + D.backgrounds, + D.weights, + D.more_point_weights, + ) + @test cuda_cpu_agree(raster, args...) skip = !cuda_available # default arguments args = (D.grid_size_3d, D.more_points, D.rotations_static, D.translations_3d_static) - @test cuda_cpu_agree(raster, args...) skip=!cuda_available + @test cuda_cpu_agree(raster, args...) skip = !cuda_available # projection - args = (D.grid_size_2d, D.more_points, D.projections_static, D.translations_2d_static, D.backgrounds, D.weights, D.more_point_weights) - @test cuda_cpu_agree(raster, args...) skip=!cuda_available + args = ( + D.grid_size_2d, + D.more_points, + D.projections_static, + D.translations_2d_static, + D.backgrounds, + D.weights, + D.more_point_weights, + ) + @test cuda_cpu_agree(raster, args...) skip = !cuda_available end @testitem "CUDA backward" begin @@ -28,17 +44,33 @@ end # no projection ds_dout_3d = randn(D.grid_size_3d..., D.batch_size) - args = (ds_dout_3d, D.more_points, D.rotations_static, D.translations_3d_static, D.backgrounds, D.weights, D.more_point_weights) - @test cuda_cpu_agree(raster_pullback!, args...) skip=!cuda_available + args = ( + ds_dout_3d, + D.more_points, + D.rotations_static, + D.translations_3d_static, + D.backgrounds, + D.weights, + D.more_point_weights, + ) + @test cuda_cpu_agree(raster_pullback!, args...) skip = !cuda_available # default arguments args = (ds_dout_3d, D.more_points, D.rotations_static, D.translations_3d_static) - @test cuda_cpu_agree(raster_pullback!, args...) skip=!cuda_available + @test cuda_cpu_agree(raster_pullback!, args...) skip = !cuda_available # projection ds_dout_2d = randn(D.grid_size_2d..., D.batch_size) - args = (ds_dout_2d, D.more_points, D.projections_static, D.translations_2d_static, D.backgrounds, D.weights, D.more_point_weights) - @test cuda_cpu_agree(raster_pullback!, args...) skip=!cuda_available + args = ( + ds_dout_2d, + D.more_points, + D.projections_static, + D.translations_2d_static, + D.backgrounds, + D.weights, + D.more_point_weights, + ) + @test cuda_cpu_agree(raster_pullback!, args...) skip = !cuda_available end # The follwing currently fails. diff --git a/test/data.jl b/test/data.jl index 682a2c3..ae4b9b9 100644 --- a/test/data.jl +++ b/test/data.jl @@ -7,7 +7,7 @@ function batch_size_for_test() while (Threads.nthreads() > 1) && (batch_size % Threads.nthreads() == 0) batch_size += 1 end - batch_size + return batch_size end const P = @SMatrix Float64[ @@ -23,25 +23,30 @@ const points = [0.4 * randn(3) for _ in 1:10] const points_static = SVector{3}.(points) const points_array = Matrix{Float64}(undef, 3, length(points)) eachcol(points_array) .= points -const points_reinterp = reinterpret(reshape, SVector{3, Float64}, points_array) +const points_reinterp = reinterpret(reshape, SVector{3,Float64}, points_array) const more_points = [0.4 * @SVector randn(3) for _ in 1:100_000] const rotation = rand(RotMatrix3{Float64}) const rotations_static = rand(RotMatrix3{Float64}, batch_size)::Vector{<:StaticMatrix} const rotations = (Array.(rotations_static))::Vector{Matrix{Float64}} -const rotations_array = Array{Float64, 3}(undef, 3, 3, batch_size) +const rotations_array = Array{Float64,3}(undef, 3, 3, batch_size) eachslice(rotations_array; dims=3) .= rotations -const rotations_reinterp = reinterpret(reshape, SMatrix{3, 3, Float64, 9}, reshape(rotations_array, 9, :)) +const rotations_reinterp = reinterpret( + reshape, SMatrix{3,3,Float64,9}, reshape(rotations_array, 9, :) +) const rotation_tangent = Array(rand(RotMatrix3)) -const rotation_tangents_static = rand(RotMatrix3{Float64}, batch_size)::Vector{<:StaticMatrix} +const rotation_tangents_static = + rand(RotMatrix3{Float64}, batch_size)::Vector{<:StaticMatrix} const rotation_tangents = (Array.(rotation_tangents_static))::Vector{Matrix{Float64}} const projection = P * rand(RotMatrix3) const projections_static = Ref(P) .* rand(RotMatrix3{Float64}, batch_size) const projections = (Array.(projections_static))::Vector{Matrix{Float64}} -const projections_array = Array{Float64, 3}(undef, 2, 3, batch_size) +const projections_array = Array{Float64,3}(undef, 2, 3, batch_size) eachslice(projections_array; dims=3) .= projections -const projections_reinterp = reinterpret(reshape, SMatrix{2, 3, Float64, 6}, reshape(projections_array, 6, :)) +const projections_reinterp = reinterpret( + reshape, SMatrix{2,3,Float64,6}, reshape(projections_array, 6, :) +) const projection_tangent = Array(P * rand(RotMatrix3)) const projection_tangents_static = Ref(P) .* rand(RotMatrix3{Float64}, batch_size) const projection_tangents = (Array.(projection_tangents_static))::Vector{Matrix{Float64}} @@ -52,12 +57,16 @@ const translations_3d_static = [0.1 * @SVector randn(3) for _ in 1:batch_size] const translations_3d = (Array.(translations_3d_static))::Vector{Vector{Float64}} const translations_3d_array = Matrix{Float64}(undef, 3, batch_size) eachcol(translations_3d_array) .= translations_3d -const translations_3d_reinterp = reinterpret(reshape, SVector{3, Float64}, translations_3d_array) +const translations_3d_reinterp = reinterpret( + reshape, SVector{3,Float64}, translations_3d_array +) const translations_2d_static = [0.1 * @SVector randn(2) for _ in 1:batch_size] const translations_2d = (Array.(translations_2d_static))::Vector{Vector{Float64}} const translations_2d_array = Matrix{Float64}(undef, 2, batch_size) eachcol(translations_2d_array) .= translations_2d -const translations_2d_reinterp = reinterpret(reshape, SVector{2, Float64}, translations_2d_array) +const translations_2d_reinterp = reinterpret( + reshape, SVector{2,Float64}, translations_2d_array +) const background = 0.1 const backgrounds = collect(1:1.0:batch_size) @@ -65,11 +74,11 @@ const backgrounds = collect(1:1.0:batch_size) const weight = rand() const weights = 10 .* rand(batch_size) -const point_weights = let +const point_weights = let w = rand(length(points)) w ./ sum(w) end -const more_point_weights = let +const more_point_weights = let w = rand(length(more_points)) w ./ sum(w) end diff --git a/test/runtests.jl b/test/runtests.jl index d2878d7..ff4bd86 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,8 @@ -using TestItemRunner: @run_package_tests +using TestItemRunner: @run_package_tests, @testitem +@testitem "Aqua.test_all" begin + import Aqua + Aqua.test_all(DiffPointRasterisation) +end @run_package_tests # filter=ti-> occursin("CUDA", ti.name) \ No newline at end of file diff --git a/test/util.jl b/test/util.jl index 0db64a8..daace99 100644 --- a/test/util.jl +++ b/test/util.jl @@ -3,18 +3,16 @@ function run_cuda(f, args...) return f(cu_args...) end - function cuda_cpu_agree(f, args...) out_cpu = f(args...) out_cuda = run_cuda(f, args...) - is_approx_equal(out_cuda, out_cpu) + return is_approx_equal(out_cuda, out_cpu) end function is_approx_equal(actual::AbstractArray, expected::AbstractArray) - Array(actual) ≈ expected + return Array(actual) ≈ expected end - function is_approx_equal(actual::NamedTuple, expected::NamedTuple) actual_cpu = adapt(Array, actual) for prop in propertynames(expected) @@ -22,7 +20,9 @@ function is_approx_equal(actual::NamedTuple, expected::NamedTuple) actual_elem = getproperty(actual_cpu, prop) expected_elem = getproperty(expected, prop) if !(actual_elem ≈ expected_elem) - throw("Values differ:\nActual: $(string(actual_elem)) \nExpected: $(string(expected_elem))") + throw( + "Values differ:\nActual: $(string(actual_elem)) \nExpected: $(string(expected_elem))", + ) return false end catch e @@ -30,5 +30,5 @@ function is_approx_equal(actual::NamedTuple, expected::NamedTuple) rethrow() end end - true + return true end \ No newline at end of file