|
10 | 10 | //! For details see [`CMAES`].
|
11 | 11 |
|
12 | 12 | use crate::core::{
|
13 |
| - ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, KV, |
| 13 | + ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, SyncAlias, |
| 14 | + KV, |
14 | 15 | };
|
15 | 16 | use argmin_math::{
|
16 | 17 | ArgminAdd, ArgminArgsort, ArgminAxisIter, ArgminBroadcast, ArgminDiv, ArgminDot,
|
@@ -45,10 +46,10 @@ use std::ops::{AddAssign, MulAssign};
|
45 | 46 | /// }
|
46 | 47 | ///
|
47 | 48 | /// impl CostFunction for Rosenbrock {
|
48 |
| -/// type Param = Vec<f32>; |
49 |
| -/// type Output = f32; |
| 49 | +/// type Param = Vec<f32>; |
| 50 | +/// type Output = f32; |
50 | 51 | ///
|
51 |
| -/// fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> { |
| 52 | +/// fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> { |
52 | 53 | /// Ok(rosenbrock_2d(p, self.a, self.b))
|
53 | 54 | /// }
|
54 | 55 | /// }
|
@@ -203,11 +204,12 @@ where
|
203 | 204 |
|
204 | 205 | impl<O, P, F> Solver<O, PopulationState<P, F, P::Array2D>> for CMAES<P, F>
|
205 | 206 | where
|
206 |
| - O: CostFunction<Param = P, Output = F>, |
| 207 | + O: CostFunction<Param = P, Output = F> + SyncAlias, |
207 | 208 | Vec<F>: ArgminArgsort,
|
208 | 209 | F: ArgminFloat + MulAssign + AddAssign + NumCast + ArgminDiv<P, P>,
|
209 | 210 | P: SerializeAlias
|
210 | 211 | + Clone
|
| 212 | + + SyncAlias |
211 | 213 | + ArgminTransition
|
212 | 214 | + ArgminSize<usize>
|
213 | 215 | + ArgminZeroLike
|
@@ -251,12 +253,9 @@ where
|
251 | 253 |
|
252 | 254 | state.population = Some(self.generate());
|
253 | 255 |
|
254 |
| - let fitness: Vec<F> = state |
255 |
| - .get_population() |
256 |
| - .unwrap() |
257 |
| - .row_iterator() |
258 |
| - .map(|p| problem.cost(&p).unwrap()) |
259 |
| - .collect(); |
| 256 | + let fitness: Vec<F> = problem |
| 257 | + .bulk_cost(&state.get_population().unwrap().row_iterator().collect()) |
| 258 | + .unwrap(); |
260 | 259 |
|
261 | 260 | let fitness_indices = fitness.argsort();
|
262 | 261 |
|
@@ -411,7 +410,6 @@ mod tests {
|
411 | 410 | assert!(state.best_individual.is_some());
|
412 | 411 |
|
413 | 412 | let solution = state.best_individual.unwrap();
|
414 |
| - println!("{:?}", solution); |
415 | 413 | assert!((solution[0] - 1.0).abs() <= precision);
|
416 | 414 | assert!((solution[1] - 1.0).abs() <= precision);
|
417 | 415 | }
|
|
0 commit comments