Skip to content

Commit

Permalink
#164 Add the ability to create Hive tables via a JDBC connection.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Mar 22, 2023
1 parent dc7ecdb commit 5a0e000
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,23 @@
package za.co.absa.pramen.core.utils.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

trait HiveHelper {
def createOrUpdateHiveTable(parquetPath: String,
partitionBy: Seq[String],
databaseName: String,
tableName: String): Unit
tableName: String)(implicit spark: SparkSession): Unit

def repairHiveTable(databaseName: String,
tableName: String): Unit

def getSchema(parquetPath: String)(implicit spark: SparkSession): StructType
}

object HiveHelper {
def apply(implicit spark: SparkSession): HiveHelper = {
val queryExecutor = new SparkQueryExecutor()
val queryExecutor = new QueryExecutorSpark()
new HiveHelperImpl(queryExecutor)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package za.co.absa.pramen.core.utils.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory

Expand All @@ -25,7 +26,8 @@ class HiveHelperImpl(queryExecutor: QueryExecutor) extends HiveHelper {
override def createOrUpdateHiveTable(parquetPath: String,
partitionBy: Seq[String],
databaseName: String,
tableName: String): Unit = {
tableName: String)
(implicit spark: SparkSession): Unit = {
val fullTableName = getFullTable(databaseName, tableName)

dropHiveTable(fullTableName)
Expand All @@ -42,6 +44,12 @@ class HiveHelperImpl(queryExecutor: QueryExecutor) extends HiveHelper {
repairHiveTable(fullTableName)
}

override def getSchema(parquetPath: String)(implicit spark: SparkSession): StructType = {
val df = spark.read.parquet(parquetPath)

df.schema
}

private def getFullTable(databaseName: String,
tableName: String): String = {
if (databaseName.isEmpty)
Expand All @@ -58,10 +66,10 @@ class HiveHelperImpl(queryExecutor: QueryExecutor) extends HiveHelper {
private def createHiveTable(fullTableName: String,
parquetPath: String,
partitionBy: Seq[String]
): Unit = {
)(implicit spark: SparkSession): Unit = {

log.info(s"Creating Hive table: $fullTableName...")
val schema = queryExecutor.getSchema(parquetPath)
val schema = getSchema(parquetPath)

val sqlHiveCreate =
s"""CREATE EXTERNAL TABLE IF NOT EXISTS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,4 @@ trait QueryExecutor {
def doesTableExist(dbName: String, tableName: String): Boolean

def execute(query: String): Unit

def getSchema(parquetPath: String): StructType
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.utils.hive

import org.slf4j.LoggerFactory

import java.sql.{Connection, ResultSet, SQLSyntaxErrorException}
import scala.util.Try

class QueryExecutorJdbc(connection: Connection) extends QueryExecutor {
  private val log = LoggerFactory.getLogger(this.getClass)

  /**
    * Checks if a table exists by running a query that is guaranteed to return no rows.
    *
    * Note: `dbName` is currently ignored — the table name is resolved against the
    * connection's current schema/catalog.
    *
    * @param dbName    the database (schema) name (currently unused).
    * @param tableName the table to check.
    * @return true if the probe query executed without an error, false otherwise.
    */
  override def doesTableExist(dbName: String, tableName: String): Boolean = {
    // 'WHERE 0 = 1' ensures the probe is cheap: it never scans or returns rows,
    // it only fails if the table cannot be resolved.
    val query = s"SELECT 1 FROM $tableName WHERE 0 = 1"

    Try {
      execute(query)
    }.isSuccess
  }

  /**
    * Executes an arbitrary SQL statement (query or DDL) over the JDBC connection,
    * discarding any results.
    *
    * @param query the SQL statement to execute.
    */
  @throws[SQLSyntaxErrorException]
  override def execute(query: String): Unit = {
    log.info(s"Executing SQL: $query")
    val statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    try {
      // Use execute() rather than executeQuery(): per the JDBC spec, executeQuery()
      // may only be used for statements that produce a ResultSet, so it fails on
      // DDL such as CREATE TABLE, which this executor must support.
      statement.execute(query)
    } finally {
      // Always close the statement (the original leaked it and closed only the
      // result set). Closing the statement also closes any ResultSet it produced.
      statement.close()
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

package za.co.absa.pramen.core.utils.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory

class SparkQueryExecutor(implicit spark: SparkSession) extends QueryExecutor {
class QueryExecutorSpark(implicit spark: SparkSession) extends QueryExecutor {
private val log = LoggerFactory.getLogger(this.getClass)

override def doesTableExist(dbName: String, tableName: String): Boolean = {
Expand All @@ -31,14 +31,9 @@ class SparkQueryExecutor(implicit spark: SparkSession) extends QueryExecutor {
}
}

@throws[AnalysisException]
override def execute(query: String): Unit = {
log.info(s"Executing SQL: $query")
spark.sql(query).take(100)
}

override def getSchema(parquetPath: String): StructType = {
val df = spark.read.parquet(parquetPath)

df.schema
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.tests.utils.hive

import org.scalatest.BeforeAndAfterAll
import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.core.fixtures.RelationalDbFixture
import za.co.absa.pramen.core.samples.RdbExampleTable
import za.co.absa.pramen.core.utils.hive.QueryExecutorJdbc

import java.sql.SQLSyntaxErrorException

class QueryExecutorJdbcSuite extends AnyWordSpec with BeforeAndAfterAll with RelationalDbFixture {

  // Creates a fresh executor bound to the fixture's JDBC connection.
  private def newExecutor: QueryExecutorJdbc = new QueryExecutorJdbc(getConnection)

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    // Populate the sample 'company' table the tests query against.
    RdbExampleTable.Company.initTable(getConnection)
  }

  override protected def afterAll(): Unit = {
    RdbExampleTable.Company.dropTable(getConnection)
    super.afterAll()
  }

  "QueryExecutorJdbc" should {
    "execute JDBC queries" in {
      newExecutor.execute("SELECT * FROM company")
    }

    "execute CREATE TABLE queries" in {
      val executor = newExecutor

      executor.execute("CREATE TABLE my_table (id INT)")

      assert(executor.doesTableExist(database, "my_table"))
    }

    "throw an exception on errors" in {
      // Querying a missing table must surface the driver's syntax/resolution error.
      val ex = intercept[SQLSyntaxErrorException] {
        newExecutor.execute("SELECT * FROM does_not_exist")
      }

      assert(ex.getMessage.contains("object not found"))
    }

    "return true if the table is found" in {
      assert(newExecutor.doesTableExist(database, "company"))
    }

    "return false if the table is not found" in {
      assert(!newExecutor.doesTableExist(database, "does_not_exist"))
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,25 @@ package za.co.absa.pramen.core.tests.utils.hive
import org.apache.spark.sql.AnalysisException
import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.core.base.SparkTestBase
import za.co.absa.pramen.core.utils.hive.SparkQueryExecutor
import za.co.absa.pramen.core.utils.hive.QueryExecutorSpark

class SparkQueryExecutorSuite extends AnyWordSpec with SparkTestBase {
class QueryExecutorSparkSuite extends AnyWordSpec with SparkTestBase {

import spark.implicits._

"SparkQueryExecutor" should {
"QueryExecutorSpark" should {
"execute Spark queries" in {
val df = List(("A", 1), ("B", 2), ("C", 3)).toDF("a", "b")

df.createOrReplaceTempView("temp")

val qe = new SparkQueryExecutor()
val qe = new QueryExecutorSpark()

qe.execute("SELECT * FROM temp")
}

"throw an exception on errors" in {
val qe = new SparkQueryExecutor()
val qe = new QueryExecutorSpark()

val ex = intercept[AnalysisException] {
qe.execute("SELECT dummy from dummy")
Expand All @@ -47,7 +47,7 @@ class SparkQueryExecutorSuite extends AnyWordSpec with SparkTestBase {
}

"throw an exception if Hive is not initialized" in {
val qe = new SparkQueryExecutor()
val qe = new QueryExecutorSpark()

val ex = intercept[IllegalArgumentException] {
qe.doesTableExist("dummyDb", "dummyTable")
Expand All @@ -57,7 +57,7 @@ class SparkQueryExecutorSuite extends AnyWordSpec with SparkTestBase {
}

"return false if the table is not found" in {
val qe = new SparkQueryExecutor()
val qe = new QueryExecutorSpark()

val exist = qe.doesTableExist("default", "dummyTable")

Expand Down

0 comments on commit 5a0e000

Please sign in to comment.