Cosmos spark config infra #17702 (Merged)
@@ -4,10 +4,10 @@ package com.azure.cosmos.spark

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}

-class CosmosBatchWriter extends BatchWrite with CosmosLoggingTrait {
+class CosmosBatchWriter(userConfig: Map[String, String]) extends BatchWrite with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

-override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = new CosmosDataWriteFactory()
+override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = new CosmosDataWriteFactory(userConfig)

override def commit(writerCommitMessages: Array[WriterCommitMessage]): Unit = {
// TODO
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.cosmos.spark

import java.net.URL
import java.util.Locale


// each config category will be a case class:
// TODO moderakh more configs
//case class ClientConfig()
//case class CosmosBatchWriteConfig()

case class CosmosAccountConfig(endpoint: String, key: String)

object CosmosAccountConfig {
val CosmosAccountEndpointUri = CosmosConfigEntry[String](key = "spark.cosmos.accountEndpoint",
mandatory = true,
parseFromStringFunction = accountEndpointUri => {
new URL(accountEndpointUri) // constructing a java.net.URL validates the format; an invalid endpoint throws here
accountEndpointUri
},
helpMessage = "Cosmos DB Account Endpoint Uri")

val CosmosKey = CosmosConfigEntry[String](key = "spark.cosmos.accountKey",
mandatory = true,
parseFromStringFunction = accountKey => accountKey,
helpMessage = "Cosmos DB Account Key")

def parseCosmosAccountConfig(cfg: Map[String, String]): CosmosAccountConfig = {
val endpointOpt = CosmosConfigEntry.parse(cfg, CosmosAccountEndpointUri)
val key = CosmosConfigEntry.parse(cfg, CosmosKey)

// the parse calls above already validated that both values are present
assert(endpointOpt.isDefined)
assert(key.isDefined)

CosmosAccountConfig(endpointOpt.get, key.get)
}
}

case class CosmosContainerConfig(database: String, container: String)

object CosmosContainerConfig {
val databaseName = CosmosConfigEntry[String](key = "spark.cosmos.database",
mandatory = true,
parseFromStringFunction = database => database,
helpMessage = "Cosmos DB database name")

val containerName = CosmosConfigEntry[String](key = "spark.cosmos.container",
mandatory = true,
parseFromStringFunction = container => container,
helpMessage = "Cosmos DB container name")

def parseCosmosContainerConfig(cfg: Map[String, String]): CosmosContainerConfig = {
val databaseOpt = CosmosConfigEntry.parse(cfg, databaseName)
val containerOpt = CosmosConfigEntry.parse(cfg, containerName)

// parsing above already validated this
assert(databaseOpt.isDefined)
assert(containerOpt.isDefined)

CosmosContainerConfig(databaseOpt.get, containerOpt.get)
}
}

case class CosmosConfigEntry[T](key: String,
mandatory: Boolean,
defaultValue: Option[String] = Option.empty,
parseFromStringFunction: String => T,
helpMessage: String) {

def parse(paramAsString: String) : T = {
try {
parseFromStringFunction(paramAsString)
} catch {
case e: Exception => throw new RuntimeException(s"invalid configuration for ${key}:${paramAsString}. Config description: ${helpMessage}", e)
}
}
}

// TODO: moderakh how to merge user config with SparkConf application config?
object CosmosConfigEntry {
def parse[T](configuration: Map[String, String], configEntry: CosmosConfigEntry[T]): Option[T] = {
// TODO moderakh: where should we handle case sensitivity?
// we are doing this here per config parsing for now
val opt = configuration.map { case (key, value) => (key.toLowerCase(Locale.ROOT), value) }.get(configEntry.key.toLowerCase(Locale.ROOT))
if (opt.isDefined) {
Option.apply(configEntry.parse(opt.get))
}
else {
if (configEntry.mandatory) {
throw new RuntimeException(s"mandatory option ${configEntry.key} is missing. Config description: ${configEntry.helpMessage}")
} else {
Option.empty
}
}
}
}
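
As a quick orientation for reviewers, here is a hedged sketch of how the config-entry pattern above is meant to be extended and consumed; the "spark.cosmos.preferredRegions" entry is purely illustrative and is not added by this PR.

// Hypothetical example of the pattern above; this entry does not exist in the PR.
val preferredRegions = CosmosConfigEntry[Array[String]](key = "spark.cosmos.preferredRegions",
mandatory = false,
parseFromStringFunction = regions => regions.split(",").map(_.trim),
helpMessage = "Comma separated list of preferred regions")

// CosmosConfigEntry.parse matches keys case-insensitively, returns Option.empty when a
// non-mandatory key is absent, and throws a RuntimeException carrying the help message
// when a mandatory key is missing or a value fails to parse.
val regions = CosmosConfigEntry.parse(
Map("spark.cosmos.preferredRegions" -> "West US, East US"), preferredRegions)
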
@@ -10,22 +10,23 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

-class CosmosDataWriteFactory extends DataWriterFactory with CosmosLoggingTrait {
+class CosmosDataWriteFactory(userConfig: Map[String, String]) extends DataWriterFactory with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

override def createWriter(i: Int, l: Long): DataWriter[InternalRow] = new CosmosWriter()

class CosmosWriter() extends DataWriter[InternalRow] {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

-// TODO moderakh account config and databaseName, containerName need to passed down from the user
+val cosmosAccountConfig = CosmosAccountConfig.parseCosmosAccountConfig(userConfig)
+val cosmosTargetContainerConfig = CosmosContainerConfig.parseCosmosContainerConfig(userConfig)

// TODO moderakh: this needs to be shared to avoid creating multiple clients
val client = new CosmosClientBuilder()
-.key(TestConfigurations.MASTER_KEY)
-.endpoint(TestConfigurations.HOST)
+.key(cosmosAccountConfig.key)
+.endpoint(cosmosAccountConfig.endpoint)
.consistencyLevel(ConsistencyLevel.EVENTUAL)
.buildAsyncClient();
-val databaseName = "testDB"
-val containerName = "testContainer"

override def write(internalRow: InternalRow): Unit = {
// TODO moderakh: schema is hard coded for now to make end to end TestE2EMain work implement schema inference code
@@ -36,8 +37,8 @@ class CosmosDataWriteFactory extends DataWriterFactory with CosmosLoggingTrait {
if (!objectNode.has("id")) {
objectNode.put("id", UUID.randomUUID().toString)
}
-client.getDatabase(databaseName)
-.getContainer(containerName)
+client.getDatabase(cosmosTargetContainerConfig.database)
+.getContainer(cosmosTargetContainerConfig.container)
.createItem(objectNode)
.block()
}
@@ -17,11 +17,11 @@ import scala.collection.JavaConverters._
* CosmosTable is the entry point; this class is registered with Spark.
* @param userProvidedSchema
* @param transforms
-* @param map
+* @param userConfig
*/
class CosmosTable(val userProvidedSchema: StructType,
val transforms: Array[Transform],
-val map: util.Map[String, String])
+val userConfig: util.Map[String, String])
extends Table with SupportsWrite with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

@@ -35,6 +35,6 @@ class CosmosTable(val userProvidedSchema: StructType,

override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_WRITE).asJava

-override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = new CosmosWriterBuilder
+override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = new CosmosWriterBuilder(userConfig.asScala.toMap)

}
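
For context, here is a hedged sketch of how the option map reaches CosmosTable from user code; the data source short name "cosmos.items" and the DataFrame df are illustrative assumptions, not something defined by this change.

// Illustrative only: "cosmos.items" is an assumed format name and df is an existing
// org.apache.spark.sql.DataFrame; Spark forwards the options below into the table
// implementation, where they end up as CosmosTable's userConfig map.
df.write
.format("cosmos.items")
.option("spark.cosmos.accountEndpoint", "https://localhost:8081")
.option("spark.cosmos.accountKey", "xyz")
.option("spark.cosmos.database", "testDB")
.option("spark.cosmos.container", "testContainer")
.mode("append")
.save()
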
@@ -4,8 +4,8 @@ package com.azure.cosmos.spark

import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder}

-class CosmosWriterBuilder extends WriteBuilder with CosmosLoggingTrait {
+class CosmosWriterBuilder(userConfig: Map[String, String]) extends WriteBuilder with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

-override def buildForBatch(): BatchWrite = new CosmosBatchWriter()
+override def buildForBatch(): BatchWrite = new CosmosBatchWriter(userConfig: Map[String, String])
}
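
Putting the pieces together, a minimal sketch of the hand-off this PR wires up, shown outside a real Spark job: the same user-supplied map flows unchanged from the builder into the batch writer and, from there, into the per-partition write factory, where it is finally parsed.

// Sketch of the configuration hand-off (illustrative; Spark normally drives these calls):
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
"spark.cosmos.database" -> "testDB",
"spark.cosmos.container" -> "testContainer")

val writerBuilder = new CosmosWriterBuilder(userConfig)
val batchWrite = writerBuilder.buildForBatch() // new CosmosBatchWriter(userConfig)
// CosmosBatchWriter.createBatchWriterFactory constructs CosmosDataWriteFactory(userConfig);
// its CosmosWriter then parses the map into CosmosAccountConfig and CosmosContainerConfig
// before building the Cosmos async client.
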
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import org.assertj.core.api.Assertions.assertThat

class CosmosConfigSpec extends UnitSpec {
//scalastyle:off multiple.string.literals

"account endpoint" should "be parsed" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhsot:8081",
"spark.cosmos.accountKey" -> "xyz"
)

val endpointConfig = CosmosAccountConfig.parseCosmosAccountConfig(userConfig)

assertThat(endpointConfig.endpoint).isEqualTo("https://localhost:8081")
assertThat(endpointConfig.key).isEqualTo("xyz")
}

"account endpoint" should "be validated" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "invalidUrl",
"spark.cosmos.accountKey" -> "xyz"
)

try {
CosmosAccountConfig.parseCosmosAccountConfig(userConfig)
fail("invalid URL")
} catch {
case e: Exception => assertThat(e.getMessage).isEqualTo(
"invalid configuration for spark.cosmos.accountEndpoint:invalidUrl." +
" Config description: Cosmos DB Account Endpoint Uri")
}
}

"account endpoint" should "mandatory config" in {
val userConfig = Map(
"spark.cosmos.accountKey" -> "xyz"
)

try {
CosmosAccountConfig.parseCosmosAccountConfig(userConfig)
fail("missing URL")
} catch {
case e: Exception => assertThat(e.getMessage).isEqualTo(
"mandatory option spark.cosmos.accountEndpoint is missing." +
" Config description: Cosmos DB Account Endpoint Uri")
}
}
//scalastyle:on multiple.string.literals
}