From 75a8774681ad5edef6bfda6c68ca13d02e2469e2 Mon Sep 17 00:00:00 2001 From: HugoGranstrom <5092565+HugoGranstrom@users.noreply.github.com> Date: Sun, 1 Jan 2023 16:22:55 +0100 Subject: [PATCH] fix bug --- src/numericalnim/rbf.nim | 24 ++++++++++++++++++------ tests/test_interpolate.nim | 4 ++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/numericalnim/rbf.nim b/src/numericalnim/rbf.nim index 0276fb2..acdbf30 100644 --- a/src/numericalnim/rbf.nim +++ b/src/numericalnim/rbf.nim @@ -56,7 +56,10 @@ proc findIndex*[T](grid: RbfGrid[T], point: Tensor[float]): int = result += (km(point, i, grid.gridDelta) - 1) * grid.gridSize ^ (grid.gridDim - i - 1) proc constructMeshedPatches*[T](grid: RbfGrid[T]): Tensor[float] = - meshgrid(@[arraymancer.linspace(0 + grid.gridDelta / 2, 1 - grid.gridDelta / 2, grid.gridSize)].cycle(grid.gridDim)) + if grid.gridSize == 1: + @[@[0.5].cycle(grid.gridDim)].toTensor + else: + meshgrid(@[arraymancer.linspace(0 + grid.gridDelta / 2, 1 - grid.gridDelta / 2, grid.gridSize)].cycle(grid.gridDim)) template dist2(p1, p2: Tensor[float]): float = var result = 0.0 @@ -85,14 +88,16 @@ proc findAllBetween*[T](grid: RbfGrid[T], x: Tensor[float], rho1, rho2: float): if rho1*rho1 <= d and d <= rho2*rho2: result.add i +proc calcGridSize(nPoints, nDims: int, gridSize: int): int = + if gridSize > 0: + gridSize + else: + max(int(round(pow(nPoints.float, 1 / nDims) / 2)), 1) + proc newRbfGrid*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0): RbfGrid[T] = let nPoints = points.shape[0] let nDims = points.shape[1] - let gridSize = - if gridSize > 0: - gridSize - else: - max(int(round(pow(nPoints.float, 1 / nDims) / 2)), 1) + let gridSize = calcGridSize(nPoints, nDims, gridSize) let delta = 1 / gridSize result = RbfGrid[T](gridSize: gridSize, gridDim: nDims, gridDelta: delta, indices: newSeq[seq[int]](gridSize ^ nDims)) for row in 0 ..< nPoints: @@ -152,13 +157,20 @@ proc newRbf*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0, rbf ## epsilon: shape parameter. Default 1. assert points.shape[0] == values.shape[0] assert points.shape.len == 2 and values.shape.len == 2 + let upperLimit = max(points, 0) let lowerLimit = min(points, 0) let limits = (upper: upperLimit, lower: lowerLimit) let scaledPoints = points.scalePoint(limits) + + let nPoints = points.shape[0] + let nDims = points.shape[1] + let gridSize = calcGridSize(nPoints, nDims, gridSize) + let dataGrid = newRbfGrid(scaledPoints, values, gridSize) let patchPoints = dataGrid.constructMeshedPatches() let nPatches = patchPoints.shape[0] + var patchRbfs: seq[RbfBaseType[T]] #= newTensor[RbfBaseType[T]](nPatches, 1) var patchIndices: seq[int] for i in 0 ..< nPatches: diff --git a/tests/test_interpolate.nim b/tests/test_interpolate.nim index e9288f0..8f10a9a 100644 --- a/tests/test_interpolate.nim +++ b/tests/test_interpolate.nim @@ -434,6 +434,6 @@ test "rbf f=x*y*z": let yTest = rbfObj.eval(xTest) let yCorrect = xTest[_, 0] *. xTest[_, 1] *. xTest[_, 2] for x in abs(yCorrect - yTest): - check x < 0.03 - check mean_squared_error(yTest, yCorrect) < 1e-4 + check x < 0.11 + check mean_squared_error(yTest, yCorrect) < 1.4e-4