diff --git a/build.sbt b/build.sbt
index 2851c88..992a103 100755
--- a/build.sbt
+++ b/build.sbt
@@ -88,6 +88,7 @@ mimaDefaultSettings ++ Seq(
       ProblemFilters.excludePackage("com.databricks.spark.csv.CsvRelation"),
       ProblemFilters.excludePackage("com.databricks.spark.csv.util.InferSchema"),
       ProblemFilters.excludePackage("com.databricks.spark.sql.readers"),
+      ProblemFilters.excludePackage("com.databricks.spark.csv.util.TypeCast"),
       // We allowed the private `CsvRelation` type to leak into the public method signature:
       ProblemFilters.exclude[IncompatibleResultTypeProblem](
         "com.databricks.spark.csv.DefaultSource.createRelation")
diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index 1884748..b19244b 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -35,6 +35,7 @@ class CsvParser extends Serializable {
   private var parseMode: String = ParseModes.DEFAULT
   private var ignoreLeadingWhiteSpace: Boolean = false
   private var ignoreTrailingWhiteSpace: Boolean = false
+  private var treatEmptyValuesAsNulls: Boolean = false
   private var parserLib: String = ParserLibs.DEFAULT
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
@@ -84,6 +85,11 @@ class CsvParser extends Serializable {
     this
   }

+  def withTreatEmptyValuesAsNulls(treatAsNull: Boolean): CsvParser = {
+    this.treatEmptyValuesAsNulls = treatAsNull
+    this
+  }
+
   def withParserLib(parserLib: String): CsvParser = {
     this.parserLib = parserLib
     this
@@ -114,6 +120,7 @@
       parserLib,
       ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace,
+      treatEmptyValuesAsNulls,
       schema,
       inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
@@ -132,6 +139,7 @@
       parserLib,
       ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace,
+      treatEmptyValuesAsNulls,
       schema,
       inferSchema)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index eacc04d..bc1ea69 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -43,6 +43,7 @@ case class CsvRelation protected[spark] (
     parserLib: String,
     ignoreLeadingWhiteSpace: Boolean,
     ignoreTrailingWhiteSpace: Boolean,
+    treatEmptyValuesAsNulls: Boolean,
     userSchema: StructType = null,
     inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with InsertableRelation {
@@ -113,7 +114,8 @@ case class CsvRelation protected[spark] (
         index = 0
         while (index < schemaFields.length) {
           val field = schemaFields(index)
-          rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
+          rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
+            treatEmptyValuesAsNulls)
           index = index + 1
         }
         Some(Row.fromSeq(rowArray))
diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
index f0e35f5..7bd4b46 100755
--- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -112,6 +112,14 @@ class DefaultSource
     } else {
       throw new Exception("Ignore white space flag can be true or false")
     }
+    val treatEmptyValuesAsNulls = parameters.getOrElse("treatEmptyValuesAsNulls", "false")
+    val treatEmptyValuesAsNullsFlag = if (treatEmptyValuesAsNulls == "false") {
+      false
+    } else if (treatEmptyValuesAsNulls == "true") {
+      true
+    } else {
+      throw new Exception("Treat empty values as null flag can be true or false")
+    }
     val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name())
     // TODO validate charset?

@@ -137,6 +145,7 @@
       parserLib,
       ignoreLeadingWhiteSpaceFlag,
       ignoreTrailingWhiteSpaceFlag,
+      treatEmptyValuesAsNullsFlag,
       schema,
       inferSchemaFlag)(sqlContext)
   }
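Note: DefaultSource above wires the new behavior to the string option key "treatEmptyValuesAsNulls", which must be exactly "true" or "false". A minimal usage sketch, assuming the Spark 1.4+ DataFrameReader API; the file path and DataFrame name are illustrative, not part of this patch:

    // Sketch only: read empty CSV fields back as nulls (path is hypothetical).
    val cars = sqlContext.read
      .format("com.databricks.spark.csv")
      .option("header", "true")
      .option("treatEmptyValuesAsNulls", "true") // empty fields become null
      .load("cars.csv")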
parameters.getOrElse("treatEmptyValuesAsNulls", "false") + val treatEmptyValuesAsNullsFlag = if (treatEmptyValuesAsNulls == "false") { + false + } else if (treatEmptyValuesAsNulls == "true") { + true + } else { + throw new Exception("Treat empty values as null flag can be true or false") + } val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name()) // TODO validate charset? @@ -137,6 +145,7 @@ class DefaultSource parserLib, ignoreLeadingWhiteSpaceFlag, ignoreTrailingWhiteSpaceFlag, + treatEmptyValuesAsNullsFlag, schema, inferSchemaFlag)(sqlContext) } diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala index b972040..8c5b5af 100755 --- a/src/main/scala/com/databricks/spark/csv/package.scala +++ b/src/main/scala/com/databricks/spark/csv/package.scala @@ -52,6 +52,7 @@ package object csv { parserLib = parserLib, ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace, + treatEmptyValuesAsNulls = false, inferCsvSchema = inferSchema)(sqlContext) sqlContext.baseRelationToDataFrame(csvRelation) } @@ -76,6 +77,7 @@ package object csv { parserLib = parserLib, ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace, + treatEmptyValuesAsNulls = false, inferCsvSchema = inferSchema)(sqlContext) sqlContext.baseRelationToDataFrame(csvRelation) } @@ -116,11 +118,13 @@ package object csv { case None => None } + val nullValue = parameters.getOrElse("nullValue", "null") + val csvFormatBase = CSVFormat.DEFAULT .withDelimiter(delimiterChar) .withEscape(escapeChar) .withSkipHeaderRecord(false) - .withNullString("null") + .withNullString(nullValue) val csvFormat = quoteChar match { case Some(c) => csvFormatBase.withQuote(c) @@ -139,7 +143,7 @@ package object csv { .withDelimiter(delimiterChar) .withEscape(escapeChar) .withSkipHeaderRecord(false) - .withNullString("null") + .withNullString(nullValue) val csvFormat = quoteChar match { case Some(c) => csvFormatBase.withQuote(c) diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala index c3f4de2..265515e 100644 --- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala +++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala @@ -35,8 +35,12 @@ object TypeCast { * @param datum string value * @param castType SparkSQL type */ - private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = { - if (datum == "" && nullable && !castType.isInstanceOf[StringType]){ + private[csv] def castTo( + datum: String, + castType: DataType, + nullable: Boolean = true, + treatEmptyValuesAsNulls: Boolean = false): Any = { + if (datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)){ null } else { castType match { diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index da1b30c..d758815 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -167,6 +167,30 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll { assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) } + test("DSL test roundtrip nulls") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = 
tempEmptyDir + "null-numbers.csv" + val agesSchema = StructType(List(StructField("name", StringType, true), + StructField("age", IntegerType, true))) + + val agesRows = Seq(Row("alice", 35), Row("bob", null), Row(null, 24)) + val agesRdd = sqlContext.sparkContext.parallelize(agesRows) + val agesDf = sqlContext.createDataFrame(agesRdd, agesSchema) + + agesDf.saveAsCsvFile(copyFilePath, Map("header" -> "true", "nullValue" -> "")) + + val agesCopy = new CsvParser() + .withSchema(agesSchema) + .withUseHeader(true) + .withTreatEmptyValuesAsNulls(true) + .withParserLib(parserLib) + .csvFile(sqlContext, copyFilePath) + + assert(agesCopy.count == agesRows.size) + assert(agesCopy.collect.toSet == agesRows.toSet) + } test("DSL test with alternative delimiter and quote") { val results = new CsvParser()