diff --git a/src/numericalnim/optimize.nim b/src/numericalnim/optimize.nim index 9e4d42a..1d45662 100644 --- a/src/numericalnim/optimize.nim +++ b/src/numericalnim/optimize.nim @@ -130,7 +130,7 @@ proc secant*(f: proc(x: float64): float64, start: array[2, float64], precision: ## Multidimensional methods ## ############################## -type LineSearchCriterion = enum +type LineSearchCriterion* = enum Armijo, Wolfe, WolfeStrong, NoLineSearch type diff --git a/tests/test_optimize.nim b/tests/test_optimize.nim index 329cdfe..a9bcbaf 100644 --- a/tests/test_optimize.nim +++ b/tests/test_optimize.nim @@ -122,6 +122,13 @@ suite "Multi-dim": for x in abs(correct - xSol): check x < 7e-10 + test "Line Search options": + for ls in LineSearchCriterion: + let op = lbfgsOptions[float](lineSearchCriterion=ls) + let xSol = lbfgs(bananaFunc, x0.clone, options=op, analyticGradient=bananaBend) + for x in abs(correct - xSol): + check x < 7e-8 + let correctParams = [10.4, -0.45].toTensor() proc fitFunc(params: Tensor[float], x: float): float = params[0] * exp(params[1] * x)