Skip to content

Commit bdda07f

Browse files
authored
Merge branch 'main' into main
2 parents 00822bb + 13a39ea commit bdda07f

File tree

14 files changed

+317
-50
lines changed

14 files changed

+317
-50
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
|
1818
<a href="https://argmin-rs.github.io/argmin/argmin/">Docs (main branch)</a>
1919
|
20-
<a href="https://github.com/argmin-rs/argmin/tree/v0.5.0/examples">Examples (latest release)</a>
20+
<a href="https://github.com/argmin-rs/argmin/tree/argmin-v0.8.1/argmin/examples">Examples (latest release)</a>
2121
|
2222
<a href="https://github.com/argmin-rs/argmin/tree/main/argmin/examples">Examples (main branch)</a>
2323
</p>

argmin-math/src/ndarray_m/inv.rs

+18
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ macro_rules! make_inv {
2222
Ok(<Self as Inverse>::inv(&self)?)
2323
}
2424
}
25+
26+
// inverse for scalars (1d solvers)
27+
impl ArgminInv<$t> for $t {
28+
#[inline]
29+
fn inv(&self) -> Result<$t, Error> {
30+
Ok(1.0 / self)
31+
}
32+
}
2533
};
2634
}
2735

@@ -60,6 +68,16 @@ mod tests {
6068
}
6169
}
6270
}
71+
72+
item! {
73+
#[test]
74+
fn [<test_inv_scalar_ $t>]() {
75+
let a = 2.0;
76+
let target = 0.5;
77+
let res = <$t as ArgminInv<$t>>::inv(&a).unwrap();
78+
assert!(((res - target) as f64).abs() < 0.000001);
79+
}
80+
}
6381
};
6482
}
6583

argmin/Cargo.toml

+9-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ serde1 = ["serde", "serde_json", "rand/serde1", "bincode", "slog-json", "rand_xo
5454
_ndarrayl = ["argmin-math/ndarray_latest-serde", "argmin-math/_dev_linalg_latest"]
5555
_nalgebral = ["argmin-math/nalgebra_latest-serde"]
5656
# When adding new features, please consider adding them to either `full` (for users)
57-
# or `_full_dev` (only for local development, tesing and computing test coverage).
57+
# or `_full_dev` (only for local development, testing and computing test coverage).
5858
full = ["default", "slog-logger", "serde1", "ctrlc"]
5959
_full_dev = ["full", "_ndarrayl", "_nalgebral"]
6060

@@ -129,6 +129,10 @@ required-features = ["argmin-math/nalgebra_latest-serde", "slog-logger"]
129129
name = "morethuente"
130130
required-features = ["slog-logger"]
131131

132+
[[example]]
133+
name = "neldermead-cubic"
134+
required-features = ["slog-logger"]
135+
132136
[[example]]
133137
name = "neldermead"
134138
required-features = ["argmin-math/ndarray_latest-serde", "slog-logger"]
@@ -177,6 +181,10 @@ required-features = ["argmin-math/ndarray_latest-serde", "slog-logger"]
177181
name = "steepestdescent"
178182
required-features = ["slog-logger"]
179183

184+
[[example]]
185+
name = "steepestdescent_manifold"
186+
required-features = ["slog-logger"]
187+
180188
[[example]]
181189
name = "trustregion_nd"
182190
required-features = ["argmin-math/ndarray_latest-serde", "slog-logger"]

argmin/examples/neldermead-cubic.rs

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// Copyright 2018-2022 argmin developers
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5+
// http://opensource.org/licenses/MIT>, at your option. This file may not be
6+
// copied, modified, or distributed except according to those terms.
7+
8+
//! A (hopefully) simple example of using Nelder-Mead to find the roots of a
9+
//! cubic polynomial.
10+
//!
11+
//! You can run this example with:
12+
//! `cargo run --example neldermead-cubic --features slog-logger`
13+
14+
use argmin::core::observers::{ObserverMode, SlogLogger};
15+
use argmin::core::{CostFunction, Error, Executor, State};
16+
use argmin::solver::neldermead::NelderMead;
17+
18+
/// Coefficients describing a cubic `f(x) = ax^3 + bx^2 + cx + d`
19+
#[derive(Clone, Copy)]
20+
struct Cubic {
21+
/// Coefficient of the `x^3` term
22+
a: f64,
23+
/// Coefficient of the `x^2` term
24+
b: f64,
25+
/// Coefficient of the `x` term
26+
c: f64,
27+
/// Coefficient of the `x^0` term
28+
d: f64,
29+
}
30+
31+
impl Cubic {
32+
/// Evaluate the cubic at `x`.
33+
fn eval(self, x: f64) -> f64 {
34+
self.a * x.powi(3) + self.b * x.powi(2) + self.c * x + self.d
35+
}
36+
}
37+
38+
impl CostFunction for Cubic {
39+
type Param = f64;
40+
type Output = f64;
41+
42+
fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
43+
// The cost function is the evaluation of the polynomial with our
44+
// parameter, squared. The parameter is a guess of `x`, and the
45+
// objective is to minimize `x` (i.e. find a polynomial root). The
46+
// square value can be considered an error. We want the error to (1)
47+
// always be positive and (2) bigger the further it is from a polynomial
48+
// root.
49+
Ok(self.eval(*p).powi(2))
50+
}
51+
}
52+
53+
fn run() -> Result<(), Error> {
54+
// Define the cost function. This needs to be something with an
55+
// implementation of `CostFunction`; in this case, the impl is right
56+
// above. Here, our cubic is `(x-2)(x+2)(x-5)`; see
57+
// <https://www.wolframalpha.com/input?i=%28x-2%29%28x%2B2%29%28x-5%29> for
58+
// more info.
59+
let cost = Cubic {
60+
a: 1.0,
61+
b: -5.0,
62+
c: -4.0,
63+
d: 20.0,
64+
};
65+
66+
// Let's find a root of the cubic (+5).
67+
{
68+
// Set up solver -- note that the proper choice of the vertices is very
69+
// important! This example should find 5, because our vertices are 6 and 7.
70+
let solver = NelderMead::new(vec![6.0, 7.0]).with_sd_tolerance(0.0001)?;
71+
72+
// Run solver
73+
let res = Executor::new(cost, solver)
74+
.configure(|state| state.max_iters(100))
75+
.add_observer(SlogLogger::term(), ObserverMode::Always)
76+
.run()?;
77+
78+
// Wait a second (lets the logger flush everything before printing again)
79+
std::thread::sleep(std::time::Duration::from_secs(1));
80+
81+
// Print result
82+
println!(
83+
"Polynomial root: {}",
84+
res.state.get_best_param().expect("Found a root")
85+
);
86+
}
87+
88+
// Now find -2.
89+
{
90+
let solver = NelderMead::new(vec![-3.0, -4.0]).with_sd_tolerance(0.0001)?;
91+
let res = Executor::new(cost, solver)
92+
.configure(|state| state.max_iters(100))
93+
.add_observer(SlogLogger::term(), ObserverMode::Always)
94+
.run()?;
95+
std::thread::sleep(std::time::Duration::from_secs(1));
96+
println!("{res}");
97+
println!(
98+
"Polynomial root: {}",
99+
res.state.get_best_param().expect("Found a root")
100+
);
101+
}
102+
103+
// This example will find +2, even though it might look like we're trying to
104+
// find +5.
105+
{
106+
let solver = NelderMead::new(vec![4.0, 6.0]).with_sd_tolerance(0.0001)?;
107+
let res = Executor::new(cost, solver)
108+
.configure(|state| state.max_iters(100))
109+
.add_observer(SlogLogger::term(), ObserverMode::Always)
110+
.run()?;
111+
std::thread::sleep(std::time::Duration::from_secs(1));
112+
println!("{res}");
113+
println!(
114+
"Polynomial root: {}",
115+
res.state.get_best_param().expect("Found a root")
116+
);
117+
}
118+
119+
Ok(())
120+
}
121+
122+
fn main() {
123+
if let Err(ref e) = run() {
124+
println!("{e}");
125+
std::process::exit(1);
126+
}
127+
}

argmin/examples/steepestdescent.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ impl Gradient for Rosenbrock {
3838
}
3939

4040
fn run() -> Result<(), Error> {
41-
// Define cost function (must implement `ArgminOperator`)
41+
// Define cost function (must implement `CostFunction` and `Gradient`)
4242
let cost = Rosenbrock { a: 1.0, b: 100.0 };
4343

4444
// Define initial parameter vector
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright 2018-2022 argmin developers
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5+
// http://opensource.org/licenses/MIT>, at your option. This file may not be
6+
// copied, modified, or distributed except according to those terms.
7+
8+
#![allow(unused_imports)]
9+
10+
use argmin::core::observers::{ObserverMode, SlogLogger};
11+
use argmin::core::{CostFunction, Error, Executor, Gradient};
12+
use argmin::solver::gradientdescent::SteepestDescent;
13+
use argmin::solver::linesearch::condition::{ArmijoCondition, LineSearchCondition};
14+
use argmin::solver::linesearch::BacktrackingLineSearch;
15+
use argmin_math::ArgminScaledAdd;
16+
17+
use serde::{Deserialize, Serialize};
18+
19+
#[derive(Clone, Copy, Debug)]
20+
struct ClosestPointOnCircle {
21+
x: f64,
22+
y: f64,
23+
}
24+
25+
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
26+
struct CirclePoint {
27+
angle: f64,
28+
}
29+
30+
impl CostFunction for ClosestPointOnCircle {
31+
type Param = CirclePoint;
32+
type Output = f64;
33+
34+
fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
35+
let x_circ = p.angle.cos();
36+
let y_circ = p.angle.sin();
37+
let x_diff = x_circ - self.x;
38+
let y_diff = y_circ - self.y;
39+
Ok(x_diff.powi(2) + y_diff.powi(2))
40+
}
41+
}
42+
43+
impl Gradient for ClosestPointOnCircle {
44+
type Param = CirclePoint;
45+
type Gradient = f64;
46+
47+
fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
48+
Ok(2.0 * (p.angle.cos() - self.x) * (-p.angle.sin())
49+
+ 2.0 * (p.angle.sin() - self.y) * p.angle.cos())
50+
}
51+
}
52+
53+
impl ArgminScaledAdd<f64, f64, CirclePoint> for CirclePoint {
54+
fn scaled_add(&self, alpha: &f64, delta: &f64) -> Self {
55+
CirclePoint {
56+
angle: self.angle + alpha * delta,
57+
}
58+
}
59+
}
60+
61+
fn run() -> Result<(), Error> {
62+
// Define cost function (must implement `CostFunction` and `Gradient`)
63+
let cost = ClosestPointOnCircle { x: 1.0, y: 1.0 };
64+
65+
// Define initial parameter vector
66+
let init_param = CirclePoint { angle: 0.0 };
67+
68+
// Pick a line search.
69+
let cond = ArmijoCondition::new(0.5)?;
70+
let linesearch = BacktrackingLineSearch::new(cond);
71+
72+
// Set up solver
73+
let solver = SteepestDescent::new(linesearch);
74+
75+
// Run solver
76+
let res = Executor::new(cost, solver)
77+
.configure(|state| state.param(init_param).max_iters(10))
78+
.add_observer(SlogLogger::term(), ObserverMode::Always)
79+
.run()?;
80+
81+
// Wait a second (lets the logger flush everything first)
82+
std::thread::sleep(std::time::Duration::from_secs(1));
83+
84+
// print result
85+
println!("{res}");
86+
Ok(())
87+
}
88+
89+
fn main() {
90+
if let Err(ref e) = run() {
91+
println!("{e}");
92+
std::process::exit(1);
93+
}
94+
}

argmin/src/solver/conjugategradient/cg.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,10 @@ mod tests {
247247

248248
assert_relative_eq!(b[0], 1.0, epsilon = f64::EPSILON);
249249
assert_relative_eq!(b[1], 2.0, epsilon = f64::EPSILON);
250-
let r0 = vec![2.0f64, 2.0];
250+
let r0 = [2.0f64, 2.0];
251251
assert_relative_eq!(r0[0], r.as_ref().unwrap()[0], epsilon = f64::EPSILON);
252252
assert_relative_eq!(r0[1], r.as_ref().unwrap()[1], epsilon = f64::EPSILON);
253-
let pp = vec![-2.0f64, -2.0];
253+
let pp = [-2.0f64, -2.0];
254254
assert_relative_eq!(pp[0], p.as_ref().unwrap()[0], epsilon = f64::EPSILON);
255255
assert_relative_eq!(pp[1], p.as_ref().unwrap()[1], epsilon = f64::EPSILON);
256256
assert_relative_eq!(rtr, 8.0, epsilon = f64::EPSILON);

argmin/src/solver/gradientdescent/steepestdescent.rs

+28-11
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
use crate::core::{
99
ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState,
10-
LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, KV,
10+
LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, State, KV,
1111
};
1212
use argmin_math::ArgminMul;
1313
#[cfg(feature = "serde1")]
@@ -54,24 +54,27 @@ impl<O, L, P, G, F> Solver<O, IterState<P, G, (), (), (), F>> for SteepestDescen
5454
where
5555
O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
5656
P: Clone + SerializeAlias + DeserializeOwnedAlias,
57-
G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, P>,
58-
L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
57+
G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, G>,
58+
L: Clone + LineSearch<G, F> + Solver<O, IterState<P, G, (), (), (), F>>,
5959
F: ArgminFloat,
6060
{
6161
const NAME: &'static str = "Steepest Descent";
6262

6363
fn next_iter(
6464
&mut self,
6565
problem: &mut Problem<O>,
66-
mut state: IterState<P, G, (), (), (), F>,
66+
state: IterState<P, G, (), (), (), F>,
6767
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
68-
let param_new = state.take_param().ok_or_else(argmin_error_closure!(
69-
NotInitialized,
70-
concat!(
71-
"`SteepestDescent` requires an initial parameter vector. ",
72-
"Please provide an initial guess via `Executor`s `configure` method."
73-
)
74-
))?;
68+
let param_new = state
69+
.get_param()
70+
.ok_or_else(argmin_error_closure!(
71+
NotInitialized,
72+
concat!(
73+
"`SteepestDescent` requires an initial parameter vector. ",
74+
"Please provide an initial guess via `Executor`s `configure` method."
75+
)
76+
))?
77+
.clone();
7578
let new_cost = problem.cost(&param_new)?;
7679
let new_grad = problem.gradient(&param_new)?;
7780

@@ -153,6 +156,20 @@ mod tests {
153156
);
154157
}
155158

159+
#[test]
160+
fn test_next_iter_prev_param_not_erased() {
161+
let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
162+
BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
163+
let mut sd = SteepestDescent::new(linesearch);
164+
let (state, _kv) = sd
165+
.next_iter(
166+
&mut Problem::new(TestProblem::new()),
167+
IterState::new().param(vec![1.0, 2.0]),
168+
)
169+
.unwrap();
170+
state.prev_param.unwrap();
171+
}
172+
156173
#[test]
157174
fn test_next_iter_regression() {
158175
struct SDProblem {}

0 commit comments

Comments
 (0)