Skip to content

Fix population scaling with other_keys + supporting fixes/changes #422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.3
Version: 0.1.4
Authors@R: c(
person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"),
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_at)
importFrom(dplyr,inner_join)
importFrom(dplyr,join_by)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
Expand Down Expand Up @@ -273,6 +274,7 @@ importFrom(hardhat,extract_recipe)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
importFrom(magrittr,extract2)
importFrom(recipes,bake)
importFrom(recipes,detect_step)
importFrom(recipes,prep)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
## Improvements

- Add `step_adjust_latency`, which give several methods to adjust the forecast if the `forecast_date` is after the last day of data.
- Fix `layer_population_scaling` default `by` with `other_keys`.
- Make key column inference more consistent within the package and with current `epiprocess`.
- Fix `quantile_reg()` producing error when asked to output just median-level predictions.
- (temporary) ahead negative is allowed for `step_epi_ahead` until we have `step_epi_shift`

## Bug fixes
Expand Down
5 changes: 2 additions & 3 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,10 @@ autoplot.epi_workflow <- function(
if (!is.null(shift)) {
edf <- mutate(edf, time_value = time_value + shift)
}
extra_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
if (length(extra_keys) == 0L) extra_keys <- NULL
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
edf <- as_epi_df(edf,
as_of = object$fit$meta$as_of,
other_keys = extra_keys %||% character()
other_keys = other_keys
)
if (is.null(predictions)) {
return(autoplot(
Expand Down
2 changes: 2 additions & 0 deletions R/epipredict-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#' @importFrom cli cli_abort cli_warn
#' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by
#' @importFrom dplyr full_join relocate summarise everything
#' @importFrom dplyr inner_join
#' @importFrom dplyr summarize filter mutate select left_join rename ungroup
#' @importFrom magrittr extract2
#' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg
#' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match
#' @importFrom stats poly predict lm residuals quantile
Expand Down
21 changes: 13 additions & 8 deletions R/key_colnames.R
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
#' @export
key_colnames.recipe <- function(x, ...) {
key_colnames.recipe <- function(x, ..., exclude = character()) {
geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"]
time_key <- x$var_info$variable[x$var_info$role %in% "time_value"]
keys <- x$var_info$variable[x$var_info$role %in% "key"]
c(geo_key, keys, time_key) %||% character(0L)
full_key <- c(geo_key, keys, time_key) %||% character(0L)
full_key[!full_key %in% exclude]
}

#' @export
key_colnames.epi_workflow <- function(x, ...) {
key_colnames.epi_workflow <- function(x, ..., exclude = character()) {
# safer to look at the mold than the preprocessor
mold <- hardhat::extract_mold(x)
molded_names <- names(mold$extras$roles)
geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value)
time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value)
keys <- names(mold$extras$roles[molded_names %in% "key"]$key)
c(geo_key, keys, time_key) %||% character(0L)
molded_roles <- mold$extras$roles
extras <- bind_cols(molded_roles$geo_value, molded_roles$key, molded_roles$time_value)
full_key <- names(extras)
if (length(full_key) == 0L) {
# No epikeytime role assignment; infer from all columns:
potential_keys <- c("geo_value", "time_value")
full_key <- potential_keys[potential_keys %in% names(bind_cols(molded_roles))]
}
full_key[!full_key %in% exclude]
}

kill_time_value <- function(v) {
Expand Down
41 changes: 34 additions & 7 deletions R/layer_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
#' inverting the existing scaling.
#' @param by A (possibly named) character vector of variables to join by.
#'
#' If `NULL`, the default, the function will perform a natural join, using all
#' variables in common across the `epi_df` produced by the `predict()` call
#' and the user-provided dataset.
#' If columns in that `epi_df` and `df` have the same name (and aren't
#' included in `by`), `.df` is added to the one from the user-provided data
#' to disambiguate.
#' If `NULL`, the default, the function will try to infer a reasonable set of
#' columns. First, it will try to join by all variables in the test data with
#' roles `"geo_value"`, `"key"`, or `"time_value"` that also appear in `df`;
#' these roles are automatically set if you are using an `epi_df`, or you can
#' use, e.g., `update_role`. If no such roles are set, it will try to perform a
#' natural join, using variables in common between the training/test data and
#' population data.
#'
#' If columns in the training/testing data and `df` have the same name (and
#' aren't included in `by`), a `.df` suffix is added to the one from the
#' user-provided data to disambiguate.
#'
#' To join by different variables on the `epi_df` and `df`, use a named vector.
#' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
Expand Down Expand Up @@ -135,6 +140,26 @@ slather.layer_population_scaling <-
)
rlang::check_dots_empty()

if (is.null(object$by)) {
# Assume `layer_predict` has calculated the prediction keys and other
# layers don't change the prediction key colnames:
prediction_key_colnames <- names(components$keys)
lhs_potential_keys <- prediction_key_colnames
rhs_potential_keys <- colnames(select(object$df, !object$df_pop_col))
object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
suggested_min_keys <- kill_time_value(lhs_potential_keys)
if (!all(suggested_min_keys %in% object$by)) {
cli_warn(c(
"{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
but {?wasn't/weren't} found in the population `df`.",
"i" = "Defaulting to join by {object$by}",
">" = "Double-check whether column names on the population `df` match those expected in your predictions",
">" = "Consider using population data with breakdowns by {suggested_min_keys}",
">" = "Manually specify `by =` to silence"
), class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys")
}
}

object$by <- object$by %||% intersect(
epi_keys_only(components$predictions),
colnames(select(object$df, !object$df_pop_col))
Expand All @@ -152,10 +177,12 @@ slather.layer_population_scaling <-
suffix <- ifelse(object$create_new, object$suffix, "")
col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions))

components$predictions <- left_join(
components$predictions <- inner_join(
components$predictions,
object$df,
by = object$by,
relationship = "many-to-one",
unmatched = c("error", "drop"),
suffix = c("", ".df")
) %>%
mutate(across(
Expand Down
2 changes: 1 addition & 1 deletion R/make_quantile_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ make_quantile_reg <- function() {

# can't make a method because object is second
out <- switch(type,
rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile
rq = dist_quantiles(unname(as.list(x)), object$tau), # one quantile
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised this didn't raise an error before...

rqs = {
x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
dist_quantiles(x, list(object$tau))
Expand Down
90 changes: 75 additions & 15 deletions R/step_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,25 @@
#' inverting the existing scaling.
#' @param by A (possibly named) character vector of variables to join by.
#'
#' If `NULL`, the default, the function will perform a natural join, using all
#' variables in common across the `epi_df` produced by the `predict()` call
#' and the user-provided dataset.
#' If columns in that `epi_df` and `df` have the same name (and aren't
#' included in `by`), `.df` is added to the one from the user-provided data
#' to disambiguate.
#' If `NULL`, the default, the function will try to infer a reasonable set of
#' columns. First, it will try to join by all variables in the training/test
#' data with roles `"geo_value"`, `"key"`, or `"time_value"` that also appear in
#' `df`; these roles are automatically set if you are using an `epi_df`, or you
#' can use, e.g., `update_role`. If no such roles are set, it will try to
#' perform a natural join, using variables in common between the training/test
#' data and population data.
#'
#' If columns in the training/testing data and `df` have the same name (and
#' aren't included in `by`), a `.df` suffix is added to the one from the
#' user-provided data to disambiguate.
#'
#' To join by different variables on the `epi_df` and `df`, use a named vector.
#' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
#' to `df$states`. To join by multiple variables, use a vector with length > 1.
#' For example, `by = c("geo_value" = "states", "county" = "county")` will match
#' `epi_df$geo_value` to `df$states` and `epi_df$county` to `df$county`.
#'
#' See [dplyr::left_join()] for more details.
#' See [dplyr::inner_join()] for more details.
#' @param df_pop_col the name of the column in the data frame `df` that
#' contains the population data and will be used for scaling.
#' This should be one column.
Expand Down Expand Up @@ -89,13 +94,25 @@ step_population_scaling <-
suffix = "_scaled",
skip = FALSE,
id = rand_id("population_scaling")) {
arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id)
arg_is_lgl(create_new, skip)
arg_is_chr(df_pop_col, suffix, id)
if (rlang::dots_n(...) == 0L) {
cli_abort(c(
"`...` must not be empty.",
">" = "Please provide one or more tidyselect expressions in `...`
specifying the columns to which scaling should be applied.",
">" = "If you really want to list `step_population_scaling` in your
recipe but not have it do anything, you can use a tidyselection
that selects zero variables, such as `c()`."
))
}
arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, skip, id)
arg_is_chr(role, df_pop_col, suffix, id)
hardhat::validate_column_names(df, df_pop_col)
arg_is_chr(by, allow_null = TRUE)
arg_is_numeric(rate_rescaling)
if (rate_rescaling <= 0) {
cli_abort("`rate_rescaling` must be a positive number.")
}
arg_is_lgl(create_new, skip)

recipes::add_step(
recipe,
Expand Down Expand Up @@ -138,6 +155,42 @@ step_population_scaling_new <-

#' @export
prep.step_population_scaling <- function(x, training, info = NULL, ...) {
if (is.null(x$by)) {
rhs_potential_keys <- setdiff(colnames(x$df), x$df_pop_col)
lhs_potential_keys <- info %>%
filter(role %in% c("geo_value", "key", "time_value")) %>%
extract2("variable") %>%
unique() # in case of weird var with multiple of above roles
if (length(lhs_potential_keys) == 0L) {
# We're working with a recipe and tibble, and *_role hasn't set up any of
# the above roles. Let's say any column could actually act as a key, and
# lean on `intersect` below to make this something reasonable.
lhs_potential_keys <- names(training)
}
suggested_min_keys <- info %>%
filter(role %in% c("geo_value", "key")) %>%
extract2("variable") %>%
unique()
# (0 suggested keys if we weren't given any epikeytime var info.)
x$by <- intersect(lhs_potential_keys, rhs_potential_keys)
if (length(x$by) == 0L) {
cli_stop(c(
"Couldn't guess a default for `by`",
">" = "Please rename columns in your population data to match those in your training data,
or manually specify `by =` in `step_population_scaling()`."
), class = "epipredict__step_population_scaling__default_by_no_intersection")
}
if (!all(suggested_min_keys %in% x$by)) {
cli_warn(c(
"{setdiff(suggested_min_keys, x$by)} {?was an/were} epikey column{?s} in the training data,
but {?wasn't/weren't} found in the population `df`.",
"i" = "Defaulting to join by {x$by}.",
">" = "Double-check whether column names on the population `df` match those for your training data.",
">" = "Consider using population data with breakdowns by {suggested_min_keys}.",
">" = "Manually specify `by =` to silence."
), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys")
}
}
step_population_scaling_new(
terms = x$terms,
role = x$role,
Expand All @@ -156,10 +209,14 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) {

#' @export
bake.step_population_scaling <- function(object, new_data, ...) {
object$by <- object$by %||% intersect(
epi_keys_only(new_data),
colnames(select(object$df, !object$df_pop_col))
)
if (is.null(object$by)) {
cli::cli_abort(c(
"`by` was not set and no default was filled in",
">" = "If this was a fit recipe generated from an older version
of epipredict that you loaded in from a file,
please regenerate with the current version of epipredict."
))
}
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
hardhat::validate_column_names(new_data, joinby$x)
hardhat::validate_column_names(object$df, joinby$y)
Expand All @@ -177,7 +234,10 @@ bake.step_population_scaling <- function(object, new_data, ...) {
suffix <- ifelse(object$create_new, object$suffix, "")
col_to_remove <- setdiff(colnames(object$df), colnames(new_data))

left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>%
inner_join(new_data, object$df,
by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"),
suffix = c("", ".df")
) %>%
mutate(
across(
all_of(object$columns),
Expand Down
82 changes: 64 additions & 18 deletions R/utils-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,76 @@ check_pname <- function(res, preds, object, newname = NULL) {
res
}

# Copied from `epiprocess`:

#' "Format" a character vector of column/variable names for cli interpolation
#'
#' Designed to give good output if interpolated with cli. Main purpose is to add
#' backticks around variable names when necessary, and something other than an
#' empty string if length 0.
#'
#' @param x `chr`; e.g., `colnames` of some data frame
#' @param empty string; what should be output if `x` is of length 0?
#' @return `chr`
#' @keywords internal
format_varnames <- function(x, empty = "*none*") {
if (length(x) == 0L) {
empty
} else {
as.character(syms(x))
}
}

grab_forged_keys <- function(forged, workflow, new_data) {
forged_roles <- names(forged$extras$roles)
extras <- bind_cols(forged$extras$roles[forged_roles %in% c("geo_value", "time_value", "key")])
# 1. these are the keys in the test data after prep/bake
new_keys <- names(extras)
# 2. these are the keys in the training data
# 1. keys in the training data post-prep, based on roles:
old_keys <- key_colnames(workflow)
# 3. these are the keys in the test data as input
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value")))
if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) {
cli_warn(paste(
"Not all epi keys that were present in the training data are available",
"in `new_data`. Predictions will have only the available keys."
# 2. keys in the test data post-bake, based on roles & structure:
forged_roles <- forged$extras$roles
new_key_tbl <- bind_cols(forged_roles$geo_value, forged_roles$key, forged_roles$time_value)
new_keys <- names(new_key_tbl)
if (length(new_keys) == 0L) {
# No epikeytime role assignment; infer from all columns:
potential_new_keys <- c("geo_value", "time_value")
forged_tbl <- bind_cols(forged$extras$roles)
new_keys <- potential_new_keys[potential_new_keys %in% names(forged_tbl)]
new_key_tbl <- forged_tbl[new_keys]
}
# Softly validate:
if (!(setequal(old_keys, new_keys))) {
cli_warn(c(
"Inconsistent epikeytime identifier columns specified/inferred in training vs. in testing data.",
"i" = "training epikeytime columns, based on roles post-mold/prep: {format_varnames(old_keys)}",
"i" = "testing epikeytime columns, based on roles post-forge/bake: {format_varnames(new_keys)}",
"*" = "",
">" = 'Some mismatches can be addressed by using `epi_df`s instead of tibbles, or by using `update_role`
to assign pre-`prep` columns the "geo_value", "key", and "time_value" roles.'
))
}
if (is_epi_df(new_data)) {
meta <- attr(new_data, "metadata")
extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character())
} else if (all(c("geo_value", "time_value") %in% new_keys)) {
if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")]
extras <- as_epi_df(extras, other_keys = other_keys %||% character())
# Convert `new_key_tbl` to `epi_df` if not renaming columns nor violating
# `epi_df` invariants. Require that our key is a unique key in any case.
if (all(c("geo_value", "time_value") %in% new_keys)) {
maybe_as_of <- attr(new_data, "metadata")$as_of # NULL if wasn't epi_df
try(return(as_epi_df(new_key_tbl, other_keys = new_keys, as_of = maybe_as_of)),
silent = TRUE
)
}
if (anyDuplicated(new_key_tbl)) {
duplicate_key_tbl <- new_key_tbl %>% filter(.by = everything(), dplyr::n() > 1L)
error_part1 <- cli::format_error(
c(
"Specified/inferred key columns had repeated combinations in the forged/baked test data.",
"i" = "Key columns: {format_varnames(new_keys)}",
"Duplicated keys:"
)
)
error_part2 <- capture.output(print(duplicate_key_tbl))
rlang::abort(
paste(collapse = "\n", c(error_part1, error_part2)),
class = "epipredict__grab_forged_keys__nonunique_key"
)
} else {
return(new_key_tbl)
}
extras
}

get_parsnip_mode <- function(trainer) {
Expand Down
Loading
Loading