Skip to content

Commit

Permalink
add scaling and gather
Browse files Browse the repository at this point in the history
  • Loading branch information
ArrogantGao committed Jan 12, 2024
1 parent ffd5b5c commit 5b68c50
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/ChebParticleMesh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module ChebParticleMesh

using LinearAlgebra, FastTransforms, SpecialFunctions, LoopVectorization, FFTW

export horner
export horner, funcpack
export Wkb, FWkb
export ChebCoef, pwcheb_eval, pwcheb_eval!, f_eval, f_eval!
export GridInfo, GridBox, PadIndex, ImageIndex
export id_image2pad, id_pad2image, grid_revise_pad!
Expand Down
61 changes: 61 additions & 0 deletions src/gather.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
function gather_single(q::T, pos::NTuple{N, T}, gridinfo::GridInfo{N, T}, gridbox::GridBox{N, T}, chebcoefs::NTuple{N, ChebCoef{T}}) where{N, T}

potential_i = zero(T)

cheb_value = gridbox.cheb_value
near_id_image = image_grid_id(pos, gridinfo)
near_pos_image = image_grid_pos(near_id_image, gridinfo)
for i in 1:N
dx = pos[i] - near_pos_image[i]
pwcheb_eval!(dx, cheb_value[i], chebcoefs[i])
end

for id in Iterators.product([- gridinfo.w[i] : gridinfo.w[i] for i in 1:N]...)
potential_i += real(gridbox.image_grid[(near_id_image.id .+ id)...]) * prod(cheb_value[i][id[i] + gridinfo.w[i] + 1] for i in 1:N)
end

return q * 4π * prod(gridinfo.h) * potential_i
end

function gather(qs::Vector{T}, poses::Vector{NTuple{N, T}}, gridinfo::GridInfo{N, T}, gridbox::GridBox{N, T}, chebcoefs::NTuple{N, ChebCoef{T}}) where{N, T}

@assert length(qs) == length(poses)

potential = zero(T)
for i in 1:length(qs)
potential += gather_single(qs[i], poses[i], gridinfo, gridbox, chebcoefs)
end

return potential
end

function gather_single_direct(q::T, pos::NTuple{N, T}, gridinfo::GridInfo{N, T}, gridbox::GridBox{N, T}, chebcoefs::NTuple{N, ChebCoef{T}}) where{N, T}

potential_i = zero(T)

cheb_value = gridbox.cheb_value
near_id_image = image_grid_id(pos, gridinfo)
near_pos_image = image_grid_pos(near_id_image, gridinfo)
for i in 1:N
dx = pos[i] - near_pos_image[i]
f_eval!(dx, cheb_value[i], chebcoefs[i])
end

for id in Iterators.product([- gridinfo.w[i] : gridinfo.w[i] for i in 1:N]...)
potential_i += real(gridbox.image_grid[(near_id_image.id .+ id)...]) * prod(cheb_value[i][id[i] + gridinfo.w[i] + 1] for i in 1:N)
end

return q * 4π * prod(gridinfo.h) * potential_i
end

function gather_direct(qs::Vector{T}, poses::Vector{NTuple{N, T}}, gridinfo::GridInfo{N, T}, gridbox::GridBox{N, T}, chebcoefs::NTuple{N, ChebCoef{T}}) where{N, T}

@assert length(qs) == length(poses)

potential = zero(T)
for i in 1:length(qs)
potential += gather_single_direct(qs[i], poses[i], gridinfo, gridbox, chebcoefs)
end

return potential
end
4 changes: 2 additions & 2 deletions src/grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function GridInfo(N_real::NTuple{N, Int}, w::NTuple{N, Int}, periodicity::NTuple
end

function GridBox(grid_info::GridInfo{N, T}) where{N, T<:Union{Float32, Float64}}
pad_grid = zeros(T, grid_info.N_pad...)
pad_grid = zeros(Complex{T}, grid_info.N_pad...)

image_grid = view(pad_grid, [mod1.(1 + grid_info.pad[i] - grid_info.image[i]:grid_info.pad[i] + grid_info.N_real[i] + grid_info.image[i], grid_info.N_pad[i]) for i in 1:N]...)

Expand All @@ -54,7 +54,7 @@ function grid_revise_pad!(gridbox::GridBox{N, T}) where{N, T}
function grid_revise_pad!(gridbox::GridBox{N, T}) where{N, T}

for i in eachindex(gridbox.pad_grid)
gridbox.pad_grid[i] = zero(T)
gridbox.pad_grid[i] = zero(Complex{T})
end

return nothing
Expand Down
4 changes: 2 additions & 2 deletions src/interpolate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
function interpolate!(qs::Vector{T}, poses::Vector{NTuple{N, T}}, gridinfo::GridInfo{N, T}, gridbox::GridBox{N, T}, chebcoefs::NTuple{N, ChebCoef{T}}) where{N, T}

@assert length(qs) == length(poses)
grid_revise_image!(gridbox)
grid_revise_pad!(gridbox)

for i in 1:length(qs)
interpolate_single!(qs[i], poses[i], gridinfo, gridbox, chebcoefs)
Expand Down Expand Up @@ -47,7 +47,7 @@ end
function interpolate_direct!(qs::Vector{T}, poses::Vector{NTuple{N, T}}, gridinfo::GridInfo{N, T}, gridbox::GridBox{N, T}, chebcoefs::NTuple{N, ChebCoef{T}}) where{N, T}

@assert length(qs) == length(poses)
grid_revise_image!(gridbox)
grid_revise_pad!(gridbox)

for i in 1:length(qs)
interpolate_single_direct!(qs[i], poses[i], gridinfo, gridbox, chebcoefs)
Expand Down
17 changes: 17 additions & 0 deletions src/scaling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
function ScalingFactor(f::Function, gridinfo::GridInfo{N, T}) where{N, T}
factors = zeros(Complex{T}, gridinfo.N_pad...)

for i in Iterators.product([1:gridinfo.N_pad[d] for d in 1:N]...)
k = [gridinfo.k[d][i[d]] for d in 1:N]
factors[i...] = Complex{T}(f(k...))
end

return ScalingFactor{N, T}(f, factors)
end

function scale!(gridbox::GridBox{N, T}, scalingfactor::ScalingFactor{N, T}) where{N, T}

gridbox.pad_grid .*= scalingfactor.factors

return gridbox
end
9 changes: 7 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ struct GridInfo{N, T}
end

mutable struct GridBox{N, T}
pad_grid::Array{T, N}
image_grid::SubArray{T, N}
pad_grid::Array{Complex{T}, N}
image_grid::SubArray{Complex{T}, N}

cheb_value::Vector{Array{T, 1}}
end
Expand All @@ -36,4 +36,9 @@ end

struct ImageIndex{N} <: AbstractIndex
id::NTuple{N, Int}
end

struct ScalingFactor{N, T}
f::Function
factors::Array{Complex{T}, N}
end
46 changes: 46 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,49 @@
return ex
end

"""
function fft!(gridbox::GridBox{N, T}) where{N, T}
calculate the inplace fft on the padded grid
"""
function FFTW.fft!(gridbox::GridBox{N, T}) where{N, T}

fft!(gridbox.pad_grid)

return gridbox
end

"""
function ifft!(gridbox::GridBox{N, T}) where{N, T}
calculate the inplace ifft on the padded grid
"""
function FFTW.ifft!(gridbox::GridBox{N, T}) where{N, T}

ifft!(gridbox.pad_grid)

return gridbox
end

function funcpack(f::Function, args::Vector)
g = x -> f(x, args...)
return g
end

"""
function Wkb(x::T, width::T, β::T) where{T<:Real}
WKB kernel function, used as window function for the Chebyshev interpolation
"""
function Wkb(x::T, width::T, β::T) where{T<:Real}
return T(besseli(0, β * sqrt(one(T) - (x / width)^2)) / besseli(0, β) * (abs(x) <= width))
end

"""
function FWkb(k::T, width::T, β::T) where{T<:Real}
Fourier transform of Wkb kernel function
"""
function FWkb(k::T, width::T, β::T) where{T}
return T(2 * width * sinh(sqrt^2 - (k * width)^2)) / (besseli(0, β) * sqrt^2 - (k * width)^2)))
end

0 comments on commit 5b68c50

Please sign in to comment.