Skip to content

Commit

Permalink
refactor: Deprecate Spanner extensions for protobuf values
Browse files Browse the repository at this point in the history
This is now handled by the Spanner client library.
  • Loading branch information
SanjayVas committed Sep 17, 2024
1 parent b681ced commit b1e62b0
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 110 deletions.
26 changes: 24 additions & 2 deletions src/main/kotlin/org/wfanet/measurement/common/ProtoReflection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.wfanet.measurement.common

import com.google.protobuf.AbstractMessage
import com.google.protobuf.Any as ProtoAny
import com.google.protobuf.AnyProto
import com.google.protobuf.ApiProto
Expand All @@ -24,6 +25,7 @@ import com.google.protobuf.Descriptors
import com.google.protobuf.DurationProto
import com.google.protobuf.EmptyProto
import com.google.protobuf.Message
import com.google.protobuf.ProtocolMessageEnum
import com.google.protobuf.StructProto
import com.google.protobuf.TimestampProto
import com.google.protobuf.TypeProto
Expand Down Expand Up @@ -118,16 +120,36 @@ object ProtoReflection {
}

/** Reflectively calls the `getDefaultInstance` static function for [T]. */
fun <T : Message> getDefaultInstance(kclass: KClass<T>): T {
fun <T : Message> getDefaultInstance(kClass: KClass<T>): T {
// Every Message type should have a static getDefaultInstance function.
@Suppress("UNCHECKED_CAST") // Guaranteed by predicate.
val function =
kclass.staticFunctions.single { it.name == "getDefaultInstance" && it.parameters.isEmpty() }
kClass.staticFunctions.single { it.name == "getDefaultInstance" && it.parameters.isEmpty() }
as kotlin.reflect.KFunction0<T>

return function.call()
}

/** Reflectively calls the `getDescriptorForType` static function for [T]. */
fun <T : ProtocolMessageEnum> getDescriptorForType(
kClass: KClass<T>
): Descriptors.EnumDescriptor {
@Suppress("UNCHECKED_CAST") // Guaranteed by predicate.
val function =
kClass.staticFunctions.single { it.name == "getDescriptorForType" && it.parameters.isEmpty() }
as kotlin.reflect.KFunction0<Descriptors.EnumDescriptor>
return function.call()
}

/** Reflectively calls the `getDescriptorForType` static function for [T]. */
fun <T : AbstractMessage> getDescriptorForType(kClass: KClass<T>): Descriptors.Descriptor {
@Suppress("UNCHECKED_CAST") // Guaranteed by predicate.
val function =
kClass.staticFunctions.single { it.name == "getDescriptorForType" && it.parameters.isEmpty() }
as kotlin.reflect.KFunction0<Descriptors.Descriptor>
return function.call()
}

/**
* Builds a [DescriptorProtos.FileDescriptorSet] from [descriptor], including direct and
* transitive dependencies.
Expand Down
27 changes: 19 additions & 8 deletions src/main/kotlin/org/wfanet/measurement/gcloud/spanner/Mutations.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ import com.google.cloud.ByteArray
import com.google.cloud.Date
import com.google.cloud.Timestamp
import com.google.cloud.spanner.Mutation
import com.google.cloud.spanner.ValueBinder
import com.google.protobuf.AbstractMessage
import com.google.protobuf.Message
import com.google.protobuf.ProtocolMessageEnum
import org.wfanet.measurement.common.ProtoReflection
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId

Expand Down Expand Up @@ -97,14 +100,14 @@ fun Mutation.WriteBuilder.set(columnValuePair: Pair<String, ByteArray?>): Mutati
@JvmName("setInternalId")
fun Mutation.WriteBuilder.set(columnValuePair: Pair<String, InternalId>): Mutation.WriteBuilder {
val (columnName, value) = columnValuePair
return set(columnName).to(value.value)
return set(columnName).to(value)
}

/** Sets the value that should be bound to the specified column. */
@JvmName("setExternalId")
fun Mutation.WriteBuilder.set(columnValuePair: Pair<String, ExternalId>): Mutation.WriteBuilder {
val (columnName, value) = columnValuePair
return set(columnName).to(value.value)
return set(columnName).to(value)
}

/** Sets the value that should be bound to the specified column. */
Expand All @@ -113,14 +116,22 @@ fun Mutation.WriteBuilder.set(
columnValuePair: Pair<String, ProtocolMessageEnum>
): Mutation.WriteBuilder {
val (columnName, value) = columnValuePair
return set(columnName).toProtoEnum(value)
return set(columnName).to(value)
}

/** Sets the value that should be bound to the specified column. */
@Deprecated(message = "Use ValueBinder directly")
@JvmName("setProtoMessageBytes")
fun Mutation.WriteBuilder.set(columnValuePair: Pair<String, Message?>): Mutation.WriteBuilder {
inline fun <reified T : AbstractMessage> Mutation.WriteBuilder.set(
columnValuePair: Pair<String, T?>
): Mutation.WriteBuilder {
val (columnName, value) = columnValuePair
return set(columnName).toProtoBytes(value)
val binder: ValueBinder<Mutation.WriteBuilder> = set(columnName)
return if (value == null) {
binder.to(null, ProtoReflection.getDescriptorForType(T::class))
} else {
binder.to(value)
}
}

/** Sets the JSON value that should be bound to the specified string column. */
Expand All @@ -144,23 +155,23 @@ inline fun insertOrUpdateMutation(table: String, bind: Mutation.WriteBuilder.()
/** Builds and buffers an [INSERT][Mutation.Op.INSERT] [Mutation]. */
inline fun AsyncDatabaseClient.TransactionContext.bufferInsertMutation(
table: String,
bind: Mutation.WriteBuilder.() -> Unit
bind: Mutation.WriteBuilder.() -> Unit,
) {
insertMutation(table, bind).bufferTo(this)
}

/** Builds and buffers an [UPDATE][Mutation.Op.UPDATE] [Mutation]. */
inline fun AsyncDatabaseClient.TransactionContext.bufferUpdateMutation(
table: String,
bind: Mutation.WriteBuilder.() -> Unit
bind: Mutation.WriteBuilder.() -> Unit,
) {
updateMutation(table, bind).bufferTo(this)
}

/** Builds and buffers an [INSERT_OR_UPDATE][Mutation.Op.INSERT_OR_UPDATE] [Mutation]. */
inline fun AsyncDatabaseClient.TransactionContext.bufferInsertOrUpdateMutation(
table: String,
bind: Mutation.WriteBuilder.() -> Unit
bind: Mutation.WriteBuilder.() -> Unit,
) {
insertOrUpdateMutation(table, bind).bufferTo(this)
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import com.google.cloud.ByteArray
import com.google.cloud.Date
import com.google.cloud.Timestamp
import com.google.cloud.spanner.Statement
import com.google.protobuf.Message
import com.google.cloud.spanner.ValueBinder
import com.google.protobuf.AbstractMessage
import com.google.protobuf.ProtocolMessageEnum
import org.wfanet.measurement.common.ProtoReflection
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId

Expand Down Expand Up @@ -97,32 +99,40 @@ fun Statement.Builder.bind(paramValuePair: Pair<String, ByteArray?>): Statement.
@JvmName("bindInternalId")
fun Statement.Builder.bind(paramValuePair: Pair<String, InternalId>): Statement.Builder {
val (paramName, value) = paramValuePair
return bind(paramName).to(value.value)
return bind(paramName).to(value)
}

/** Binds the value that should be bound to the specified param. */
@JvmName("bindExternalId")
fun Statement.Builder.bind(paramValuePair: Pair<String, ExternalId>): Statement.Builder {
val (paramName, value) = paramValuePair
return bind(paramName).to(value.value)
return bind(paramName).to(value)
}

/** Binds the value that should be bound to the specified param. */
@JvmName("bindProtoEnum")
fun Statement.Builder.bind(paramValuePair: Pair<String, ProtocolMessageEnum>): Statement.Builder {
val (paramName, value) = paramValuePair
return bind(paramName).toProtoEnum(value)
return bind(paramName).to(value)
}

/** Binds the value that should be bound to the specified param. */
@Deprecated(message = "Use ValueBinder directly")
@JvmName("bindProtoMessageBytes")
fun Statement.Builder.bind(paramValuePair: Pair<String, Message?>): Statement.Builder {
inline fun <reified T : AbstractMessage> Statement.Builder.bind(
paramValuePair: Pair<String, T?>
): Statement.Builder {
val (paramName, value) = paramValuePair
return bind(paramName).toProtoBytes(value)
val binder: ValueBinder<Statement.Builder> = bind(paramName)
return if (value == null) {
binder.to(null, ProtoReflection.getDescriptorForType(T::class))
} else {
binder.to(value)
}
}

/** Binds the JSON value that should be bound to the specified string param. */
fun Statement.Builder.bindJson(paramValuePair: Pair<String, Message?>): Statement.Builder {
fun Statement.Builder.bindJson(paramValuePair: Pair<String, AbstractMessage?>): Statement.Builder {
val (paramName, value) = paramValuePair
return bind(paramName).toProtoJson(value)
}
Expand Down
62 changes: 59 additions & 3 deletions src/main/kotlin/org/wfanet/measurement/gcloud/spanner/Structs.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@ import com.google.cloud.Date
import com.google.cloud.Timestamp
import com.google.cloud.spanner.Struct
import com.google.cloud.spanner.StructReader
import com.google.cloud.spanner.Type
import com.google.cloud.spanner.ValueBinder
import com.google.protobuf.AbstractMessage
import com.google.protobuf.ByteString
import com.google.protobuf.Message
import com.google.protobuf.Parser
import com.google.protobuf.ProtocolMessageEnum
import org.wfanet.measurement.common.ProtoReflection
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId

Expand Down Expand Up @@ -98,14 +104,22 @@ fun Struct.Builder.set(columnValuePair: Pair<String, ByteArray?>): Struct.Builde
@JvmName("setProtoEnum")
fun Struct.Builder.set(columnValuePair: Pair<String, ProtocolMessageEnum>): Struct.Builder {
val (columnName, value) = columnValuePair
return set(columnName).toProtoEnum(value)
return set(columnName).to(value)
}

/** Sets the value that should be bound to the specified column. */
@Deprecated(message = "Use ValueBinder directly")
@JvmName("setProtoMessageBytes")
fun Struct.Builder.set(columnValuePair: Pair<String, Message?>): Struct.Builder {
inline fun <reified T : AbstractMessage> Struct.Builder.set(
columnValuePair: Pair<String, T?>
): Struct.Builder {
val (columnName, value) = columnValuePair
return set(columnName).toProtoBytes(value)
val binder: ValueBinder<Struct.Builder> = set(columnName)
return if (value == null) {
binder.to(null, ProtoReflection.getDescriptorForType(T::class))
} else {
binder.to(value)
}
}

/** Sets the JSON value that should be bound to the specified string column. */
Expand All @@ -126,5 +140,47 @@ fun StructReader.getInternalId(columnName: String) = InternalId(getLong(columnNa
*/
fun StructReader.getExternalId(columnName: String) = ExternalId(getLong(columnName))

private fun <T> StructReader.nullOrValue(
column: String,
typeCode: Type.Code,
getter: StructReader.(String) -> T,
): T? {
val columnType = getColumnType(column).code
check(columnType == typeCode) { "Cannot read $typeCode from $column, it has type $columnType" }
return if (isNull(column)) null else getter(column)
}

/** Returns the value of a String column even if it is null. */
fun StructReader.getNullableString(column: String): String? =
nullOrValue(column, Type.Code.STRING, StructReader::getString)

/** Returns the value of an Array of Structs column even if it is null. */
fun StructReader.getNullableStructList(column: String): MutableList<Struct>? =
nullOrValue(column, Type.Code.ARRAY, StructReader::getStructList)

/** Returns the value of a Timestamp column even if it is null. */
fun StructReader.getNullableTimestamp(column: String): Timestamp? =
nullOrValue(column, Type.Code.TIMESTAMP, StructReader::getTimestamp)

/** Returns the value of a INT64 column even if it is null. */
fun StructReader.getNullableLong(column: String): Long? =
nullOrValue(column, Type.Code.INT64, StructReader::getLong)

/** Returns a bytes column as a Kotlin native ByteArray. This is useful for deserializing protos. */
fun StructReader.getBytesAsByteArray(column: String): kotlin.ByteArray =
getBytes(column).toByteArray()

/** Returns a bytes column as a protobuf ByteString. */
fun StructReader.getBytesAsByteString(column: String): ByteString =
ByteString.copyFrom(getBytes(column).asReadOnlyByteBuffer())

/** Parses a protobuf [Message] from a BYTES column. */
@Suppress("DeprecatedCallableAddReplaceWith") // Should use manual replacement to avoid reflection.
@Deprecated(message = "Use `getProtoMessage` overload which takes in a default message instance")
inline fun <reified T : AbstractMessage> StructReader.getProtoMessage(
column: String,
parser: Parser<T>,
): T = getProtoMessage(column, ProtoReflection.getDefaultInstance(T::class))

/** Builds a [Struct]. */
inline fun struct(bind: Struct.Builder.() -> Unit): Struct = Struct.newBuilder().apply(bind).build()
Loading

0 comments on commit b1e62b0

Please sign in to comment.