Skip to content

Commit

Permalink
feat(spark): add support for ‘offset’ clause
Browse files Browse the repository at this point in the history
Add missing support for the ‘offset’ clause in the spark module.

Signed-off-by: Andrew Coleman <andrew_coleman@uk.ibm.com>
  • Authored by andrew-coleman (branch information)
andrew-coleman committed Sep 24, 2024
1 parent 5d66cd1 commit c8bde6d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.types.{DataTypes, IntegerType, StructType}

import org.apache.spark.sql.types.{DataTypes, IntegerType, LongType, StructField, StructType}
import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
Expand Down Expand Up @@ -157,11 +156,16 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}
override def visit(fetch: relation.Fetch): LogicalPlan = {
  // Convert a Substrait Fetch relation into the matching Spark logical operator.
  // A Fetch carries an optional row count (limit) and a mandatory offset.
  val child = fetch.getInput.accept(this)
  if (fetch.getCount.isPresent) {
    // A count is present: this Fetch encodes a LIMIT. The offset field is
    // used as a sentinel here: 1L -> GlobalLimit, -1L -> LocalLimit.
    // NOTE(review): these sentinels presumably mirror what ToSubstraitRel
    // emits for GlobalLimit/LocalLimit — confirm against the serializer side.
    val limit = Literal(fetch.getCount.getAsLong.intValue(), IntegerType)
    fetch.getOffset match {
      case 1L => GlobalLimit(limitExpr = limit, child = child)
      case -1L => LocalLimit(limitExpr = limit, child = child)
      // Any other offset/count combination is not handled here; defer to the
      // default handling so the caller sees a consistent failure mode.
      case _ => visitFallback(fetch)
    }
  } else {
    // No count: this Fetch encodes a plain OFFSET clause.
    val offset = Literal(fetch.getOffset, LongType)
    Offset(offset, child)
  }
}
override def visit(sort: relation.Sort): LogicalPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
.build()
}

override def visitOffset(p: Offset): relation.Rel = {
  // Translate a Spark Offset operator into a Substrait Fetch relation
  // that carries only an offset (no row count).
  // The offset expression is evaluated eagerly; only integral literals
  // (Int or Long) are supported.
  val rowOffset: Long = p.offsetExpr.eval() match {
    case i: java.lang.Integer => i.longValue()
    case l: java.lang.Long => l.longValue()
    case other =>
      throw new UnsupportedOperationException(
        s"Unable to convert the offset expression ${other}")
  }
  val builder = relation.Fetch.builder()
  builder.input(visit(p.child))
  builder.offset(rowOffset)
  builder.build()
}

override def visitFilter(p: Filter): relation.Rel = {
val condition = toExpression(p.child.output)(p.condition)
relation.Filter.builder().condition(condition).input(visit(p.child)).build()
Expand Down
2 changes: 1 addition & 1 deletion spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
"order by l_shipdate asc, l_discount desc nulls last")
}

ignore("simpleOffsetClause") { // TODO need to implement the 'offset' clause for this to pass
test("simpleOffsetClause") {
assertSqlSubstraitRelRoundTrip(
"select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
"order by l_shipdate asc, l_discount desc limit 100 offset 1000")
Expand Down

0 comments on commit c8bde6d

Please sign in to comment.