Adding joinWithMapRDBTable function (apache#529)
The related documentation for this function is here: https://github.com/anicolaspp/MapRDBConnector#joinwithmaprdbtable.

The main idea is that, given a dataframe (no matter how it was constructed), we can join it with a MapR-DB table. This function looks at the join query and loads only those records from MapR-DB that will participate in the join, instead of loading the full table and then joining in memory. In other words, we only load what we know will be joined.
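
As a usage illustration, here is a minimal sketch of how the new function could be called. It assumes the connector exposes the usual implicit conversion that adds these functions to DataFrame (imported below via com.mapr.db.spark.sql._); the table path /tables/users, the schema, and the column names are made-up placeholders, not part of this commit.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import com.mapr.db.spark.sql._ // assumed import that enables df.joinWithMapRDBTable

implicit val session: SparkSession = SparkSession.builder().appName("join-sketch").getOrCreate()

// Schema of the hypothetical MapR-DB table /tables/users.
val usersSchema = StructType(Seq(
  StructField("_id", StringType),
  StructField("name", StringType),
  StructField("age", IntegerType)))

// Any DataFrame works as the left side, no matter how it was constructed.
val events = session.read.json("/data/events.json")

// Only the documents from /tables/users whose _id matches events.user_id are loaded;
// this overload defaults to an inner join.
val joined = events.joinWithMapRDBTable("/tables/users", usersSchema, "user_id", "_id")
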
anicolaspp authored and ekrivokonmapr committed Sep 19, 2019
1 parent c79c632 commit 669fb06
Showing 7 changed files with 265 additions and 3 deletions.
@@ -1,14 +1,16 @@
/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.db.spark.sql

import com.mapr.db.spark.sql.ojai.{JoinType, OJAISparkPartitionReader}
import com.mapr.db.spark.sql.ojai.OJAISparkPartitionReader.Cell
import com.mapr.db.spark.utils.{LoggingTrait, MapRSpark}
import org.ojai.DocumentConstants

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.StructType

private[spark] case class MapRDBDataFrameFunctions(@transient df: DataFrame,
bufferWrites: Boolean = true)
extends LoggingTrait {
extends LoggingTrait {

def setBufferWrites(bufferWrites: Boolean): MapRDBDataFrameFunctions =
MapRDBDataFrameFunctions(df, bufferWrites)
@@ -24,4 +26,44 @@ private[spark] case class MapRDBDataFrameFunctions(@transient df: DataFrame,
createTable: Boolean = false,
bulkInsert: Boolean = false): Unit =
MapRSpark.insert(df, tableName, idFieldPath, createTable, bulkInsert, bufferWrites)

def joinWithMapRDBTable(table: String,
schema: StructType,
left: String,
right: String,
joinType: JoinType,
concurrentQueries: Int = 20)(implicit session: SparkSession): DataFrame = {

val columnDataType = schema.fields(schema.fieldIndex(right)).dataType

val documents = df
.select(left)
.distinct()
.rdd
.mapPartitions { partition =>
if (partition.isEmpty) {
List.empty.iterator
} else {

val partitionCellIterator = partition.map(row => Cell(row.get(0), columnDataType))

OJAISparkPartitionReader
.groupedPartitionReader(concurrentQueries)
.readFrom(partitionCellIterator, table, schema, right)
}
}

import org.apache.spark.sql.functions._
import session.implicits._

val rightDF = session.read.schema(schema).json(documents.toDS())

df.join(rightDF, col(left) === col(right), joinType.toString)
}

def joinWithMapRDBTable(maprdbTable: String,
schema: StructType,
left: String,
right: String)(implicit session: SparkSession): DataFrame =
joinWithMapRDBTable(maprdbTable, schema, left, right, JoinType.inner)
}
@@ -0,0 +1,20 @@
package com.mapr.db.spark.sql.concurrent

import java.util.concurrent.Executors

import scala.concurrent.ExecutionContext

/**
* This is the default ConcurrentContext.
*
* We use a bounded, fixed-size thread pool so the number of concurrent queries per executor stays capped while idle threads are reused.
*/
private[concurrent] object BoundedConcurrentContext extends ConcurrentContext {

/**
* A fixed pool of 24 threads bounds how many concurrent tasks can run at a time within an Executor.
*/
override def ec: ExecutionContext = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(24))
}


@@ -0,0 +1,74 @@
package com.mapr.db.spark.sql.concurrent

import scala.concurrent.duration.Duration.Inf
import scala.concurrent.{Await, ExecutionContext, Future}

/**
* ConcurrentContext is used to control a multithreaded context within a Spark Task.
*/
trait ConcurrentContext {

/**
* Wraps a block within a concurrent task.
*
* @param task Block to be executed concurrently.
* @tparam A Result type of the passed-in block.
* @return A concurrent task, that is, a Future[A].
*/
def async[A](task: => A): Future[A] = Future(task)(ec)

/**
* Awaits multiple concurrent tasks using a sliding window so we don't have to hold all task results in memory
* at once.
*
* @param it Iterator of concurrent tasks.
* @param batchSize The number of concurrent tasks we wait on at a time.
* @tparam A Result type of each concurrent task.
* @return An iterator that contains the result of executing each concurrent task.
*/
def awaitSliding[A](it: Iterator[Future[A]], batchSize: Int = 20): Iterator[A] = {

implicit val context: ExecutionContext = ec

val slidingIterator = it.sliding(batchSize - 1).withPartial(true)

// `span` splits the windows into all-but-last (head) and the last window (tail):
// for every full window we only await its first future, and the final window is
// awaited in full, so roughly batchSize futures are pending at any point.
val (head, tail) = slidingIterator.span(_ => slidingIterator.hasNext)

head.map(batchOfFuture => Await.result(batchOfFuture.head, Inf)) ++
tail.flatMap(batchOfFuture => Await.result(Future.sequence(batchOfFuture), Inf))
}

/**
* We allow implementations to define the ExecutionContext to be used.
*
* @return ExecutionContext to be used when spawning new threads.
*/
def ec: ExecutionContext
}

object ConcurrentContext {

/**
* Implicit instance pointing to BoundedConcurrentContext since it is our default one.
*/
implicit val defaultConcurrentContext: ConcurrentContext = BoundedConcurrentContext

def unboundedConcurrentContext: ConcurrentContext = UnboundedConcurrentContext

/**
* Implicit syntax
*/
object Implicits {

implicit class ConcurrentIteratorOps[A](it: Iterator[Future[A]]) {
def awaitSliding(batchSize: Int = 20)(implicit concurrentContext: ConcurrentContext): Iterator[A] =
concurrentContext.awaitSliding(it, batchSize)
}

implicit class AsyncOps[A](task: => A) {
def async(implicit concurrentContext: ConcurrentContext): Future[A] = concurrentContext.async(task)
}

}

}
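
A small, self-contained sketch of how the async/awaitSliding pair above can be exercised on its own, independent of MapR-DB. The simulated work (Thread.sleep) and the numbers are illustrative only; the implicit defaultConcurrentContext is picked up from the ConcurrentContext companion object.

import scala.concurrent.Future
import com.mapr.db.spark.sql.concurrent.ConcurrentContext.Implicits._

// Spawn 100 concurrent tasks on the bounded pool...
val tasks: Iterator[Future[Int]] = (1 to 100).iterator.map { i =>
  { Thread.sleep(10); i * i }.async
}

// ...and consume them with a sliding window of 20, so only a bounded number of
// results is held in memory while the remaining tasks are still running.
val results: Iterator[Int] = tasks.awaitSliding(batchSize = 20)

results.foreach(println)
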
@@ -0,0 +1,12 @@
package com.mapr.db.spark.sql.concurrent

import scala.concurrent.ExecutionContext

private[concurrent] object UnboundedConcurrentContext extends ConcurrentContext {
/**
* Uses Scala's global ExecutionContext, so the degree of concurrency is not explicitly bounded by the connector.
*
* @return ExecutionContext to be used when spawning new threads.
*/
override def ec: ExecutionContext = scala.concurrent.ExecutionContext.global
}
@@ -0,0 +1,51 @@
package com.mapr.db.spark.sql.ojai

import com.mapr.db.spark.sql.ojai.OJAISparkPartitionReader.Cell
import org.apache.spark.sql.types.StructType

/**
* GroupedPartitionQueryRunner reads the MapR-DB data that matches certain rows.
*
* Each Spark executor has an instance of GroupedPartitionQueryRunner.
*/
private[ojai] class GroupedPartitionQueryRunner(querySize: Int) extends OJAISparkPartitionReader {

import com.mapr.db.spark.sql.concurrent.ConcurrentContext.Implicits._
import com.mapr.db.spark.sql.utils.MapRSqlUtils._
import org.ojai.store._

import scala.collection.JavaConverters._

/**
* Reads MapR-DB records that match the data in a given partition.
*
* @param partition Contains the records used to match the data to be read from MapR-DB.
* @param table MapR-DB table to read from.
* @param schema Schema to be enforced over the MapR-DB data after the read.
* @param right Column to be used for the MapR-DB query.
* @return Iterator that contains all records from MapR-DB that match the data of the given partition.
*/
def readFrom(partition: Iterator[Cell],
table: String,
schema: StructType,
right: String): Iterator[String] = {

val connection = DriverManager.getConnection("ojai:mapr:")
val store = connection.getStore(table)

val parallelRunningQueries = partition
.map(cell => convertToDataType(cell.value, cell.dataType))
.grouped(querySize)
.map(group => connection.newCondition().in(right, group.asJava).build())
.map(cond =>
connection
.newQuery()
.where(cond) // Filters push down. Secondary indexes kick in here.
.select(schema.fields.map(_.name): _*) // Projections push down.
.build()
)
.map(query => store.find(query).asScala.map(_.asJsonString()).async)

parallelRunningQueries.awaitSliding().flatten
}
}
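
For completeness, a hypothetical sketch of driving the grouped reader directly, outside of joinWithMapRDBTable. It must live under a com.mapr package because Cell is private[mapr]; the object name, the table path /tables/users, and the key values are made up, and running it requires a MapR-DB cluster reachable through the "ojai:mapr:" driver.

package com.mapr.db.spark.sql.ojai

import org.apache.spark.sql.types.{StringType, StructField, StructType}

object GroupedReaderSketch extends App {

  val schema = StructType(Seq(
    StructField("_id", StringType),
    StructField("name", StringType)))

  // Each Cell carries only the join-key value and its Spark type, not a whole Row.
  val keys = Iterator(
    OJAISparkPartitionReader.Cell("user_0001", StringType),
    OJAISparkPartitionReader.Cell("user_0002", StringType))

  // The keys are grouped into OJAI `in` conditions (one query per group), the queries
  // run concurrently, and the matching documents come back as JSON strings.
  val docs: Iterator[String] =
    OJAISparkPartitionReader
      .groupedPartitionReader(batchSize = 2)
      .readFrom(keys, "/tables/users", schema, "_id")

  docs.foreach(println)
}
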
@@ -0,0 +1,34 @@
package com.mapr.db.spark.sql.ojai

sealed trait JoinType

object JoinType {

def apply(value: String): JoinType = joins.indexWhere(_.toString == value.toLowerCase()) match {
case -1 => throw new IllegalArgumentException(s"$value is not a supported join type")
case idx => joins(idx)
}

private lazy val joins = List(inner, outer, full, left, left_outer)

case object inner extends JoinType {
override def toString: String = "inner"
}

case object outer extends JoinType {
override def toString: String = "outer"
}

case object full extends JoinType {
override def toString: String = "full"
}

case object left extends JoinType {
override def toString: String = "left"
}

case object left_outer extends JoinType {
override def toString: String = "left_outer"
}

}
@@ -0,0 +1,29 @@
package com.mapr.db.spark.sql.ojai


import com.mapr.db.spark.sql.ojai.OJAISparkPartitionReader.Cell
import org.apache.spark.sql.types.{DataType, StructType}

trait OJAISparkPartitionReader {
def readFrom(partition: Iterator[Cell],
table: String,
schema: StructType,
right: String): Iterator[String]
}

object OJAISparkPartitionReader {

def groupedPartitionReader(batchSize: Int = 20): OJAISparkPartitionReader = new GroupedPartitionQueryRunner(batchSize)

def sequentialPartitionReader: OJAISparkPartitionReader = new GroupedPartitionQueryRunner(1)

/**
* Used to project the exact column we need to filter the MapR-DB table. We can use Cell instead of passing the
* entire Row to reduce the memory footprint.
*
* @param value Spark value of the Row at the specific column.
* @param dataType The corresponding data type
*/
private[mapr] case class Cell(value: Any, dataType: DataType)

}
