Skip to content

Commit

Permalink
reorganize init checks for RustCycle
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle Carow authored and Kyle Carow committed Jan 29, 2024
1 parent 7bbbeba commit 4ae6654
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions rust/fastsim-core/src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,21 +633,7 @@ impl SerdeAPI for RustCycle {
const ACCEPTED_STR_FORMATS: &'static [&'static str] = &["yaml", "json", "csv"];

fn init(&mut self) -> anyhow::Result<()> {
ensure!(!self.is_empty(), "Deserialized cycle is empty");
let cyc_len = self.len();
ensure!(
self.mps.len() == cyc_len,
"Length of `mps` does not match length of `time_s`"
);
ensure!(
self.grade.len() == cyc_len,
"Length of `grade` does not match length of `time_s`"
);
ensure!(
self.road_type.len() == cyc_len,
"Length of `road_type` does not match length of `time_s`"
);
Ok(())
self.init_checks()
}

fn to_file<P: AsRef<Path>>(&self, filepath: P) -> anyhow::Result<()> {
Expand Down Expand Up @@ -738,7 +724,7 @@ impl TryFrom<HashMap<String, Vec<f64>>> for RustCycle {
let time_s = Array::from_vec(
hashmap
.get("time_s")
.with_context(|| "`time_s` not in HashMap")?
.with_context(|| format!("`time_s` not in HashMap: {hashmap:?}"))?
.to_owned(),
);
let cyc_len = time_s.len();
Expand All @@ -747,7 +733,7 @@ impl TryFrom<HashMap<String, Vec<f64>>> for RustCycle {
mps: Array::from_vec(
hashmap
.get("mps")
.with_context(|| "`mps` not in HashMap")?
.with_context(|| format!("`mps` not in HashMap: {hashmap:?}"))?
.to_owned(),
),
grade: Array::from_vec(
Expand Down Expand Up @@ -783,6 +769,20 @@ impl From<RustCycle> for HashMap<String, Vec<f64>> {

/// pure Rust methods that need to be separate due to pymethods incompatibility
impl RustCycle {
fn init_checks(&self) -> anyhow::Result<()> {
ensure!(!self.is_empty(), "Deserialized cycle is empty");
ensure!(self.is_sorted(), "Deserialized cycle is not sorted in time");
ensure!(
self.are_fields_equal_length(),
"Deserialized cycle has unequal field lengths\ntime_s: {}\nmps: {}\ngrade: {}\nroad_type: {}",
self.time_s.len(),
self.mps.len(),
self.grade.len(),
self.road_type.len(),
);
Ok(())
}

/// Load cycle from CSV file, parsing name from filepath
pub fn from_csv_file<P: AsRef<Path>>(filepath: P) -> anyhow::Result<Self> {
let filepath = filepath.as_ref();
Expand Down Expand Up @@ -853,6 +853,21 @@ impl RustCycle {
self.len() == 0
}

pub fn is_sorted(&self) -> bool {
self.time_s
.as_slice()
.unwrap()
.windows(2)
.all(|window| window[0] < window[1])
}

pub fn are_fields_equal_length(&self) -> bool {
let cyc_len = self.len();
[self.mps.len(), self.grade.len(), self.road_type.len()]
.iter()
.all(|len| len == &cyc_len)
}

pub fn test_cyc() -> Self {
Self {
time_s: Array::range(0.0, 10.0, 1.0),
Expand Down

0 comments on commit 4ae6654

Please sign in to comment.