diff --git a/src/cli.rs b/src/cli.rs index 24ff203..1f8909d 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -37,6 +37,15 @@ enum Commands { #[arg(long, default_value = "", help = "Filter by frequency")] frequency: String, + + #[arg(long, action, group = "status", help = "Filter by completed status")] + completed: bool, + + #[arg(long, action, group = "status", help = "Filter by waiting status")] + waiting: bool, + + #[arg(long, action, group = "status", help = "Filter by missed status")] + missed: bool, }, #[command(about = "Create a new streak", long_about = None, short_flag = 'a')] Add { @@ -140,7 +149,11 @@ fn build_table(streaks: Vec) -> String { header_style.paint("\nTotal").to_string(), ]); - let (width, _) = dimensions().unwrap(); + let (width, _) = match dimensions() { + Some((w, _)) => (w, 0), + None => (80, 0), + }; + dbg!(&width); for streak in streaks.iter() { let mut wrapped_text = String::new(); @@ -269,6 +282,9 @@ pub fn parse() { sort_by, search, frequency, + completed, + waiting, + missed, } => { let mut streak_list = match search.is_empty() { true => db.get_all(), @@ -286,6 +302,21 @@ pub fn parse() { .collect(); } + if *completed { + streak_list = streak_list + .into_iter() + .filter(|s| s.is_completed()) + .collect(); + } + + if *missed { + streak_list = streak_list.into_iter().filter(|s| s.is_missed()).collect(); + } + + if *waiting { + streak_list = streak_list.into_iter().filter(|s| s.is_waiting()).collect(); + } + streak_list = sort_streaks(streak_list, sort_by.0, sort_by.1); println!("{}", build_table(streak_list)); } @@ -322,7 +353,7 @@ pub fn parse() { #[cfg(test)] mod tests { - use super::{get_sort_order, Streak}; + use super::get_sort_order; use assert_cmd::Command; use assert_fs::TempDir; use rstest::*; @@ -464,4 +495,62 @@ mod tests { .assert() .success(); } + + #[rstest] + fn test_frequency_filter(mut command: Command) { + let temp = TempDir::new().unwrap(); + + command + .arg("--database-url") + .arg(format!( + "{}/{}", + temp.path().display(), + "test-frequency-filter.ron" + )) + .arg("list") + .arg("--frequency") + .arg("daily") + .assert() + .success(); + } + + #[rstest] + fn test_frequency_filter_and_sort(mut command: Command) { + let temp = TempDir::new().unwrap(); + + command + .arg("--database-url") + .arg(format!( + "{}/{}", + temp.path().display(), + "test-frequency-filter-sort.ron" + )) + .arg("list") + .arg("--frequency") + .arg("daily") + .arg("--sort-by") + .arg("task+") + .assert() + .success(); + } + + #[rstest] + #[case("completed")] + #[case("missed")] + #[case("waiting")] + fn test_filter_by_status(mut command: Command, #[case] status: &str) { + let temp = TempDir::new().unwrap(); + + command + .arg("--database-url") + .arg(format!( + "{}/{}", + temp.path().display(), + "test-filter-by-status.ron" + )) + .arg("list") + .arg(format!("--{status}")) + .assert() + .success(); + } } diff --git a/src/streak.rs b/src/streak.rs index 12f7485..75c5215 100644 --- a/src/streak.rs +++ b/src/streak.rs @@ -1,11 +1,11 @@ use std::fmt::Display; +use crate::cli::{SortByDirection, SortByField}; #[allow(unused_imports)] use chrono::{Local, NaiveDate}; use clap::ValueEnum; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::cli::{SortByDirection, SortByField}; #[derive( Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, ValueEnum, Serialize, Deserialize, @@ -157,6 +157,18 @@ impl Streak { } } + pub fn is_completed(&self) -> bool { + self.status() == Status::Done + } + + pub fn is_missed(&self) -> bool { + self.status() == Status::Missed + } + + pub fn is_waiting(&self) -> bool { + self.status() == Status::Waiting + } + pub fn update(&mut self, new_self: Streak) { let id = self.id; *self = new_self; @@ -178,7 +190,11 @@ impl Default for Streak { } } -pub fn sort_streaks(mut streaks: Vec, sort_field: SortByField, sort_direction: SortByDirection) -> Vec { +pub fn sort_streaks( + mut streaks: Vec, + sort_field: SortByField, + sort_direction: SortByDirection, +) -> Vec { match (sort_field, sort_direction) { (SortByField::Task, SortByDirection::Ascending) => { streaks.sort_by(|a, b| a.task.cmp(&b.task))