diff --git a/NEWS.md b/NEWS.md index 285c9b7fb..fa6322c47 100644 --- a/NEWS.md +++ b/NEWS.md @@ -32,6 +32,8 @@ double (#577). - The conversion of R's `POSIXct` class to Polars datetime now works correctly with millisecond precision (#589). +- `$filter()`, `$filter()`, and `pl$when()` now allow multiple conditions + to be separated by commas, like `lf$filter(pl$col("foo") == 1, pl$col("bar") != 2)` (#598). ## polars 0.11.0 diff --git a/R/dataframe__frame.R b/R/dataframe__frame.R index 3fbb90edd..d608e2b4e 100644 --- a/R/dataframe__frame.R +++ b/R/dataframe__frame.R @@ -164,7 +164,7 @@ pl$DataFrame = function(..., make_names_unique = TRUE, schema = NULL) { # keys are tentative new column names keys = names(largs) - if (length(keys) == 0) keys <- rep(NA_character_, length(largs)) + if (length(keys) == 0) keys = rep(NA_character_, length(largs)) keys = mapply(largs, keys, FUN = function(column, key) { if (is.na(key) || nchar(key) == 0) { if (inherits(column, "RPolarsSeries")) { @@ -291,7 +291,7 @@ DataFrame.property_setters = new.env(parent = emptyenv()) pstop(err = paste("no setter method for", name)) } - if (polars_optenv$strictly_immutable) self <- self$clone() + if (polars_optenv$strictly_immutable) self = self$clone() func = DataFrame.property_setters[[name]] func(self, value) self @@ -791,10 +791,8 @@ DataFrame_tail = function(n) { #' Filter rows of a DataFrame #' @name DataFrame_filter #' -#' @description This is equivalent to `dplyr::filter()`. Note that rows where -#' the condition returns `NA` are dropped, unlike base subsetting with `[`. +#' @inherit LazyFrame_filter description params details #' -#' @param bool_expr Polars expression which will evaluate to a boolean. #' @keywords DataFrame #' @return A DataFrame with only the rows where the conditions are `TRUE`. #' @examples @@ -802,14 +800,18 @@ DataFrame_tail = function(n) { #' #' df$filter(pl$col("Sepal.Length") > 5) #' +#' # This is equivalent to +#' # df$filter(pl$col("Sepal.Length") > 5 & pl$col("Petal.Width") < 1) +#' df$filter(pl$col("Sepal.Length") > 5, pl$col("Petal.Width") < 1) +#' #' # rows where condition is NA are dropped #' iris2 = iris #' iris2[c(1, 3, 5), "Species"] = NA #' df = pl$DataFrame(iris2) #' #' df$filter(pl$col("Species") == "setosa") -DataFrame_filter = function(bool_expr) { - .pr$DataFrame$lazy(self)$filter(bool_expr)$collect() +DataFrame_filter = function(...) { + .pr$DataFrame$lazy(self)$filter(...)$collect() } #' Group a DataFrame @@ -1542,7 +1544,7 @@ DataFrame_glimpse = function(..., return_as_string = FALSE) { max_col_name_trunc = 50 parse_column_ = \(col_name, dtype) { dtype_str = dtype_str_repr(dtype) |> unwrap_or(paste0("??", str_string(dtype))) - if (inherits(dtype, "RPolarsDataType")) dtype_str <- paste0(" <", dtype_str, ">") + if (inherits(dtype, "RPolarsDataType")) dtype_str = paste0(" <", dtype_str, ">") val = self$select(pl$col(col_name)$slice(0, max_num_value))$to_list()[[1]] val_str = paste(val, collapse = ", ") if (nchar(col_name) > max_col_name_trunc) { diff --git a/R/functions__whenthen.R b/R/functions__whenthen.R index 152f4e9a2..a5681f98c 100644 --- a/R/functions__whenthen.R +++ b/R/functions__whenthen.R @@ -2,7 +2,7 @@ #' @name Expr_when_then_otherwise #' @description Start a “when, then, otherwise” expression. #' @keywords Expr -#' @param condition Into Expr into a boolean mask to branch by. Strings interpreted as column. +#' @param ... Into Expr into a boolean mask to branch by. #' @param statement Into Expr value to insert in when() or otherwise(). #' Strings interpreted as column. #' @return Expr @@ -37,14 +37,36 @@ #' a nested when-then-otherwise expression. #' #' @examples -#' df = pl$DataFrame(mtcars) -#' wtt = -#' pl$when(pl$col("cyl") <= 4)$then(pl$lit("<=4cyl"))$ -#' when(pl$col("cyl") <= 6)$then(pl$lit("<=6cyl"))$ -#' otherwise(pl$lit(">6cyl"))$alias("cyl_groups") -#' print(wtt) -#' df$with_columns(wtt) -pl$when = function(condition) { +#' df = pl$DataFrame(foo = c(1, 3, 4), bar = c(3, 4, 0)) +#' +#' # Add a column with the value 1, where column "foo" > 2 and the value -1 where it isn’t. +#' df$with_columns( +#' pl$when(pl$col("foo") > 2)$then(1)$otherwise(-1)$alias("val") +#' ) +#' +#' # With multiple when, thens chained: +#' df$with_columns( +#' pl$when(pl$col("foo") > 2) +#' $then(1) +#' $when(pl$col("bar") > 2) +#' $then(4) +#' $otherwise(-1) +#' $alias("val") +#' ) +#' +#' # Pass multiple predicates, each of which must be met: +#' df$with_columns( +#' val = pl$when( +#' pl$col("bar") > 0, +#' pl$col("foo") %% 2 != 0 +#' ) +#' $then(99) +#' $otherwise(-1) +#' ) +pl$when = function(...) { + condition = unpack_bool_expr(...) |> + unwrap("in pl$when():") + .pr$When$new(condition) |> unwrap("in pl$when():") } @@ -57,7 +79,10 @@ When_then = function(statement) { unwrap("in $then():") } -Then_when = function(condition) { +Then_when = function(...) { + condition = unpack_bool_expr(...) |> + unwrap("in $when():") + .pr$Then$when(self, condition) |> unwrap("in $when():") } @@ -72,7 +97,10 @@ ChainedWhen_then = function(statement) { unwrap("in $then():") } -ChainedThen_when = function(condition) { +ChainedThen_when = function(...) { + condition = unpack_bool_expr(...) |> + unwrap("in $when():") + .pr$ChainedThen$when(self, condition) |> unwrap("in $when():") } diff --git a/R/lazyframe__lazy.R b/R/lazyframe__lazy.R index 81bcb1f4e..f9c4e0e0c 100644 --- a/R/lazyframe__lazy.R +++ b/R/lazyframe__lazy.R @@ -266,15 +266,31 @@ LazyFrame_with_row_count = function(name, offset = NULL) { .pr$LazyFrame$with_row_count(self, name, offset) |> unwrap() } -#' @title Apply filter to LazyFrame -#' @description Filter rows with an Expression defining a boolean column +#' Apply filter to LazyFrame +#' +#' Filter rows with an Expression defining a boolean column. +#' Multiple expressions are combined with `&` (AND). +#' This is equivalent to [dplyr::filter()]. +#' +#' Rows where the condition returns `NA` are dropped. #' @keywords LazyFrame -#' @param expr one Expr or string naming a column +#' @param ... Polars expressions which will evaluate to a boolean. #' @return A new `LazyFrame` object with add/modified column. #' @docType NULL -#' @usage LazyFrame_filter(expr) -#' @examples pl$LazyFrame(iris)$filter(pl$col("Species") == "setosa")$collect() -LazyFrame_filter = "use_extendr_wrapper" +#' @examples +#' lf = pl$LazyFrame(iris) +#' +#' lf$filter(pl$col("Species") == "setosa")$collect() +#' +#' # This is equivalent to +#' # lf$filter(pl$col("Sepal.Length") > 5 & pl$col("Petal.Width") < 1) +#' lf$filter(pl$col("Sepal.Length") > 5, pl$col("Petal.Width") < 1) +LazyFrame_filter = function(...) { + bool_expr = unpack_bool_expr(...) |> + unwrap("in $filter()") + + .pr$LazyFrame$filter(self, bool_expr) +} #' @title Get optimization settings #' @description Get the current optimization toggles for the lazy query @@ -1148,8 +1164,8 @@ LazyFrame_join_asof = function( tolerance = NULL, allow_parallel = TRUE, force_parallel = FALSE) { - if (!is.null(by)) by_left <- by_right <- by - if (!is.null(on)) left_on <- right_on <- on + if (!is.null(by)) by_left = by_right = by + if (!is.null(on)) left_on = right_on = on tolerance_str = if (is.character(tolerance)) tolerance else NULL tolerance_num = if (!is.character(tolerance)) tolerance else NULL diff --git a/R/utils.R b/R/utils.R index f6326054f..ecdf8ccb0 100644 --- a/R/utils.R +++ b/R/utils.R @@ -117,6 +117,27 @@ unpack_list = function(..., skip_classes = NULL) { } } +#' Convert dot-dot-dot to bool expression +#' @noRd +#' @return Result, a list has `ok` (RPolarsExpr class) and `err` (RPolarsErr class) +#' @examples +#' unpack_bool_expr(pl$lit(TRUE), pl$lit(FALSE)) +unpack_bool_expr = function(..., .msg = NULL) { + dots = list2(...) + + if (!is.null(names(dots))) { + return(Err_plain( + "Detected a named input.", + "This usually means that you've used `=` instead of `==`." + )) + } + + dots |> + Reduce(`&`, x = _) |> + result(msg = .msg) |> + suppressWarnings() +} + #' Simple SQL CASE WHEN implementation for R #' @noRd #' @description Inspired by data.table::fcase + dplyr::case_when. diff --git a/man/DataFrame_filter.Rd b/man/DataFrame_filter.Rd index feb895424..a785cf491 100644 --- a/man/DataFrame_filter.Rd +++ b/man/DataFrame_filter.Rd @@ -4,23 +4,31 @@ \alias{DataFrame_filter} \title{Filter rows of a DataFrame} \usage{ -DataFrame_filter(bool_expr) +DataFrame_filter(...) } \arguments{ -\item{bool_expr}{Polars expression which will evaluate to a boolean.} +\item{...}{Polars expressions which will evaluate to a boolean.} } \value{ A DataFrame with only the rows where the conditions are \code{TRUE}. } \description{ -This is equivalent to \code{dplyr::filter()}. Note that rows where -the condition returns \code{NA} are dropped, unlike base subsetting with \code{[}. +Filter rows with an Expression defining a boolean column. +Multiple expressions are combined with \code{&} (AND). +This is equivalent to \code{\link[dplyr:filter]{dplyr::filter()}}. +} +\details{ +Rows where the condition returns \code{NA} are dropped. } \examples{ df = pl$DataFrame(iris) df$filter(pl$col("Sepal.Length") > 5) +# This is equivalent to +# df$filter(pl$col("Sepal.Length") > 5 & pl$col("Petal.Width") < 1) +df$filter(pl$col("Sepal.Length") > 5, pl$col("Petal.Width") < 1) + # rows where condition is NA are dropped iris2 = iris iris2[c(1, 3, 5), "Species"] = NA diff --git a/man/Expr_when_then_otherwise.Rd b/man/Expr_when_then_otherwise.Rd index 78e81cb62..7b067fc80 100644 --- a/man/Expr_when_then_otherwise.Rd +++ b/man/Expr_when_then_otherwise.Rd @@ -11,7 +11,7 @@ \alias{ChainedThen} \title{when-then-otherwise Expr} \arguments{ -\item{condition}{Into Expr into a boolean mask to branch by. Strings interpreted as column.} +\item{...}{Into Expr into a boolean mask to branch by.} \item{statement}{Into Expr value to insert in when() or otherwise(). Strings interpreted as column.} @@ -51,12 +51,31 @@ This statemachine ensures only syntacticly allowed methods are availble at any s a nested when-then-otherwise expression. } \examples{ -df = pl$DataFrame(mtcars) -wtt = - pl$when(pl$col("cyl") <= 4)$then(pl$lit("<=4cyl"))$ - when(pl$col("cyl") <= 6)$then(pl$lit("<=6cyl"))$ - otherwise(pl$lit(">6cyl"))$alias("cyl_groups") -print(wtt) -df$with_columns(wtt) +df = pl$DataFrame(foo = c(1, 3, 4), bar = c(3, 4, 0)) + +# Add a column with the value 1, where column "foo" > 2 and the value -1 where it isn’t. +df$with_columns( + pl$when(pl$col("foo") > 2)$then(1)$otherwise(-1)$alias("val") +) + +# With multiple when, thens chained: +df$with_columns( + pl$when(pl$col("foo") > 2) + $then(1) + $when(pl$col("bar") > 2) + $then(4) + $otherwise(-1) + $alias("val") +) + +# Pass multiple predicates, each of which must be met: +df$with_columns( + val = pl$when( + pl$col("bar") > 0, + pl$col("foo") \%\% 2 != 0 + ) + $then(99) + $otherwise(-1) +) } \keyword{Expr} diff --git a/man/LazyFrame_filter.Rd b/man/LazyFrame_filter.Rd index a7c9d272e..41cf09c64 100644 --- a/man/LazyFrame_filter.Rd +++ b/man/LazyFrame_filter.Rd @@ -3,22 +3,30 @@ \name{LazyFrame_filter} \alias{LazyFrame_filter} \title{Apply filter to LazyFrame} -\format{ -An object of class \code{character} of length 1. -} \usage{ -LazyFrame_filter(expr) +LazyFrame_filter(...) } \arguments{ -\item{expr}{one Expr or string naming a column} +\item{...}{Polars expressions which will evaluate to a boolean.} } \value{ A new \code{LazyFrame} object with add/modified column. } \description{ -Filter rows with an Expression defining a boolean column +Filter rows with an Expression defining a boolean column. +Multiple expressions are combined with \code{&} (AND). +This is equivalent to \code{\link[dplyr:filter]{dplyr::filter()}}. +} +\details{ +Rows where the condition returns \code{NA} are dropped. } \examples{ -pl$LazyFrame(iris)$filter(pl$col("Species") == "setosa")$collect() +lf = pl$LazyFrame(iris) + +lf$filter(pl$col("Species") == "setosa")$collect() + +# This is equivalent to +# lf$filter(pl$col("Sepal.Length") > 5 & pl$col("Petal.Width") < 1) +lf$filter(pl$col("Sepal.Length") > 5, pl$col("Petal.Width") < 1) } \keyword{LazyFrame} diff --git a/tests/testthat/_snaps/after-wrappers.md b/tests/testthat/_snaps/after-wrappers.md index 58731aed8..51c4b4048 100644 --- a/tests/testthat/_snaps/after-wrappers.md +++ b/tests/testthat/_snaps/after-wrappers.md @@ -252,6 +252,154 @@ [172] "unique_counts" "upper_bound" "value_counts" [175] "var" "where" "xor" +--- + + Code + ls(.pr[[private_key]]) + Output + [1] "abs" "add" + [3] "agg_groups" "alias" + [5] "all" "and" + [7] "any" "append" + [9] "approx_n_unique" "arccos" + [11] "arccosh" "arcsin" + [13] "arcsinh" "arctan" + [15] "arctanh" "arg_max" + [17] "arg_min" "arg_sort" + [19] "arg_unique" "backward_fill" + [21] "bin_contains" "bin_decode_base64" + [23] "bin_decode_hex" "bin_encode_base64" + [25] "bin_encode_hex" "bin_ends_with" + [27] "bin_starts_with" "bottom_k" + [29] "cast" "cat_get_categories" + [31] "cat_set_ordering" "ceil" + [33] "clip" "clip_max" + [35] "clip_min" "col" + [37] "cols" "corr" + [39] "cos" "cosh" + [41] "count" "cov" + [43] "cum_count" "cum_max" + [45] "cum_min" "cum_prod" + [47] "cum_sum" "cumulative_eval" + [49] "diff" "div" + [51] "dot" "drop_nans" + [53] "drop_nulls" "dt_cast_time_unit" + [55] "dt_combine" "dt_convert_time_zone" + [57] "dt_day" "dt_epoch_seconds" + [59] "dt_hour" "dt_iso_year" + [61] "dt_microsecond" "dt_millisecond" + [63] "dt_minute" "dt_month" + [65] "dt_nanosecond" "dt_offset_by" + [67] "dt_ordinal_day" "dt_quarter" + [69] "dt_replace_time_zone" "dt_round" + [71] "dt_second" "dt_strftime" + [73] "dt_time" "dt_total_days" + [75] "dt_total_hours" "dt_total_microseconds" + [77] "dt_total_milliseconds" "dt_total_minutes" + [79] "dt_total_nanoseconds" "dt_total_seconds" + [81] "dt_truncate" "dt_week" + [83] "dt_weekday" "dt_with_time_unit" + [85] "dt_year" "dtype_cols" + [87] "entropy" "eq" + [89] "eq_missing" "ewm_mean" + [91] "ewm_std" "ewm_var" + [93] "exclude" "exclude_dtype" + [95] "exp" "explode" + [97] "extend_constant" "fill_nan" + [99] "fill_null" "fill_null_with_strategy" + [101] "filter" "first" + [103] "flatten" "floor" + [105] "floor_div" "forward_fill" + [107] "gather" "gather_every" + [109] "gt" "gt_eq" + [111] "hash" "head" + [113] "implode" "interpolate" + [115] "is_duplicated" "is_finite" + [117] "is_first_distinct" "is_in" + [119] "is_infinite" "is_last_distinct" + [121] "is_nan" "is_not_nan" + [123] "is_not_null" "is_null" + [125] "is_unique" "kurtosis" + [127] "last" "len" + [129] "list_arg_max" "list_arg_min" + [131] "list_contains" "list_diff" + [133] "list_eval" "list_gather" + [135] "list_get" "list_join" + [137] "list_lengths" "list_max" + [139] "list_mean" "list_min" + [141] "list_reverse" "list_shift" + [143] "list_slice" "list_sort" + [145] "list_sum" "list_to_struct" + [147] "list_unique" "lit" + [149] "log" "log10" + [151] "lower_bound" "lt" + [153] "lt_eq" "map_batches" + [155] "map_batches_in_background" "map_elements_in_background" + [157] "max" "mean" + [159] "median" "meta_eq" + [161] "meta_has_multiple_outputs" "meta_is_regex_projection" + [163] "meta_output_name" "meta_pop" + [165] "meta_roots" "meta_tree_format" + [167] "meta_undo_aliases" "min" + [169] "mode" "mul" + [171] "n_unique" "name_keep" + [173] "name_map" "name_prefix" + [175] "name_suffix" "name_to_lowercase" + [177] "name_to_uppercase" "nan_max" + [179] "nan_min" "neq" + [181] "neq_missing" "new_count" + [183] "new_first" "new_last" + [185] "not" "null_count" + [187] "or" "over" + [189] "pct_change" "peak_max" + [191] "peak_min" "pow" + [193] "print" "product" + [195] "quantile" "rank" + [197] "rechunk" "reinterpret" + [199] "rem" "rep" + [201] "repeat_by" "reshape" + [203] "reverse" "rolling" + [205] "rolling_corr" "rolling_cov" + [207] "rolling_max" "rolling_mean" + [209] "rolling_median" "rolling_min" + [211] "rolling_quantile" "rolling_skew" + [213] "rolling_std" "rolling_sum" + [215] "rolling_var" "round" + [217] "sample_frac" "sample_n" + [219] "search_sorted" "shift" + [221] "shift_and_fill" "shrink_dtype" + [223] "shuffle" "sign" + [225] "sin" "sinh" + [227] "skew" "slice" + [229] "sort" "sort_by" + [231] "std" "str_base64_decode" + [233] "str_base64_encode" "str_concat" + [235] "str_contains" "str_count_matches" + [237] "str_ends_with" "str_explode" + [239] "str_extract" "str_extract_all" + [241] "str_hex_decode" "str_hex_encode" + [243] "str_json_extract" "str_json_path_match" + [245] "str_len_bytes" "str_len_chars" + [247] "str_pad_end" "str_pad_start" + [249] "str_parse_int" "str_replace" + [251] "str_replace_all" "str_slice" + [253] "str_split" "str_split_exact" + [255] "str_splitn" "str_starts_with" + [257] "str_strip_chars" "str_strip_chars_end" + [259] "str_strip_chars_start" "str_to_date" + [261] "str_to_datetime" "str_to_lowercase" + [263] "str_to_time" "str_to_titlecase" + [265] "str_to_uppercase" "str_zfill" + [267] "struct_field_by_name" "struct_rename_fields" + [269] "sub" "sum" + [271] "tail" "tan" + [273] "tanh" "timestamp" + [275] "to_physical" "top_k" + [277] "unique" "unique_counts" + [279] "unique_stable" "upper_bound" + [281] "value_counts" "var" + [283] "xor" + # public and private methods of each class When Code diff --git a/tests/testthat/test-lazy.R b/tests/testthat/test-lazy.R index 5edb8a0dc..f421c88e9 100644 --- a/tests/testthat/test-lazy.R +++ b/tests/testthat/test-lazy.R @@ -544,7 +544,7 @@ test_that("melt vs data.table::melt", { )$lazy() rdf = plf$collect()$to_data_frame() - dtt = data.table(rdf) + dtt = data.table::data.table(rdf) melt_mod = \(...) { data.table::melt(variable.factor = FALSE, value.factor = FALSE, ...) @@ -871,3 +871,15 @@ test_that("with_context works", { data.frame(feature_0 = c(-1, 0, 1)) ) }) + +test_that("Multiple conditions in filter", { + expect_identical( + pl$LazyFrame(mtcars)$filter( + pl$col("cyl") > 6, + pl$col("mpg") > 15 + )$collect()$to_data_frame(), + pl$LazyFrame(mtcars)$filter( + pl$col("cyl") > 6 & pl$col("mpg") > 15 + )$collect()$to_data_frame() + ) +}) diff --git a/tests/testthat/test-whenthen.R b/tests/testthat/test-whenthen.R index 0ff78c241..6854324c2 100644 --- a/tests/testthat/test-whenthen.R +++ b/tests/testthat/test-whenthen.R @@ -83,3 +83,32 @@ test_that("when-then-otherwise", { ) ) }) + +test_that("when-then multiple predicates", { + df = pl$DataFrame(foo = c(1, 3, 4), bar = c(3, 4, 0)) + + expect_identical( + df$with_columns( + val = pl$when( + pl$col("bar") > 0, + pl$col("foo") %% 2 != 0 + ) + $then(99) + $when( + pl$col("bar") == 0, + pl$col("foo") %% 2 == 0 + ) + $then(-1) + $otherwise(NA) + )$to_data_frame(), + data.frame( + foo = c(1, 3, 4), + bar = c(3, 4, 0), + val = c(99, 99, -1) + ) + ) +}) + +test_that("named input is not allowed in when", { + expect_error(pl$when(foo = 1), "Detected a named input") +})