From e66fcfe9fee8dcff007848614de9ed104c912eeb Mon Sep 17 00:00:00 2001 From: Lukas Schneiderbauer Date: Sat, 17 Aug 2024 20:38:58 +0200 Subject: [PATCH] `n_distinct()`: restore `na.rm = TRUE` capability with single vector argument --- R/backend-dbplyr__duckdb_connection.R | 19 ++++++++++++++----- .../test-backend-dbplyr__duckdb_connection.R | 14 +++++++++++++- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/R/backend-dbplyr__duckdb_connection.R b/R/backend-dbplyr__duckdb_connection.R index 9512db8f0..a36bf377a 100644 --- a/R/backend-dbplyr__duckdb_connection.R +++ b/R/backend-dbplyr__duckdb_connection.R @@ -78,13 +78,22 @@ duckdb_n_distinct <- function(..., na.rm = FALSE) { check_dots_unnamed() if (!identical(na.rm, FALSE)) { - stop("Parameter `na.rm = TRUE` in n_distinct() is currently not supported in DuckDB backend.", call. = FALSE) - } + cols <- list(...) + + # check for more than one vector argument: only one vector is supported + # Why not use ROW() as well? Because duckdb's FILTER clause does not support + # a windowing context as of now: https://duckdb.org/docs/sql/query_syntax/filter.html + if (length(cols) > 1) { + stop("n_distinct(): Only one vector argument is currently supported when `na.rm = TRUE`.", call. = FALSE) + } - # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs-with-the-row-function - str_struct <- paste0("row(", paste0(list(...), collapse = ", "), ")") + return(sql(paste0("COUNT(DISTINCT ", cols[[1]], ")"))) + } else { + # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs-with-the-row-function + str_struct <- paste0("row(", paste0(list(...), collapse = ", "), ")") - sql(paste0("COUNT(DISTINCT ", str_struct, ")")) + return(sql(paste0("COUNT(DISTINCT ", str_struct, ")"))) + } } # Customized translation functions for DuckDB SQL diff --git a/tests/testthat/test-backend-dbplyr__duckdb_connection.R b/tests/testthat/test-backend-dbplyr__duckdb_connection.R index 6afa52569..0a4435c1e 100644 --- a/tests/testthat/test-backend-dbplyr__duckdb_connection.R +++ b/tests/testthat/test-backend-dbplyr__duckdb_connection.R @@ -281,6 +281,9 @@ test_that("aggregators translated correctly", { expect_equal(translate(n_distinct(x), window = FALSE), sql(r"{COUNT(DISTINCT row(x))}")) expect_equal(translate(n_distinct(x), window = TRUE), sql(r"{COUNT(DISTINCT row(x)) OVER ()}")) expect_equal(translate(n_distinct(x), window = TRUE, vars_group = "y"), sql(r"{COUNT(DISTINCT row(x)) OVER (PARTITION BY y)}")) + expect_equal(translate(n_distinct(x, na.rm = TRUE), window = FALSE), sql(r"{COUNT(DISTINCT x)}")) + expect_equal(translate(n_distinct(x, na.rm = TRUE), window = TRUE), sql(r"{COUNT(DISTINCT x) OVER ()}")) + expect_equal(translate(n_distinct(x, na.rm = TRUE), window = TRUE, vars_group = "y"), sql(r"{COUNT(DISTINCT x) OVER (PARTITION BY y)}")) }) test_that("two variable aggregates are translated correctly", { @@ -317,8 +320,17 @@ test_that("n_distinct() computations are correct", { df <- tbl(con, "df") df_na <- tbl(con, "df_na") + expect_equal( + pull(summarize(df, n = n_distinct(x, na.rm = TRUE)), n), + 2 + ) + expect_equal( + pull(summarize(df_na, n = n_distinct(x, na.rm = TRUE)), n), + 2 + ) + expect_error( - pull(summarize(df, n = n_distinct(x, na.rm = TRUE)), n) + pull(summarize(df, n = n_distinct(x, y, na.rm = TRUE)), n) ) # single column is working as usual