diff --git a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
index 04fa92b25b..98a3a08134 100644
--- a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
+++ b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
@@ -5,6 +5,7 @@ package org.opensearch.sql
 
+import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, SparkSession, Row}
 import org.apache.spark.sql.types._
 
@@ -31,8 +32,17 @@ object SQLJob {
     val auth = args(5)
     val region = args(6)
 
+    val conf: SparkConf = new SparkConf()
+      .setAppName("SQLJob")
+      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
+      .set("spark.datasource.flint.host", host)
+      .set("spark.datasource.flint.port", port)
+      .set("spark.datasource.flint.scheme", scheme)
+      .set("spark.datasource.flint.auth", auth)
+      .set("spark.datasource.flint.region", region)
+
     // Create a SparkSession
-    val spark = SparkSession.builder().appName("SQLJob").getOrCreate()
+    val spark = SparkSession.builder().config(conf).enableHiveSupport().getOrCreate()
 
     try {
       // Execute SQL query
@@ -89,7 +99,7 @@ object SQLJob {
 
     // Create the data rows
     val rows = Seq((
-      result.toJSON.collect.toList.map(_.replaceAll("\"", "'")),
+      result.toJSON.collect.toList.map(_.replaceAll("'", "\\\\'").replaceAll("\"", "'")),
       resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")),
       sys.env.getOrElse("EMR_STEP_ID", "unknown"),
       spark.sparkContext.applicationId))
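
Note on the escaping change in the last hunk: the two replaceAll calls run left to right, so any single quotes already present in the JSON are backslash-escaped before double quotes are rewritten to single quotes. If the order were reversed, the freshly produced single quotes would be escaped too, corrupting the output. A minimal standalone sketch of the new chain (EscapeCheck and the sample payload are illustrative, not part of this change):

// EscapeCheck.scala: standalone sanity check for the quote-escaping chain.
// In a replaceAll replacement string, backslash is itself an escape character,
// so the Scala literal "\\\\'" (the runtime string \\') emits one literal
// backslash before each matched single quote.
object EscapeCheck extends App {
  val json = """{"message": "it's a test"}"""
  val escaped = json
    .replaceAll("'", "\\\\'") // escape existing single quotes: it's -> it\'s
    .replaceAll("\"", "'")    // then rewrite double quotes to single quotes
  println(escaped)            // prints: {'message': 'it\'s a test'}
}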