diff --git a/NEWS.md b/NEWS.md index f725ba52b9..ba4cd217f2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,10 @@ # dplyr (development version) +* `nest_join()` now preserves the type of `y` (#6295). + +* `*_join()` now error if you supply them with additional arguments that + aren't used (#6228). + * `df |> arrange(mydesc::desc(x))` works correctly when the mydesc re-exports `dplyr::desc()` (#6231). diff --git a/R/join.r b/R/join.r index d4cae09074..5522b5df72 100644 --- a/R/join.r +++ b/R/join.r @@ -193,6 +193,7 @@ inner_join <- function(x, suffix = c(".x", ".y"), ..., keep = NULL) { + check_dots_used() UseMethod("inner_join") } @@ -231,6 +232,7 @@ left_join <- function(x, suffix = c(".x", ".y"), ..., keep = NULL) { + check_dots_used() UseMethod("left_join") } @@ -269,6 +271,7 @@ right_join <- function(x, suffix = c(".x", ".y"), ..., keep = NULL) { + check_dots_used() UseMethod("right_join") } @@ -307,6 +310,7 @@ full_join <- function(x, suffix = c(".x", ".y"), ..., keep = NULL) { + check_dots_used() UseMethod("full_join") } @@ -380,6 +384,7 @@ NULL #' @export #' @rdname filter-joins semi_join <- function(x, y, by = NULL, copy = FALSE, ...) { + check_dots_used() UseMethod("semi_join") } @@ -393,6 +398,7 @@ semi_join.data.frame <- function(x, y, by = NULL, copy = FALSE, ..., na_matches #' @export #' @rdname filter-joins anti_join <- function(x, y, by = NULL, copy = FALSE, ...) { + check_dots_used() UseMethod("anti_join") } @@ -443,6 +449,7 @@ nest_join <- function(x, keep = NULL, name = NULL, ...) { + check_dots_used() UseMethod("nest_join") } @@ -506,7 +513,7 @@ nest_join.data.frame <- function(x, # changing the key vars because of the cast new_cols <- vec_cast(out[names(x_key)], vec_ptype2(x_key, y_key)) - y_out <- set_names(y_in[vars$y$out], names(vars$y$out)) + y_out <- set_names(y[vars$y$out], names(vars$y$out)) new_cols[[name_var]] <- map(y_loc, vec_slice, x = y_out) out <- dplyr_col_modify(out, new_cols) diff --git a/tests/testthat/_snaps/join.md b/tests/testthat/_snaps/join.md index 93db061d16..47e284b28a 100644 --- a/tests/testthat/_snaps/join.md +++ b/tests/testthat/_snaps/join.md @@ -60,6 +60,72 @@ Message Joining, by = "x" +# error if passed additional arguments + + Code + inner_join(df1, df2, on = "a") + Message + Joining, by = "a" + Condition + Error: + ! Arguments in `...` must be used. + x Problematic argument: + * on = "a" + Code + left_join(df1, df2, on = "a") + Message + Joining, by = "a" + Condition + Error: + ! Arguments in `...` must be used. + x Problematic argument: + * on = "a" + Code + right_join(df1, df2, on = "a") + Message + Joining, by = "a" + Condition + Error: + ! Arguments in `...` must be used. + x Problematic argument: + * on = "a" + Code + full_join(df1, df2, on = "a") + Message + Joining, by = "a" + Condition + Error: + ! Arguments in `...` must be used. + x Problematic argument: + * on = "a" + Code + nest_join(df1, df2, on = "a") + Message + Joining, by = "a" + Condition + Error: + ! Arguments in `...` must be used. + x Problematic argument: + * on = "a" + Code + anti_join(df1, df2, on = "a") + Message + Joining, by = "a" + Condition + Error: + ! Arguments in `...` must be used. + x Problematic argument: + * on = "a" + Code + semi_join(df1, df2, on = "a") + Message + Joining, by = "a" + Condition + Error: + ! Arguments in `...` must be used. + x Problematic argument: + * on = "a" + # nest_join computes common columns Code diff --git a/tests/testthat/test-join.r b/tests/testthat/test-join.r index ae77383276..fff1805b44 100644 --- a/tests/testthat/test-join.r +++ b/tests/testthat/test-join.r @@ -234,6 +234,21 @@ test_that("filtering joins compute common columns", { expect_snapshot(out <- semi_join(df1, df2)) }) +test_that("error if passed additional arguments", { + df1 <- data.frame(a = 1:3) + df2 <- data.frame(a = 1) + + expect_snapshot(error = TRUE, { + inner_join(df1, df2, on = "a") + left_join(df1, df2, on = "a") + right_join(df1, df2, on = "a") + full_join(df1, df2, on = "a") + nest_join(df1, df2, on = "a") + anti_join(df1, df2, on = "a") + semi_join(df1, df2, on = "a") + }) +}) + # nest_join --------------------------------------------------------------- test_that("nest_join returns list of tibbles (#3570)",{ @@ -246,6 +261,15 @@ test_that("nest_join returns list of tibbles (#3570)",{ expect_s3_class(out$df2[[1]], "tbl_df") }) +test_that("nest_join respects types of y",{ + df1 <- tibble(x = c(1, 2), y = c(2, 3)) + df2 <- rowwise(tibble(x = c(1, 1), z = c(2, 3))) + out <- nest_join(df1, df2, by = "x", multiple = "all") + + expect_s3_class(out$df2[[1]], "rowwise_df") +}) + + test_that("nest_join computes common columns", { df1 <- tibble(x = c(1, 2), y = c(2, 3)) df2 <- tibble(x = c(1, 3), z = c(2, 3)) @@ -304,7 +328,7 @@ test_that("joins preserve groups", { # See comment in nest_join i <- count_regroups(out <- nest_join(gf1, gf2, by = "a", multiple = "all")) - expect_equal(i, 1L) + expect_equal(i, 4L) expect_equal(group_vars(out), "a") })