#522 Fix Denodo dialect handling of Time columns.
yruslan committed Aug 21, 2024
1 parent 31b0b94 commit ce475b4
Showing 4 changed files with 191 additions and 8 deletions.

TableReaderJdbcNative.scala
@@ -21,9 +21,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.slf4j.LoggerFactory
 import za.co.absa.pramen.api.Query
 import za.co.absa.pramen.core.reader.model.{JdbcConfig, TableReaderJdbcConfig}
-import za.co.absa.pramen.core.utils.{JdbcNativeUtils, StringUtils, TimeUtils}
+import za.co.absa.pramen.core.utils.{JdbcNativeUtils, JdbcSparkUtils, StringUtils, TimeUtils}

-import java.time.format.DateTimeFormatter
 import java.time.{Instant, LocalDate}

 class TableReaderJdbcNative(jdbcReaderConfig: TableReaderJdbcConfig,
@@ -58,8 +57,8 @@ class TableReaderJdbcNative(jdbcReaderConfig: TableReaderJdbcConfig,
   override def getData(query: Query, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String]): DataFrame = {
     log.info(s"JDBC Native data of: $query")
     query match {
-      case Query.Sql(sql) => getDataFrame(getFilteredSql(sql, infoDateBegin, infoDateEnd))
-      case Query.Table(table) => getDataFrame(getSqlDataQuery(table, infoDateBegin, infoDateEnd, columns))
+      case Query.Sql(sql) => getDataFrame(getFilteredSql(sql, infoDateBegin, infoDateEnd), None)
+      case Query.Table(table) => getDataFrame(getSqlDataQuery(table, infoDateBegin, infoDateEnd, columns), Option(table))
       case other => throw new IllegalArgumentException(s"'${other.name}' is not supported by the JDBC Native reader. Use 'sql' or 'table' instead.")
     }
   }
@@ -79,15 +78,27 @@ class TableReaderJdbcNative(jdbcReaderConfig: TableReaderJdbcConfig,
     }
   }

-  private[core] def getDataFrame(sql: String): DataFrame = {
+  private[core] def getDataFrame(sql: String, tableOpt: Option[String]): DataFrame = {
     log.info(s"JDBC Query: $sql")

-    val df = JdbcNativeUtils.getJdbcNativeDataFrame(jdbcConfig, url, sql)
+    var df = JdbcNativeUtils.getJdbcNativeDataFrame(jdbcConfig, url, sql)

     if (log.isDebugEnabled) {
       log.debug(df.schema.treeString)
     }

+    if (jdbcReaderConfig.enableSchemaMetadata) {
+      JdbcSparkUtils.withJdbcMetadata(jdbcReaderConfig.jdbcConfig, sql) { (connection, jdbcMetadata) =>
+        val schemaWithColumnDescriptions = tableOpt match {
+          case Some(table) =>
+            log.info(s"Reading JDBC metadata descriptions for the query: $sql")
+            JdbcSparkUtils.addColumnDescriptionsFromJdbc(df.schema, table, connection)
+          case None => df.schema
+        }
+        df = spark.createDataFrame(df.rdd, schemaWithColumnDescriptions)
+      }
+    }
+
     df
   }
 }
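
The enableSchemaMetadata branch above relies on a common Spark pattern: rebuild the DataFrame from its own RDD with an enriched schema, so the rows are untouched and only field metadata changes. A minimal, self-contained sketch of that pattern (the column name, description text, and app name are illustrative, not part of the commit):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{MetadataBuilder, StructType}

object SchemaRewrapSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rewrap").getOrCreate()
    import spark.implicits._

    val df = Seq(1, 2, 3).toDF("col1")

    // Attach a column description as Spark schema metadata.
    val described = new MetadataBuilder().putString("comment", "Illustrative description").build()
    val enrichedSchema = StructType(Seq(df.schema.fields.head.copy(metadata = described)))

    // Same rows, new schema: the spark.createDataFrame(df.rdd, ...) pattern
    // used by getDataFrame above when schema metadata is enabled.
    val dfWithMetadata = spark.createDataFrame(df.rdd, enrichedSchema)
    dfWithMetadata.schema.fields.foreach(f => println(s"${f.name}: ${f.metadata}"))
  }
}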

DenodoDialect.scala
@@ -17,9 +17,9 @@
 package za.co.absa.pramen.core.sql.dialects

 import org.apache.spark.sql.jdbc.JdbcDialect
-import org.apache.spark.sql.types.{DataType, MetadataBuilder, TimestampType}
+import org.apache.spark.sql.types.{DataType, MetadataBuilder, StringType, TimestampType}

-import java.sql.Types.TIMESTAMP_WITH_TIMEZONE
+import java.sql.Types._

 object DenodoDialect extends JdbcDialect {
   override def canHandle(url: String): Boolean = url.startsWith("jdbc:denodo") || url.startsWith("jdbc:vdb")
@@ -31,6 +31,7 @@ object DenodoDialect extends JdbcDialect {
   ): Option[DataType] =
     sqlType match {
       case TIMESTAMP_WITH_TIMEZONE => Some(TimestampType)
+      case TIME => Some(StringType)
       case _ => None
     }
 }
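
For context, Spark only consults a custom JdbcDialect once it has been registered. A minimal sketch of how the TIME-to-String mapping above takes effect; the explicit registration call and the connection details here are illustrative assumptions, not necessarily how Pramen wires the dialect in:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.jdbc.JdbcDialects
import za.co.absa.pramen.core.sql.dialects.DenodoDialect

object DenodoTimeColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("denodo-time").getOrCreate()

    // canHandle() matches URLs starting with "jdbc:denodo" or "jdbc:vdb",
    // so getCatalystType() above is consulted for every column read below.
    JdbcDialects.registerDialect(DenodoDialect)

    // Placeholder connection details. With the dialect registered, a Denodo
    // TIME column is mapped to StringType instead of an unsupported SQL type.
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:denodo://denodo-host:9999/my_database")
      .option("dbtable", "my_view")
      .load()

    df.printSchema()
  }
}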

SparkUtils.scala
@@ -30,6 +30,7 @@ import za.co.absa.pramen.core.pipeline.TransformExpression
 import java.io.ByteArrayOutputStream
 import java.time.format.DateTimeFormatter
 import java.time.{Instant, LocalDate}
+import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe._
 import scala.util.{Failure, Success, Try}
@@ -516,6 +517,77 @@ object SparkUtils {
     output.toString()
   }

+  /**
+    * Copies metadata from one schema to another as long as field names and data types match.
+    *
+    * @param schemaFrom      Schema to copy metadata from.
+    * @param schemaTo        Schema to copy metadata to.
+    * @param overwrite       If true, the metadata of schemaTo is not retained.
+    * @param sourcePreferred If true, schemaFrom metadata wins on conflicts; otherwise schemaTo metadata wins.
+    * @return The same schema as schemaTo, with metadata copied from schemaFrom.
+    */
+  def copyMetadata(schemaFrom: StructType,
+                   schemaTo: StructType,
+                   overwrite: Boolean = false,
+                   sourcePreferred: Boolean = false): StructType = {
+    def joinMetadata(from: Metadata, to: Metadata): Metadata = {
+      val newMetadataMerged = new MetadataBuilder
+
+      if (sourcePreferred) {
+        newMetadataMerged.withMetadata(to)
+        newMetadataMerged.withMetadata(from)
+      } else {
+        newMetadataMerged.withMetadata(from)
+        newMetadataMerged.withMetadata(to)
+      }
+
+      newMetadataMerged.build()
+    }
+
+    @tailrec
+    def processArray(ar: ArrayType, fieldFrom: StructField, fieldTo: StructField): ArrayType = {
+      ar.elementType match {
+        case st: StructType if fieldFrom.dataType.isInstanceOf[ArrayType] && fieldFrom.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType] =>
+          val innerStructFrom = fieldFrom.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
+          val newDataType = StructType(copyMetadata(innerStructFrom, st).fields)
+          ArrayType(newDataType, ar.containsNull)
+        case at: ArrayType =>
+          processArray(at, fieldFrom, fieldTo)
+        case p =>
+          ArrayType(p, ar.containsNull)
+      }
+    }
+
+    val fieldsMap = schemaFrom.fields.map(f => (f.name, f)).toMap
+
+    val newFields: Array[StructField] = schemaTo.fields.map { fieldTo =>
+      fieldsMap.get(fieldTo.name) match {
+        case Some(fieldFrom) =>
+          val newMetadata = if (overwrite) {
+            fieldFrom.metadata
+          } else {
+            joinMetadata(fieldFrom.metadata, fieldTo.metadata)
+          }
+
+          fieldTo.dataType match {
+            case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
+              val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st).fields)
+              fieldTo.copy(dataType = newDataType, metadata = newMetadata)
+            case at: ArrayType =>
+              val newType = processArray(at, fieldFrom, fieldTo)
+              fieldTo.copy(dataType = newType, metadata = newMetadata)
+            case _ =>
+              fieldTo.copy(metadata = newMetadata)
+          }
+        case None =>
+          fieldTo
+      }
+    }
+
+    StructType(newFields)
+  }
+
   private def getActualProcessingTimeUdf: UserDefinedFunction = {
     udf((_: Long) => Instant.now().getEpochSecond)
   }
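
The joinMetadata helper in copyMetadata above leans on MetadataBuilder semantics: when the same key is written twice, the later withMetadata call wins. A small sketch of that behaviour in isolation (key names and values are illustrative):

import org.apache.spark.sql.types.MetadataBuilder

object MetadataMergeOrderSketch {
  def main(args: Array[String]): Unit = {
    val from = new MetadataBuilder().putLong("maxLength", 100).build()
    val to = new MetadataBuilder().putLong("maxLength", 120).build()

    // Keys from the later withMetadata() call overwrite earlier ones,
    // which is how sourcePreferred picks the winner on conflicts.
    val toWins = new MetadataBuilder().withMetadata(from).withMetadata(to).build()
    val fromWins = new MetadataBuilder().withMetadata(to).withMetadata(from).build()

    assert(toWins.getLong("maxLength") == 120) // default: schemaTo metadata wins
    assert(fromWins.getLong("maxLength") == 100) // sourcePreferred = true: schemaFrom wins
  }
}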

SparkUtilsSuite.scala
@@ -720,4 +720,103 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture
     }
   }

+  "copyMetadata" should {
+    "copy metadata from one schema to another when overwrite = false" in {
+      val df1 = List(1, 2, 3).toDF("col1")
+      val df2 = List(1, 2, 3).toDF("col1")
+
+      val metadata1 = new MetadataBuilder()
+      metadata1.putString("comment", "Test")
+
+      val metadata2 = new MetadataBuilder()
+      metadata2.putLong("maxLength", 120)
+
+      val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+      val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
+
+      val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
+
+      val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata)
+
+      val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
+
+      assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+      assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
+    }
+
+    "retain metadata on conflicts by default" in {
+      val df1 = List(1, 2, 3).toDF("col1")
+      val df2 = List(1, 2, 3).toDF("col1")
+
+      val metadata1 = new MetadataBuilder()
+      metadata1.putString("comment", "Test")
+      metadata1.putLong("maxLength", 100)
+
+      val metadata2 = new MetadataBuilder()
+      metadata2.putLong("maxLength", 120)
+      metadata2.putLong("newMetadata", 180)
+
+      val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+      val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
+
+      val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
+
+      val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata)
+
+      val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
+
+      assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+      assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
+      assert(newDf.schema.fields.head.metadata.getLong("newMetadata") == 180)
+    }
+
+    "overwrite metadata on conflicts when sourcePreferred=true" in {
+      val df1 = List(1, 2, 3).toDF("col1")
+      val df2 = List(1, 2, 3).toDF("col1")
+
+      val metadata1 = new MetadataBuilder()
+      metadata1.putString("comment", "Test")
+      metadata1.putLong("maxLength", 100)
+
+      val metadata2 = new MetadataBuilder()
+      metadata2.putLong("maxLength", 120)
+      metadata2.putLong("newMetadata", 180)
+
+      val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+      val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
+
+      val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
+
+      val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata, sourcePreferred = true)
+
+      val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
+
+      assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+      assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 100)
+      assert(newDf.schema.fields.head.metadata.getLong("newMetadata") == 180)
+    }
+
+    "not retain original metadata when overwrite = true" in {
+      val df1 = List(1, 2, 3).toDF("col1")
+      val df2 = List(1, 2, 3).toDF("col1")
+
+      val metadata1 = new MetadataBuilder()
+      metadata1.putString("comment", "Test")
+
+      val metadata2 = new MetadataBuilder()
+      metadata2.putLong("maxLength", 120)
+
+      val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+      val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
+
+      val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
+
+      val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata, overwrite = true)
+
+      val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
+
+      assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+      assert(!newDf.schema.fields.head.metadata.contains("maxLength"))
+    }
+  }
 }
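
The tests above cover flat schemas; the nested-struct branch of copyMetadata can be exercised with a sketch like the following (field names and the description text are illustrative):

import org.apache.spark.sql.types._
import za.co.absa.pramen.core.utils.SparkUtils

object NestedMetadataSketch {
  def main(args: Array[String]): Unit = {
    val comment = new MetadataBuilder().putString("comment", "City name").build()

    val schemaFrom = StructType(Seq(
      StructField("address", StructType(Seq(
        StructField("city", StringType, nullable = true, metadata = comment)
      )))
    ))

    val schemaTo = StructType(Seq(
      StructField("address", StructType(Seq(
        StructField("city", StringType)
      )))
    ))

    val copied = SparkUtils.copyMetadata(schemaFrom, schemaTo)

    // The nested description is carried over because both schemas contain
    // an 'address' struct with a matching 'city' field.
    val addressType = copied("address").dataType.asInstanceOf[StructType]
    assert(addressType("city").metadata.getString("comment") == "City name")
  }
}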
