basic support for user provided schema in the query path #18031

Merged
@@ -3,11 +3,12 @@
package com.azure.cosmos.spark

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
+import org.apache.spark.sql.types.StructType

-class CosmosBatchWriter(userConfig: Map[String, String]) extends BatchWrite with CosmosLoggingTrait {
+class CosmosBatchWriter(userConfig: Map[String, String], inputSchema: StructType) extends BatchWrite with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

-override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = new CosmosDataWriteFactory(userConfig)
+override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = new CosmosDataWriteFactory(userConfig, inputSchema)

override def commit(writerCommitMessages: Array[WriterCommitMessage]): Unit = {
// TODO
@@ -195,7 +195,7 @@ class CosmosCatalog extends CatalogPlugin
checkNamespace(ident.namespace())
getContainerMetadata(ident) // validates that table exists
// scalastyle:off null
-new CosmosTable(null, null, tableOptions.asJava)
+new CosmosTable(Array[Transform](), tableOptions.asJava, Option.empty)
// scalastyle:off on
}

@@ -232,7 +232,7 @@ class CosmosCatalog extends CatalogPlugin
).block()
}
// TODO: moderakh this needs to be wired up against CosmosTabl
-new CosmosTable(schema, partitions, tableOptions.asJava)
+new CosmosTable(partitions, tableOptions.asJava, Option.apply(schema))
}

@throws(classOf[UnsupportedOperationException])
@@ -10,12 +10,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

-class CosmosDataWriteFactory(userConfig: Map[String, String]) extends DataWriterFactory with CosmosLoggingTrait {
+class CosmosDataWriteFactory(userConfig: Map[String, String],
+inputSchema: StructType) extends DataWriterFactory with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

-override def createWriter(i: Int, l: Long): DataWriter[InternalRow] = new CosmosWriter()
+override def createWriter(i: Int, l: Long): DataWriter[InternalRow] = new CosmosWriter(inputSchema)

-class CosmosWriter() extends DataWriter[InternalRow] {
+class CosmosWriter(inputSchema: StructType) extends DataWriter[InternalRow] {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

val cosmosAccountConfig = CosmosAccountConfig.parseCosmosAccountConfig(userConfig)
@@ -30,9 +31,7 @@ class CosmosDataWriteFactory(userConfig: Map[String, String]) extends DataWriter

override def write(internalRow: InternalRow): Unit = {
// TODO moderakh: schema is hard coded for now to make end to end TestE2EMain work implement schema inference code
-val userProvidedSchema = StructType(Seq(StructField("number", IntegerType), StructField("word", StringType)))

-val objectNode = CosmosRowConverter.internalRowToObjectNode(internalRow, userProvidedSchema)
+val objectNode = CosmosRowConverter.internalRowToObjectNode(internalRow, inputSchema)
// TODO: moderakh how should we handle absence of id?
if (!objectNode.has("id")) {
objectNode.put("id", UUID.randomUUID().toString)
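Note on the change above: the hard-coded two-column schema is gone and CosmosRowConverter.internalRowToObjectNode now receives the propagated inputSchema. As a rough illustration of what such a conversion involves (a conceptual sketch only, not the repo's CosmosRowConverter, handling just a few Spark types):

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType}

// Conceptual sketch: walk the schema and copy each field of the InternalRow into an ObjectNode.
def rowToObjectNode(row: InternalRow, schema: StructType): ObjectNode = {
  val node = new ObjectMapper().createObjectNode()
  schema.fields.zipWithIndex.foreach { case (field, i) =>
    field.dataType match {
      case StringType  => node.put(field.name, row.getUTF8String(i).toString)
      case IntegerType => node.put(field.name, row.getInt(i))
      case BooleanType => node.put(field.name, row.getBoolean(i))
      case other       => node.put(field.name, String.valueOf(row.get(i, other))) // fallback as text
    }
  }
  node
}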
@@ -13,18 +13,41 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
class CosmosItemsDataSource extends DataSourceRegister with TableProvider with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

-override def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType = {
-// scalastyle:off null
-getTable(null,
-Array.empty[Transform],
-caseInsensitiveStringMap.asCaseSensitiveMap()).schema()
-// scalastyle:on null
+/**
+* Infer the schema of the table identified by the given options.
+* @param options an immutable case-insensitive string-to-string map
+* @return StructType inferred schema
+*/
+override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+new CosmosTable(Array.empty, options).schema()
}

+/**
+* Represents the format that this data source provider uses.
+*/
override def shortName(): String = "cosmos.items"

-override def getTable(structType: StructType, transforms: Array[Transform], map: util.Map[String, String]): Table = {
+/**
+* Return a {@link Table} instance with the specified table schema, partitioning and properties
+* to do read/write. The returned table should report the same schema and partitioning with the
+* specified ones, or Spark may fail the operation.
+*
+* @param schema The specified table schema.
+* @param partitioning The specified table partitioning.
+* @param properties The specified table properties. It's case preserving (contains exactly what
+*                   users specified) and implementations are free to use it case sensitively or
+*                   insensitively. It should be able to identify a table, e.g. file path, Kafka
+*                   topic name, etc.
+*/
+override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table = {
// getTable - This is used for loading table with user specified schema and other transformations.
-new CosmosTable(structType, transforms, map)
+new CosmosTable(partitioning, properties, Option.apply(schema))
}

+/**
+* Returns true if the source has the ability of accepting external table metadata when getting
+* tables. The external table metadata includes user-specified schema from
+* `DataFrameReader`/`DataStreamReader` and schema/partitioning stored in Spark catalog.
+*/
+override def supportsExternalMetadata(): Boolean = true
}
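Because supportsExternalMetadata() now returns true, Spark forwards a schema supplied via DataFrameReader.schema(...) to getTable(schema, ...) instead of calling inferSchema(). A minimal read-side sketch of that path, with config keys taken from the test added in this PR and placeholder endpoint/key values:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val spark = SparkSession.builder().appName("cosmos read sample").master("local").getOrCreate()

val cfg = Map(
  "spark.cosmos.accountEndpoint" -> "<endpoint>",   // placeholder
  "spark.cosmos.accountKey"      -> "<key>",        // placeholder
  "spark.cosmos.database"        -> "testDB",
  "spark.cosmos.container"       -> "testContainer"
)

// The user-provided schema ends up as CosmosTable.userProvidedSchema via getTable(...).
val customSchema = StructType(Seq(
  StructField("id", StringType),
  StructField("age", IntegerType)
))

val df = spark.read.schema(customSchema).format("cosmos.items").options(cfg).load()
df.printSchema()   // reports the user-provided schema rather than the hard-coded fallback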
@@ -6,16 +6,19 @@ import com.azure.cosmos.models.CosmosParametrizedQuery
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

-case class CosmosScan(config: Map[String, String], cosmosQuery: CosmosParametrizedQuery)
+case class CosmosScan(schema: StructType,
+config: Map[String, String], cosmosQuery: CosmosParametrizedQuery)
extends Scan
with Batch
with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

+/**
+* Returns the actual schema of this data source scan, which may be different from the physical
+* schema of the underlying storage, as column pruning or other optimizations may happen.
+*/
override def readSchema(): StructType = {
-// TODO: moderakh add support for schema inference
-// for now schema is hard coded to make TestE2EMain to work
-StructType(Seq(StructField("number", IntegerType), StructField("word", StringType)))
+schema
}

override def planInputPartitions(): Array[InputPartition] = {
@@ -25,7 +28,7 @@ case class CosmosScan(config: Map[String, String], cosmosQuery: CosmosParametriz
}

override def createReaderFactory(): PartitionReaderFactory = {
-CosmosScanPartitionReaderFactory(config, readSchema, cosmosQuery)
+CosmosScanPartitionReaderFactory(config, schema, cosmosQuery)
}

override def toBatch: Batch = {
@@ -13,7 +13,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

-case class CosmosScanBuilder(config: CaseInsensitiveStringMap)
+case class CosmosScanBuilder(config: CaseInsensitiveStringMap, inputSchema: StructType)
extends ScanBuilder
with SupportsPushDownFilters
with SupportsPushDownRequiredColumns
@@ -48,7 +48,9 @@ case class CosmosScanBuilder(config: CaseInsensitiveStringMap)

override def build(): Scan = {
assert(this.processedPredicates.isDefined)
-CosmosScan(config.asScala.toMap, this.processedPredicates.get.cosmosParametrizedQuery)
+
+// TODO moderakh when inferring schema we should consolidate the schema from pruneColumns
+CosmosScan(inputSchema, config.asScala.toMap, this.processedPredicates.get.cosmosParametrizedQuery)
}

/**
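The pruneColumns TODO above concerns column pruning: Spark tells the ScanBuilder which columns it actually needs, and readSchema() should eventually report that pruned subset rather than the full input schema. A rough sketch of that consolidation, with a hypothetical requiredSchema standing in for what pruneColumns would capture:

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Full schema supplied by the user (or inferred later).
val inputSchema = StructType(Seq(
  StructField("id", StringType),
  StructField("name", StringType),
  StructField("age", IntegerType)
))

// Hypothetical: what Spark passed to pruneColumns(requiredSchema) after column pruning.
val requiredSchema = StructType(Seq(StructField("id", StringType), StructField("age", IntegerType)))

// Consolidated schema the Scan could report: keep only the required fields, in input order.
val effectiveSchema = StructType(inputSchema.filter(f => requiredSchema.fieldNames.contains(f.name)))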
@@ -21,21 +21,29 @@ import scala.collection.JavaConverters._
* @param transforms
* @param userConfig
*/
-class CosmosTable(val userProvidedSchema: StructType,
-val transforms: Array[Transform],
-val userConfig: util.Map[String, String])
+class CosmosTable(val transforms: Array[Transform],
+val userConfig: util.Map[String, String],
+val userProvidedSchema: Option[StructType] = Option.empty)
extends Table
with SupportsWrite
with SupportsRead
with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

+// TODO: FIXME moderakh
+// A name to identify this table. Implementations should provide a meaningful name, like the
+// database and table name from catalog, or the location of files for this table.
override def name(): String = "com.azure.cosmos.spark.write"

+/**
+* Returns the schema of this table. If the table is not readable and doesn't have a schema, an
+* empty schema can be returned here.
+*/
override def schema(): StructType = {
// TODO: moderakh add support for schema inference
// for now schema is hard coded to make TestE2EMain to work
-StructType(Seq(StructField("number", IntegerType), StructField("word", StringType)))
+val hardCodedSchema = StructType(Seq(StructField("number", IntegerType), StructField("word", StringType)))
+userProvidedSchema.getOrElse(hardCodedSchema)
}

override def capabilities(): util.Set[TableCapability] = Set(
@@ -44,10 +52,15 @@ class CosmosTable(val userProvidedSchema: StructType,

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
// TODO moderakh how options and userConfig should be merged? is there any difference?
-CosmosScanBuilder(options)
+CosmosScanBuilder(options, schema())
}

override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = {
-new CosmosWriterBuilder(userConfig.asScala.toMap)
+// TODO: moderakh merge logicalWriteInfo config with other configs
+
+new CosmosWriterBuilder(
+userConfig.asScala.toMap,
+logicalWriteInfo.schema()
+)
}
}
@@ -3,9 +3,10 @@
package com.azure.cosmos.spark

import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder}
+import org.apache.spark.sql.types.StructType

-class CosmosWriterBuilder(userConfig: Map[String, String]) extends WriteBuilder with CosmosLoggingTrait {
+class CosmosWriterBuilder(userConfig: Map[String, String], inputSchema: StructType) extends WriteBuilder with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

-override def buildForBatch(): BatchWrite = new CosmosBatchWriter(userConfig: Map[String, String])
+override def buildForBatch(): BatchWrite = new CosmosBatchWriter(userConfig: Map[String, String], inputSchema)
}
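On the write side the schema now originates from LogicalWriteInfo.schema(), i.e. the DataFrame's own schema, and is threaded through CosmosWriterBuilder, CosmosBatchWriter, and CosmosDataWriteFactory down to each CosmosWriter. A minimal write-side usage sketch under those assumptions (placeholder endpoint/key; exact save-mode behavior in this early connector may differ):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("cosmos write sample").master("local").getOrCreate()
import spark.implicits._

val cfg = Map(
  "spark.cosmos.accountEndpoint" -> "<endpoint>",   // placeholder
  "spark.cosmos.accountKey"      -> "<key>",        // placeholder
  "spark.cosmos.database"        -> "testDB",
  "spark.cosmos.container"       -> "testContainer"
)

// The DataFrame schema (word: String, number: Int) becomes the inputSchema used when
// serializing each InternalRow to an ObjectNode in CosmosWriter.write(...).
Seq(("cat", 1), ("dog", 2)).toDF("word", "number")
  .write
  .format("cosmos.items")
  .options(cfg)
  .mode("append")
  .save()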
@@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import java.util.UUID

import com.azure.cosmos.CosmosClientBuilder
import com.azure.cosmos.implementation.{TestConfigurations, Utils}
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import org.assertj.core.api.Assertions.assertThat
import org.codehaus.jackson.map.ObjectMapper
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

class SparkE2EQuerySpec extends IntegrationSpec {
//scalastyle:off multiple.string.literals
//scalastyle:off magic.number

it can "query cosmos and use user provided schema" taggedAs (RequiresCosmosEndpoint) in {
val cosmosEndpoint = TestConfigurations.HOST
val cosmosMasterKey = TestConfigurations.MASTER_KEY
val cosmosDatabase = "testDB"
val cosmosContainer = UUID.randomUUID().toString

val client = new CosmosClientBuilder()
.endpoint(cosmosEndpoint)
.key(cosmosMasterKey)
.buildAsyncClient()

client.createDatabaseIfNotExists(cosmosDatabase).block()
client.getDatabase(cosmosDatabase).createContainerIfNotExists(cosmosContainer, "/id").block()

val container = client.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
for (state <- Array(true, false)) {
val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
objectNode.put("name", "Shrodigner's cat")
objectNode.put("type", "cat")
objectNode.put("age", 20)
objectNode.put("isAlive", state)
objectNode.put("id", UUID.randomUUID().toString)
container.createItem(objectNode).block()
}
val cfg = Map("spark.cosmos.accountEndpoint" -> cosmosEndpoint,
"spark.cosmos.accountKey" -> cosmosMasterKey,
"spark.cosmos.database" -> cosmosDatabase,
"spark.cosmos.container" -> cosmosContainer
)

val spark = SparkSession.builder()
.appName("spark connector sample")
.master("local")
.getOrCreate()

// scalastyle:off underscore.import
// scalastyle:off import.grouping
import org.apache.spark.sql.types._
// scalastyle:on underscore.import
// scalastyle:on import.grouping

val customSchema = StructType(Array(
StructField("id", StringType),
StructField("name", StringType),
StructField("type", StringType),
StructField("age", IntegerType),
StructField("isAlive", BooleanType)
))

val df = spark.read.schema(customSchema).format("cosmos.items").options(cfg).load()
val rowsArray = df.where("isAlive = 'true'").collect()
rowsArray should have size 1

val row = rowsArray(0)
row.getAs[String]("name") shouldEqual "Shrodigner's cat"
row.getAs[String]("type") shouldEqual "cat"
row.getAs[Integer]("age") shouldEqual 20
row.getAs[Boolean]("isAlive") shouldEqual true

client.close()
spark.close()
}

//scalastyle:on magic.number
//scalastyle:on multiple.string.literals
}