Skip to content

Commit

Permalink
Merge pull request #34 from SciNim/rbf
Browse files Browse the repository at this point in the history
fix rbf bug
  • Loading branch information
HugoGranstrom authored Jan 1, 2023
2 parents a17b80c + 75a8774 commit 04d3227
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
24 changes: 18 additions & 6 deletions src/numericalnim/rbf.nim
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ proc findIndex*[T](grid: RbfGrid[T], point: Tensor[float]): int =
result += (km(point, i, grid.gridDelta) - 1) * grid.gridSize ^ (grid.gridDim - i - 1)

proc constructMeshedPatches*[T](grid: RbfGrid[T]): Tensor[float] =
meshgrid(@[arraymancer.linspace(0 + grid.gridDelta / 2, 1 - grid.gridDelta / 2, grid.gridSize)].cycle(grid.gridDim))
if grid.gridSize == 1:
@[@[0.5].cycle(grid.gridDim)].toTensor
else:
meshgrid(@[arraymancer.linspace(0 + grid.gridDelta / 2, 1 - grid.gridDelta / 2, grid.gridSize)].cycle(grid.gridDim))

template dist2(p1, p2: Tensor[float]): float =
var result = 0.0
Expand Down Expand Up @@ -85,14 +88,16 @@ proc findAllBetween*[T](grid: RbfGrid[T], x: Tensor[float], rho1, rho2: float):
if rho1*rho1 <= d and d <= rho2*rho2:
result.add i

proc calcGridSize(nPoints, nDims: int, gridSize: int): int =
if gridSize > 0:
gridSize
else:
max(int(round(pow(nPoints.float, 1 / nDims) / 2)), 1)

proc newRbfGrid*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0): RbfGrid[T] =
let nPoints = points.shape[0]
let nDims = points.shape[1]
let gridSize =
if gridSize > 0:
gridSize
else:
max(int(round(pow(nPoints.float, 1 / nDims) / 2)), 1)
let gridSize = calcGridSize(nPoints, nDims, gridSize)
let delta = 1 / gridSize
result = RbfGrid[T](gridSize: gridSize, gridDim: nDims, gridDelta: delta, indices: newSeq[seq[int]](gridSize ^ nDims))
for row in 0 ..< nPoints:
Expand Down Expand Up @@ -152,13 +157,20 @@ proc newRbf*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0, rbf
## epsilon: shape parameter. Default 1.
assert points.shape[0] == values.shape[0]
assert points.shape.len == 2 and values.shape.len == 2

let upperLimit = max(points, 0)
let lowerLimit = min(points, 0)
let limits = (upper: upperLimit, lower: lowerLimit)
let scaledPoints = points.scalePoint(limits)

let nPoints = points.shape[0]
let nDims = points.shape[1]
let gridSize = calcGridSize(nPoints, nDims, gridSize)

let dataGrid = newRbfGrid(scaledPoints, values, gridSize)
let patchPoints = dataGrid.constructMeshedPatches()
let nPatches = patchPoints.shape[0]

var patchRbfs: seq[RbfBaseType[T]] #= newTensor[RbfBaseType[T]](nPatches, 1)
var patchIndices: seq[int]
for i in 0 ..< nPatches:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interpolate.nim
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,6 @@ test "rbf f=x*y*z":
let yTest = rbfObj.eval(xTest)
let yCorrect = xTest[_, 0] *. xTest[_, 1] *. xTest[_, 2]
for x in abs(yCorrect - yTest):
check x < 0.03
check mean_squared_error(yTest, yCorrect) < 1e-4
check x < 0.11
check mean_squared_error(yTest, yCorrect) < 1.4e-4

0 comments on commit 04d3227

Please sign in to comment.