diff --git a/external/maprdb/src/main/scala/com/mapr/db/spark/sql/MapRDBDataFrameFunctions.scala b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/MapRDBDataFrameFunctions.scala
index e893ceb3faf5c..c49240810579b 100644
--- a/external/maprdb/src/main/scala/com/mapr/db/spark/sql/MapRDBDataFrameFunctions.scala
+++ b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/MapRDBDataFrameFunctions.scala
@@ -1,14 +1,16 @@
 /* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
 package com.mapr.db.spark.sql
 
+import com.mapr.db.spark.sql.ojai.{JoinType, OJAISparkPartitionReader}
+import com.mapr.db.spark.sql.ojai.OJAISparkPartitionReader.Cell
 import com.mapr.db.spark.utils.{LoggingTrait, MapRSpark}
 import org.ojai.DocumentConstants
-
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.types.StructType
 
 private[spark] case class MapRDBDataFrameFunctions(@transient df: DataFrame,
                                                    bufferWrites: Boolean = true)
-  extends LoggingTrait {
+    extends LoggingTrait {
 
   def setBufferWrites(bufferWrites: Boolean): MapRDBDataFrameFunctions =
     MapRDBDataFrameFunctions(df, bufferWrites)
@@ -24,4 +26,44 @@ private[spark] case class MapRDBDataFrameFunctions(@transient df: DataFrame,
                  createTable: Boolean = false,
                  bulkInsert: Boolean = false): Unit =
     MapRSpark.insert(df, tableName, idFieldPath, createTable, bulkInsert, bufferWrites)
+
+  def joinWithMapRDBTable(table: String,
+                          schema: StructType,
+                          left: String,
+                          right: String,
+                          joinType: JoinType,
+                          concurrentQueries: Int = 20)(implicit session: SparkSession): DataFrame = {
+
+    val columnDataType = schema.fields(schema.fieldIndex(right)).dataType
+
+    val documents = df
+      .select(left)
+      .distinct()
+      .rdd
+      .mapPartitions { partition =>
+        if (partition.isEmpty) {
+          List.empty.iterator
+        } else {
+
+          val partitionCellIterator = partition.map(row => Cell(row.get(0), columnDataType))
+
+          OJAISparkPartitionReader
+            .groupedPartitionReader(concurrentQueries)
+            .readFrom(partitionCellIterator, table, schema, right)
+        }
+      }
+
+    import org.apache.spark.sql.functions._
+    import session.implicits._
+
+    val rightDF = session.read.schema(schema).json(documents.toDS())
+
+    df.join(rightDF, col(left) === col(right), joinType.toString)
+  }
+
+  def joinWithMapRDBTable(maprdbTable: String,
+                          schema: StructType,
+                          left: String,
+                          right: String)(implicit session: SparkSession): DataFrame =
+    joinWithMapRDBTable(maprdbTable, schema, left, right, JoinType.inner)
 }
diff --git a/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/BoundedConcurrentContext.scala b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/BoundedConcurrentContext.scala
new file mode 100644
index 0000000000000..e60188a3df5c8
--- /dev/null
+++ b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/BoundedConcurrentContext.scala
@@ -0,0 +1,20 @@
+package com.mapr.db.spark.sql.concurrent
+
+import java.util.concurrent.Executors
+
+import scala.concurrent.ExecutionContext
+
+/**
+ * This is the default ConcurrentContext.
+ *
+ * We use a fixed-size thread pool so queries can run concurrently while the number of threads stays bounded.
+ */
+private[concurrent] object BoundedConcurrentContext extends ConcurrentContext {
+
+  /**
+   * A fixed pool of 24 threads, shared by every task running in the same executor JVM.
+   */
+  override def ec: ExecutionContext = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(24))
+}
+
+
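For context, a minimal usage sketch of the new join API. The table path, schema, and column names below are hypothetical, and the sketch assumes the connector's usual package import exposes joinWithMapRDBTable on DataFrame via an implicit conversion:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    import com.mapr.db.spark.sql._ // assumed to provide the DataFrame -> MapRDBDataFrameFunctions conversion
    import com.mapr.db.spark.sql.ojai.JoinType

    implicit val session: SparkSession = SparkSession.builder().getOrCreate()

    // Schema enforced over the MapR-DB documents after the read.
    val usersSchema = StructType(Seq(
      StructField("_id", StringType),
      StructField("name", StringType)))

    val ordersDF = session.read.parquet("/data/orders") // hypothetical left side

    // Distinct values of orders("user_id") are sent as grouped IN queries against
    // the "_id" column of the users table; the matches are joined back.
    val joined = ordersDF.joinWithMapRDBTable(
      "/tables/users", usersSchema, "user_id", "_id", JoinType.inner, concurrentQueries = 20)
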
diff --git a/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/ConcurrentContext.scala b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/ConcurrentContext.scala
new file mode 100644
index 0000000000000..f6f61463669c6
--- /dev/null
+++ b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/ConcurrentContext.scala
@@ -0,0 +1,74 @@
+package com.mapr.db.spark.sql.concurrent
+
+import scala.concurrent.duration.Duration.Inf
+import scala.concurrent.{Await, ExecutionContext, Future}
+
+/**
+ * ConcurrentContext is used to control a multithreaded context within a Spark task.
+ */
+trait ConcurrentContext {
+
+  /**
+   * Wraps a block in a concurrent task.
+   *
+   * @param task Block to be executed concurrently.
+   * @tparam A Result type of the passed-in block.
+   * @return A concurrent task, that is, a Future[A].
+   */
+  def async[A](task: => A): Future[A] = Future(task)(ec)
+
+  /**
+   * Awaits multiple concurrent tasks using a sliding window so we don't have to hold all task results in memory
+   * at once.
+   *
+   * @param it Iterator of concurrent tasks.
+   * @param batchSize The number of concurrent tasks to wait on at a time.
+   * @tparam A Result type of each concurrent task.
+   * @return An iterator that contains the result of executing each concurrent task.
+   */
+  def awaitSliding[A](it: Iterator[Future[A]], batchSize: Int = 20): Iterator[A] = {
+
+    implicit val context: ExecutionContext = ec
+
+    val slidingIterator = it.sliding(batchSize - 1).withPartial(true)
+
+    val (head, tail) = slidingIterator.span(_ => slidingIterator.hasNext)
+
+    head.map(batchOfFuture => Await.result(batchOfFuture.head, Inf)) ++
+      tail.flatMap(batchOfFuture => Await.result(Future.sequence(batchOfFuture), Inf))
+  }
+
+  /**
+   * We allow implementations to define the ExecutionContext to be used.
+   *
+   * @return ExecutionContext to be used when spawning new threads.
+   */
+  def ec: ExecutionContext
+}
+
+object ConcurrentContext {
+
+  /**
+   * Implicit instance of BoundedConcurrentContext since it is our default one.
+   */
+  implicit val defaultConcurrentContext: ConcurrentContext = BoundedConcurrentContext
+
+  def unboundedConcurrentContext: ConcurrentContext = UnboundedConcurrentContext
+
+  /**
+   * Implicit syntax.
+   */
+  object Implicits {
+
+    implicit class ConcurrentIteratorOps[A](it: Iterator[Future[A]]) {
+      def awaitSliding(batchSize: Int = 20)(implicit concurrentContext: ConcurrentContext): Iterator[A] =
+        concurrentContext.awaitSliding(it, batchSize)
+    }
+
+    implicit class AsyncOps[A](task: => A) {
+      def async(implicit concurrentContext: ConcurrentContext): Future[A] = concurrentContext.async(task)
+    }
+
+  }
+
+}
\ No newline at end of file
diff --git a/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/UnboundedConcurrentContext.scala b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/UnboundedConcurrentContext.scala
new file mode 100644
index 0000000000000..243d69ee4c3e9
--- /dev/null
+++ b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/concurrent/UnboundedConcurrentContext.scala
@@ -0,0 +1,12 @@
+package com.mapr.db.spark.sql.concurrent
+
+import scala.concurrent.ExecutionContext
+
+private[concurrent] object UnboundedConcurrentContext extends ConcurrentContext {
+  /**
+   * Uses Scala's global ExecutionContext, which is not bounded by our fixed-size pool.
+   *
+   * @return ExecutionContext to be used when spawning new threads.
+   */
+  override def ec: ExecutionContext = scala.concurrent.ExecutionContext.global
+}
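A self-contained sketch of the intended flow of this API, relying only on the code above: async defers each block to the context's pool, and awaitSliding pulls results while only a bounded window of futures is pending:

    import scala.concurrent.Future

    import com.mapr.db.spark.sql.concurrent.ConcurrentContext.Implicits._

    // defaultConcurrentContext (the bounded pool) is resolved implicitly from
    // the ConcurrentContext companion object.
    val tasks: Iterator[Future[Int]] = (1 to 100).iterator.map(n => (n * n).async)

    // At most ~10 futures are pending at a time, so the 100 results are never
    // all materialized at once.
    val results: Iterator[Int] = tasks.awaitSliding(batchSize = 10)

    results.foreach(println)
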
diff --git a/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/GroupedPartitionQueryRunner.scala b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/GroupedPartitionQueryRunner.scala
new file mode 100644
index 0000000000000..018465d19dd28
--- /dev/null
+++ b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/GroupedPartitionQueryRunner.scala
@@ -0,0 +1,51 @@
+package com.mapr.db.spark.sql.ojai
+
+import com.mapr.db.spark.sql.ojai.OJAISparkPartitionReader.Cell
+import org.apache.spark.sql.types.StructType
+
+/**
+ * GroupedPartitionQueryRunner reads the MapR-DB data that matches certain rows.
+ *
+ * Each Spark task uses an instance of GroupedPartitionQueryRunner.
+ */
+private[ojai] class GroupedPartitionQueryRunner(querySize: Int) extends OJAISparkPartitionReader {
+
+  import com.mapr.db.spark.sql.concurrent.ConcurrentContext.Implicits._
+  import com.mapr.db.spark.sql.utils.MapRSqlUtils._
+  import org.ojai.store._
+
+  import scala.collection.JavaConverters._
+
+  /**
+   * Reads the MapR-DB records that match the data in a given partition.
+   *
+   * @param partition Contains the records used to match the data to be read from MapR-DB.
+   * @param table MapR-DB table to read from.
+   * @param schema Schema to be enforced over the MapR-DB data after the read.
+   * @param right Column to be used in the MapR-DB query.
+   * @return Iterator that contains all records from MapR-DB that match the data of the given partition.
+   */
+  def readFrom(partition: Iterator[Cell],
+               table: String,
+               schema: StructType,
+               right: String): Iterator[String] = {
+
+    val connection = DriverManager.getConnection("ojai:mapr:")
+    val store = connection.getStore(table)
+
+    val parallelRunningQueries = partition
+      .map(cell => convertToDataType(cell.value, cell.dataType))
+      .grouped(querySize)
+      .map(group => connection.newCondition().in(right, group.asJava).build())
+      .map(cond =>
+        connection
+          .newQuery()
+          .where(cond) // Filters are pushed down; secondary indexes kick in here.
+          .select(schema.fields.map(_.name): _*) // Projections are pushed down.
+          .build()
+      )
+      .map(query => store.find(query).asScala.map(_.asJsonString()).async)
+
+    parallelRunningQueries.awaitSliding().flatten
+  }
+}
\ No newline at end of file
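To make the grouping concrete, here is a plain-Scala sketch (with made-up key values) of how readFrom splits a partition of join keys into IN-condition batches before any query runs:

    // With querySize = 20, a partition of 90 keys yields 5 OJAI queries:
    // four IN conditions over 20 keys each and one over the remaining 10.
    val querySize = 20
    val keys = (1 to 90).iterator.map(n => s"user$n") // hypothetical _id values

    val batches = keys.grouped(querySize).toList
    assert(batches.size == 5)
    assert(batches.last.size == 10)
    // Each batch then becomes one condition: newCondition().in("_id", batch.asJava).build()
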
diff --git a/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/JoinType.scala b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/JoinType.scala
new file mode 100644
index 0000000000000..431418d775120
--- /dev/null
+++ b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/JoinType.scala
@@ -0,0 +1,34 @@
+package com.mapr.db.spark.sql.ojai
+
+sealed trait JoinType
+
+object JoinType {
+
+  def apply(value: String): JoinType = joins.indexWhere(_.toString == value.toLowerCase()) match {
+    case -1 => throw new IllegalArgumentException(s"$value is not a supported join type")
+    case idx => joins(idx)
+  }
+
+  private lazy val joins = List(inner, outer, full, left, left_outer)
+
+  case object inner extends JoinType {
+    override def toString: String = "inner"
+  }
+
+  case object outer extends JoinType {
+    override def toString: String = "outer"
+  }
+
+  case object full extends JoinType {
+    override def toString: String = "full"
+  }
+
+  case object left extends JoinType {
+    override def toString: String = "left"
+  }
+
+  case object left_outer extends JoinType {
+    override def toString: String = "left_outer"
+  }
+
+}
\ No newline at end of file
diff --git a/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/OJAISparkPartitionReader.scala b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/OJAISparkPartitionReader.scala
new file mode 100644
index 0000000000000..a471f9e8a141d
--- /dev/null
+++ b/external/maprdb/src/main/scala/com/mapr/db/spark/sql/ojai/OJAISparkPartitionReader.scala
@@ -0,0 +1,29 @@
+package com.mapr.db.spark.sql.ojai
+
+
+import com.mapr.db.spark.sql.ojai.OJAISparkPartitionReader.Cell
+import org.apache.spark.sql.types.{DataType, StructType}
+
+trait OJAISparkPartitionReader {
+  def readFrom(partition: Iterator[Cell],
+               table: String,
+               schema: StructType,
+               right: String): Iterator[String]
+}
+
+object OJAISparkPartitionReader {
+
+  def groupedPartitionReader(batchSize: Int = 20): OJAISparkPartitionReader = new GroupedPartitionQueryRunner(batchSize)
+
+  def sequentialPartitionReader: OJAISparkPartitionReader = new GroupedPartitionQueryRunner(1)
+
+  /**
+   * Used to project the exact column we need to filter the MapR-DB table. We can use Cell instead of passing the
+   * entire Row to reduce the memory footprint.
+   *
+   * @param value Spark value of the Row at the specific column.
+   * @param dataType The corresponding Spark data type.
+   */
+  private[mapr] case class Cell(value: Any, dataType: DataType)
+
+}
\ No newline at end of file
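Finally, a short sketch of how the JoinType factory above behaves; matching is case-insensitive because apply lowercases the input before comparing it with each case object's toString:

    import com.mapr.db.spark.sql.ojai.JoinType

    val a = JoinType("INNER")      // resolves to JoinType.inner
    val b = JoinType("left_outer") // resolves to JoinType.left_outer

    // Unsupported names fail fast:
    // JoinType("cross") => IllegalArgumentException: cross is not a supported join type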