Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interpolate at arbitrary points #109

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down
3 changes: 3 additions & 0 deletions src/Bcube.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using Printf # just for tmp vtk, to be removed
# import LinearSolve: solve, solve!, LinearProblem
import LinearSolve
using Symbolics # used for generation of Lagrange shape functions
using NearestNeighbors

const MAX_LENGTH_STATICARRAY = (10^6)

Expand Down Expand Up @@ -96,6 +97,8 @@ include("./mapping/mapping.jl")
include("./mapping/ref2phys.jl")
export get_cell_centers

include("mapping/findpoint.jl")

include("./cellfunction/eval_point.jl")

include("./cellfunction/cellfunction.jl")
Expand Down
47 changes: 47 additions & 0 deletions src/mapping/findpoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
struct PointFinder{T, M}
tree::T
mesh::M
end

function PointFinder(mesh::AbstractMesh)
xc0 = get_cell_centers(mesh)
xc = reshape(
[xc0[n][idim] for n in 1:length(xc0) for idim in 1:spacedim(mesh)],
spacedim(mesh),
length(xc0),
)
tree = KDTree(xc; leafsize = 10)
PointFinder{typeof(tree), typeof(mesh)}(tree, mesh)
end

function find_cell(pf::PointFinder, point)
idxs, dists = knn(pf.tree, point, 1)
c2c_n = connectivity_cell2cell_by_nodes(pf.mesh)
cellids = [idxs[1], c2c_n[idxs[1]]...]

for icell in cellids
cinfo = CellInfo(pf.mesh, icell)
cpoint = CellPoint(point, cinfo, PhysicalDomain())
isIn = point_in_cell(cinfo, cpoint)
isIn && return icell
end
return nothing

Check warning on line 28 in src/mapping/findpoint.jl

View check run for this annotation

Codecov / codecov/patch

src/mapping/findpoint.jl#L28

Added line #L28 was not covered by tests
end

function point_in_cell(cinfo, cpoint)
cpoint_ref = change_domain(cpoint, ReferenceDomain())
point_in_shape(shape(celltype(cinfo)), get_coords(cpoint_ref))
end

function point_in_shape(s::Square, x)
get_coords(s)[1][1] ≤ x[1] ≤ get_coords(s)[3][1] &&
get_coords(s)[1][2] ≤ x[2] ≤ get_coords(s)[3][2]
end

function point_in_shape(shape::AbstractShape, x)
for (normal, f2n) in zip(normals(shape), faces2nodes(shape))
dx = (x - get_coords(shape)[first(f2n)])
(dx ⋅ normal > 0) && (return false)
end
return true
end
30 changes: 30 additions & 0 deletions test/interpolation/test_cellfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,34 @@ end
@test v1 ⋅ u ≈ v1_in_2 ⋅ u
@test abs(ν2 ⋅ v1_in_2) < 1e-16
end

@testset "Point interpolation" begin
degree = 2
path = joinpath(tempdir, "mesh.msh")
gen_rectangle_mesh_with_tri_and_quad(path; nx = 10, ny = 10, xc = 0.5, yc = 0.5)
mesh = read_msh(path)
Uspace = TrialFESpace(FunctionSpace(:Lagrange, degree), mesh; size = 2)
Vspace = TestFESpace(Uspace)
dΩ = Measure(CellDomain(mesh), 2 * degree + 1)

pointFinder = Bcube.PointFinder(mesh)

u = FEFunction(Uspace)
f1((x, y)) = SA[x + 4y + 1, 7x * x + 2y * x + 2y * y + 3y + 12]
projection_l2!(u, PhysicalFunction(f1), dΩ)
npoints = 20
xp = [rand(SVector{2}) for i in 1:npoints]
icells = map(Base.Fix1(Bcube.find_cell, pointFinder), xp)
cpoints = map(
(x, i) ->
Bcube.CellPoint(x, Bcube.CellInfo(mesh, i), Bcube.PhysicalDomain()),
xp,
icells,
)
up = map(cpoints) do cpoint
uᵢ = Bcube.materialize(u, Bcube.get_cellinfo(cpoint))
Bcube.materialize(uᵢ, cpoint)
end
@test all(f1.(xp) .≈ up)
end
end
Loading