Skip to content

Commit 7dfe4b0

Browse files
Switches CMA-ES to bulk_cost
1 parent cb259aa commit 7dfe4b0

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

argmin/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
//!
7979
//! - [Particle Swarm Optimization](`crate::solver::particleswarm::ParticleSwarm`)
8080
//!
81-
//! - [CMA-ES](solver/cma_es/struct.CMAES.html)
81+
//! - [CMA-ES](`crate::solver::cma_es::CMAES`)
8282
//!
8383
//! # License
8484
//!

argmin/src/solver/cma_es/mod.rs

+10-12
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
//! For details see [`CMAES`].
1111
1212
use crate::core::{
13-
ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, KV,
13+
ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, SyncAlias,
14+
KV,
1415
};
1516
use argmin_math::{
1617
ArgminAdd, ArgminArgsort, ArgminAxisIter, ArgminBroadcast, ArgminDiv, ArgminDot,
@@ -45,10 +46,10 @@ use std::ops::{AddAssign, MulAssign};
4546
/// }
4647
///
4748
/// impl CostFunction for Rosenbrock {
48-
/// type Param = Vec<f32>;
49-
/// type Output = f32;
49+
/// type Param = Vec<f32>;
50+
/// type Output = f32;
5051
///
51-
/// fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
52+
/// fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
5253
/// Ok(rosenbrock_2d(p, self.a, self.b))
5354
/// }
5455
/// }
@@ -203,11 +204,12 @@ where
203204

204205
impl<O, P, F> Solver<O, PopulationState<P, F, P::Array2D>> for CMAES<P, F>
205206
where
206-
O: CostFunction<Param = P, Output = F>,
207+
O: CostFunction<Param = P, Output = F> + SyncAlias,
207208
Vec<F>: ArgminArgsort,
208209
F: ArgminFloat + MulAssign + AddAssign + NumCast + ArgminDiv<P, P>,
209210
P: SerializeAlias
210211
+ Clone
212+
+ SyncAlias
211213
+ ArgminTransition
212214
+ ArgminSize<usize>
213215
+ ArgminZeroLike
@@ -251,12 +253,9 @@ where
251253

252254
state.population = Some(self.generate());
253255

254-
let fitness: Vec<F> = state
255-
.get_population()
256-
.unwrap()
257-
.row_iterator()
258-
.map(|p| problem.cost(&p).unwrap())
259-
.collect();
256+
let fitness: Vec<F> = problem
257+
.bulk_cost(&state.get_population().unwrap().row_iterator().collect())
258+
.unwrap();
260259

261260
let fitness_indices = fitness.argsort();
262261

@@ -411,7 +410,6 @@ mod tests {
411410
assert!(state.best_individual.is_some());
412411

413412
let solution = state.best_individual.unwrap();
414-
println!("{:?}", solution);
415413
assert!((solution[0] - 1.0).abs() <= precision);
416414
assert!((solution[1] - 1.0).abs() <= precision);
417415
}

0 commit comments

Comments
 (0)