Skip to content

Commit

Permalink
apacheGH-34640: [R] Can't read in partitioning column in CSV datasets…
Browse files Browse the repository at this point in the history
… when both (non-hive) partition and schema supplied (apache#37658)

### Rationale for this change

It wasn't possible to use the partitioning column in the dataset when reading in CSV datasets and supplying both a schema and a partition variable.

### What changes are included in this PR?

This PR updates the code which creates the `CSVReadOptions` object and makes sure we don't pass in the partition variable column name as a column name there, as previously this was resulting in an error.

### Are these changes tested?

Yes

### Are there any user-facing changes?

Yes 

* Closes: apache#34640

Lead-authored-by: Nic Crane <thisisnic@gmail.com>
Co-authored-by: Dewey Dunnington <dewey@dunnington.ca>
Signed-off-by: Nic Crane <thisisnic@gmail.com>
  • Loading branch information
2 people authored and dgreiss committed Feb 17, 2024
1 parent 911bc3f commit 49c8af0
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 14 deletions.
2 changes: 1 addition & 1 deletion r/R/dataset-factory.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ DatasetFactory$create <- function(x,
}

if (is.character(format)) {
format <- FileFormat$create(match.arg(format), ...)
format <- FileFormat$create(match.arg(format), partitioning = partitioning, ...)
} else {
assert_is(format, "FileFormat")
}
Expand Down
31 changes: 18 additions & 13 deletions r/R/dataset-format.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,14 @@ FileFormat <- R6Class("FileFormat",
type = function() dataset___FileFormat__type_name(self)
)
)
FileFormat$create <- function(format, schema = NULL, ...) {

FileFormat$create <- function(format, schema = NULL, partitioning = NULL, ...) {
opt_names <- names(list(...))
if (format %in% c("csv", "text", "txt") || any(opt_names %in% c("delim", "delimiter"))) {
CsvFileFormat$create(schema = schema, ...)
CsvFileFormat$create(schema = schema, partitioning = partitioning, ...)
} else if (format == "tsv") {
# This delimiter argument is ignored.
CsvFileFormat$create(delimiter = "\t", schema = schema, ...)
CsvFileFormat$create(delimiter = "\t", schema = schema, partitioning = partitioning, ...)
} else if (format == "parquet") {
ParquetFileFormat$create(...)
} else if (format %in% c("ipc", "arrow", "feather")) { # These are aliases for the same thing
Expand Down Expand Up @@ -189,16 +190,19 @@ JsonFileFormat$create <- function(...) {
#'
#' @export
CsvFileFormat <- R6Class("CsvFileFormat", inherit = FileFormat)
CsvFileFormat$create <- function(...) {
CsvFileFormat$create <- function(..., partitioning = NULL) {

dots <- list(...)
options <- check_csv_file_format_args(dots)
check_schema(options[["schema"]], options[["read_options"]]$column_names)

options <- check_csv_file_format_args(dots, partitioning = partitioning)
check_schema(options[["schema"]], partitioning, options[["read_options"]]$column_names)

dataset___CsvFileFormat__Make(options$parse_options, options$convert_options, options$read_options)
}

# Check all arguments are valid
check_csv_file_format_args <- function(args) {
check_csv_file_format_args <- function(args, partitioning = NULL) {

options <- list(
parse_options = args$parse_options,
convert_options = args$convert_options,
Expand All @@ -223,7 +227,7 @@ check_csv_file_format_args <- function(args) {
}

if (is.null(args$read_options)) {
options$read_options <- do.call(csv_file_format_read_opts, args)
options$read_options <- do.call(csv_file_format_read_opts, c(args, list(partitioning = partitioning)))
} else if (is.list(args$read_options)) {
options$read_options <- do.call(CsvReadOptions$create, args$read_options)
}
Expand Down Expand Up @@ -339,7 +343,7 @@ check_ambiguous_options <- function(passed_opts, opts1, opts2) {
}
}

check_schema <- function(schema, column_names) {
check_schema <- function(schema, partitioning, column_names) {
if (!is.null(schema) && !inherits(schema, "Schema")) {
abort(paste0(
"`schema` must be an object of class 'Schema' not '",
Expand All @@ -348,7 +352,7 @@ check_schema <- function(schema, column_names) {
))
}

schema_names <- names(schema)
schema_names <- setdiff(names(schema), names(partitioning))

if (!is.null(schema) && !identical(schema_names, column_names)) {
missing_from_schema <- setdiff(column_names, schema_names)
Expand Down Expand Up @@ -451,7 +455,8 @@ csv_file_format_convert_opts <- function(...) {
do.call(CsvConvertOptions$create, opts)
}

csv_file_format_read_opts <- function(schema = NULL, ...) {
csv_file_format_read_opts <- function(schema = NULL, partitioning = NULL, ...) {

opts <- list(...)
# Filter out arguments meant for CsvParseOptions/CsvConvertOptions
arrow_opts <- c(names(formals(CsvParseOptions$create)), "parse_options")
Expand All @@ -477,9 +482,9 @@ csv_file_format_read_opts <- function(schema = NULL, ...) {

if (!is.null(schema) && null_or_true(opts[["column_names"]]) && null_or_true(opts[["col_names"]])) {
if (any(is_readr_opt)) {
opts[["col_names"]] <- names(schema)
opts[["col_names"]] <- setdiff(names(schema), names(partitioning))
} else {
opts[["column_names"]] <- names(schema)
opts[["column_names"]] <- setdiff(names(schema), names(partitioning))
}
}

Expand Down
34 changes: 34 additions & 0 deletions r/tests/testthat/test-dataset-csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -593,3 +593,37 @@ test_that("CSVReadOptions field access", {
expect_equal(options$block_size, 1048576L)
expect_equal(options$encoding, "UTF-8")
})

test_that("GH-34640 - CSV datasets are read in correctly when both schema and partitioning supplied", {
target_schema <- schema(
int = int32(), dbl = float32(), lgl = bool(), chr = utf8(),
fct = utf8(), ts = timestamp(unit = "s"), part = int8()
)

ds <- open_dataset(
csv_dir,
partitioning = schema(part = int32()),
format = "csv",
schema = target_schema,
skip = 1
)
expect_r6_class(ds$format, "CsvFileFormat")
expect_r6_class(ds$filesystem, "LocalFileSystem")
expect_identical(names(ds), c(names(df1), "part"))
expect_identical(names(collect(ds)), c(names(df1), "part"))

expect_identical(dim(ds), c(20L, 7L))
expect_equal(schema(ds), target_schema)

expect_equal(
ds %>%
select(string = chr, integer = int, part) %>%
filter(integer > 6 & part == 5) %>%
collect() %>%
summarize(mean = mean(as.numeric(integer))),
df1 %>%
select(string = chr, integer = int) %>%
filter(integer > 6) %>%
summarize(mean = mean(integer))
)
})

0 comments on commit 49c8af0

Please sign in to comment.