Skip to content

Commit c17e77b

Browse files
Switches CMA-ES to bulk_cost
1 parent 61de867 commit c17e77b

File tree

1 file changed

+10
-12
lines changed
  • argmin/src/solver/cma_es

1 file changed

+10
-12
lines changed

argmin/src/solver/cma_es/mod.rs

+10-12
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
//! For details see [`CMAES`].
1111
1212
use crate::core::{
13-
ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, KV,
13+
ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, SyncAlias,
14+
KV,
1415
};
1516
use argmin_math::{
1617
ArgminAdd, ArgminArgsort, ArgminAxisIter, ArgminBroadcast, ArgminDiv, ArgminDot,
@@ -45,10 +46,10 @@ use std::ops::{AddAssign, MulAssign};
4546
/// }
4647
///
4748
/// impl CostFunction for Rosenbrock {
48-
/// type Param = Vec<f32>;
49-
/// type Output = f32;
49+
/// type Param = Vec<f32>;
50+
/// type Output = f32;
5051
///
51-
/// fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
52+
/// fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
5253
/// Ok(rosenbrock_2d(p, self.a, self.b))
5354
/// }
5455
/// }
@@ -203,11 +204,12 @@ where
203204

204205
impl<O, P, F> Solver<O, PopulationState<P, F, P::Array2D>> for CMAES<P, F>
205206
where
206-
O: CostFunction<Param = P, Output = F>,
207+
O: CostFunction<Param = P, Output = F> + SyncAlias,
207208
Vec<F>: ArgminArgsort,
208209
F: ArgminFloat + MulAssign + AddAssign + NumCast + ArgminDiv<P, P>,
209210
P: SerializeAlias
210211
+ Clone
212+
+ SyncAlias
211213
+ ArgminTransition
212214
+ ArgminSize<usize>
213215
+ ArgminZeroLike
@@ -251,12 +253,9 @@ where
251253

252254
state.population = Some(self.generate());
253255

254-
let fitness: Vec<F> = state
255-
.get_population()
256-
.unwrap()
257-
.row_iterator()
258-
.map(|p| problem.cost(&p).unwrap())
259-
.collect();
256+
let fitness: Vec<F> = problem
257+
.bulk_cost(&state.get_population().unwrap().row_iterator().collect())
258+
.unwrap();
260259

261260
let fitness_indices = fitness.argsort();
262261

@@ -411,7 +410,6 @@ mod tests {
411410
assert!(state.best_individual.is_some());
412411

413412
let solution = state.best_individual.unwrap();
414-
println!("{:?}", solution);
415413
assert!((solution[0] - 1.0).abs() <= precision);
416414
assert!((solution[1] - 1.0).abs() <= precision);
417415
}

0 commit comments

Comments
 (0)