[SPARK-13108][SQL] Support for ascii compatible encodings at CSV data source #11016

Closed
@@ -17,12 +17,15 @@

package org.apache.spark.sql.execution.datasources.csv

import java.nio.charset.Charset

import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.RecordWriter
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.{LineRecordReader, TextInputFormat}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

import org.apache.spark.Logging
@@ -113,18 +116,83 @@ object CSVRelation extends Logging {
}
}

/**
 * Because `TextInputFormat` in Hadoop does not support encodings that are not ASCII-compatible
 * (such as UTF-16 and UTF-32), we need another `InputFormat` to handle them. See SPARK-13108.
 */
private[csv] class EncodingTextInputFormat extends TextInputFormat {
override def createRecordReader(
split: InputSplit,
context: TaskAttemptContext): RecordReader[LongWritable, Text] = {
val conf: Configuration = {
// Use reflection to get the Configuration. This is necessary because TaskAttemptContext is
// a class in Hadoop 1.x and an interface in Hadoop 2.x.
val method = context.getClass.getMethod("getConfiguration")
method.invoke(context).asInstanceOf[Configuration]
}
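    // ENCODING_KEY is set by the CSV data source from the user's "charset" option; fall back to UTF-8.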
val charset = Charset.forName(conf.get(EncodingTextInputFormat.ENCODING_KEY, "UTF-8"))
val charsetName = charset.name
val safeRecordDelimiterBytes = {
val delimiter = "\n"
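      // Note: getBytes may prepend a BOM for some charsets (e.g. "\n".getBytes("UTF-16")
      // yields FE FF 00 0A), so strip it to keep only the raw delimiter bytes.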
val recordDelimiterBytes = delimiter.getBytes(charset)
EncodingTextInputFormat.stripBOM(charsetName, recordDelimiterBytes)
}

new LineRecordReader(safeRecordDelimiterBytes) {
var isFirst = true
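      // Strip a leading BOM from the first record returned by this reader;
      // later records are returned as-is.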
override def getCurrentValue: Text = {
val value = super.getCurrentValue
if (isFirst) {
isFirst = false
          // Text.getBytes returns the backing array, which may be longer than the
          // record, so copy only the first getLength bytes before stripping the BOM.
          val contentBytes = java.util.Arrays.copyOfRange(value.getBytes, 0, value.getLength)
          val safeBytes = EncodingTextInputFormat.stripBOM(charsetName, contentBytes)
new Text(safeBytes)
} else {
value
}
}
}
}
}

private[csv] object EncodingTextInputFormat {
// configuration key for encoding type
val ENCODING_KEY = "encodinginputformat.encoding"
// BOM bytes for UTF-8, UTF-16 and UTF-32
private val utf8BOM = Array(0xEF.toByte, 0xBB.toByte, 0xBF.toByte)
private val utf16beBOM = Array(0xFE.toByte, 0xFF.toByte)
private val utf16leBOM = Array(0xFF.toByte, 0xFE.toByte)
private val utf32beBOM = Array(0x00.toByte, 0x00.toByte, 0xFE.toByte, 0xFF.toByte)
private val utf32leBOM = Array(0xFF.toByte, 0xFE.toByte, 0x00.toByte, 0x00.toByte)

def stripBOM(charsetName: String, bytes: Array[Byte]): Array[Byte] = {
charsetName match {
case "UTF-8" if bytes.startsWith(utf8BOM) =>
bytes.slice(utf8BOM.length, bytes.length)
case "UTF-16" | "UTF-16BE" if bytes.startsWith(utf16beBOM) =>
bytes.slice(utf16beBOM.length, bytes.length)
case "UTF-16LE" if bytes.startsWith(utf16leBOM) =>
bytes.slice(utf16leBOM.length, bytes.length)
case "UTF-32" | "UTF-32BE" if bytes.startsWith(utf32beBOM) =>
bytes.slice(utf32beBOM.length, bytes.length)
case "UTF-32LE" if bytes.startsWith(utf32leBOM) =>
bytes.slice(utf32leBOM.length, bytes.length)
case _ => bytes
}
}
}

private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
override def newInstance(
path: String,
bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
if (bucketId.isDefined) sys.error("csv doesn't support bucketing")
new CsvOutputWriter(path, dataSchema, context, params)
new CSVOutputWriter(path, dataSchema, context, params)
}
}

private[sql] class CsvOutputWriter(
private[sql] class CSVOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext,
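For illustration only (not part of this diff), a minimal sketch of how EncodingTextInputFormat.stripBOM is expected to behave, assuming it is called from code inside the `csv` package (it is `private[csv]`); the column names are made up:

// A byte order mark matching the declared charset is dropped; anything else passes through.
val withBom = Array(0xEF.toByte, 0xBB.toByte, 0xBF.toByte) ++ "year,make,model".getBytes("UTF-8")
val stripped = EncodingTextInputFormat.stripBOM("UTF-8", withBom)
assert(new String(stripped, "UTF-8") == "year,make,model")
assert(EncodingTextInputFormat.stripBOM("UTF-8", "no bom".getBytes("UTF-8"))
  .sameElements("no bom".getBytes("UTF-8")))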
@@ -165,8 +165,10 @@ class DefaultSource extends FileFormat with DataSourceRegister {
sqlContext.sparkContext.textFile(location)
} else {
val charset = options.charset
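      // Propagate the user-specified charset to EncodingTextInputFormat through the Hadoop
      // configuration so the record reader can build a delimiter in the same encoding.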
val conf = sqlContext.sparkContext.hadoopConfiguration
conf.set(EncodingTextInputFormat.ENCODING_KEY, charset)
sqlContext.sparkContext
.hadoopFile[LongWritable, Text, TextInputFormat](location)
.newAPIHadoopFile[LongWritable, Text, EncodingTextInputFormat](location)
.mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
}
}
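For background (illustrative only, not part of this diff), why the previous TextInputFormat-based path cannot handle UTF-16: the encoded form of a line mixes a BOM with two-byte code units, so a reader that splits records on the single byte 0x0A cuts through the middle of the '\n' code unit.

// "a\n" encoded as UTF-16 is FE FF 00 61 00 0A (BOM, 'a', '\n').
// Splitting on the raw byte 0x0A strands a 0x00 at the end of the record, which is why
// the record delimiter must be derived from the target charset instead.
val bytes = "a\n".getBytes("UTF-16")
println(bytes.map(b => f"$b%02X").mkString(" "))  // prints: FE FF 00 61 00 0A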
Binary file added sql/core/src/test/resources/cars_utf-16.csv
@@ -36,6 +36,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val carsFile = "cars.csv"
private val carsMalformedFile = "cars-malformed.csv"
private val carsFile8859 = "cars_iso-8859-1.csv"
private val carsFileUTF16 = "cars_utf-16.csv"
private val carsTsvFile = "cars.tsv"
private val carsAltFile = "cars-alternative.csv"
private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv"
@@ -152,6 +153,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(exception.getMessage.contains("1-9588-osi"))
}

test("non-ascii compatible encoding") {
val cars = sqlContext
.read
.format("csv")
.option("charset", "utf-16")
.option("header", "true")
.load(testFile(carsFileUTF16))

verifyCars(cars, withHeader = true, checkTypes = false)
}

test("test different encoding") {
// scalastyle:off
sqlContext.sql(