From 49c8af06887ea5bb2c1be2d99eb36d076b8ff736 Mon Sep 17 00:00:00 2001 From: Nic Crane Date: Thu, 28 Sep 2023 21:41:56 +0100 Subject: [PATCH] GH-34640: [R] Can't read in partitioning column in CSV datasets when both (non-hive) partition and schema supplied (#37658) ### Rationale for this change It wasn't possible to use the partitioning column in the dataset when reading in CSV datasets and supplying both a schema and a partition variable. ### What changes are included in this PR? This PR updates the code which creates the `CSVReadOptions` object and makes sure we don't pass in the partition variable column name as a column name there, as previously this was resulting in an error. ### Are these changes tested? Yes ### Are there any user-facing changes? Yes * Closes: #34640 Lead-authored-by: Nic Crane Co-authored-by: Dewey Dunnington Signed-off-by: Nic Crane --- r/R/dataset-factory.R | 2 +- r/R/dataset-format.R | 31 +++++++++++++++----------- r/tests/testthat/test-dataset-csv.R | 34 +++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/r/R/dataset-factory.R b/r/R/dataset-factory.R index adb7353a043b9..d3d4f639e3729 100644 --- a/r/R/dataset-factory.R +++ b/r/R/dataset-factory.R @@ -49,7 +49,7 @@ DatasetFactory$create <- function(x, } if (is.character(format)) { - format <- FileFormat$create(match.arg(format), ...) + format <- FileFormat$create(match.arg(format), partitioning = partitioning, ...) } else { assert_is(format, "FileFormat") } diff --git a/r/R/dataset-format.R b/r/R/dataset-format.R index e1f434d60cd50..5dd00b9344014 100644 --- a/r/R/dataset-format.R +++ b/r/R/dataset-format.R @@ -74,13 +74,14 @@ FileFormat <- R6Class("FileFormat", type = function() dataset___FileFormat__type_name(self) ) ) -FileFormat$create <- function(format, schema = NULL, ...) { + +FileFormat$create <- function(format, schema = NULL, partitioning = NULL, ...) { opt_names <- names(list(...)) if (format %in% c("csv", "text", "txt") || any(opt_names %in% c("delim", "delimiter"))) { - CsvFileFormat$create(schema = schema, ...) + CsvFileFormat$create(schema = schema, partitioning = partitioning, ...) } else if (format == "tsv") { # This delimiter argument is ignored. - CsvFileFormat$create(delimiter = "\t", schema = schema, ...) + CsvFileFormat$create(delimiter = "\t", schema = schema, partitioning = partitioning, ...) } else if (format == "parquet") { ParquetFileFormat$create(...) } else if (format %in% c("ipc", "arrow", "feather")) { # These are aliases for the same thing @@ -189,16 +190,19 @@ JsonFileFormat$create <- function(...) { #' #' @export CsvFileFormat <- R6Class("CsvFileFormat", inherit = FileFormat) -CsvFileFormat$create <- function(...) { +CsvFileFormat$create <- function(..., partitioning = NULL) { + dots <- list(...) - options <- check_csv_file_format_args(dots) - check_schema(options[["schema"]], options[["read_options"]]$column_names) + + options <- check_csv_file_format_args(dots, partitioning = partitioning) + check_schema(options[["schema"]], partitioning, options[["read_options"]]$column_names) dataset___CsvFileFormat__Make(options$parse_options, options$convert_options, options$read_options) } # Check all arguments are valid -check_csv_file_format_args <- function(args) { +check_csv_file_format_args <- function(args, partitioning = NULL) { + options <- list( parse_options = args$parse_options, convert_options = args$convert_options, @@ -223,7 +227,7 @@ check_csv_file_format_args <- function(args) { } if (is.null(args$read_options)) { - options$read_options <- do.call(csv_file_format_read_opts, args) + options$read_options <- do.call(csv_file_format_read_opts, c(args, list(partitioning = partitioning))) } else if (is.list(args$read_options)) { options$read_options <- do.call(CsvReadOptions$create, args$read_options) } @@ -339,7 +343,7 @@ check_ambiguous_options <- function(passed_opts, opts1, opts2) { } } -check_schema <- function(schema, column_names) { +check_schema <- function(schema, partitioning, column_names) { if (!is.null(schema) && !inherits(schema, "Schema")) { abort(paste0( "`schema` must be an object of class 'Schema' not '", @@ -348,7 +352,7 @@ check_schema <- function(schema, column_names) { )) } - schema_names <- names(schema) + schema_names <- setdiff(names(schema), names(partitioning)) if (!is.null(schema) && !identical(schema_names, column_names)) { missing_from_schema <- setdiff(column_names, schema_names) @@ -451,7 +455,8 @@ csv_file_format_convert_opts <- function(...) { do.call(CsvConvertOptions$create, opts) } -csv_file_format_read_opts <- function(schema = NULL, ...) { +csv_file_format_read_opts <- function(schema = NULL, partitioning = NULL, ...) { + opts <- list(...) # Filter out arguments meant for CsvParseOptions/CsvConvertOptions arrow_opts <- c(names(formals(CsvParseOptions$create)), "parse_options") @@ -477,9 +482,9 @@ csv_file_format_read_opts <- function(schema = NULL, ...) { if (!is.null(schema) && null_or_true(opts[["column_names"]]) && null_or_true(opts[["col_names"]])) { if (any(is_readr_opt)) { - opts[["col_names"]] <- names(schema) + opts[["col_names"]] <- setdiff(names(schema), names(partitioning)) } else { - opts[["column_names"]] <- names(schema) + opts[["column_names"]] <- setdiff(names(schema), names(partitioning)) } } diff --git a/r/tests/testthat/test-dataset-csv.R b/r/tests/testthat/test-dataset-csv.R index c83c30ff904ff..ff1712646a472 100644 --- a/r/tests/testthat/test-dataset-csv.R +++ b/r/tests/testthat/test-dataset-csv.R @@ -593,3 +593,37 @@ test_that("CSVReadOptions field access", { expect_equal(options$block_size, 1048576L) expect_equal(options$encoding, "UTF-8") }) + +test_that("GH-34640 - CSV datasets are read in correctly when both schema and partitioning supplied", { + target_schema <- schema( + int = int32(), dbl = float32(), lgl = bool(), chr = utf8(), + fct = utf8(), ts = timestamp(unit = "s"), part = int8() + ) + + ds <- open_dataset( + csv_dir, + partitioning = schema(part = int32()), + format = "csv", + schema = target_schema, + skip = 1 + ) + expect_r6_class(ds$format, "CsvFileFormat") + expect_r6_class(ds$filesystem, "LocalFileSystem") + expect_identical(names(ds), c(names(df1), "part")) + expect_identical(names(collect(ds)), c(names(df1), "part")) + + expect_identical(dim(ds), c(20L, 7L)) + expect_equal(schema(ds), target_schema) + + expect_equal( + ds %>% + select(string = chr, integer = int, part) %>% + filter(integer > 6 & part == 5) %>% + collect() %>% + summarize(mean = mean(as.numeric(integer))), + df1 %>% + select(string = chr, integer = int) %>% + filter(integer > 6) %>% + summarize(mean = mean(integer)) + ) +})