Skip to content

Commit

Permalink
Preserve type of y in nest_join()
Browse files Browse the repository at this point in the history
Fixes #6295
  • Loading branch information
hadley committed Jul 26, 2022
1 parent 27f1c67 commit 66b333b
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 2 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# dplyr (development version)

* `nest_join()` now preserves the type of `y` (#6295).

* `*_join()` now error if you supply them with additional arguments that
aren't used (#6228).

* `df |> arrange(mydesc::desc(x))` works correctly when the mydesc re-exports
`dplyr::desc()` (#6231).

Expand Down
9 changes: 8 additions & 1 deletion R/join.r
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ inner_join <- function(x,
suffix = c(".x", ".y"),
...,
keep = NULL) {
check_dots_used()
UseMethod("inner_join")
}

Expand Down Expand Up @@ -231,6 +232,7 @@ left_join <- function(x,
suffix = c(".x", ".y"),
...,
keep = NULL) {
check_dots_used()
UseMethod("left_join")
}

Expand Down Expand Up @@ -269,6 +271,7 @@ right_join <- function(x,
suffix = c(".x", ".y"),
...,
keep = NULL) {
check_dots_used()
UseMethod("right_join")
}

Expand Down Expand Up @@ -307,6 +310,7 @@ full_join <- function(x,
suffix = c(".x", ".y"),
...,
keep = NULL) {
check_dots_used()
UseMethod("full_join")
}

Expand Down Expand Up @@ -380,6 +384,7 @@ NULL
#' @export
#' @rdname filter-joins
semi_join <- function(x, y, by = NULL, copy = FALSE, ...) {
check_dots_used()
UseMethod("semi_join")
}

Expand All @@ -393,6 +398,7 @@ semi_join.data.frame <- function(x, y, by = NULL, copy = FALSE, ..., na_matches
#' @export
#' @rdname filter-joins
anti_join <- function(x, y, by = NULL, copy = FALSE, ...) {
check_dots_used()
UseMethod("anti_join")
}

Expand Down Expand Up @@ -443,6 +449,7 @@ nest_join <- function(x,
keep = NULL,
name = NULL,
...) {
check_dots_used()
UseMethod("nest_join")
}

Expand Down Expand Up @@ -506,7 +513,7 @@ nest_join.data.frame <- function(x,
# changing the key vars because of the cast
new_cols <- vec_cast(out[names(x_key)], vec_ptype2(x_key, y_key))

y_out <- set_names(y_in[vars$y$out], names(vars$y$out))
y_out <- set_names(y[vars$y$out], names(vars$y$out))
new_cols[[name_var]] <- map(y_loc, vec_slice, x = y_out)

out <- dplyr_col_modify(out, new_cols)
Expand Down
66 changes: 66 additions & 0 deletions tests/testthat/_snaps/join.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,72 @@
Message
Joining, by = "x"

# error if passed additional arguments

Code
inner_join(df1, df2, on = "a")
Message
Joining, by = "a"
Condition
Error:
! Arguments in `...` must be used.
x Problematic argument:
* on = "a"
Code
left_join(df1, df2, on = "a")
Message
Joining, by = "a"
Condition
Error:
! Arguments in `...` must be used.
x Problematic argument:
* on = "a"
Code
right_join(df1, df2, on = "a")
Message
Joining, by = "a"
Condition
Error:
! Arguments in `...` must be used.
x Problematic argument:
* on = "a"
Code
full_join(df1, df2, on = "a")
Message
Joining, by = "a"
Condition
Error:
! Arguments in `...` must be used.
x Problematic argument:
* on = "a"
Code
nest_join(df1, df2, on = "a")
Message
Joining, by = "a"
Condition
Error:
! Arguments in `...` must be used.
x Problematic argument:
* on = "a"
Code
anti_join(df1, df2, on = "a")
Message
Joining, by = "a"
Condition
Error:
! Arguments in `...` must be used.
x Problematic argument:
* on = "a"
Code
semi_join(df1, df2, on = "a")
Message
Joining, by = "a"
Condition
Error:
! Arguments in `...` must be used.
x Problematic argument:
* on = "a"

# nest_join computes common columns

Code
Expand Down
26 changes: 25 additions & 1 deletion tests/testthat/test-join.r
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,21 @@ test_that("filtering joins compute common columns", {
expect_snapshot(out <- semi_join(df1, df2))
})

test_that("error if passed additional arguments", {
df1 <- data.frame(a = 1:3)
df2 <- data.frame(a = 1)

expect_snapshot(error = TRUE, {
inner_join(df1, df2, on = "a")
left_join(df1, df2, on = "a")
right_join(df1, df2, on = "a")
full_join(df1, df2, on = "a")
nest_join(df1, df2, on = "a")
anti_join(df1, df2, on = "a")
semi_join(df1, df2, on = "a")
})
})

# nest_join ---------------------------------------------------------------

test_that("nest_join returns list of tibbles (#3570)",{
Expand All @@ -246,6 +261,15 @@ test_that("nest_join returns list of tibbles (#3570)",{
expect_s3_class(out$df2[[1]], "tbl_df")
})

test_that("nest_join respects types of y",{
df1 <- tibble(x = c(1, 2), y = c(2, 3))
df2 <- rowwise(tibble(x = c(1, 1), z = c(2, 3)))
out <- nest_join(df1, df2, by = "x", multiple = "all")

expect_s3_class(out$df2[[1]], "rowwise_df")
})


test_that("nest_join computes common columns", {
df1 <- tibble(x = c(1, 2), y = c(2, 3))
df2 <- tibble(x = c(1, 3), z = c(2, 3))
Expand Down Expand Up @@ -304,7 +328,7 @@ test_that("joins preserve groups", {

# See comment in nest_join
i <- count_regroups(out <- nest_join(gf1, gf2, by = "a", multiple = "all"))
expect_equal(i, 1L)
expect_equal(i, 4L)
expect_equal(group_vars(out), "a")
})

Expand Down

0 comments on commit 66b333b

Please sign in to comment.