Skip to content

Commit

Permalink
Allow negative n/prop in slice_sample() (#6405)
Browse files (browse the repository at this point in the history)
And generally improve logic by doing more work in `get_slice_size()`.

Fixes #6402
  • Loading branch information
hadley authored Aug 18, 2022
1 parent 978d7e3 commit ee313a6
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 62 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# dplyr (development version)

* `slice_sample()` now accepts negative `n` and `prop` values (#6402).

* `slice_*()` now requires `n` to be an integer.

* New `case_match()` function that is a "vectorised switch" variant of
`case_when()` that matches on values rather than logical expressions. It is
like a SQL "simple" `CASE WHEN` statement, whereas `case_when()` is like a SQL
Expand Down
53 changes: 27 additions & 26 deletions R/slice.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,7 @@ slice_head <- function(.data, ..., n, prop) {
slice_head.data.frame <- function(.data, ..., n, prop) {
size <- get_slice_size(n = n, prop = prop)
idx <- function(n) {
to <- size(n)
if (to > n) {
to <- n
}
seq2(1, to)
seq2(1, size(n))
}

dplyr_local_error_call()
Expand All @@ -163,11 +159,7 @@ slice_tail <- function(.data, ..., n, prop) {
slice_tail.data.frame <- function(.data, ..., n, prop) {
size <- get_slice_size(n = n, prop = prop)
idx <- function(n) {
from <- n - size(n) + 1
if (from < 1L) {
from <- 1L
}
seq2(from, n)
seq2(n - size(n) + 1, n)
}

dplyr_local_error_call()
Expand Down Expand Up @@ -263,7 +255,7 @@ slice_sample <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE)

#' @export
slice_sample.data.frame <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE) {
size <- get_slice_size(n = n, prop = prop, allow_negative = FALSE)
size <- get_slice_size(n = n, prop = prop, allow_outsize = replace)

dplyr_local_error_call()
slice(.data, local({
Expand Down Expand Up @@ -423,48 +415,57 @@ check_slice_n_prop <- function(n, prop, error_call = caller_env()) {
list(type = "n", n = 1L)
} else if (!missing(n) && missing(prop)) {
n <- check_constant(n, "n", error_call = error_call)
if (!is.numeric(n) || length(n) != 1 || is.na(n)) {
abort("`n` must be a single number.", call = error_call)
if (!is_integerish(n, n = 1) || is.na(n)) {
abort(
glue("`n` must be a round number, not {obj_type_friendly(n)}."),
call = error_call
)
}
list(type = "n", n = n)
} else if (!missing(prop) && missing(n)) {
prop <- check_constant(prop, "prop", error_call = error_call)
if (!is.numeric(prop) || length(prop) != 1 || is.na(prop)) {
abort("`prop` must be a single number.", call = error_call)
abort(
glue("`prop` must be a number, not {obj_type_friendly(prop)}."),
call = error_call
)
}
list(type = "prop", prop = prop)
} else {
abort("Must supply `n` or `prop`, but not both.", call = error_call)
}
}

get_slice_size <- function(n, prop, allow_negative = TRUE, error_call = caller_env()) {
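# Returns a function that maps a group size to a whole number of rows.
# The result is always >= 0 and is capped at the group size, except that
# a non-negative `n`/`prop` is left uncapped when `allow_outsize = TRUE`.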
get_slice_size <- function(n, prop, allow_outsize = FALSE, error_call = caller_env()) {
slice_input <- check_slice_n_prop(n, prop, error_call = error_call)

if (slice_input$type == "n") {
if (slice_input$n >= 0) {
function(n) floor(slice_input$n)
} else if (allow_negative) {
function(n) ceiling(n + slice_input$n)
function(n) clamp(0, floor(slice_input$n), if (allow_outsize) Inf else n)
} else {
abort("`n` must be positive.", call = error_call)
function(n) clamp(0, ceiling(n + slice_input$n), n)
}
} else if (slice_input$type == "prop") {
if (slice_input$prop >= 0) {
function(n) floor(slice_input$prop * n)
} else if (allow_negative) {
function(n) ceiling(n + slice_input$prop * n)
function(n) clamp(0, floor(slice_input$prop * n), if (allow_outsize) Inf else n)
} else {
abort("`prop` must be positive.", call = error_call)
function(n) clamp(0, ceiling(n + slice_input$prop * n), n)
}
}
}

sample_int <- function(n, size, replace = FALSE, wt = NULL, call = caller_env()) {
if (!replace && n < size) {
size <- n
# Restrict a single value `x` to the closed interval [`min`, `max`].
#
# Returns `min` when `x` lies below it, `max` when `x` lies above it, and
# `x` itself otherwise. The bounds are checked in that order and the chosen
# argument is returned unchanged (no coercion). Scalar comparison only: `x`
# is used directly as an `if` condition operand.
clamp <- function(min, x, max) {
  # Lower bound is tested first.
  if (x < min) {
    return(min)
  }
  # Then the upper bound.
  if (x > max) {
    return(max)
  }
  # Already inside the interval.
  x
}

sample_int <- function(n, size, replace = FALSE, wt = NULL) {
if (size == 0L) {
integer(0)
} else {
Expand Down
35 changes: 15 additions & 20 deletions tests/testthat/_snaps/slice.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,27 @@
slice_head(df, n = "a")
Condition
Error in `slice_head()`:
! `n` must be a single number.
! `n` must be a round number, not a string.
Code
slice_tail(df, n = "a")
Condition
Error in `slice_tail()`:
! `n` must be a single number.
! `n` must be a round number, not a string.
Code
slice_min(df, x, n = "a")
Condition
Error in `slice_min()`:
! `n` must be a single number.
! `n` must be a round number, not a string.
Code
slice_max(df, x, n = "a")
Condition
Error in `slice_max()`:
! `n` must be a single number.
! `n` must be a round number, not a string.
Code
slice_sample(df, n = "a")
Condition
Error in `slice_sample()`:
! `n` must be a single number.
! `n` must be a round number, not a string.

# get_slice_size() validates its inputs

Expand All @@ -91,12 +91,20 @@
get_slice_size(n = "a")
Condition
Error:
! `n` must be a single number.
! `n` must be a round number, not a string.
Code
get_slice_size(prop = "a")
Condition
Error:
! `prop` must be a single number.
! `prop` must be a number, not a string.

# n must be an integer

Code
slice_head(df, n = 1.1)
Condition
Error in `slice_head()`:
! `n` must be a round number, not a number.

# slice_*() checks that `n=` is explicitly named and ... is empty

Expand Down Expand Up @@ -252,16 +260,3 @@
Error in `slice_sample()`:
! `replace` must be `TRUE` or `FALSE`, not `NA`.

# slice_sample() handles n= and prop=

Code
slice_sample(gf, n = -1)
Condition
Error in `slice_sample()`:
! `n` must be positive.
Code
slice_sample(gf, prop = -1)
Condition
Error in `slice_sample()`:
! `prop` must be positive.

56 changes: 40 additions & 16 deletions tests/testthat/test-slice.r
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,38 @@ test_that("get_slice_size() validates its inputs", {
})
})

test_that("get_slice_size() converts proportions to numbers", {
expect_equal(get_slice_size(prop = 0.5)(10), 5)
expect_equal(get_slice_size(prop = -0.1)(10), 9)
expect_equal(get_slice_size(prop = 1.1)(10), 11)
# Pins how get_slice_size(prop = ) converts a proportion into a row count
# for a group of 10 rows.
test_that("get_slice_size() standardises prop", {
# Zero proportion selects no rows.
expect_equal(get_slice_size(prop = 0)(10), 0)

# Positive proportions: capped at the group size by default, and left
# uncapped when allow_outsize = TRUE.
expect_equal(get_slice_size(prop = 0.4)(10), 4)
expect_equal(get_slice_size(prop = 2)(10), 10)
expect_equal(get_slice_size(prop = 2, allow_outsize = TRUE)(10), 20)

# Negative proportions: the result counts what remains after dropping that
# fraction, and never goes below zero.
expect_equal(get_slice_size(prop = -0.4)(10), 6)
expect_equal(get_slice_size(prop = -2)(10), 0)
})

test_that("get_slice_size() converts negative to positive", {
expect_equal(get_slice_size(n = -1)(10), 9)
expect_equal(get_slice_size(prop = -0.5)(10), 5)
# Pins how get_slice_size(n = ) converts a requested count into a row count
# for a group of 10 rows.
test_that("get_slice_size() standardises n", {
# Zero selects no rows.
expect_equal(get_slice_size(n = 0)(10), 0)

# Positive n: capped at the group size by default, and left uncapped when
# allow_outsize = TRUE.
expect_equal(get_slice_size(n = 4)(10), 4)
expect_equal(get_slice_size(n = 20)(10), 10)
expect_equal(get_slice_size(n = 20, allow_outsize = TRUE)(10), 20)

# Negative n: the result counts what remains after dropping that many rows,
# and never goes below zero.
expect_equal(get_slice_size(n = -4)(10), 6)
expect_equal(get_slice_size(n = -20)(10), 0)
})

test_that("get_slice_size() rounds non-integers", {
expect_equal(get_slice_size(n = 1.6)(10), 1)
test_that("get_slice_size() rounds prop in the right direction", {
expect_equal(get_slice_size(prop = 0.16)(10), 1)
expect_equal(get_slice_size(n = -1.6)(10), 9)
expect_equal(get_slice_size(prop = -0.16)(10), 9)
})

# A fractional `n` is expected to raise an error; the error output is
# captured as a snapshot.
test_that("n must be an integer", {
df <- tibble(x = 1:5)
expect_snapshot(slice_head(df, n = 1.1), error = TRUE)
})

test_that("functions silently truncate results", {
# only test positive n because get_slice_size() converts all others

Expand Down Expand Up @@ -293,6 +307,12 @@ test_that("slice_sample() respects weight_by and replaces", {
expect_equal(out$x, c(1, 1))
})

# Asking for more rows than exist (20 from a 10-row table) yields all 10 rows
# without replacement, and the full 20 only when sampling with replacement.
test_that("slice_sample() can increase rows iff replace = TRUE", {
df <- tibble(x = 1:10)
expect_equal(nrow(slice_sample(df, n = 20, replace = FALSE)), 10)
expect_equal(nrow(slice_sample(df, n = 20, replace = TRUE)), 20)
})

test_that("slice_sample() checks size of `weight_by=` (#5922)", {
df <- tibble(x = 1:10)
expect_snapshot(slice_sample(df, n = 2, weight_by = 1:6), error = TRUE)
Expand All @@ -312,16 +332,20 @@ test_that("`slice_sample()` validates `replace`", {
})
})

test_that("slice_sample() handles n= and prop=", {
test_that("slice_sample() handles positive n= and prop=", {
gf <- group_by(tibble(a = 1, b = 1), a)
expect_equal(slice_sample(gf, n = 3, replace = TRUE), gf[c(1, 1, 1), ])
expect_equal(slice_sample(gf, prop = 3, replace = TRUE), gf[c(1, 1, 1), ])
})

# Unlike other helpers, can't supply negative values
expect_snapshot(error = TRUE, {
slice_sample(gf, n = -1)
slice_sample(gf, prop = -1)
})
# Negative `n`/`prop` are accepted by slice_sample(); on a 2-row table,
# n = -1 and prop = -0.5 each leave one row.
test_that("slice_sample() handles negative n= and prop= (#6402)", {
df <- tibble(a = 1:2)
expect_equal(nrow(slice_sample(df, n = -1)), 1)
expect_equal(nrow(slice_sample(df, prop = -0.5)), 1)

# even if the amount to drop is larger than the number of rows: the result
# is an empty table rather than an error
expect_equal(nrow(slice_sample(df, n = -3)), 0)
expect_equal(nrow(slice_sample(df, prop = -2)), 0)
})

# slice_head/slice_tail ---------------------------------------------------
Expand Down

0 comments on commit ee313a6

Please sign in to comment.