#398 Fix the usage of custom schema and decimal corrections with Hive JDBC source.
yruslan committed Dec 9, 2024
1 parent e1ab9b0 commit 5d8e74f
Showing 6 changed files with 50 additions and 40 deletions.
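
The heart of the fix is in TableReaderJdbc: when decimal corrections are enabled, DataFrame column names are now sanitized before the corrected schema is derived. This matters for the Hive JDBC source because the driver can return columns qualified as table.column, and names containing special characters produce an invalid customSchema string. A minimal sketch of what such sanitization involves (illustration only, with assumed behavior; the real logic is SparkUtils.sanitizeDfColumns):

import org.apache.spark.sql.DataFrame

// Illustrative sketch, not Pramen's actual implementation.
def sanitizeColumns(df: DataFrame, specialCharacters: String): DataFrame =
  df.columns.foldLeft(df) { (acc, name) =>
    val unqualified = name.split('.').last // drop a "table." prefix if present
    val sanitized = unqualified.map(c => if (specialCharacters.contains(c)) '_' else c)
    if (sanitized == name) acc else acc.withColumnRenamed(name, sanitized)
  }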
@@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.Query
import za.co.absa.pramen.core.config.Keys
import za.co.absa.pramen.core.reader.model.TableReaderJdbcConfig
import za.co.absa.pramen.core.utils.{ConfigUtils, JdbcNativeUtils, JdbcSparkUtils, TimeUtils}
import za.co.absa.pramen.core.utils.{ConfigUtils, JdbcNativeUtils, JdbcSparkUtils, SparkUtils, TimeUtils}

import java.time.format.DateTimeFormatter
import java.time.{Instant, LocalDate}
@@ -173,6 +173,10 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig,
.load()

if (jdbcReaderConfig.correctDecimalsInSchema || jdbcReaderConfig.correctDecimalsFixPrecision) {
if (isDataQuery) {
df = SparkUtils.sanitizeDfColumns(df, jdbcReaderConfig.specialCharacters)
}

JdbcSparkUtils.getCorrectedDecimalsSchema(df, jdbcReaderConfig.correctDecimalsFixPrecision).foreach(schema =>
df = spark
.read
@@ -222,8 +226,8 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig,
}

object TableReaderJdbc {
def apply(conf: Config, parent: String)(implicit spark: SparkSession): TableReaderJdbc = {
val jdbcTableReaderConfig = TableReaderJdbcConfig.load(conf, parent)
def apply(conf: Config, workflowConf: Config, parent: String)(implicit spark: SparkSession): TableReaderJdbc = {
val jdbcTableReaderConfig = TableReaderJdbcConfig.load(conf, workflowConf, parent)

val urlSelector = JdbcUrlSelector(jdbcTableReaderConfig.jdbcConfig)

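For context on the ordering above: the corrected-decimals branch re-reads the source with an explicit customSchema built from the observed DataFrame schema, so the column names referenced in that schema string must already be in their sanitized form. A rough sketch of the pattern, with an assumed correction rule and helper names (the real derivation is JdbcSparkUtils.getCorrectedDecimalsSchema):

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DecimalType, StructField}

// Assumed rule for illustration: clamp decimals whose reported scale is unusable.
def correctedDecimalsSchema(df: DataFrame): Option[String] = {
  val corrected = df.schema.fields.collect {
    case StructField(name, d: DecimalType, _, _) if d.scale >= d.precision =>
      s"$name DECIMAL(38,${math.min(d.scale, 18)})"
  }
  if (corrected.isEmpty) None else Some(corrected.mkString(", "))
}

def reloadWithCorrections(spark: SparkSession, df: DataFrame, url: String, table: String): DataFrame =
  correctedDecimalsSchema(df).fold(df) { schema =>
    spark.read
      .format("jdbc")
      .option("url", url)
      .option("dbtable", table)
      .option("customSchema", schema) // names here must match the sanitized columns
      .load()
  }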
@@ -109,9 +109,10 @@ object TableReaderJdbcNative {
val FETCH_SIZE_KEY = "option.fetchsize"

def apply(conf: Config,
workflowConf: Config,
parent: String = "")
(implicit spark: SparkSession): TableReaderJdbcNative = {
val tableReaderJdbcOrig = TableReaderJdbcConfig.load(conf, parent)
val tableReaderJdbcOrig = TableReaderJdbcConfig.load(conf, workflowConf, parent)
val jdbcConfig = getJdbcConfig(tableReaderJdbcOrig, conf)
val tableReaderJdbc = tableReaderJdbcOrig.copy(jdbcConfig = jdbcConfig)
val urlSelector = JdbcUrlSelector(tableReaderJdbc.jdbcConfig)
@@ -19,6 +19,7 @@ package za.co.absa.pramen.core.reader.model
import com.typesafe.config.Config
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.sql.QuotingPolicy
import za.co.absa.pramen.core.config.Keys.SPECIAL_CHARACTERS_IN_COLUMN_NAMES
import za.co.absa.pramen.core.utils.ConfigUtils

case class TableReaderJdbcConfig(
@@ -33,6 +34,7 @@
correctDecimalsFixPrecision: Boolean = false,
enableSchemaMetadata: Boolean = false,
useJdbcNative: Boolean = false,
specialCharacters: String = " ",
identifierQuotingPolicy: QuotingPolicy = QuotingPolicy.Auto,
sqlGeneratorClass: Option[String] = None
)
@@ -55,7 +57,7 @@
val IDENTIFIER_QUOTING_POLICY = "identifier.quoting.policy"
val SQL_GENERATOR_CLASS_KEY = "sql.generator.class"

def load(conf: Config, parent: String = ""): TableReaderJdbcConfig = {
def load(conf: Config, workflowConf: Config, parent: String = ""): TableReaderJdbcConfig = {
ConfigUtils.validatePathsExistence(conf, parent, HAS_INFO_DATE :: Nil)

val hasInformationDate = conf.getBoolean(HAS_INFO_DATE)
@@ -78,6 +80,8 @@
.map(s => QuotingPolicy.fromString(s))
.getOrElse(QuotingPolicy.Auto)

val specialCharacters = ConfigUtils.getOptionString(workflowConf, SPECIAL_CHARACTERS_IN_COLUMN_NAMES).getOrElse(" ")

TableReaderJdbcConfig(
jdbcConfig = JdbcConfig.load(conf, parent),
hasInfoDate = conf.getBoolean(HAS_INFO_DATE),
@@ -90,6 +94,7 @@
correctDecimalsFixPrecision = ConfigUtils.getOptionBoolean(conf, CORRECT_DECIMALS_FIX_PRECISION).getOrElse(false),
enableSchemaMetadata = ConfigUtils.getOptionBoolean(conf, ENABLE_SCHEMA_METADATA_KEY).getOrElse(false),
useJdbcNative = ConfigUtils.getOptionBoolean(conf, USE_JDBC_NATIVE).getOrElse(false),
specialCharacters = specialCharacters,
identifierQuotingPolicy = identifierQuotingPolicy,
sqlGeneratorClass = ConfigUtils.getOptionString(conf, SQL_GENERATOR_CLASS_KEY)
)
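TableReaderJdbcConfig.load now takes the workflow-level config alongside the reader's own scope because the special-characters list is a global setting rather than a per-source one. A sketch of that lookup with Typesafe Config (the key string below is an assumption; the code above references it through Keys.SPECIAL_CHARACTERS_IN_COLUMN_NAMES):

import com.typesafe.config.Config

// Assumed key name for illustration only.
val SpecialCharactersKey = "pramen.special.characters.in.column.names"

def specialCharacters(workflowConf: Config): String =
  if (workflowConf.hasPath(SpecialCharactersKey)) workflowConf.getString(SpecialCharactersKey)
  else " " // the same default as the case class field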
@@ -76,9 +76,9 @@ class JdbcSource(sourceConfig: Config,
}
}

object JdbcSource extends ExternalChannelFactory[JdbcSource] {
override def apply(conf: Config, parentPath: String, spark: SparkSession): JdbcSource = {
val tableReaderJdbc = TableReaderJdbcConfig.load(conf)
object JdbcSource extends ExternalChannelFactoryV2[JdbcSource] {
override def apply(conf: Config, workflowConfig: Config, parentPath: String, spark: SparkSession): JdbcSource = {
val tableReaderJdbc = TableReaderJdbcConfig.load(conf, workflowConfig)

new JdbcSource(conf, parentPath, tableReaderJdbc)(spark)
}
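Since JdbcSource now implements ExternalChannelFactoryV2, its factory receives both the source-scoped config and the full workflow config, which is how the setting above reaches the reader. A usage sketch (the config path and Spark session setup are assumptions):

import com.typesafe.config.ConfigFactory
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("jdbc-source-example").getOrCreate()
val workflowConf = ConfigFactory.load()
val sourceConf = workflowConf.getConfig("pramen.sources") // assumed path

val source = JdbcSource(sourceConf, workflowConf, "pramen.sources", spark)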
@@ -97,7 +97,7 @@ class TableReaderJdbcNativeSuite extends AnyWordSpec with RelationalDbFixture wi
}

private def getReader: TableReaderJdbcNative =
TableReaderJdbcNative(conf.getConfig("reader"), "reader")
TableReaderJdbcNative(conf.getConfig("reader"), conf, "reader")

"TableReaderJdbcNative factory" should {
"construct a reader object" in {
@@ -109,13 +109,13 @@
}

"work with legacy config" in {
val reader = TableReaderJdbcNative(conf.getConfig("reader_legacy"), "reader_legacy")
val reader = TableReaderJdbcNative(conf.getConfig("reader_legacy"), conf, "reader_legacy")
assert(reader.getJdbcReaderConfig.infoDateFormat == "yyyy-MM-DD")
assert(!reader.getJdbcReaderConfig.jdbcConfig.sanitizeDateTime)
}

"work with minimal config" in {
val reader = TableReaderJdbcNative(conf.getConfig("reader_minimal"), "reader_minimal")
val reader = TableReaderJdbcNative(conf.getConfig("reader_minimal"), conf, "reader_minimal")
assert(reader.getJdbcReaderConfig.infoDateFormat == "yyyy-MM-dd")
assert(reader.getJdbcReaderConfig.jdbcConfig.sanitizeDateTime)
assert(!reader.getJdbcReaderConfig.jdbcConfig.autoCommit)
@@ -124,7 +124,7 @@

"throw an exception if config is missing" in {
intercept[IllegalArgumentException] {
TableReaderJdbcNative(conf)
TableReaderJdbcNative(conf, conf)
}
}
}
@@ -293,15 +293,15 @@ class TableReaderJdbcNativeSuite extends AnyWordSpec with RelationalDbFixture wi
}

"return a query without info date if it is disabled" in {
val reader = TableReaderJdbcNative(conf.getConfig("reader_minimal"), "reader_minimal")
val reader = TableReaderJdbcNative(conf.getConfig("reader_minimal"), conf, "reader_minimal")

val actual = reader.getSqlDataQuery("table1", infoDateBegin, infoDateEnd, Nil)

assert(actual == "SELECT * FROM table1")
}

"return a query without with limits" in {
val reader = TableReaderJdbcNative(conf.getConfig("reader_limit"), "reader_limit")
val reader = TableReaderJdbcNative(conf.getConfig("reader_limit"), conf, "reader_limit")

val actual = reader.getSqlDataQuery("table1", infoDateBegin, infoDateEnd, Nil)

@@ -94,7 +94,7 @@ class TableReaderJdbcSuite extends AnyWordSpec with BeforeAndAfterAll with Spark
|}""".stripMargin)

"be able to be constructed properly from config" in {
val reader = TableReaderJdbc(conf.getConfig("reader"), "reader")
val reader = TableReaderJdbc(conf.getConfig("reader"), conf, "reader")

val jdbc = reader.getJdbcConfig

@@ -113,7 +113,7 @@
}

"be able to be constructed properly from legacy config" in {
val reader = TableReaderJdbc(conf.getConfig("reader_legacy"), "reader_legacy")
val reader = TableReaderJdbc(conf.getConfig("reader_legacy"), conf, "reader_legacy")

val jdbc = reader.getJdbcConfig

@@ -132,7 +132,7 @@
}

"be able to be constructed properly from minimal config" in {
val reader = TableReaderJdbc(conf.getConfig("reader_minimal"), "reader_minimal")
val reader = TableReaderJdbc(conf.getConfig("reader_minimal"), conf, "reader_minimal")

val jdbc = reader.getJdbcConfig

@@ -151,13 +151,13 @@
}

"ensure sql query generator is properly selected 1" in {
val reader = TableReaderJdbc(conf.getConfig("reader"), "reader")
val reader = TableReaderJdbc(conf.getConfig("reader"), conf, "reader")

assert(reader.sqlGen.isInstanceOf[SqlGeneratorHsqlDb])
}

"ensure sql query generator is properly selected 2" in {
val reader = TableReaderJdbc(conf.getConfig("reader_legacy"), "reader_legacy")
val reader = TableReaderJdbc(conf.getConfig("reader_legacy"), conf, "reader_legacy")

assert(reader.sqlGen.isInstanceOf[SqlGeneratorDummy])
}
@@ -166,7 +166,7 @@
val testConfig = conf
.withValue("reader.save.timestamps.as.dates", ConfigValueFactory.fromAnyRef(true))
.withValue("reader.correct.decimals.in.schema", ConfigValueFactory.fromAnyRef(true))
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), conf, "reader")

val jdbc = reader.getJdbcConfig

@@ -186,7 +186,7 @@
|
| has.information.date.column = false
|}""".stripMargin)
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), conf, "reader")

val jdbc = reader.getJdbcConfig

@@ -207,7 +207,7 @@
| information.date.column = "sync_date"
| information.date.type = "date"
|}""".stripMargin)
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), conf, "reader")

val jdbc = reader.getJdbcConfig

@@ -219,7 +219,7 @@
"getWithRetry" should {
"return the successful dataframe on the second try" in {
val readerConfig = conf.getConfig("reader")
val jdbcTableReaderConfig = TableReaderJdbcConfig.load(readerConfig, "reader")
val jdbcTableReaderConfig = TableReaderJdbcConfig.load(readerConfig, conf, "reader")

val urlSelector = mock(classOf[JdbcUrlSelector])

@@ -238,7 +238,7 @@

"pass the exception when out of retries" in {
val readerConfig = conf.getConfig("reader")
val jdbcTableReaderConfig = TableReaderJdbcConfig.load(readerConfig, "reader")
val jdbcTableReaderConfig = TableReaderJdbcConfig.load(readerConfig, conf, "reader")

val urlSelector = mock(classOf[JdbcUrlSelector])

@@ -264,7 +264,7 @@
.withValue("reader.correct.decimals.in.schema", ConfigValueFactory.fromAnyRef(true))
.withValue("enable.schema.metadata", ConfigValueFactory.fromAnyRef(true))

val jdbcTableReaderConfig = TableReaderJdbcConfig.load(readerConfig, "reader")
val jdbcTableReaderConfig = TableReaderJdbcConfig.load(readerConfig, conf, "reader")
val urlSelector = JdbcUrlSelector(jdbcTableReaderConfig.jdbcConfig)

val reader = new TableReaderJdbc(jdbcTableReaderConfig, urlSelector, readerConfig)
@@ -284,7 +284,7 @@
"getCount()" should {
"return count for a table snapshot-like query" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), testConfig, "reader")

val count = reader.getRecordCount(Query.Table("company"), null, null)

@@ -293,7 +293,7 @@

"return count for a sql snapshot-like query" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), testConfig, "reader")

val count = reader.getRecordCount(Query.Sql("SELECT * FROM company"), null, null)

@@ -307,7 +307,7 @@
.withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
.withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))

val reader = TableReaderJdbc(testConfig, "reader")
val reader = TableReaderJdbc(testConfig, testConfig, "reader")

val count = reader.getRecordCount(Query.Table("company"), infoDate, infoDate)

@@ -316,7 +316,7 @@

"return count for a snapshot-like SQL" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), testConfig, "reader")

val count = reader.getRecordCount(Query.Sql("SELECT id FROM company"), infoDate, infoDate)

@@ -330,7 +330,7 @@
.withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
.withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))

val reader = TableReaderJdbc(testConfig, "reader")
val reader = TableReaderJdbc(testConfig, testConfig, "reader")

val count = reader.getRecordCount(Query.Sql("SELECT id, info_date FROM company WHERE info_date BETWEEN '@dateFrom' AND '@dateTo'"), infoDate, infoDate)

@@ -341,7 +341,7 @@
"getCountSqlQuery" should {
"return a count query for a table snapshot-like query" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), testConfig, "reader")

val sql = reader.getCountSqlQuery("SELECT * FROM COMPANY", infoDate, infoDate)

@@ -355,7 +355,7 @@
.withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
.withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))

val reader = TableReaderJdbc(testConfig, "reader")
val reader = TableReaderJdbc(testConfig, testConfig, "reader")

val sql = reader.getCountSqlQuery("SELECT * FROM COMPANY WHERE info_date BETWEEN '@dateFrom' AND '@dateTo'", infoDate, infoDate)

@@ -370,7 +370,7 @@
.withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
.withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))

val reader = TableReaderJdbc(testConfig, "reader")
val reader = TableReaderJdbc(testConfig, testConfig, "reader")

val sql = reader.getCountSqlQuery("SELECT * FROM my_db.my_table WHERE info_date = CAST(REPLACE(CAST(CAST('@infoDate' AS DATE) AS VARCHAR(10)), '-', '') AS INTEGER)", infoDate, infoDate)

@@ -381,7 +381,7 @@
"getData()" should {
"return data for a table snapshot-like query" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), testConfig, "reader")

val df = reader.getData(Query.Table("company"), null, null, Seq.empty[String])

@@ -391,7 +391,7 @@

"return selected column for a table snapshot-like query" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), testConfig, "reader")

val df = reader.getData(Query.Table("company"), null, null, Seq("id", "name"))

@@ -407,7 +407,7 @@
.withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))
.withValue("correct.decimals.in.schema", ConfigValueFactory.fromAnyRef(true))

val reader = TableReaderJdbc(testConfig, "reader")
val reader = TableReaderJdbc(testConfig, testConfig, "reader")

val df = reader.getData(Query.Table("company"), infoDate, infoDate, Seq.empty[String])

@@ -416,7 +416,7 @@

"return data for a snapshot-like SQL" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), testConfig, "reader")

val df = reader.getData(Query.Sql("SELECT id FROM company"), infoDate, infoDate, Seq.empty[String])

@@ -426,7 +426,7 @@

"return selected columns for a snapshot-like SQL" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")
val reader = TableReaderJdbc(testConfig.getConfig("reader"), testConfig, "reader")

val df = reader.getData(Query.Sql("SELECT * FROM company"), infoDate, infoDate, Seq("id", "name"))

@@ -441,7 +441,7 @@
.withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
.withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))

val reader = TableReaderJdbc(testConfig, "reader")
val reader = TableReaderJdbc(testConfig, testConfig, "reader")

val df = reader.getData(Query.Sql("SELECT id, info_date FROM company WHERE info_date BETWEEN '@dateFrom' AND '@dateTo'"), infoDate, infoDate, Seq.empty[String])

@@ -456,7 +456,7 @@
.withValue("information.date.column", ConfigValueFactory.fromAnyRef("info_date"))
.withValue("information.date.type", ConfigValueFactory.fromAnyRef("not_exist"))

val reader = TableReaderJdbc(testConfig, "reader")
val reader = TableReaderJdbc(testConfig, testConfig, "reader")

assertThrows[IllegalArgumentException] {
reader.getSqlConfig
