diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteExtension.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteExtension.scala
index debbb1c3e..7f1d6e153 100644
--- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteExtension.scala
+++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteExtension.scala
@@ -20,18 +20,21 @@ package com.intel.oap.spark.sql
 import com.intel.oap.spark.sql.ArrowWriteExtension.ArrowWritePostRule
 import com.intel.oap.spark.sql.ArrowWriteExtension.DummyRule
 import com.intel.oap.spark.sql.ArrowWriteExtension.SimpleColumnarRule
+import com.intel.oap.spark.sql.ArrowWriteExtension.SimpleStrategy
 import com.intel.oap.spark.sql.execution.datasources.arrow.ArrowFileFormat
 import com.intel.oap.sql.execution.RowToArrowColumnarExec
-
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.OrderPreservingUnaryNode
+
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.catalyst.util.MapData
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.CodegenSupport
 import org.apache.spark.sql.execution.ColumnarRule
 import org.apache.spark.sql.execution.ColumnarToRowExec
@@ -48,6 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
 class ArrowWriteExtension extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectColumnar(session => SimpleColumnarRule(DummyRule, ArrowWritePostRule(session)))
+    e.injectPlannerStrategy(session => SimpleStrategy())
   }
 }
 
@@ -68,7 +72,7 @@ object ArrowWriteExtension {
       cmd match {
         case command: InsertIntoHadoopFsRelationCommand =>
           if (command.fileFormat
-            .isInstanceOf[ArrowFileFormat]) {
+                .isInstanceOf[ArrowFileFormat]) {
             rc.withNewChildren(Array(ColumnarToFakeRowAdaptor(child)))
           } else {
             plan.withNewChildren(plan.children.map(apply))
           }
@@ -79,8 +83,20 @@ object ArrowWriteExtension {
       cmd match {
         case command: InsertIntoHadoopFsRelationCommand =>
           if (command.fileFormat
-            .isInstanceOf[ArrowFileFormat]) {
-            rc.withNewChildren(Array(ColumnarToFakeRowAdaptor(RowToArrowColumnarExec(child))))
+                .isInstanceOf[ArrowFileFormat]) {
+            child match {
+              case c: AdaptiveSparkPlanExec =>
+                rc.withNewChildren(
+                  Array(
+                    AdaptiveSparkPlanExec(
+                      ColumnarToFakeRowAdaptor(c.inputPlan),
+                      c.context,
+                      c.preprocessingRules,
+                      c.isSubquery)))
+              case other =>
+                rc.withNewChildren(
+                  Array(ColumnarToFakeRowAdaptor(RowToArrowColumnarExec(child))))
+            }
           } else {
             plan.withNewChildren(plan.children.map(apply))
           }
@@ -90,18 +106,6 @@ object ArrowWriteExtension {
     }
   }
 
-  private case class ColumnarToFakeRowAdaptor(child: SparkPlan) extends ColumnarToRowTransition {
-    assert(child.supportsColumnar)
-
-    override protected def doExecute(): RDD[InternalRow] = {
-      child.executeColumnar().map { cb =>
-        new FakeRow(cb)
-      }
-    }
-
-    override def output: Seq[Attribute] = child.output
-  }
-
   class FakeRow(val batch: ColumnarBatch) extends InternalRow {
     override def numFields: Int = throw new UnsupportedOperationException()
     override def setNullAt(i: Int): Unit = throw new UnsupportedOperationException()
@@ -117,7 +121,8 @@ object ArrowWriteExtension {
     override def getDouble(ordinal: Int): Double = throw new UnsupportedOperationException()
     override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
      throw new UnsupportedOperationException()
-    override def getUTF8String(ordinal: Int): UTF8String = throw new UnsupportedOperationException()
+    override def getUTF8String(ordinal: Int): UTF8String =
+      throw new UnsupportedOperationException()
     override def getBinary(ordinal: Int): Array[Byte] = throw new UnsupportedOperationException()
     override def getInterval(ordinal: Int): CalendarInterval =
      throw new UnsupportedOperationException()
@@ -128,4 +133,31 @@ object ArrowWriteExtension {
     override def get(ordinal: Int, dataType: DataType): AnyRef =
       throw new UnsupportedOperationException()
   }
+
+  private case class ColumnarToFakeRowLogicAdaptor(child: LogicalPlan)
+      extends OrderPreservingUnaryNode {
+    override def output: Seq[Attribute] = child.output
+  }
+
+  private case class ColumnarToFakeRowAdaptor(child: SparkPlan) extends ColumnarToRowTransition {
+    if (!child.logicalLink.isEmpty) {
+      setLogicalLink(ColumnarToFakeRowLogicAdaptor(child.logicalLink.get))
+    }
+
+    override protected def doExecute(): RDD[InternalRow] = {
+      child.executeColumnar().map { cb => new FakeRow(cb) }
+    }
+
+    override def output: Seq[Attribute] = child.output
+  }
+
+  case class SimpleStrategy() extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case ColumnarToFakeRowLogicAdaptor(child: LogicalPlan) =>
+        Seq(ColumnarToFakeRowAdaptor(planLater(child)))
+      case other =>
+        Nil
+    }
+  }
+
 }
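
For anyone wanting to exercise the new AQE branch end to end, a minimal sketch follows. It is not part of the patch; it assumes a local test session, assumes the Arrow data source registers under the short name "arrow", and uses an illustrative output path.

    import org.apache.spark.sql.SparkSession

    // Register ArrowWriteExtension (the class changed above) and keep AQE
    // enabled, so the write command's child is an AdaptiveSparkPlanExec and
    // the new match arm in ArrowWritePostRule rewraps c.inputPlan with
    // ColumnarToFakeRowAdaptor instead of missing the AQE wrapper.
    val spark = SparkSession
      .builder()
      .master("local[2]") // assumption: local smoke test
      .config("spark.sql.extensions", "com.intel.oap.spark.sql.ArrowWriteExtension")
      .config("spark.sql.adaptive.enabled", "true")
      .getOrCreate()

    // The repartition introduces an exchange, giving AQE something to
    // re-optimize; the save goes through InsertIntoHadoopFsRelationCommand,
    // which ArrowWritePostRule rewrites as shown in the diff.
    spark.range(1000)
      .repartition(4)
      .write
      .format("arrow") // assumption: DataSourceRegister short name of ArrowFileFormat
      .save("/tmp/arrow-write-aqe-demo") // illustrative path

The design point the diff suggests: ColumnarToFakeRowAdaptor now carries a logical counterpart (ColumnarToFakeRowLogicAdaptor, propagated via setLogicalLink) plus an injected planner strategy, presumably because AQE re-plans stages from the logical plan and would otherwise have no way to recover a purely physical adaptor node; SimpleStrategy maps the logical adaptor back to the physical one via planLater(child).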