Skip to content

Commit

Permalink
refactor: Separate Empirical::mean_and_var
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon authored and YeungOnion committed Dec 3, 2024
1 parent f685851 commit 7456947
Showing 1 changed file with 39 additions and 33 deletions.
72 changes: 39 additions & 33 deletions src/distribution/empirical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ mod non_nan {
/// ```
#[derive(Clone, PartialEq, Debug)]
pub struct Empirical {
mean_and_var: Option<(f64, f64)>,
// keys are data points, values are number of data points with equal value
data: BTreeMap<NonNan<f64>, u64>,

// The following fields are only logically valid if !data.is_empty():
/// Total number of data points (== sum of all _values_ inside self.data).
/// Must be 0 iff data.is_empty()
sum: u64,
/// Running arithmetic mean of the data points; updated incrementally
/// whenever a point is added or removed.
mean: f64,
/// Running sum of squared deviations from the mean -- NOT the variance
/// itself. `Distribution::variance` divides this by `sum - 1` on read.
var: f64,
}

impl Empirical {
Expand All @@ -84,9 +85,10 @@ impl Empirical {
#[allow(clippy::result_unit_err)]
pub fn new() -> Result<Empirical, ()> {
// Never fails as written: the only return path is `Ok`, carrying an empty
// map and zeroed running statistics.
// NOTE(review): the `()` error type (and the clippy allow above) is
// presumably retained for API compatibility -- confirm.
Ok(Empirical {
sum: 0,
mean_and_var: None,
data: BTreeMap::new(),
sum: 0,
mean: 0.0,
var: 0.0,
})
}

Expand All @@ -97,17 +99,10 @@ impl Empirical {
};

self.sum += 1;
match self.mean_and_var {
Some((mean, var)) => {
let sum = self.sum as f64;
let var = var + (sum - 1.) * (data_point - mean) * (data_point - mean) / sum;
let mean = mean + (data_point - mean) / sum;
self.mean_and_var = Some((mean, var));
}
None => {
self.mean_and_var = Some((data_point, 0.));
}
}
let sum = self.sum as f64;
self.var += (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
self.mean += (data_point - self.mean) / sum;

*self.data.entry(map_key).or_insert(0) += 1;
}

Expand All @@ -117,21 +112,25 @@ impl Empirical {
None => return,
};

if let (Some(val), Some((mean, var))) = (self.data.remove(&map_key), self.mean_and_var) {
if val == 1 && self.data.is_empty() {
self.mean_and_var = None;
self.sum = 0;
return;
};
// reset mean and var
let sum = self.sum as f64;
let mean = (sum * mean - data_point) / (sum - 1.);
let var = var - (sum - 1.) * (data_point - mean) * (data_point - mean) / sum;
self.sum -= 1;
if val != 1 {
self.data.insert(map_key, val - 1);
};
self.mean_and_var = Some((mean, var));
let val = match self.data.remove(&map_key) {
Some(v) => v,
None => return,
};

if val == 1 && self.data.is_empty() {
self.sum = 0;
self.mean = 0.0;
self.var = 0.0;
return;
};

// reset mean and var
let sum = self.sum as f64;
self.mean = (sum * self.mean - data_point) / (sum - 1.);
self.var -= (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
self.sum -= 1;
if val != 1 {
self.data.insert(map_key, val - 1);
}
}

Expand Down Expand Up @@ -232,12 +231,19 @@ impl Min<f64> for Empirical {

impl Distribution<f64> for Empirical {
/// Returns the running mean of the stored data points, or `None` when no
/// data points are stored.
fn mean(&self) -> Option<f64> {
self.mean_and_var.map(|(mean, _)| mean)
if self.data.is_empty() {
None
} else {
Some(self.mean)
}
}

/// Returns the sample variance: the accumulated sum of squared deviations
/// divided by `sum - 1`. Returns `None` when no data points are stored.
///
/// With exactly one stored data point the divisor is zero, so the result
/// is not a meaningful variance (`0.0 / 0.0`, i.e. NaN, for a finite point).
fn variance(&self) -> Option<f64> {
self.mean_and_var
.map(|(_, var)| var / (self.sum as f64 - 1.))
if self.data.is_empty() {
None
} else {
// `var` holds the un-normalized sum of squared deviations, hence the division here.
Some(self.var / (self.sum as f64 - 1.))
}
}
}

Expand Down Expand Up @@ -293,7 +299,7 @@ mod tests {
#[test]
fn test_remove_nonexisting() {
let mut empirical = Empirical::new().unwrap();

empirical.add(5.2);
// should not panic
empirical.remove(10.0);
Expand Down

0 comments on commit 7456947

Please sign in to comment.