Skip to content

Commit

Permalink
Maintain original feature and target names after shuffling (#377)
Browse files Browse the repository at this point in the history
* maintain original feature and target names after shuffling

* fixed formatting issues

* fixed doctest

* moved test example to mod

* fixed doctest syntax
  • Loading branch information
Plutone11011 authored Feb 10, 2025
1 parent d910389 commit 1cf33f9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ where
/// ### Returns
///
/// A new shuffled version of the current Dataset
///
pub fn shuffle<R: Rng>(&self, rng: &mut R) -> DatasetBase<Array2<F>, T::Owned> {
let mut indices = (0..self.nsamples()).collect::<Vec<_>>();
indices.shuffle(rng);
Expand All @@ -586,6 +587,8 @@ where
let targets = T::new_targets(targets);

DatasetBase::new(records, targets)
.with_feature_names(self.feature_names().to_vec())
.with_target_names(self.target_names().to_vec())
}

#[allow(clippy::type_complexity)]
Expand Down
22 changes: 21 additions & 1 deletion src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ pub trait Labels {
mod tests {
use super::*;
use crate::error::Error;
use approx::assert_abs_diff_eq;
use approx::{assert_abs_diff_eq, assert_abs_diff_ne};
use linfa_datasets::generate::make_dataset;
use ndarray::{array, Array1, Array2, Axis};
use rand::{rngs::SmallRng, SeedableRng};
Expand Down Expand Up @@ -1050,4 +1050,24 @@ mod tests {
let prob = -0.5;
assert_abs_diff_eq!(Pr::new_unchecked(prob).0, prob);
}

/// Checks that `shuffle` changes the sample order while carrying the
/// feature and target names over to the shuffled dataset unchanged.
#[test]
fn test_dataset_shuffle() {
    // Names attached to the dataset before shuffling.
    let f_names = vec!["f1", "f2", "f3"];
    let t_names = vec!["t1"];

    // Three samples with three features each, plus one target per sample.
    let input_records = array![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]];
    let input_targets = array![0., 1., 3.];
    let original = Dataset::new(input_records, input_targets)
        .with_feature_names(f_names.clone())
        .with_target_names(t_names.clone());

    // Fixed seed keeps the permutation, and therefore the test, deterministic.
    let mut seeded_rng = SmallRng::seed_from_u64(42);
    let shuffled = original.shuffle(&mut seeded_rng);

    // The shuffled records and targets must differ from the originals.
    assert_abs_diff_ne!(original.records(), shuffled.records());
    assert_abs_diff_ne!(original.targets(), shuffled.targets());

    // Names must be the ones that were set before shuffling.
    assert_eq!(f_names, shuffled.feature_names());
    assert_eq!(t_names, shuffled.target_names());
}
}

0 comments on commit 1cf33f9

Please sign in to comment.