Skip to content

Commit

Permalink
fix: Improve hist binning around breakpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley committed Nov 28, 2024
1 parent 9f3f012 commit c9ebbb5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
19 changes: 8 additions & 11 deletions crates/polars-ops/src/chunked_array/hist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,22 @@ where
let mut count: Vec<IdxSize> = vec![0; num_bins];
let min_break: f64 = breaks[0];
let max_break: f64 = breaks[num_bins];
let width = breaks[1] - min_break; // guaranteed at least one bin
let scale = num_bins as f64 / (max_break - min_break);

for chunk in ca.downcast_iter() {
for item in chunk.non_null_values_iter() {
let item = item.to_f64().unwrap();
if include_lower && item == min_break {
count[0] += 1;
} else if item == max_break {
count[num_bins - 1] += 1;
} else if item > min_break && item < max_break {
let width_multiple = (item - min_break) / width;
let idx = width_multiple.floor();
// handle the case where item lands on the boundary
let idx = if idx == width_multiple {
if item > min_break && item <= max_break {
let idx = scale * (item - min_break);
let idx_floor = idx.floor();
let idx = if idx == idx_floor {
idx - 1.0
} else {
idx
idx_floor
};
count[idx as usize] += 1;
} else if include_lower && item == min_break {
count[0] += 1;
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/unit/operations/test_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,18 @@ def test_hist_same_values_20030() -> None:
}
)
assert_frame_equal(out, expected)


def test_hist_breakpoint_accuracy() -> None:
    """Check that values lying exactly on a breakpoint are binned correctly.

    With the series [1, 2, 3, 4] split into three bins, the expected frame
    places each of 2, 3 and 4 in the right-closed bin whose upper edge equals
    the value, giving counts of [2, 1, 1].
    """
    values = pl.Series([1, 2, 3, 4])
    result = values.hist(bin_count=3)

    # Build each expected column separately so the dtypes are easy to see.
    breakpoints = pl.Series([2.0, 3.0, 4.0], dtype=pl.Float64)
    categories = pl.Series(
        ["(0.997, 2.0]", "(2.0, 3.0]", "(3.0, 4.0]"], dtype=pl.Categorical
    )
    counts = pl.Series([2, 1, 1], dtype=pl.get_index_type())

    expected = pl.DataFrame(
        {
            "breakpoint": breakpoints,
            "category": categories,
            "count": counts,
        }
    )
    assert_frame_equal(result, expected)

0 comments on commit c9ebbb5

Please sign in to comment.