Sketch writing with V2 API
EnricoMi committed Jan 25, 2024
1 parent d7efaad commit 9967e97
Showing 8 changed files with 137 additions and 4 deletions.
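For orientation, this commit wires up Spark's DataSource V2 batch write path: TripleTable (now also a SupportsWrite) returns a TripleWriteBuilder, which builds a TripleBatchWrite, which hands a TripleDataWriterFactory to the executors, where one TripleDataWriter per task counts the rows it receives. The sketch below is not part of the commit; it only illustrates the call chain that Spark itself drives, with the table, write info and rows assumed to be supplied by Spark at runtime.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriterCommitMessage}
import uk.co.gresearch.spark.dgraph.connector.TripleTable

// Illustrative only: roughly how Spark exercises the classes added in this commit.
def sketchWritePath(table: TripleTable, info: LogicalWriteInfo, rows: Iterator[InternalRow]): Unit = {
  val batchWrite = table.newWriteBuilder(info).buildForBatch()       // TripleWriteBuilder -> TripleBatchWrite
  val factory    = batchWrite.createBatchWriterFactory(null)         // TripleDataWriterFactory (PhysicalWriteInfo elided here)
  val writer     = factory.createWriter(0, 0L)                       // one TripleDataWriter per partition/task
  rows.foreach(writer.write)                                         // currently only counts triples
  val messages: Array[WriterCommitMessage] = Array(writer.commit())  // per-task commit message
  batchWrite.commit(messages)                                        // prints "Committed ..." for each message
}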
@@ -22,7 +22,7 @@ import java.util
import java.util.UUID
import scala.jdk.CollectionConverters._

trait TableBase extends Table with SupportsRead{
trait TableBase extends Table with SupportsRead {

val cid: UUID

@@ -0,0 +1,16 @@
package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel

case class TripleBatchWrite(schema: StructType, model: GraphTableModel) extends BatchWrite {
override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory =
TripleDataWriterFactory(schema, model)

override def commit(writerCommitMessages: Array[WriterCommitMessage]): Unit = {
writerCommitMessages.foreach(msg => Console.println(s"Committed $msg"))
}

override def abort(writerCommitMessages: Array[WriterCommitMessage]): Unit = { }
}
@@ -0,0 +1,28 @@
package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel

case class TripleDataWriter(schema: StructType, model: GraphTableModel) extends DataWriter[InternalRow] {
var triples = 0L

override def write(row: InternalRow): Unit = {
// Console.println(s"Writing row: $row")
triples = triples + 1
}

override def commit(): WriterCommitMessage = {
val msg: WriterCommitMessage = new WriterCommitMessage {
val name: String = s"$triples triples (${Thread.currentThread().getName})"
override def toString: String = name
}
Console.println(s"Committing $msg")
msg
}

override def abort(): Unit = { }

override def close(): Unit = { }
}
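As a possible refinement (not part of this commit), the anonymous WriterCommitMessage above could be a small case class, so the driver receives the triple count as structured data instead of only a string; purely a sketch of one option:

import org.apache.spark.sql.connector.write.WriterCommitMessage

// Hypothetical alternative to the anonymous WriterCommitMessage used in commit() above.
case class TripleWriterCommitMessage(triples: Long, task: String) extends WriterCommitMessage {
  override def toString: String = s"$triples triples ($task)"
}

// TripleDataWriter.commit() would then return:
//   TripleWriterCommitMessage(triples, Thread.currentThread().getName)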
@@ -0,0 +1,10 @@
package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
import org.apache.spark.sql.types.StructType
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel

case class TripleDataWriterFactory(schema: StructType, model: GraphTableModel) extends DataWriterFactory {
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = TripleDataWriter(schema, model)
}
@@ -16,19 +16,31 @@

package uk.co.gresearch.spark.dgraph.connector

import java.util.UUID

import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel
import uk.co.gresearch.spark.dgraph.connector.partitioner.Partitioner

case class TripleTable(partitioner: Partitioner, model: GraphTableModel, val cid: UUID) extends TableBase {
import java.util
import java.util.UUID
import scala.jdk.CollectionConverters._

case class TripleTable(partitioner: Partitioner, model: GraphTableModel, cid: UUID)
extends TableBase with SupportsWrite {

override def schema(): StructType = model.schema()

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
TripleScanBuilder(partitioner, model)

override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder =
TripleWriteBuilder(logicalWriteInfo.schema(), model)

override def capabilities(): util.Set[TableCapability] = Set(
TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA
).asJava

}
@@ -0,0 +1,10 @@
package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder}
import org.apache.spark.sql.types.StructType
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel

case class TripleWriteBuilder(schema: StructType, model: GraphTableModel)
extends WriteBuilder {
override def buildForBatch(): BatchWrite = TripleBatchWrite(schema, model)
}
@@ -16,8 +16,10 @@

package uk.co.gresearch.spark.dgraph.connector.sources

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import uk.co.gresearch.spark.dgraph.connector._
@@ -99,4 +101,11 @@ class NodeSource() extends TableProviderBase
TripleTable(partitioner, model, clusterState.cid)
}

def createRelation(context: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
new BaseRelation {
override def sqlContext: SQLContext = context
override def schema: StructType = data.schema
}
}

}
@@ -0,0 +1,48 @@
/*
* Copyright 2020 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.{DoubleType, StringType}
import org.scalatest.funspec.AnyFunSpec
import uk.co.gresearch.spark.dgraph.DgraphTestCluster

class TestWriter extends AnyFunSpec with ConnectorSparkTestSession with DgraphTestCluster {

import spark.implicits._

// we want a fresh cluster that we can mutate, definitely not one that is always running and used by all tests
override val clusterAlwaysStartUp: Boolean = true

describe("Connector") {
it("should write") {
spark.range(0, 1000000, 1, 10)
.select(
$"id".as("subject"),
$"id".cast(StringType).as("str"),
$"id".cast(DoubleType).as("dbl")
)
.repartition($"subject")
.sortWithinPartitions($"subject")
.write
.mode(SaveMode.Append)
.option("dgraph.nodes.mode", "wide")
.format("uk.co.gresearch.spark.dgraph.nodes")
.save(dgraph.target)
}
}
}
