Skip to content

Commit

Permalink
Issue 317: Reduce optimization overhead (Qbeast-io#318)
Browse files Browse the repository at this point in the history
* Broadcast rollup map and cube max weights
  • Loading branch information
Jiaweihu08 authored Apr 18, 2024
1 parent ea0a26d commit d38a94c
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/main/scala/io/qbeast/spark/delta/writer/RollupDataWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -291,21 +291,29 @@ object RollupDataWriter
extendedData: DataFrame,
revision: Revision,
indexStatus: IndexStatus): DataFrame = {
val spark = extendedData.sparkSession
val cubeMaxWeightsBroadcast =
spark.sparkContext.broadcast(
indexStatus.cubesStatuses
.mapValues(_.maxWeight)
.map(identity))
val columns = revision.columnTransformers.map(_.columnName)
extendedData
.withColumn(
QbeastColumns.cubeColumnName,
getCubeIdUDF(revision, indexStatus)(
getCubeIdUDF(revision, cubeMaxWeightsBroadcast.value)(
struct(columns.map(col): _*),
col(QbeastColumns.weightColumnName)))
}

private def getCubeIdUDF(revision: Revision, indexStatus: IndexStatus): UserDefinedFunction =
private def getCubeIdUDF(
revision: Revision,
cubeMaxWeights: Map[CubeId, Weight]): UserDefinedFunction =
udf { (row: Row, weight: Int) =>
val point = RowUtils.rowValuesToPoint(row, revision)
val cubeId = CubeId.containers(point).find { cubeId =>
indexStatus.cubesStatuses.get(cubeId) match {
case Some(status) => weight <= status.maxWeight.value
cubeMaxWeights.get(cubeId) match {
case Some(maxWeight) => weight <= maxWeight.value
case None => true
}
}
Expand All @@ -315,10 +323,11 @@ object RollupDataWriter
private def extendDataWithCubeToRollup(
extendedData: DataFrame,
revision: Revision): DataFrame = {
val rollup = computeRollup(revision, extendedData)
val spark = extendedData.sparkSession
val rollupBroadcast = spark.sparkContext.broadcast(computeRollup(revision, extendedData))
extendedData.withColumn(
QbeastColumns.cubeToRollupColumnName,
getRollupCubeIdUDF(revision, rollup)(col(QbeastColumns.cubeColumnName)))
getRollupCubeIdUDF(revision, rollupBroadcast.value)(col(QbeastColumns.cubeColumnName)))
}

private def computeRollup(revision: Revision, extendedData: DataFrame): Map[CubeId, CubeId] = {
Expand Down

0 comments on commit d38a94c

Please sign in to comment.