Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Keep row names in output from predict #4977

Merged
merged 15 commits into from
Apr 5, 2022
34 changes: 22 additions & 12 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,9 @@ Booster <- R6::R6Class(
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case.
#'
#' If passing `reshape=TRUE` and `data` has row names, the output will also have those
#' row names.
#' @param params a list of additional named parameters. See
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#' the "Predict Parameters" section of the documentation} for a list of parameters and
Expand Down Expand Up @@ -803,19 +806,26 @@ predict.lgb.Booster <- function(object,
))
}

return(
object$predict(
data = newdata
, start_iteration = start_iteration
, num_iteration = num_iteration
, rawscore = rawscore
, predleaf = predleaf
, predcontrib = predcontrib
, header = header
, reshape = reshape
, params = params
)
pred <- object$predict(
data = newdata
, start_iteration = start_iteration
, num_iteration = num_iteration
, rawscore = rawscore
, predleaf = predleaf
, predcontrib = predcontrib
, header = header
, reshape = reshape
, params = params
)

if (reshape && NROW(row.names(newdata))) {
if (is.null(dim(pred))) {
names(pred) <- row.names(newdata)
} else {
row.names(pred) <- row.names(newdata)
}
}
return(pred)
}

#' @name print.lgb.Booster
Expand Down
5 changes: 4 additions & 1 deletion R-package/man/predict.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

116 changes: 116 additions & 0 deletions R-package/tests/testthat/test_Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ VERBOSITY <- as.integer(
Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
)

library(Matrix)

test_that("Predictor$finalize() should not fail", {
X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L)
y <- iris[["Sepal.Length"]]
Expand Down Expand Up @@ -111,3 +113,117 @@ test_that("start_iteration works correctly", {
pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE)
expect_equal(pred_leaf1, pred_leaf2)
})

test_that("predictions keep row names from the data", {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these tests look great, thank you! To speed up debugging if they break due to changes in future PRs, can you please break them into 3 test cases?

  • "predict() keeps row names from data (binary classification)"
  • "predict() keeps row names from data (multi-class classification)"
  • "predict() keeps row names from data (regression)"

To avoid duplicating test helper code and to make it clear that those methods were written just for these tests, please move them out of the test case and rename them to .expect_has_row_names() and .expect_doesnt_have_row_names().

data("mtcars")
X <- as.matrix(mtcars[, -1L])
y <- as.numeric(mtcars[, 1L])
dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
bst <- lgb.train(
data = dtrain
, obj = "regression"
, nrounds = 5L
, verbose = -1L
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
)

expect_has_row_names <- function(pred) {
if (is.vector(pred)) {
rnames <- names(pred)
} else {
rnames <- row.names(pred)
}
expect_false(is.null(rnames))
expect_true(is.vector(rnames))
expect_equal(row.names(X), rnames)
}

expect_doesnt_have_row_names <- function(pred) {
expect_null(row.names(pred))
}

pred <- predict(bst, X, reshape = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, rawscore = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, predleaf = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, predcontrib = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = FALSE)
expect_doesnt_have_row_names(pred)
Xcopy <- X
row.names(Xcopy) <- NULL
pred <- predict(bst, Xcopy, reshape = TRUE)
expect_doesnt_have_row_names(pred)

Xcsc <- as(X, "CsparseMatrix")
pred <- predict(bst, Xcsc, reshape = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, Xcsc, reshape = TRUE, rawscore = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, Xcsc, reshape = TRUE, predleaf = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, Xcsc, reshape = TRUE, predcontrib = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, Xcsc, reshape = FALSE)
expect_doesnt_have_row_names(pred)
Xcopy <- Xcsc
row.names(Xcopy) <- NULL
pred <- predict(bst, Xcopy, reshape = TRUE)
expect_doesnt_have_row_names(pred)

data(iris)
y <- as.numeric(iris$Species) - 1.0
X <- as.matrix(iris[, names(iris) != "Species"])
row.names(X) <- paste("rname", seq(1L, nrow(X)), sep = "")
dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
bst <- lgb.train(
data = dtrain
, obj = "multiclass"
, params = list(num_class = 3L)
, nrounds = 5L
, verbose = -1L
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
)

pred <- predict(bst, X, reshape = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, rawscore = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, predleaf = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, predcontrib = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = FALSE)
expect_doesnt_have_row_names(pred)
Xcopy <- X
row.names(Xcopy) <- NULL
pred <- predict(bst, Xcopy, reshape = TRUE)
expect_doesnt_have_row_names(pred)

data(agaricus.train, package = "lightgbm")
X <- agaricus.train$data
y <- agaricus.train$label
row.names(X) <- paste("rname", seq(1L, nrow(X)), sep = "")
dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
bst <- lgb.train(
data = dtrain
, obj = "binary"
, nrounds = 5L
, verbose = -1L
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
)

pred <- predict(bst, X, reshape = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, rawscore = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, predleaf = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = TRUE, predcontrib = TRUE)
expect_has_row_names(pred)
pred <- predict(bst, X, reshape = FALSE)
expect_doesnt_have_row_names(pred)
Xcopy <- X
row.names(Xcopy) <- NULL
pred <- predict(bst, Xcopy, reshape = TRUE)
expect_doesnt_have_row_names(pred)
})