Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Extract Details message type definitions in Kingdom internal API #1798

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.internal.kingdom.CreateDuchyMeasurementLogEntryRequest
import org.wfanet.measurement.internal.kingdom.DuchyMeasurementLogEntry
import org.wfanet.measurement.internal.kingdom.MeasurementLogEntriesGrpcKt.MeasurementLogEntriesCoroutineImplBase
import org.wfanet.measurement.internal.kingdom.MeasurementLogEntry.ErrorDetails.Type.TRANSIENT
import org.wfanet.measurement.internal.kingdom.MeasurementLogEntryError
import org.wfanet.measurement.internal.kingdom.StateTransitionMeasurementLogEntry
import org.wfanet.measurement.internal.kingdom.StreamStateTransitionMeasurementLogEntriesRequest
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DuchyNotFoundException
Expand All @@ -43,7 +43,9 @@ class SpannerMeasurementLogEntriesService(
): DuchyMeasurementLogEntry {

if (request.measurementLogEntryDetails.hasError()) {
grpcRequire(request.measurementLogEntryDetails.error.type == TRANSIENT) {
grpcRequire(
request.measurementLogEntryDetails.error.type == MeasurementLogEntryError.Type.TRANSIENT
) {
"MeasurementLogEntries Service only supports TRANSIENT errors, " +
"use FailComputationParticipant instead."
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.wfanet.measurement.internal.kingdom.FulfillRequisitionRequest
import org.wfanet.measurement.internal.kingdom.GetRequisitionRequest
import org.wfanet.measurement.internal.kingdom.RefuseRequisitionRequest
import org.wfanet.measurement.internal.kingdom.Requisition
import org.wfanet.measurement.internal.kingdom.Requisition.Refusal
import org.wfanet.measurement.internal.kingdom.RequisitionRefusal
import org.wfanet.measurement.internal.kingdom.RequisitionsGrpcKt.RequisitionsCoroutineImplBase
import org.wfanet.measurement.internal.kingdom.StreamRequisitionsRequest
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DuchyNotFoundException
Expand Down Expand Up @@ -117,10 +117,12 @@ class SpannerRequisitionsService(
with(request) {
grpcRequire(externalDataProviderId != 0L) { "external_data_provider_id not specified" }
grpcRequire(externalRequisitionId != 0L) { "external_requisition_id not specified" }
grpcRequire(refusal.justification != Refusal.Justification.UNRECOGNIZED) {
grpcRequire(refusal.justification != RequisitionRefusal.Justification.UNRECOGNIZED) {
"Unrecognized refusal justification ${refusal.justificationValue}"
}
grpcRequire(refusal.justification != Refusal.Justification.JUSTIFICATION_UNSPECIFIED) {
grpcRequire(
refusal.justification != RequisitionRefusal.Justification.JUSTIFICATION_UNSPECIFIED
) {
"refusal justification not specified"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.getBytesAsByteString
import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoEnum
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.Certificate
import org.wfanet.measurement.internal.kingdom.CertificateDetails
import org.wfanet.measurement.internal.kingdom.CertificateKt
import org.wfanet.measurement.internal.kingdom.certificate
import org.wfanet.measurement.kingdom.deploy.common.DuchyIds
Expand Down Expand Up @@ -297,7 +297,7 @@ class CertificateReader(private val parentType: ParentType) :
notValidAfter = struct.getTimestamp("NotValidAfter").toProto()
revocationState =
struct.getProtoEnum("RevocationState", Certificate.RevocationState::forNumber)
details = struct.getProtoMessage("CertificateDetails", Certificate.Details.parser())
details = struct.getProtoMessage("CertificateDetails", CertificateDetails.parser())
}

/** Returns the internal Certificate ID for a Duchy Certificate, or `null` if not found. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoEnum
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.ComputationParticipant
import org.wfanet.measurement.internal.kingdom.ComputationParticipantDetails
import org.wfanet.measurement.internal.kingdom.DuchyMeasurementLogEntry
import org.wfanet.measurement.internal.kingdom.DuchyMeasurementLogEntryDetails
import org.wfanet.measurement.internal.kingdom.Measurement
import org.wfanet.measurement.internal.kingdom.MeasurementLogEntry
import org.wfanet.measurement.internal.kingdom.MeasurementDetails
import org.wfanet.measurement.internal.kingdom.MeasurementLogEntryDetails
import org.wfanet.measurement.internal.kingdom.computationParticipant
import org.wfanet.measurement.internal.kingdom.duchyMeasurementLogEntry
import org.wfanet.measurement.internal.kingdom.measurementLogEntry
Expand Down Expand Up @@ -96,7 +98,7 @@ class ComputationParticipantReader : BaseSpannerReader<ComputationParticipantRea
val measurementId: InternalId,
val measurementConsumerId: InternalId,
val measurementState: Measurement.State,
val measurementDetails: Measurement.Details,
val measurementDetails: MeasurementDetails,
)

override val builder: Statement.Builder = Statement.newBuilder(BASE_SQL)
Expand Down Expand Up @@ -144,15 +146,15 @@ class ComputationParticipantReader : BaseSpannerReader<ComputationParticipantRea
struct.getInternalId("MeasurementId"),
struct.getInternalId("MeasurementConsumerId"),
struct.getProtoEnum("MeasurementState", Measurement.State::forNumber),
struct.getProtoMessage("MeasurementDetails", Measurement.Details.parser()),
struct.getProtoMessage("MeasurementDetails", MeasurementDetails.parser()),
)

private fun buildComputationParticipant(struct: Struct): ComputationParticipant {
val externalMeasurementConsumerId = ExternalId(struct.getLong("ExternalMeasurementConsumerId"))
val externalMeasurementId = ExternalId(struct.getLong("ExternalMeasurementId"))
val externalComputationId = ExternalId(struct.getLong("ExternalComputationId"))
val measurementDetails =
struct.getProtoMessage("MeasurementDetails", Measurement.Details.parser())
struct.getProtoMessage("MeasurementDetails", MeasurementDetails.parser())

val duchyId = struct.getLong("DuchyId")
val externalDuchyId =
Expand All @@ -174,7 +176,7 @@ class ComputationParticipantReader : BaseSpannerReader<ComputationParticipantRea
externalMeasurementId: ExternalId,
externalDuchyId: String,
externalComputationId: ExternalId,
measurementDetails: Measurement.Details,
measurementDetails: MeasurementDetails,
struct: Struct,
): ComputationParticipant {
val failureLogEntry: DuchyMeasurementLogEntry? =
Expand All @@ -198,7 +200,7 @@ class ComputationParticipantReader : BaseSpannerReader<ComputationParticipantRea
this.etag = etag
state = struct.getProtoEnum("State", ComputationParticipant.State::forNumber)
details =
struct.getProtoMessage("ParticipantDetails", ComputationParticipant.Details.parser())
struct.getProtoMessage("ParticipantDetails", ComputationParticipantDetails.parser())
apiVersion = measurementDetails.apiVersion

if (failureLogEntry != null) {
Expand All @@ -216,7 +218,7 @@ class ComputationParticipantReader : BaseSpannerReader<ComputationParticipantRea
return logEntryStructs
.asSequence()
.map {
it to it.getProtoMessage("MeasurementLogDetails", MeasurementLogEntry.Details.parser())
it to it.getProtoMessage("MeasurementLogDetails", MeasurementLogEntryDetails.parser())
}
.find { (_, logEntryDetails) -> logEntryDetails.hasError() }
?.let { (struct, logEntryDetails) ->
Expand All @@ -232,7 +234,7 @@ class ComputationParticipantReader : BaseSpannerReader<ComputationParticipantRea
details =
struct.getProtoMessage(
"DuchyMeasurementLogDetails",
DuchyMeasurementLogEntry.Details.parser(),
DuchyMeasurementLogEntryDetails.parser(),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.Certificate
import org.wfanet.measurement.internal.kingdom.DataProvider
import org.wfanet.measurement.internal.kingdom.DataProviderDetails
import org.wfanet.measurement.kingdom.deploy.common.DuchyIds
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DataProviderNotFoundException

Expand Down Expand Up @@ -118,7 +119,7 @@ class DataProviderReader : SpannerReader<DataProviderReader.Result>() {
DataProvider.newBuilder()
.apply {
externalDataProviderId = struct.getLong("ExternalDataProviderId")
details = struct.getProtoMessage("DataProviderDetails", DataProvider.Details.parser())
details = struct.getProtoMessage("DataProviderDetails", DataProviderDetails.parser())
certificate = CertificateReader.buildDataProviderCertificate(struct)
addAllRequiredExternalDuchyIds(buildExternalDuchyIdList(struct))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.EventGroupMetadataDescriptor
import org.wfanet.measurement.internal.kingdom.EventGroupMetadataDescriptorDetails
import org.wfanet.measurement.internal.kingdom.eventGroupMetadataDescriptor

class EventGroupMetadataDescriptorReader :
Expand Down Expand Up @@ -120,7 +121,7 @@ class EventGroupMetadataDescriptorReader :
}
if (!struct.isNull("DescriptorDetails")) {
details =
struct.getProtoMessage("DescriptorDetails", EventGroupMetadataDescriptor.Details.parser())
struct.getProtoMessage("DescriptorDetails", EventGroupMetadataDescriptorDetails.parser())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.getProtoEnum
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.EventGroup
import org.wfanet.measurement.internal.kingdom.EventGroupDetails
import org.wfanet.measurement.internal.kingdom.eventGroup

class EventGroupReader : BaseSpannerReader<EventGroupReader.Result>() {
Expand Down Expand Up @@ -124,7 +124,7 @@ class EventGroupReader : BaseSpannerReader<EventGroupReader.Result>() {
createTime = struct.getTimestamp("CreateTime").toProto()
updateTime = struct.getTimestamp("UpdateTime").toProto()
if (!struct.isNull("EventGroupDetails")) {
details = struct.getProtoMessage("EventGroupDetails", EventGroup.Details.parser())
details = struct.getProtoMessage("EventGroupDetails", EventGroupDetails.parser())
}
state = struct.getProtoEnum("State", EventGroup.State::forNumber)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.MeasurementConsumer
import org.wfanet.measurement.internal.kingdom.MeasurementConsumerDetails

class MeasurementConsumerReader : SpannerReader<MeasurementConsumerReader.Result>() {
data class Result(val measurementConsumer: MeasurementConsumer, val measurementConsumerId: Long)
Expand Down Expand Up @@ -58,7 +59,7 @@ class MeasurementConsumerReader : SpannerReader<MeasurementConsumerReader.Result
.apply {
externalMeasurementConsumerId = struct.getLong("ExternalMeasurementConsumerId")
details =
struct.getProtoMessage("MeasurementConsumerDetails", MeasurementConsumer.Details.parser())
struct.getProtoMessage("MeasurementConsumerDetails", MeasurementConsumerDetails.parser())
certificate = CertificateReader.buildMeasurementConsumerCertificate(struct)
}
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ import com.google.cloud.spanner.Struct
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.Measurement
import org.wfanet.measurement.internal.kingdom.MeasurementDetails

class MeasurementDetailsReader() : SpannerReader<MeasurementDetailsReader.Result>() {

data class Result(
val measurementConsumerId: InternalId,
val measurementId: InternalId,
val measurementDetails: Measurement.Details,
val measurementDetails: MeasurementDetails,
)

override val baseSql =
Expand All @@ -43,6 +43,6 @@ class MeasurementDetailsReader() : SpannerReader<MeasurementDetailsReader.Result
Result(
struct.getInternalId("MeasurementConsumerId"),
struct.getInternalId("MeasurementId"),
struct.getProtoMessage("MeasurementDetails", Measurement.Details.parser()),
struct.getProtoMessage("MeasurementDetails", MeasurementDetails.parser()),
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.getBytesAsByteString
import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoEnum
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.gcloud.spanner.statement
import org.wfanet.measurement.internal.kingdom.Measurement
import org.wfanet.measurement.internal.kingdom.MeasurementDetails
import org.wfanet.measurement.internal.kingdom.MeasurementKt
import org.wfanet.measurement.internal.kingdom.MeasurementKt.dataProviderValue
import org.wfanet.measurement.internal.kingdom.MeasurementKt.resultInfo
import org.wfanet.measurement.internal.kingdom.Requisition
import org.wfanet.measurement.internal.kingdom.RequisitionDetails
import org.wfanet.measurement.internal.kingdom.measurement
import org.wfanet.measurement.kingdom.deploy.common.DuchyIds
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.ETags
Expand Down Expand Up @@ -414,7 +414,7 @@ private fun MeasurementKt.Dsl.fillMeasurementCommon(struct: Struct) {
createTime = struct.getTimestamp("CreateTime").toProto()
updateTime = struct.getTimestamp("UpdateTime").toProto()
state = struct.getProtoEnum("MeasurementState", Measurement.State::forNumber)
details = struct.getProtoMessage("MeasurementDetails", Measurement.Details.parser())
details = struct.getProtoMessage("MeasurementDetails", MeasurementDetails.parser())
if (state == Measurement.State.SUCCEEDED) {
for (duchyResultStruct in struct.getStructList("DuchyResults")) {
results += resultInfo {
Expand All @@ -438,7 +438,7 @@ private fun MeasurementKt.Dsl.fillDefaultView(struct: Struct) {
val measurementSucceeded = state == Measurement.State.SUCCEEDED
for (requisitionStruct in struct.getStructList("Requisitions")) {
val requisitionDetails =
requisitionStruct.getProtoMessage("RequisitionDetails", Requisition.Details.parser())
requisitionStruct.getProtoMessage("RequisitionDetails", RequisitionDetails.parser())
val externalDataProviderId = requisitionStruct.getLong("ExternalDataProviderId")
val externalDataProviderCertificateId =
requisitionStruct.getLong("ExternalDataProviderCertificateId")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.getProtoEnum
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.ComputationParticipant
import org.wfanet.measurement.internal.kingdom.ComputationParticipantDetails
import org.wfanet.measurement.internal.kingdom.Measurement
import org.wfanet.measurement.internal.kingdom.MeasurementDetails
import org.wfanet.measurement.internal.kingdom.Requisition
import org.wfanet.measurement.internal.kingdom.RequisitionDetails
import org.wfanet.measurement.internal.kingdom.RequisitionKt.duchyValue
import org.wfanet.measurement.internal.kingdom.RequisitionKt.parentMeasurement
import org.wfanet.measurement.internal.kingdom.requisition
Expand Down Expand Up @@ -102,7 +103,7 @@ class RequisitionReader : BaseSpannerReader<RequisitionReader.Result>() {
val measurementId: InternalId,
val requisitionId: InternalId,
val requisition: Requisition,
val measurementDetails: Measurement.Details,
val measurementDetails: MeasurementDetails,
)

override val builder: Statement.Builder = Statement.newBuilder(BASE_SQL)
Expand All @@ -113,7 +114,7 @@ class RequisitionReader : BaseSpannerReader<RequisitionReader.Result>() {
InternalId(struct.getLong("MeasurementId")),
InternalId(struct.getLong("RequisitionId")),
buildRequisition(struct),
struct.getProtoMessage("MeasurementDetails", Measurement.Details.parser()),
struct.getProtoMessage("MeasurementDetails", MeasurementDetails.parser()),
)
}

Expand Down Expand Up @@ -206,8 +207,7 @@ class RequisitionReader : BaseSpannerReader<RequisitionReader.Result>() {
for ((externalDuchyId, participantStruct) in participantStructs) {
duchies[externalDuchyId] = buildDuchyValue(participantStruct)
}
details =
requisitionStruct.getProtoMessage("RequisitionDetails", Requisition.Details.parser())
details = requisitionStruct.getProtoMessage("RequisitionDetails", RequisitionDetails.parser())
dataProviderCertificate = CertificateReader.buildDataProviderCertificate(requisitionStruct)

parentMeasurement = buildParentMeasurement(measurementStruct, dataProviderCount)
Expand All @@ -224,26 +224,26 @@ class RequisitionReader : BaseSpannerReader<RequisitionReader.Result>() {
}

val participantDetails =
struct.getProtoMessage("ParticipantDetails", ComputationParticipant.Details.parser())
struct.getProtoMessage("ParticipantDetails", ComputationParticipantDetails.parser())
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null.
when (participantDetails.protocolCase) {
ComputationParticipant.Details.ProtocolCase.LIQUID_LEGIONS_V2 -> {
ComputationParticipantDetails.ProtocolCase.LIQUID_LEGIONS_V2 -> {
liquidLegionsV2 = participantDetails.liquidLegionsV2
}
ComputationParticipant.Details.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
ComputationParticipantDetails.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
reachOnlyLiquidLegionsV2 = participantDetails.reachOnlyLiquidLegionsV2
}
ComputationParticipant.Details.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
ComputationParticipantDetails.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
honestMajorityShareShuffle = participantDetails.honestMajorityShareShuffle
}
// Protocol may only be set after computation participant sets requisition params.
ComputationParticipant.Details.ProtocolCase.PROTOCOL_NOT_SET -> Unit
ComputationParticipantDetails.ProtocolCase.PROTOCOL_NOT_SET -> Unit
}
}

private fun buildParentMeasurement(struct: Struct, dataProviderCount: Int) = parentMeasurement {
val measurementDetails =
struct.getProtoMessage("MeasurementDetails", Measurement.Details.parser())
struct.getProtoMessage("MeasurementDetails", MeasurementDetails.parser())
apiVersion = measurementDetails.apiVersion
externalMeasurementConsumerCertificateId =
struct.getLong("ExternalMeasurementConsumerCertificateId")
Expand Down
Loading
Loading