diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/HiveHelper.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/HiveHelper.scala index dd0224d19..31f13436b 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/HiveHelper.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/HiveHelper.scala @@ -17,21 +17,23 @@ package za.co.absa.pramen.core.utils.hive import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType trait HiveHelper { def createOrUpdateHiveTable(parquetPath: String, partitionBy: Seq[String], databaseName: String, - tableName: String): Unit + tableName: String)(implicit spark: SparkSession): Unit def repairHiveTable(databaseName: String, tableName: String): Unit + def getSchema(parquetPath: String)(implicit spark: SparkSession): StructType } object HiveHelper { def apply(implicit spark: SparkSession): HiveHelper = { - val queryExecutor = new SparkQueryExecutor() + val queryExecutor = new QueryExecutorSpark() new HiveHelperImpl(queryExecutor) } } diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/HiveHelperImpl.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/HiveHelperImpl.scala index 918069bb5..f600646cc 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/HiveHelperImpl.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/HiveHelperImpl.scala @@ -16,6 +16,7 @@ package za.co.absa.pramen.core.utils.hive +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory @@ -25,7 +26,8 @@ class HiveHelperImpl(queryExecutor: QueryExecutor) extends HiveHelper { override def createOrUpdateHiveTable(parquetPath: String, partitionBy: Seq[String], databaseName: String, - tableName: String): Unit = { + tableName: String) + (implicit spark: SparkSession): Unit = { val fullTableName = getFullTable(databaseName, tableName) dropHiveTable(fullTableName) @@ -42,6 +44,12 @@ class HiveHelperImpl(queryExecutor: QueryExecutor) extends HiveHelper { repairHiveTable(fullTableName) } + override def getSchema(parquetPath: String)(implicit spark: SparkSession): StructType = { + val df = spark.read.parquet(parquetPath) + + df.schema + } + private def getFullTable(databaseName: String, tableName: String): String = { if (databaseName.isEmpty) @@ -58,10 +66,10 @@ class HiveHelperImpl(queryExecutor: QueryExecutor) extends HiveHelper { private def createHiveTable(fullTableName: String, parquetPath: String, partitionBy: Seq[String] - ): Unit = { + )(implicit spark: SparkSession): Unit = { log.info(s"Creating Hive table: $fullTableName...") - val schema = queryExecutor.getSchema(parquetPath) + val schema = getSchema(parquetPath) val sqlHiveCreate = s"""CREATE EXTERNAL TABLE IF NOT EXISTS diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutor.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutor.scala index 83317d2b1..cd915fa60 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutor.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutor.scala @@ -22,6 +22,4 @@ trait QueryExecutor { def doesTableExist(dbName: String, tableName: String): Boolean def execute(query: String): Unit - - def getSchema(parquetPath: String): StructType } diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorJdbc.scala new file mode 100644 index 000000000..70252f1b2 --- /dev/null +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorJdbc.scala @@ -0,0 +1,44 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.utils.hive + +import org.slf4j.LoggerFactory + +import java.sql.{Connection, ResultSet, SQLSyntaxErrorException} +import scala.util.Try + +class QueryExecutorJdbc(connection: Connection) extends QueryExecutor { + private val log = LoggerFactory.getLogger(this.getClass) + + override def doesTableExist(dbName: String, tableName: String): Boolean = { + val query = s"SELECT 1 FROM $tableName WHERE 0 = 1" + + Try { + execute(query) + }.isSuccess + } + + @throws[SQLSyntaxErrorException] + override def execute(query: String): Unit = { + log.info(s"Executing SQL: $query") + val statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + + val resultSet = statement.executeQuery(query) + + resultSet.close() + } +} diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/SparkQueryExecutor.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorSpark.scala similarity index 83% rename from pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/SparkQueryExecutor.scala rename to pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorSpark.scala index 608f3555f..18d4e81b9 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/SparkQueryExecutor.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/hive/QueryExecutorSpark.scala @@ -16,11 +16,11 @@ package za.co.absa.pramen.core.utils.hive -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory -class SparkQueryExecutor(implicit spark: SparkSession) extends QueryExecutor { +class QueryExecutorSpark(implicit spark: SparkSession) extends QueryExecutor { private val log = LoggerFactory.getLogger(this.getClass) override def doesTableExist(dbName: String, tableName: String): Boolean = { @@ -31,14 +31,9 @@ class SparkQueryExecutor(implicit spark: SparkSession) extends QueryExecutor { } } + @throws[AnalysisException] override def execute(query: String): Unit = { log.info(s"Executing SQL: $query") spark.sql(query).take(100) } - - override def getSchema(parquetPath: String): StructType = { - val df = spark.read.parquet(parquetPath) - - df.schema - } } diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/QueryExecutorJdbcSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/QueryExecutorJdbcSuite.scala new file mode 100644 index 000000000..4bb519e71 --- /dev/null +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/QueryExecutorJdbcSuite.scala @@ -0,0 +1,86 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.tests.utils.hive + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.wordspec.AnyWordSpec +import za.co.absa.pramen.core.fixtures.RelationalDbFixture +import za.co.absa.pramen.core.samples.RdbExampleTable +import za.co.absa.pramen.core.utils.hive.QueryExecutorJdbc + +import java.sql.SQLSyntaxErrorException + +class QueryExecutorJdbcSuite extends AnyWordSpec with BeforeAndAfterAll with RelationalDbFixture { + + override protected def beforeAll(): Unit = { + super.beforeAll() + RdbExampleTable.Company.initTable(getConnection) + } + + override protected def afterAll(): Unit = { + RdbExampleTable.Company.dropTable(getConnection) + super.afterAll() + } + + "QueryExecutorJdbc" should { + "execute JDBC queries" in { + val connection = getConnection + + val qe = new QueryExecutorJdbc(connection) + + qe.execute("SELECT * FROM company") + } + + "execute CREATE TABLE queries" in { + val connection = getConnection + + val qe = new QueryExecutorJdbc(connection) + + qe.execute("CREATE TABLE my_table (id INT)") + + val exist = qe.doesTableExist(database, "my_table") + + assert(exist) + } + + "throw an exception on errors" in { + val qe = new QueryExecutorJdbc(getConnection) + + val ex = intercept[SQLSyntaxErrorException] { + qe.execute("SELECT * FROM does_not_exist") + } + + assert(ex.getMessage.contains("object not found")) + } + + "return true if the table is found" in { + val qe = new QueryExecutorJdbc(getConnection) + + val exist = qe.doesTableExist(database, "company") + + assert(exist) + } + + "return false if the table is not found" in { + val qe = new QueryExecutorJdbc(getConnection) + + val exist = qe.doesTableExist(database, "does_not_exist") + + assert(!exist) + } + } +} diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/SparkQueryExecutorSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/QueryExecutorSparkSuite.scala similarity index 83% rename from pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/SparkQueryExecutorSuite.scala rename to pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/QueryExecutorSparkSuite.scala index f6e28bb37..978240ece 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/SparkQueryExecutorSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/hive/QueryExecutorSparkSuite.scala @@ -19,25 +19,25 @@ package za.co.absa.pramen.core.tests.utils.hive import org.apache.spark.sql.AnalysisException import org.scalatest.wordspec.AnyWordSpec import za.co.absa.pramen.core.base.SparkTestBase -import za.co.absa.pramen.core.utils.hive.SparkQueryExecutor +import za.co.absa.pramen.core.utils.hive.QueryExecutorSpark -class SparkQueryExecutorSuite extends AnyWordSpec with SparkTestBase { +class QueryExecutorSparkSuite extends AnyWordSpec with SparkTestBase { import spark.implicits._ - "SparkQueryExecutor" should { + "QueryExecutorSpark" should { "execute Spark queries" in { val df = List(("A", 1), ("B", 2), ("C", 3)).toDF("a", "b") df.createOrReplaceTempView("temp") - val qe = new SparkQueryExecutor() + val qe = new QueryExecutorSpark() qe.execute("SELECT * FROM temp") } "throw an exception on errors" in { - val qe = new SparkQueryExecutor() + val qe = new QueryExecutorSpark() val ex = intercept[AnalysisException] { qe.execute("SELECT dummy from dummy") @@ -47,7 +47,7 @@ class SparkQueryExecutorSuite extends AnyWordSpec with SparkTestBase { } "throw an exception if Hive is not initialized" in { - val qe = new SparkQueryExecutor() + val qe = new QueryExecutorSpark() val ex = intercept[IllegalArgumentException] { qe.doesTableExist("dummyDb", "dummyTable") @@ -57,7 +57,7 @@ class SparkQueryExecutorSuite extends AnyWordSpec with SparkTestBase { } "return false if the table is not found" in { - val qe = new SparkQueryExecutor() + val qe = new QueryExecutorSpark() val exist = qe.doesTableExist("default", "dummyTable")