Skip to content

Commit

Permalink
Fix constraint-related errors in Rpcv2CBOR server implementation (#3794)
Browse files Browse the repository at this point in the history
  • Loading branch information
aws-sdk-rust-ci authored Oct 1, 2024
2 parents 684c15f + 27ca7f1 commit 191c577
Show file tree
Hide file tree
Showing 15 changed files with 437 additions and 43 deletions.
9 changes: 9 additions & 0 deletions .changelog/2155171.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
applies_to: ["server","client"]
authors: ["drganjoo"]
references: [smithy-rs#3573]
breaking: false
new_feature: true
bug_fix: false
---
Support for the [rpcv2Cbor](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html) protocol has been added, allowing services to serialize RPC payloads as CBOR (Concise Binary Object Representation), improving performance and efficiency in data transmission.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.generateRustPayloadInitializer
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.testModule
Expand Down Expand Up @@ -46,7 +47,7 @@ class ClientEventStreamUnmarshallerGeneratorTest {
"exception",
"UnmodeledError",
"${testCase.responseContentType}",
br#"${testCase.validUnmodeledError}"#
${testCase.generateRustPayloadInitializer(testCase.validUnmodeledError)}
);
let result = $generator::new().unmarshall(&message);
assert!(result.is_ok(), "expected ok, got: {:?}", result);
Expand Down
1 change: 1 addition & 0 deletions codegen-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies {
implementation("org.jsoup:jsoup:1.16.2")
api("software.amazon.smithy:smithy-codegen-core:$smithyVersion")
api("com.moandjiezana.toml:toml4j:0.7.2")
implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0")
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
Expand Down Expand Up @@ -140,9 +141,23 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
TODO("rpcv2Cbor event streams have not yet been implemented")
ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName ->
rustTemplate(
"""
pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{DeserializeError}> {
#{cbor_errors}::parse_error_metadata(0, &#{Headers}::new(), payload)
}
""",
"cbor_errors" to RuntimeType.cborErrors(runtimeConfig),
"Bytes" to RuntimeType.Bytes,
"ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig),
"DeserializeError" to
CargoDependency.smithyCbor(runtimeConfig).toType()
.resolve("decode::DeserializeError"),
"Headers" to RuntimeType.headers(runtimeConfig),
)
}

// Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set
// unless there is no input or if the operation is an event stream, see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -447,7 +446,24 @@ class CborParserGenerator(
}

override fun payloadParser(member: MemberShape): RuntimeType {
UNREACHABLE("No protocol using CBOR serialization supports payload binding")
val shape = model.expectShape(member.target)
val returnSymbol = returnSymbolToParse(shape)
check(shape is UnionShape || shape is StructureShape) {
"Payload parser should only be used on structure and union shapes."
}
return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName ->
rustTemplate(
"""
pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> {
let decoder = &mut #{Decoder}::new(value);
#{DeserializeMember}
}
""",
"ReturnType" to returnSymbol.symbol,
"DeserializeMember" to deserializeMember(member),
*codegenScope,
)
}
}

override fun operationParser(operationShape: OperationShape): RuntimeType? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) {
/** Manipulate the serializer context for a map prior to it being serialized. **/
data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context<Shape>) :
CborSerializerSection("BeforeIteratingOverMapOrCollection")

/** Manipulate the serializer context for a non-null member prior to it being serialized. **/
data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) :
CborSerializerSection("BeforeSerializingNonNullMember")
}

/**
Expand Down Expand Up @@ -200,9 +204,26 @@ class CborSerializerGenerator(
}
}

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun payloadSerializer(member: MemberShape): RuntimeType {
TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573")
val target = model.expectShape(member.target)
return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName ->
rustBlockTemplate(
"pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>",
*codegenScope,
"target" to symbolProvider.toSymbol(target),
) {
rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope)
rustBlock("") {
rust("let encoder = &mut encoder;")
when (target) {
is StructureShape -> serializeStructure(StructContext("input", target))
is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target))
else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions")
}
}
rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope)
}
}
}

override fun unsetStructure(structure: StructureShape): RuntimeType =
Expand Down Expand Up @@ -311,6 +332,7 @@ class CborSerializerGenerator(
safeName().also { local ->
rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") {
context.valueExpression = ValueExpression.Reference(local)
resolveValueExpressionForConstrainedType(targetShape, context)
serializeMemberValue(context, targetShape)
}
if (context.writeNulls) {
Expand All @@ -320,6 +342,7 @@ class CborSerializerGenerator(
}
}
} else {
resolveValueExpressionForConstrainedType(targetShape, context)
with(serializerUtil) {
ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) {
serializeMemberValue(context, targetShape)
Expand All @@ -328,6 +351,20 @@ class CborSerializerGenerator(
}
}

private fun RustWriter.resolveValueExpressionForConstrainedType(
targetShape: Shape,
context: MemberContext,
) {
for (customization in customizations) {
customization.section(
CborSerializerSection.BeforeSerializingNonNullMember(
targetShape,
context,
),
)(this)
}
}

private fun RustWriter.serializeMemberValue(
context: MemberContext,
target: Shape,
Expand Down Expand Up @@ -362,7 +399,7 @@ class CborSerializerGenerator(
rust("$encoder;") // Encode the member key.
}
when (target) {
is StructureShape -> serializeStructure(StructContext(value.name, target))
is StructureShape -> serializeStructure(StructContext(value.asRef(), target))
is CollectionShape -> serializeCollection(Context(value, target))
is MapShape -> serializeMap(Context(value, target))
is UnionShape -> serializeUnion(Context(value, target))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@

package software.amazon.smithy.rust.codegen.core.testutil

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.cbor.CBORFactory
import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor
import java.util.Base64

private fun fillInBaseModel(
protocolName: String,
namespacedProtocolName: String,
extraServiceAnnotations: String = "",
): String =
"""
namespace test
use smithy.framework#ValidationException
use aws.protocols#$protocolName
use $namespacedProtocolName
union TestUnion {
Foo: String,
Expand Down Expand Up @@ -86,22 +90,24 @@ private fun fillInBaseModel(
}
$extraServiceAnnotations
@$protocolName
@${namespacedProtocolName.substringAfter("#")}
service TestService { version: "123", operations: [TestStreamOp] }
"""

object EventStreamTestModels {
private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel()

private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel()

private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel()

private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel()

private fun awsQuery(): Model =
fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()

private fun ec2Query(): Model =
fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()

data class TestCase(
val protocolShapeId: String,
Expand All @@ -120,39 +126,67 @@ object EventStreamTestModels {
override fun toString(): String = protocolShapeId
}

private fun base64Encode(input: ByteArray): String {
val encodedBytes = Base64.getEncoder().encode(input)
return String(encodedBytes)
}

private fun createCborFromJson(jsonString: String): ByteArray {
val jsonMapper = ObjectMapper()
val cborMapper = ObjectMapper(CBORFactory())
// Parse JSON string to a generic type.
val jsonData = jsonMapper.readValue(jsonString, Any::class.java)
// Convert the parsed data to CBOR.
return cborMapper.writeValueAsBytes(jsonData)
}

private val restJsonTestCase =
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) }

val TEST_CASES =
listOf(
//
// restJson1
//
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) },
restJsonTestCase,
//
// rpcV2Cbor
//
restJsonTestCase.copy(
protocolShapeId = "smithy.protocols#rpcv2Cbor",
model = rpcv2Cbor(),
mediaType = "application/cbor",
responseContentType = "application/cbor",
eventStreamMessageContentType = "application/cbor",
validTestStruct = base64Encode(createCborFromJson(restJsonTestCase.validTestStruct)),
validMessageWithNoHeaderPayloadTraits = base64Encode(createCborFromJson(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)),
validTestUnion = base64Encode(createCborFromJson(restJsonTestCase.validTestUnion)),
validSomeError = base64Encode(createCborFromJson(restJsonTestCase.validSomeError)),
validUnmodeledError = base64Encode(createCborFromJson(restJsonTestCase.validUnmodeledError)),
protocolBuilder = { RpcV2Cbor(it) },
),
//
// awsJson1_1
//
TestCase(
restJsonTestCase.copy(
protocolShapeId = "aws.protocols#awsJson1_1",
model = awsJson11(),
mediaType = "application/x-amz-json-1.1",
requestContentType = "application/x-amz-json-1.1",
responseContentType = "application/x-amz-json-1.1",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { AwsJson(it, AwsJsonVersion.Json11) },
//
// restXml
Expand Down
Loading

0 comments on commit 191c577

Please sign in to comment.