|
7 | 7 |
|
8 | 8 | use crate::core::{
|
9 | 9 | ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState,
|
10 |
| - LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, KV, |
| 10 | + LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, State, KV, |
11 | 11 | };
|
12 | 12 | use argmin_math::ArgminMul;
|
13 | 13 | #[cfg(feature = "serde1")]
|
@@ -54,24 +54,27 @@ impl<O, L, P, G, F> Solver<O, IterState<P, G, (), (), (), F>> for SteepestDescen
|
54 | 54 | where
|
55 | 55 | O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
|
56 | 56 | P: Clone + SerializeAlias + DeserializeOwnedAlias,
|
57 |
| - G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, P>, |
58 |
| - L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>, |
| 57 | + G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, G>, |
| 58 | + L: Clone + LineSearch<G, F> + Solver<O, IterState<P, G, (), (), (), F>>, |
59 | 59 | F: ArgminFloat,
|
60 | 60 | {
|
61 | 61 | const NAME: &'static str = "Steepest Descent";
|
62 | 62 |
|
63 | 63 | fn next_iter(
|
64 | 64 | &mut self,
|
65 | 65 | problem: &mut Problem<O>,
|
66 |
| - mut state: IterState<P, G, (), (), (), F>, |
| 66 | + state: IterState<P, G, (), (), (), F>, |
67 | 67 | ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
|
68 |
| - let param_new = state.take_param().ok_or_else(argmin_error_closure!( |
69 |
| - NotInitialized, |
70 |
| - concat!( |
71 |
| - "`SteepestDescent` requires an initial parameter vector. ", |
72 |
| - "Please provide an initial guess via `Executor`s `configure` method." |
73 |
| - ) |
74 |
| - ))?; |
| 68 | + let param_new = state |
| 69 | + .get_param() |
| 70 | + .ok_or_else(argmin_error_closure!( |
| 71 | + NotInitialized, |
| 72 | + concat!( |
| 73 | + "`SteepestDescent` requires an initial parameter vector. ", |
| 74 | + "Please provide an initial guess via `Executor`s `configure` method." |
| 75 | + ) |
| 76 | + ))? |
| 77 | + .clone(); |
75 | 78 | let new_cost = problem.cost(&param_new)?;
|
76 | 79 | let new_grad = problem.gradient(&param_new)?;
|
77 | 80 |
|
@@ -153,6 +156,20 @@ mod tests {
|
153 | 156 | );
|
154 | 157 | }
|
155 | 158 |
|
| 159 | + #[test] |
| 160 | + fn test_next_iter_prev_param_not_erased() { |
| 161 | + let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> = |
| 162 | + BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap()); |
| 163 | + let mut sd = SteepestDescent::new(linesearch); |
| 164 | + let (state, _kv) = sd |
| 165 | + .next_iter( |
| 166 | + &mut Problem::new(TestProblem::new()), |
| 167 | + IterState::new().param(vec![1.0, 2.0]), |
| 168 | + ) |
| 169 | + .unwrap(); |
| 170 | + state.prev_param.unwrap(); |
| 171 | + } |
| 172 | + |
156 | 173 | #[test]
|
157 | 174 | fn test_next_iter_regression() {
|
158 | 175 | struct SDProblem {}
|
|
0 commit comments