Skip to content

Commit

Permalink
Merge pull request #10 from microscopic-image-analysis/weights
Browse files Browse the repository at this point in the history
Weights
  • Loading branch information
trahflow authored Apr 17, 2024
2 parents 28077f6 + 46172c4 commit 4ccfcc3
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 183 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffPointRasterisation"
uuid = "f984992d-3c45-4382-99a1-cf20f5c47c61"
authors = ["Wolfhart Feldmeier <wolfhart.feldmeier@uni-jena.de>"]
version = "0.1.0"
version = "0.2.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
143 changes: 97 additions & 46 deletions ext/DiffPointRasterisationCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,32 @@ module DiffPointRasterisationCUDAExt

using DiffPointRasterisation, CUDA
using ArgCheck
using FillArrays
using StaticArrays


# Union alias: an `N`-dimensional array with element type `T` that is either
# a `CuArray` or any `FillArrays.AbstractFill`.
const CuOrFillArray{T, N} = Union{CuArray{T, N}, FillArrays.AbstractFill{T, N}}


# One-dimensional (vector) special case of `CuOrFillArray`.
const CuOrFillVector{T} = CuOrFillArray{T, 1}


function raster_pullback_kernel!(
::Type{T},
ds_dout,
points::AbstractVector{<:StaticVector{N_in}},
rotations::AbstractVector{<:StaticMatrix{N_out, N_in, TR}},
translations::AbstractVector{<:StaticVector{N_out, TT}},
weights,
out_weights,
point_weights,
shifts,
scale,
# outputs:
ds_dpoints,
ds_drotation,
ds_dtranslation,
ds_dweight,
ds_dout_weight,
ds_dpoint_weight,

) where {T, TR, TT, N_in, N_out}
n_voxel = blockDim().z
Expand All @@ -45,16 +54,22 @@ function raster_pullback_kernel!(
batch_idx = (blockIdx().y - 1) * batchsize_per_workgroup + b
in_batch = batch_idx <= length(rotations)

dimension = (N_out, n_voxel, batchsize_per_workgroup)
ds_dpoint_rot = CuDynamicSharedArray(T, dimension)
ds_dpoint_local = CuDynamicSharedArray(T, (N_in, batchsize_per_workgroup), sizeof(T) * prod(dimension))
dimension1 = (N_out, n_voxel, batchsize_per_workgroup)
ds_dpoint_rot_shared = CuDynamicSharedArray(T, dimension1)
offset = sizeof(T) * prod(dimension1)
dimension2 = (N_in, batchsize_per_workgroup)
ds_dpoint_shared = CuDynamicSharedArray(T, dimension2, offset)
dimension3 = (n_voxel, batchsize_per_workgroup)
offset += sizeof(T) * prod(dimension2)
ds_dpoint_weight_shared = CuDynamicSharedArray(T, dimension3, offset)

rotation = @inbounds in_batch ? rotations[batch_idx] : @SMatrix zeros(TR, N_in, N_in)
point = @inbounds points[point_idx]
point_weight = @inbounds point_weights[point_idx]

if in_batch
translation = @inbounds translations[batch_idx]
weight = @inbounds weights[batch_idx]
out_weight = @inbounds out_weights[batch_idx]
shift = @inbounds shifts[neighbor_voxel_id]
origin = (-@SVector ones(TT, N_out)) - translation

Expand All @@ -75,35 +90,38 @@ function raster_pullback_kernel!(
ds_dout[voxel_idx],
)

factor = ds_dout[voxel_idx] * weight
factor = ds_dout[voxel_idx] * out_weight * point_weight
ds_dcoord_part = SVector(factor .* ntuple(n -> DiffPointRasterisation.interpolation_weight(n, N_out, deltas, shift), Val(N_out)))
@inbounds ds_dpoint_rot[:, s, b] .= ds_dcoord_part .* scale
@inbounds ds_dpoint_rot_shared[:, s, b] .= ds_dcoord_part .* scale
else
@inbounds ds_dpoint_rot[:, s, b] .= zero(T)
@inbounds ds_dpoint_rot_shared[:, s, b] .= zero(T)
end

@inbounds CUDA.@atomic ds_dweight[batch_idx] += ds_dweight_local
@inbounds ds_dpoint_weight_shared[s, b] = ds_dweight_local * out_weight
ds_dout_weight_local = ds_dweight_local * point_weight
@inbounds CUDA.@atomic ds_dout_weight[batch_idx] += ds_dout_weight_local
else
@inbounds ds_dpoint_rot[:, s, b] .= zero(T)
@inbounds ds_dpoint_weight_shared[s, b] = zero(T)
@inbounds ds_dpoint_rot_shared[:, s, b] .= zero(T)
end

# parallel summation of ds_dpoint_rot over neighboring-voxel dimension
# parallel summation of ds_dpoint_rot_shared over neighboring-voxel dimension
# for a given thread-local batch index
stride = 1
@inbounds while stride < n_voxel
sync_threads()
idx = 2 * stride * (s - 1) + 1
dim = 1
while dim <= N_out
if idx <= n_voxel
other_val = if idx + stride <= n_voxel
ds_dpoint_rot[dim, idx + stride, b]
if idx <= n_voxel
dim = 1
while dim <= N_out
other_val_p = if idx + stride <= n_voxel
ds_dpoint_rot_shared[dim, idx + stride, b]
else
zero(T)
end
ds_dpoint_rot[dim, idx, b] += other_val
ds_dpoint_rot_shared[dim, idx, b] += other_val_p
dim += 1
end
dim += 1
end
stride *= 2
end
Expand All @@ -113,7 +131,7 @@ function raster_pullback_kernel!(
if in_batch
dim = s
if dim <= N_out
coef = ds_dpoint_rot[dim, 1, b]
coef = ds_dpoint_rot_shared[dim, 1, b]
@inbounds CUDA.@atomic ds_dtranslation[dim, batch_idx] += coef
j = 1
while j <= N_in
Expand All @@ -130,29 +148,45 @@ function raster_pullback_kernel!(
val = zero(T)
j = 1
while j <= N_out
@inbounds val += rotation[j, dim] * ds_dpoint_rot[j, 1, b]
@inbounds val += rotation[j, dim] * ds_dpoint_rot_shared[j, 1, b]
j += 1
end
@inbounds ds_dpoint_local[dim, b] = val
@inbounds ds_dpoint_shared[dim, b] = val
dim += n_voxel
end

# parallel summation of ds_dpoint_local over batch dimension
# parallel summation of ds_dpoint_shared over batch dimension
stride = 1
@inbounds while stride < batchsize_per_workgroup
sync_threads()
idx = 2 * stride * (b - 1) + 1
dim = s
while dim <= N_in
if idx <= batchsize_per_workgroup
other_val = if idx + stride <= batchsize_per_workgroup
ds_dpoint_local[dim, idx + stride]
if idx <= batchsize_per_workgroup
dim = s
while dim <= N_in
other_val_p = if idx + stride <= batchsize_per_workgroup
ds_dpoint_shared[dim, idx + stride]
else
zero(T)
end
ds_dpoint_local[dim, idx] += other_val
ds_dpoint_shared[dim, idx] += other_val_p
dim += n_voxel
end
end
stride *= 2
end

# parallel summation of ds_dpoint_weight_shared over voxel and batch dimension
stride = 1
@inbounds while stride < n_threads_per_workgroup
sync_threads()
idx = 2 * stride * (thread - 1) + 1
if idx <= n_threads_per_workgroup
other_val_w = if idx + stride <= n_threads_per_workgroup
ds_dpoint_weight_shared[idx + stride]
else
zero(T)
end
dim += n_voxel
ds_dpoint_weight_shared[idx] += other_val_w
end
stride *= 2
end
Expand All @@ -161,12 +195,18 @@ function raster_pullback_kernel!(

dim = thread
while dim <= N_in
val = ds_dpoint_local[dim, 1]
val = ds_dpoint_shared[dim, 1]
# batch might be split across blocks, so need atomic add
@inbounds CUDA.@atomic ds_dpoints[dim, point_idx] += val
dim += n_threads_per_workgroup
end

if thread == 1
val_w = ds_dpoint_weight_shared[1, 1]
# batch might be split across blocks, so need atomic add
@inbounds CUDA.@atomic ds_dpoint_weight[point_idx] += val_w
end

nothing
end

Expand All @@ -177,8 +217,10 @@ raster_pullback!(
rotation::StaticMatrix{N_out, N_in, <:Number},
translation::StaticVector{N_out, <:Number},
background::Number,
weight::Number,
ds_dpoints::AbstractMatrix{<:Number};
out_weight::Number,
point_weight::CuOrFillVector{<:Number},
ds_dpoints::AbstractMatrix{<:Number},
ds_dpoint_weight::AbstractVector{<:Number};
kwargs...
) where {N_in, N_out} = error("Not implemented: raster_pullback! for single image not implemented on GPU. Consider using CPU arrays")

Expand All @@ -188,22 +230,25 @@ function DiffPointRasterisation.raster_pullback!(
points::CuVector{<:StaticVector{N_in, <:Number}},
rotation::CuVector{<:StaticMatrix{N_out, N_in, <:Number}},
translation::CuVector{<:StaticVector{N_out, <:Number}},
background::CuVector{<:Number},
weight::CuVector{<:Number},
background::CuOrFillVector{<:Number},
out_weight::CuOrFillVector{<:Number},
point_weight::CuOrFillVector{<:Number},
ds_dpoints::CuMatrix{TP},
ds_drotation::CuArray{TR, 3},
ds_dtranslation::CuMatrix{TT},
ds_dbackground::CuVector{<:Number},
ds_dweight::CuVector{TW},
) where {N_in, N_out, N_out_p1, TP<:Number, TR<:Number, TT<:Number, TW<:Number}
T = promote_type(eltype(ds_dout), TP, TR, TT, TW)
ds_dout_weight::CuVector{OW},
ds_dpoint_weight::CuVector{PW},
) where {N_in, N_out, N_out_p1, TP<:Number, TR<:Number, TT<:Number, OW<:Number, PW<:Number}
T = promote_type(eltype(ds_dout), TP, TR, TT, OW, PW)
batch_axis = axes(ds_dout, N_out_p1)
@argcheck N_out == N_out_p1 - 1
@argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(weight, 1)
@argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dweight, 1)
@argcheck batch_axis == axes(rotation, 1) == axes(translation, 1) == axes(background, 1) == axes(out_weight, 1)
@argcheck batch_axis == axes(ds_drotation, 3) == axes(ds_dtranslation, 2) == axes(ds_dbackground, 1) == axes(ds_dout_weight, 1)
@argcheck N_out == N_out_p1 - 1

n_points = length(points)
@argcheck length(ds_dpoint_weight) == n_points
batch_size = length(batch_axis)

ds_dbackground = vec(sum!(reshape(ds_dbackground, ntuple(_ -> 1, Val(N_out))..., batch_size), ds_dout))
Expand All @@ -214,36 +259,42 @@ function DiffPointRasterisation.raster_pullback!(
ds_dpoints = fill!(ds_dpoints, zero(TP))
ds_drotation = fill!(ds_drotation, zero(TR))
ds_dtranslation = fill!(ds_dtranslation, zero(TT))
ds_dweight = fill!(ds_dweight, zero(TW))
ds_dout_weight = fill!(ds_dout_weight, zero(OW))
ds_dpoint_weight = fill!(ds_dpoint_weight, zero(PW))

args = (T, ds_dout, points, rotation, translation, weight, shifts, scale, ds_dpoints, ds_drotation, ds_dtranslation, ds_dweight)
args = (T, ds_dout, points, rotation, translation, out_weight, point_weight, shifts, scale, ds_dpoints, ds_drotation, ds_dtranslation, ds_dout_weight, ds_dpoint_weight)

ndrange = (n_points, batch_size, 2^N_out)

workgroup_size(threads) = (1, min(threads ÷ (2^N_out), batch_size), 2^N_out)

function shmem(threads)
batchsize_per_workgroup = workgroup_size(threads)[2]
(N_out * threads + N_in * batchsize_per_workgroup) * sizeof(T)
_, bs_p_wg, n_voxel = workgroup_size(threads)
((N_out + 1) * n_voxel + N_in) * bs_p_wg * sizeof(T)
# ((N_out + 1) * threads + N_in * bs_p_wg) * sizeof(T)
end

let kernel = @cuda launch=false raster_pullback_kernel!(args...)
config = CUDA.launch_configuration(kernel.fun; shmem)
workgroup_sz = workgroup_size(config.threads)
blocks = cld.(ndrange, workgroup_sz)
kernel(args...; threads=workgroup_sz, blocks=blocks, shmem=shmem(prod(workgroup_sz)))

kernel(args...; threads=workgroup_sz, blocks=blocks, shmem=shmem(config.threads))
end

return (;
points=ds_dpoints,
rotation=ds_drotation,
translation=ds_dtranslation,
background=ds_dbackground,
weight=ds_dweight,
out_weight=ds_dout_weight,
point_weight=ds_dpoint_weight,
)
end


# Allocate the default buffer for the gradient w.r.t. `points` in the batched
# pullback when `points` is a `CuVector`: an array obtained via `similar` from
# `points`, with element type `TP` and size `(N_in, length(points))`.
# The contents are uninitialised; `batch_size` is accepted but not used here.
function DiffPointRasterisation.default_ds_dpoints_batched(
    points::CuVector{<:AbstractVector{TP}}, N_in, batch_size
) where {TP<:Number}
    n_points = length(points)
    return similar(points, TP, (N_in, n_points))
end

# Allocate the default buffer for the gradient w.r.t. the per-point weights in
# the batched pullback when `points` is a `CuVector`: an array obtained via
# `similar` from `points`, with the same shape and element type `T`.
# The contents are uninitialised; `batch_size` is accepted but not used here.
function DiffPointRasterisation.default_ds_dpoint_weight_batched(
    points::CuVector{<:AbstractVector{<:Number}}, T, batch_size
)
    return similar(points, T)
end

end # module
Loading

0 comments on commit 4ccfcc3

Please sign in to comment.