diff --git a/.github/workflows/jacoco_check.yml b/.github/workflows/jacoco_check.yml index 755d2a9..d1e0909 100644 --- a/.github/workflows/jacoco_check.yml +++ b/.github/workflows/jacoco_check.yml @@ -48,7 +48,7 @@ jobs: run: sbt ++${{matrix.scala}} jacoco - name: Add coverage to PR id: jacoco - uses: madrapps/jacoco-report@v1.3 + uses: madrapps/jacoco-report@v1.7.1 with: paths: > ${{ github.workspace }}/atum/target/scala-${{ matrix.scala_short }}/jacoco/report/jacoco.xml, diff --git a/README.md b/README.md index ec97c91..4892f34 100644 --- a/README.md +++ b/README.md @@ -419,13 +419,15 @@ The summary of common control framework routines you can use as Spark and Datafr The control measurement of a column is a hash sum. It can be calculated differently depending on the column's data type and on business requirements. This table represents all currently supported measurement types: -| Type | Description | -| ------------------------------ |:----------------------------------------------------- | -| controlType.Count | Calculates the number of rows in the dataset | -| controlType.distinctCount | Calculates DISTINCT(COUNT(()) of the specified column | -| controlType.aggregatedTotal | Calculates SUM() of the specified column | -| controlType.absAggregatedTotal | Calculates SUM(ABS()) of the specified column | -| controlType.HashCrc32 | Calculates SUM(CRC32()) of the specified column | +| Type | Description | +| ----------------------------------- |:----------------------------------------------------- | +| controlType.Count | Calculates the number of rows in the dataset | +| controlType.distinctCount | Calculates DISTINCT(COUNT(()) of the specified column | +| controlType.aggregatedTotal | Calculates SUM() of the specified column | +| controlType.absAggregatedTotal | Calculates SUM(ABS()) of the specified column | +| controlType.HashCrc32 | Calculates SUM(CRC32()) of the specified column | +| controlType.aggregatedTruncTotal | Calculates SUM(TRUNC()) of the specified column | +| controlType.absAggregatedTruncTotal | Calculates SUM(TRUNC(ABS())) of the specified column | ## How to generate Code coverage report ```sbt diff --git a/atum/src/main/scala/za/co/absa/atum/core/ControlType.scala b/atum/src/main/scala/za/co/absa/atum/core/ControlType.scala index c4917a9..27f9630 100644 --- a/atum/src/main/scala/za/co/absa/atum/core/ControlType.scala +++ b/atum/src/main/scala/za/co/absa/atum/core/ControlType.scala @@ -22,9 +22,12 @@ object ControlType { case object DistinctCount extends ControlType("distinctCount", false) case object AggregatedTotal extends ControlType("aggregatedTotal", true) case object AbsAggregatedTotal extends ControlType("absAggregatedTotal", true) + case object AggregatedTruncTotal extends ControlType("aggregatedTruncTotal", true) + case object AbsAggregatedTruncTotal extends ControlType("absAggregatedTruncTotal", true) case object HashCrc32 extends ControlType("hashCrc32", false) - val values: Seq[ControlType] = Seq(Count, DistinctCount, AggregatedTotal, AbsAggregatedTotal, HashCrc32) + val values: Seq[ControlType] = Seq(Count, DistinctCount, AggregatedTotal, AbsAggregatedTotal, + AggregatedTruncTotal, AbsAggregatedTruncTotal, HashCrc32) val valueNames: Seq[String] = values.map(_.value) def getNormalizedValueName(input: String): String = { diff --git a/atum/src/main/scala/za/co/absa/atum/core/MeasurementProcessor.scala b/atum/src/main/scala/za/co/absa/atum/core/MeasurementProcessor.scala index 25a3cc6..2d75f95 100644 --- a/atum/src/main/scala/za/co/absa/atum/core/MeasurementProcessor.scala +++ b/atum/src/main/scala/za/co/absa/atum/core/MeasurementProcessor.scala @@ -61,6 +61,16 @@ object MeasurementProcessor { .agg(sum(col(aggColName))).collect()(0)(0) if (v == null) "" else v.toString } + case AggregatedTruncTotal => + (ds: Dataset[Row]) => { + val aggCol = sum(col(valueColumnName).cast(LongType)) + aggregateColumn(ds, controlCol, aggCol) + } + case AbsAggregatedTruncTotal => + (ds: Dataset[Row]) => { + val aggCol = sum(abs(col(valueColumnName).cast(LongType))) + aggregateColumn(ds, controlCol, aggCol) + } } } diff --git a/atum/src/main/scala/za/co/absa/atum/utils/SparkLocalMaster.scala b/atum/src/main/scala/za/co/absa/atum/utils/SparkLocalMaster.scala index 10a31c9..b9b96ab 100644 --- a/atum/src/main/scala/za/co/absa/atum/utils/SparkLocalMaster.scala +++ b/atum/src/main/scala/za/co/absa/atum/utils/SparkLocalMaster.scala @@ -22,4 +22,5 @@ trait SparkLocalMaster { // in order to runSampleMeasuremts as tests, otherwise // java.lang.IllegalArgumentException: System memory 259522560 must be at least 471859200... is thrown System.getProperties.setProperty("spark.testing.memory", (1024*1024*1024).toString) // 1g + System.getProperties.setProperty("spark.app.name", "unit-test") } diff --git a/atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala b/atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala index f23dc0c..f229fcd 100644 --- a/atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala +++ b/atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala @@ -37,7 +37,7 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa ) )) - val measurementsIntOverflow = List( + val measurementsIntOverflow: Seq[Measurement] = List( Measurement( controlName = "RecordCount", controlType = ControlType.Count.value, @@ -112,7 +112,7 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa assert(newMeasurements == measurementsIntOverflow) } - val measurementsAggregation = List( + val measurementsAggregation: Seq[Measurement] = List( Measurement( controlName = "RecordCount", controlType = ControlType.Count.value, @@ -304,7 +304,7 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa assert(newMeasurements == measurements3) } - val measurementsWithHash = List( + val measurementsWithHash: Seq[Measurement] = List( Measurement( controlName = "RecordCount", controlType = ControlType.Count.value, @@ -394,4 +394,34 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa assert(newMeasurements == measurementsAggregationShort) } + val measurementsAggregatedTruncTotal: Seq[Measurement] = List( + Measurement( + controlName = "aggregatedTruncTotal", + controlType = "aggregatedTruncTotal", + controlCol = "price", + controlValue = "999" + ), + Measurement( + controlName = "absAggregatedTruncTotal", + controlType = "absAggregatedTruncTotal", + controlCol = "price", + controlValue = "2999" + ) + ) + + "aggregatedTruncTotal types" should "return truncated sum of values" in { + val inputDataJson = spark.sparkContext.parallelize( + s"""{"id": ${Long.MaxValue}, "price": -1000.000001, "order": { "orderid": 1, "items": 1 } } """ :: + s"""{"id": ${Long.MinValue}, "price": 1000.9, "order": { "orderid": -1, "items": -1 } } """ :: + s"""{"id": ${Long.MinValue}, "price": 999.999999, "order": { "orderid": -1, "items": -1 } } """ ::Nil) + val df = spark.read + .schema(schema) + .json(inputDataJson.toDS) + + val processor = new MeasurementProcessor(measurementsAggregatedTruncTotal) + val newMeasurements = processor.measureDataset(df) + + assert(newMeasurements == measurementsAggregatedTruncTotal) + } + }