diff --git a/R-package/R/lgb.Predictor.R b/R-package/R/lgb.Predictor.R
index 26b70d778d40..7a709e5dd20b 100644
--- a/R-package/R/lgb.Predictor.R
+++ b/R-package/R/lgb.Predictor.R
@@ -219,6 +219,17 @@ Predictor <- R6::R6Class(
         preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
       }
 
+      # Keep row names if possible: 'data' must carry row names and have
+      # exactly one row per row (or element) of 'preds'
+      if (NROW(row.names(data)) > 0L && NROW(data) == NROW(preds)) {
+        if (is.null(dim(preds))) {
+          # no dimensions: 'preds' is a plain vector, so label its elements
+          names(preds) <- row.names(data)
+        } else {
+          row.names(preds) <- row.names(data)
+        }
+      }
+
       return(preds)
 
     }
diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R
index 2057324d07fd..19360daa688f 100644
--- a/R-package/tests/testthat/test_Predictor.R
+++ b/R-package/tests/testthat/test_Predictor.R
@@ -2,6 +2,8 @@ VERBOSITY <- as.integer(
   Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
 )
 
+library(Matrix)
+
 test_that("Predictor$finalize() should not fail", {
     X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L)
     y <- iris[["Sepal.Length"]]
@@ -112,6 +114,122 @@ test_that("start_iteration works correctly", {
     expect_equal(pred_leaf1, pred_leaf2)
 })
 
+# asserts that 'pred' carries the row names of 'X'
+# (names() for vectors, row.names() otherwise)
+.expect_has_row_names <- function(pred, X) {
+    if (is.vector(pred)) {
+        rnames <- names(pred)
+    } else {
+        rnames <- row.names(pred)
+    }
+    expect_false(is.null(rnames))
+    expect_true(is.vector(rnames))
+    expect_true(length(rnames) > 0L)
+    expect_equal(rnames, row.names(X))
+}
+
+# asserts that 'pred' has no names (vectors) or row names (otherwise)
+.expect_doesnt_have_row_names <- function(pred) {
+    if (is.vector(pred)) {
+        expect_null(names(pred))
+    } else {
+        expect_null(row.names(pred))
+    }
+}
+
+# checks, for dense and sparse versions of 'X', that predictions from 'bst'
+# carry the input's row names when it has them and none when it does not
+.check_all_row_name_expectations <- function(bst, X) {
+
+    # dense matrix with row names
+    pred <- predict(bst, X)
+    .expect_has_row_names(pred, X)
+    pred <- predict(bst, X, rawscore = TRUE)
+    .expect_has_row_names(pred, X)
+    pred <- predict(bst, X, predleaf = TRUE)
+    .expect_has_row_names(pred, X)
+    pred <- predict(bst, X, predcontrib = TRUE)
+    .expect_has_row_names(pred, X)
+
+    # dense matrix without row names
+    Xcopy <- X
+    row.names(Xcopy) <- NULL
+    pred <- predict(bst, Xcopy)
+    .expect_doesnt_have_row_names(pred)
+
+    # sparse matrix with row names
+    Xcsc <- as(X, "CsparseMatrix")
+    pred <- predict(bst, Xcsc)
+    .expect_has_row_names(pred, Xcsc)
+    pred <- predict(bst, Xcsc, rawscore = TRUE)
+    .expect_has_row_names(pred, Xcsc)
+    pred <- predict(bst, Xcsc, predleaf = TRUE)
+    .expect_has_row_names(pred, Xcsc)
+    pred <- predict(bst, Xcsc, predcontrib = TRUE)
+    .expect_has_row_names(pred, Xcsc)
+
+    # sparse matrix without row names
+    Xcopy <- Xcsc
+    row.names(Xcopy) <- NULL
+    pred <- predict(bst, Xcopy)
+    .expect_doesnt_have_row_names(pred)
+}
+
+test_that("predict() keeps row names from data (regression)", {
+    data("mtcars")
+    X <- as.matrix(mtcars[, -1L])
+    y <- as.numeric(mtcars[, 1L])
+    dtrain <- lgb.Dataset(
+        X
+        , label = y
+        , params = list(
+            max_bins = 5L
+            , min_data_in_bin = 1L
+        )
+    )
+    bst <- lgb.train(
+        data = dtrain
+        , obj = "regression"
+        , nrounds = 5L
+        , verbose = VERBOSITY
+        , params = list(min_data_in_leaf = 1L)
+    )
+    .check_all_row_name_expectations(bst, X)
+})
+
+test_that("predict() keeps row names from data (binary classification)", {
+    data(agaricus.train, package = "lightgbm")
+    X <- as.matrix(agaricus.train$data)
+    y <- agaricus.train$label
+    # set explicit row names for predict() to carry over
+    row.names(X) <- paste0("rname", seq_len(nrow(X)))
+    dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
+    bst <- lgb.train(
+        data = dtrain
+        , obj = "binary"
+        , nrounds = 5L
+        , verbose = VERBOSITY
+    )
+    .check_all_row_name_expectations(bst, X)
+})
+
+test_that("predict() keeps row names from data (multi-class classification)", {
+    data(iris)
+    y <- as.numeric(iris$Species) - 1.0
+    X <- as.matrix(iris[, names(iris) != "Species"])
+    # set explicit row names for predict() to carry over
+    row.names(X) <- paste0("rname", seq_len(nrow(X)))
+    dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
+    bst <- lgb.train(
+        data = dtrain
+        , obj = "multiclass"
+        , params = list(num_class = 3L)
+        , nrounds = 5L
+        , verbose = VERBOSITY
+    )
+    .check_all_row_name_expectations(bst, X)
+})
+
 test_that("predictions for regression and binary classification are returned as vectors", {
     data(mtcars)
     X <- as.matrix(mtcars[, -1L])