diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index d7ce9a0ce8894..7704f9d3d3015 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -17,12 +17,15 @@
 package org.apache.spark.sql.execution.datasources.csv
 
+import java.nio.charset.Charset
+
 import scala.util.control.NonFatal
 
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.RecordWriter
-import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{LineRecordReader, TextInputFormat}
 import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.Logging
@@ -113,6 +116,71 @@ object CSVRelation extends Logging {
   }
 }
 
+/**
+ * Because `TextInputFormat` in Hadoop does not support non-ascii compatible encodings,
+ * We need another `InputFormat` to handle the encodings. See SPARK-13108.
+ */
+private[csv] class EncodingTextInputFormat extends TextInputFormat {
+  override def createRecordReader(
+      split: InputSplit,
+      context: TaskAttemptContext): RecordReader[LongWritable, Text] = {
+    val conf: Configuration = {
+      // Use reflection to get the Configuration. This is necessary because TaskAttemptContext is
+      // a class in Hadoop 1.x and an interface in Hadoop 2.x.
+      val method = context.getClass.getMethod("getConfiguration")
+      method.invoke(context).asInstanceOf[Configuration]
+    }
+    val charset = Charset.forName(conf.get(EncodingTextInputFormat.ENCODING_KEY, "UTF-8"))
+    val charsetName = charset.name
+    val safeRecordDelimiterBytes = {
+      val delimiter = "\n"
+      val recordDelimiterBytes = delimiter.getBytes(charset)
+      EncodingTextInputFormat.stripBOM(charsetName, recordDelimiterBytes)
+    }
+
+    new LineRecordReader(safeRecordDelimiterBytes) {
+      var isFirst = true
+      override def getCurrentValue: Text = {
+        val value = super.getCurrentValue
+        if (isFirst) {
+          isFirst = false
+          val safeBytes = EncodingTextInputFormat.stripBOM(charsetName, value.getBytes)
+          new Text(safeBytes)
+        } else {
+          value
+        }
+      }
+    }
+  }
+}
+
+private[csv] object EncodingTextInputFormat {
+  // configuration key for encoding type
+  val ENCODING_KEY = "encodinginputformat.encoding"
+  // BOM bytes for UTF-8, UTF-16 and UTF-32
+  private val utf8BOM = Array(0xEF.toByte, 0xBB.toByte, 0xBF.toByte)
+  private val utf16beBOM = Array(0xFE.toByte, 0xFF.toByte)
+  private val utf16leBOM = Array(0xFF.toByte, 0xFE.toByte)
+  private val utf32beBOM = Array(0x00.toByte, 0x00.toByte, 0xFE.toByte, 0xFF.toByte)
+  private val utf32leBOM = Array(0xFF.toByte, 0xFE.toByte, 0x00.toByte, 0x00.toByte)
+
+  def stripBOM(charsetName: String, bytes: Array[Byte]): Array[Byte] = {
+    charsetName match {
+      case "UTF-8" if bytes.startsWith(utf8BOM) =>
+        bytes.slice(utf8BOM.length, bytes.length)
+      case "UTF-16" | "UTF-16BE" if bytes.startsWith(utf16beBOM) =>
+        bytes.slice(utf16beBOM.length, bytes.length)
+      case "UTF-16LE" if bytes.startsWith(utf16leBOM) =>
+        bytes.slice(utf16leBOM.length, bytes.length)
+      case "UTF-32" | "UTF-32BE" if bytes.startsWith(utf32beBOM) =>
+        bytes.slice(utf32beBOM.length, bytes.length)
+      case "UTF-32LE" if bytes.startsWith(utf32leBOM) =>
+        bytes.slice(utf32leBOM.length, bytes.length)
+      case _ => bytes
+    }
+  }
+}
+
 private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
   override def newInstance(
       path: String,
@@ -120,11 +188,11 @@ private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
     if (bucketId.isDefined) sys.error("csv doesn't support bucketing")
-    new CsvOutputWriter(path, dataSchema, context, params)
+    new CSVOutputWriter(path, dataSchema, context, params)
   }
 }
 
-private[sql] class CsvOutputWriter(
+private[sql] class CSVOutputWriter(
     path: String,
     dataSchema: StructType,
     context: TaskAttemptContext,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index aff672281d640..9d7b21ae026c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -165,8 +165,10 @@ class DefaultSource extends FileFormat with DataSourceRegister {
       sqlContext.sparkContext.textFile(location)
     } else {
       val charset = options.charset
+      val conf = sqlContext.sparkContext.hadoopConfiguration
+      conf.set(EncodingTextInputFormat.ENCODING_KEY, charset)
       sqlContext.sparkContext
-        .hadoopFile[LongWritable, Text, TextInputFormat](location)
+        .newAPIHadoopFile[LongWritable, Text, EncodingTextInputFormat](location)
         .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
     }
   }
diff --git a/sql/core/src/test/resources/cars_utf-16.csv b/sql/core/src/test/resources/cars_utf-16.csv
new file mode 100644
index 0000000000000..a94ed8cb14be5
Binary files /dev/null and b/sql/core/src/test/resources/cars_utf-16.csv differ
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 53027bb698bf8..2d2621a8d74aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -36,6 +36,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   private val carsFile = "cars.csv"
   private val carsMalformedFile = "cars-malformed.csv"
   private val carsFile8859 = "cars_iso-8859-1.csv"
+  private val carsFileUTF16 = "cars_utf-16.csv"
   private val carsTsvFile = "cars.tsv"
   private val carsAltFile = "cars-alternative.csv"
   private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv"
@@ -152,6 +153,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(exception.getMessage.contains("1-9588-osi"))
   }
 
+  test("non-ascii compatible encoding") {
+    val cars = sqlContext
+      .read
+      .format("csv")
+      .option("charset", "utf-16")
+      .option("header", "true")
+      .load(testFile(carsFileUTF16))
+
+    verifyCars(cars, withHeader = true, checkTypes = false)
+  }
+
   test("test different encoding") {
     // scalastyle:off
     sqlContext.sql(
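
A minimal usage sketch of the code path this patch enables (not part of the patch itself), mirroring the new "non-ascii compatible encoding" test above: the `charset` option is written into the Hadoop configuration under `EncodingTextInputFormat.ENCODING_KEY`, so the read goes through `EncodingTextInputFormat`, which strips the BOM and uses encoding-aware record delimiter bytes. The file path is illustrative only, and `sqlContext` is assumed to be in scope as in the test suite.

    // Read a UTF-16 encoded CSV file using the charset option exercised by this patch.
    // "/tmp/cars_utf-16.csv" is a hypothetical path; any UTF-16 CSV with a header row works.
    val cars = sqlContext
      .read
      .format("csv")
      .option("charset", "utf-16")
      .option("header", "true")
      .load("/tmp/cars_utf-16.csv")

    cars.show()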