Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
relf committed Jan 31, 2025
1 parent 31f6d6e commit e596579
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 36 deletions.
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/linear_scaling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
/// Substitutes the records of the dataset with their scaled version.
/// Panics if the shape of the records is not compatible with the shape of the dataset used for fitting.
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let feature_names = x.feature_names().to_vec();
let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
.with_target_names(target_names)
}
}

Expand Down Expand Up @@ -575,7 +577,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
let original_feature_names = dataset.feature_names();
let original_feature_names = dataset.feature_names().to_vec();
let transformed = LinearScaler::standard()
.fit(&dataset)
.unwrap()
Expand Down
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/norm_scaling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
{
/// Substitutes the records of the dataset with their scaled versions with unit norm.
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let feature_names = x.feature_names().to_vec();
let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
.with_target_names(target_names)
}
}

Expand Down Expand Up @@ -160,7 +162,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
let original_feature_names = dataset.feature_names();
let original_feature_names = dataset.feature_names().to_vec();
let transformed = NormScaler::l2().transform(dataset);
assert_eq!(original_feature_names, transformed.feature_names())
}
Expand Down
6 changes: 4 additions & 2 deletions algorithms/linfa-preprocessing/src/whitening.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,14 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
for FittedWhitener<F>
{
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let feature_names = x.feature_names().to_vec();
let target_names = x.target_names().to_vec();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
.with_target_names(target_names)
}
}

Expand Down Expand Up @@ -334,7 +336,7 @@ mod tests {
#[test]
fn test_retain_feature_names() {
let dataset = linfa_datasets::diabetes();
let original_feature_names = dataset.feature_names();
let original_feature_names = dataset.feature_names().to_vec();
let transformed = Whitener::cholesky()
.fit(&dataset)
.unwrap()
Expand Down
8 changes: 7 additions & 1 deletion algorithms/linfa-trees/src/decision_trees/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,13 @@ where
/// a matrix of features `x` and an array of labels `y`.
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
let x = dataset.records();
let feature_names = dataset.feature_names();
let feature_names = if dataset.feature_names().is_empty() {
(0..x.nfeatures())
.map(|idx| format!("feature-{idx}"))
.collect()
} else {
dataset.feature_names().to_vec()
};
let all_idxs = RowMask::all(x.nrows());
let sorted_indices: Vec<_> = (0..(x.ncols()))
.map(|feature_idx| {
Expand Down
9 changes: 8 additions & 1 deletion datasets/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,11 @@ pub fn linnerud() -> Dataset<f64, f64> {
let output_array = array_from_gz_csv(&output_data[..], true, b',').unwrap();

let feature_names = vec!["Chins", "Situps", "Jumps"];
let target_names = vec!["Weight", "Waist", "Pulse"];

Dataset::new(input_array, output_array).with_feature_names(feature_names)
Dataset::new(input_array, output_array)
.with_feature_names(feature_names)
.with_target_names(target_names)
}

#[cfg(test)]
Expand Down Expand Up @@ -261,6 +264,10 @@ mod tests {
let feature_names = vec!["Chins", "Situps", "Jumps"];
assert_eq!(ds.feature_names(), feature_names);

// check for target names
let target_names = vec!["Weight", "Waist", "Pulse"];
assert_eq!(ds.target_names(), target_names);

// get the mean per target: Weight, Waist, Pulse
let mean_targets = ds.targets().mean_axis(Axis(0)).unwrap();
assert_abs_diff_eq!(mean_targets, array![178.6, 35.4, 56.1]);
Expand Down
2 changes: 1 addition & 1 deletion src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ impl<F: Float> PearsonCorrelation<F> {
PearsonCorrelation {
pearson_coeffs,
p_values,
feature_names: dataset.feature_names(),
feature_names: dataset.feature_names().to_vec(),
}
}

Expand Down
53 changes: 29 additions & 24 deletions src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,8 @@ impl<R: Records, S> DatasetBase<R, S> {
/// A feature name gives a human-readable string describing the purpose of a single feature.
/// This allows the reader to understand its purpose while analysing results, for example
/// correlation analysis or feature importance.
pub fn feature_names(&self) -> Vec<String> {
if !self.feature_names.is_empty() {
self.feature_names.clone()
} else {
(0..self.records.nfeatures())
.map(|idx| format!("feature-{idx}"))
.collect()
}
pub fn feature_names(&self) -> &[String] {
&self.feature_names
}

/// Return records of a dataset
Expand Down Expand Up @@ -114,20 +108,14 @@ impl<R: Records, S> DatasetBase<R, S> {
}

/// Updates the feature names of a dataset
///
/// **Panics** when the given names are non-empty and their length does not equal the number of features
pub fn with_feature_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
let feature_names = names.into_iter().map(|x| x.into()).collect();

self.feature_names = feature_names;

self
}

/// Updates the target names of a dataset
pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
let target_names = names.into_iter().map(|x| x.into()).collect();

self.target_names = target_names;

assert!(
names.is_empty() || names.len() == self.nfeatures(),
"Wrong number of feature names"
);
self.feature_names = names.into_iter().map(|x| x.into()).collect();
self
}
}
Expand All @@ -143,6 +131,18 @@ impl<X, Y> Dataset<X, Y> {
}

impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
/// Updates the target names of a dataset
///
/// **Panics** when the given names are non-empty and their length does not equal the number of targets
pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, T> {
assert!(
names.is_empty() || names.len() == self.ntargets(),
"Wrong number of target names"
);
self.target_names = names.into_iter().map(|x| x.into()).collect();
self
}

/// Map targets with a function `f`
///
/// # Example
Expand Down Expand Up @@ -238,6 +238,7 @@ impl<'a, F: 'a, L: 'a, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
where
D: Data<Elem = F>,
T: AsTargets<Elem = L> + FromTargetArray<'a>,
T::View: AsTargets<Elem = L>,
{
/// Creates a view of a dataset
pub fn view(&'a self) -> DatasetBase<ArrayView2<'a, F>, T::View> {
Expand Down Expand Up @@ -290,6 +291,7 @@ impl<L, R: Records, T: AsTargetsMut<Elem = L>> AsTargetsMut for DatasetBase<R, T
impl<'a, L: 'a, F, T> DatasetBase<ArrayView2<'a, F>, T>
where
T: AsTargets<Elem = L> + FromTargetArray<'a>,
T::View: AsTargets<Elem = L>,
{
/// Split dataset into two disjoint chunks
///
Expand Down Expand Up @@ -984,7 +986,8 @@ impl<F, E, I: TargetDim> Dataset<F, E, I> {
let n1 = (self.nsamples() as f32 * ratio).ceil() as usize;
let n2 = self.nsamples() - n1;

let feature_names = self.feature_names();
let feature_names = self.feature_names().to_vec();
let target_names = self.target_names().to_vec();

// split records into two disjoint arrays
let mut array_buf = self.records.into_raw_vec();
Expand Down Expand Up @@ -1017,10 +1020,12 @@ impl<F, E, I: TargetDim> Dataset<F, E, I> {
// create new datasets with attached weights
let dataset1 = Dataset::new(first, first_targets)
.with_weights(self.weights)
.with_feature_names(feature_names.clone());
.with_feature_names(feature_names.clone())
.with_target_names(target_names.clone());
let dataset2 = Dataset::new(second, second_targets)
.with_weights(second_weights)
.with_feature_names(feature_names);
.with_feature_names(feature_names.clone())
.with_target_names(target_names.clone());

(dataset1, dataset2)
}
Expand Down
2 changes: 1 addition & 1 deletion src/dataset/impl_targets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl<'a, L: Label + 'a, T> FromTargetArray<'a> for CountedTargets<L, T>
where
T: FromTargetArray<'a, Elem = L>,
T::Owned: Labels<Elem = L>,
T::View: Labels<Elem = L>,
T::View: Labels<Elem = L> + AsTargets,
{
type Owned = CountedTargets<L, T::Owned>;
type View = CountedTargets<L, T::View>;
Expand Down
5 changes: 3 additions & 2 deletions src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,8 @@ mod tests {
let dataset = Dataset::new(
array![[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]],
array![[1, 2], [3, 4], [5, 6]],
).with_target_names(vec!["a", "b"]);
)
.with_target_names(vec!["a", "b"]);

let res = dataset
.target_iter()
Expand All @@ -568,7 +569,7 @@ mod tests {
let mut iter = dataset.target_iter();
let first = iter.next();
let second = iter.next();

assert_eq!(vec!["a"], first.unwrap().target_names());
assert_eq!(vec!["b"], second.unwrap().target_names());

Expand Down

0 comments on commit e596579

Please sign in to comment.