From 2296ab9489bcfb9500c5ce6d581fae03adcb7b70 Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Fri, 26 Jul 2024 16:33:08 -0400 Subject: [PATCH] Support Constraint Traits & Sparse collections of collections --- .../generators/StructureGeneratorTest.kt | 6 - codegen-serde/build.gradle.kts | 1 + .../codegen/serde/SerializeImplGenerator.kt | 141 ++++++++++++++---- .../smithy/rust/codegen/serde/Traits.kt | 2 +- .../rust/codegen/serde/SerdeDecoratorTest.kt | 68 ++++++++- .../codegen/serde/SerdeProtocolTestTest.kt | 37 +++++ rust-runtime/aws-smithy-types/Cargo.toml | 2 +- .../aws-smithy-types/src/byte_stream.rs | 7 + 8 files changed, 219 insertions(+), 45 deletions(-) create mode 100644 codegen-serde/src/test/kotlin/software/amazon/smithy/rust/codegen/serde/SerdeProtocolTestTest.kt diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt index c962a1d959d..d8a4fed3ec4 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt @@ -41,12 +41,6 @@ class StructureGeneratorTest { ts: Timestamp, inner: Inner, byteValue: Byte, - sensitiveMap: SensitiveMap - } - - map SensitiveMap { - key: String - value: Password } // Intentionally empty diff --git a/codegen-serde/build.gradle.kts b/codegen-serde/build.gradle.kts index 99d80e500a2..e011de24814 100644 --- a/codegen-serde/build.gradle.kts +++ b/codegen-serde/build.gradle.kts @@ -64,6 +64,7 @@ if (isTestingEnabled.toBoolean()) { runtimeOnly(project(":rust-runtime")) testImplementation("org.junit.jupiter:junit-jupiter:5.6.1") testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") + testImplementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion") } diff --git a/codegen-serde/src/main/kotlin/software/amazon/smithy/rust/codegen/serde/SerializeImplGenerator.kt b/codegen-serde/src/main/kotlin/software/amazon/smithy/rust/codegen/serde/SerializeImplGenerator.kt index 206ffc73dad..ef714f9552f 100644 --- a/codegen-serde/src/main/kotlin/software/amazon/smithy/rust/codegen/serde/SerializeImplGenerator.kt +++ b/codegen-serde/src/main/kotlin/software/amazon/smithy/rust/codegen/serde/SerializeImplGenerator.kt @@ -7,19 +7,26 @@ package software.amazon.smithy.rust.codegen.serde import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.DocumentShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.SensitiveTrait +import software.amazon.smithy.model.traits.SparseTrait +import software.amazon.smithy.model.traits.StreamingTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.map import software.amazon.smithy.rust.codegen.core.rustlang.plus import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -28,6 +35,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.contextName import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator @@ -44,6 +52,8 @@ import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTrait class SerializeImplGenerator(private val codegenContext: CodegenContext) { + private val model = codegenContext.model + fun generateRootSerializerForShape(shape: Shape): Writable = serializerFn(shape, null) /** @@ -59,17 +69,32 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { shape: Shape, applyTo: Writable?, ): Writable { + if (shape is ServiceShape) { + return shape.operations.map { serializerFn(model.expectShape(it), null) }.join("\n") + } else if (shape is OperationShape) { + return writable { + serializerFn(model.expectShape(shape.inputShape), null)(this) + serializerFn(model.expectShape(shape.outputShape), null)(this) + } + } val name = codegenContext.symbolProvider.shapeFunctionName(codegenContext.serviceShape, shape) + "_serde" val deps = when (shape) { is StructureShape -> RuntimeType.forInlineFun(name, SerdeModule, structSerdeImpl(shape)) is UnionShape -> RuntimeType.forInlineFun(name, SerdeModule, serializeUnionImpl(shape)) is TimestampShape -> serializeDateTime(shape) - is BlobShape -> serializeBlob(shape) - is StringShape, is NumberShape -> directSerde(shape) + is BlobShape -> + if (shape.hasTrait()) { + serializeByteStream(shape) + } else { + serializeBlob(shape) + } + + is StringShape, is NumberShape, is BooleanShape -> directSerde(shape) is DocumentShape -> serializeDocument(shape) else -> null } + return writable { val wrapper = when { @@ -82,10 +107,7 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { if (wrapper != null && applyTo != null) { rustTemplate( "&#{wrapper}(#{applyTo})", "wrapper" to wrapper, - "applyTo" to - applyTo.letIf(shape.hasConstraintTrait()) { - it.plus { rust("") } - }, + "applyTo" to applyTo, ) } else { deps?.toSymbol().also { addDependency(it) } @@ -96,38 +118,63 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { private fun serializeMap(shape: MapShape): RuntimeType = serializeWithWrapper(shape) { value -> + val member = serializeMember(shape.value, "v") + val writeEntry = + writable { + when (shape.hasTrait()) { + true -> + rust( + """ + match v { + Some(v) => map.serialize_entry(k, #T)?, + None => map.serialize_entry(k, &None::)? + }; + """, + member, + ) + false -> rust("map.serialize_entry(k, #T)?;", member) + } + } writable { rustTemplate( """ use #{serde}::ser::SerializeMap; let mut map = serializer.serialize_map(Some(#{value}.len()))?; for (k, v) in #{value}.iter() { - map.serialize_entry(k, #{member})?; + #{writeEntry} } map.end() """, *SupportStructures.codegenScope, "value" to value, - "member" to serializeMember(shape.value, "v"), + "writeEntry" to writeEntry, ) } } private fun serializeList(shape: CollectionShape): RuntimeType = serializeWithWrapper(shape) { value -> + val member = serializeMember(shape.member, "v") + val serializeElement = + writable { + when (shape.hasTrait()) { + false -> rust("seq.serialize_element(#T)?;", member) + true -> rust("match v { Some(v) => seq.serialize_element(#T)?, None => seq.serialize_element(&None::)? };", member) + } + } writable { rustTemplate( """ use #{serde}::ser::SerializeSeq; let mut seq = serializer.serialize_seq(Some(#{value}.len()))?; for v in #{value}.iter() { - seq.serialize_element(#{member})?; + #{element} } seq.end() """, *SupportStructures.codegenScope, + "element" to serializeElement, "value" to value, - "member" to serializeMember(shape.member, "v"), ) } } @@ -136,7 +183,10 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { * Serialize a type that already implements `Serialize` directly via `value.serialize(serializer)` */ private fun directSerde(shape: Shape): RuntimeType { - return RuntimeType.forInlineFun(codegenContext.symbolProvider.toSymbol(shape).rustType().toString(), SerdeModule) { + return RuntimeType.forInlineFun( + codegenContext.symbolProvider.toSymbol(shape).rustType().toString(), + SerdeModule, + ) { implSerializeConfigured(codegenContext.symbolProvider.toSymbol(shape)) { val baseValue = writable { rust("self.value") }.letIf(shape.isStringShape) { it.plus(writable(".as_str()")) } @@ -179,7 +229,7 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { // awkward hack to recover the symbol referring to the type val type = SerdeModule.toType().resolve(name).toSymbol() val base = - writable { rust("self.value.0") }.letIf(shape.hasConstraintTrait()) { + writable { rust("self.value.0") }.letIf(shape.hasConstraintTrait() && codegenContext.target == CodegenTarget.SERVER) { it.plus { rust(".0") } } val serialization = @@ -201,7 +251,7 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { } private fun Shape.unwrapConstraints(): Boolean = - hasConstraintTrait() && (isBlobShape || isTimestampShape || isDocumentShape) + codegenContext.target == CodegenTarget.SERVER && hasConstraintTrait() && (isBlobShape || isTimestampShape || isDocumentShape) /** * Serialize the field of a structure, union, list or map. @@ -240,27 +290,30 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { """, *SupportStructures.codegenScope, ) + Attribute.AllowUnusedMut.render(this) rust( "let mut s = serializer.serialize_struct(${ shape.contextName(codegenContext.serviceShape).dq() }, ${shape.members().size})?;", ) - rust("let inner = &self.value;") - for (member in shape.members()) { - val serializedName = member.memberName.dq() - val fieldName = codegenContext.symbolProvider.toMemberName(member) - val field = safeName("member") - val fieldSerialization = - writable { - rustTemplate( - "s.serialize_field($serializedName, #{member})?;", - "member" to serializeMember(member, field), - ) + if (!shape.members().isEmpty()) { + rust("let inner = &self.value;") + for (member in shape.members()) { + val serializedName = member.memberName.dq() + val fieldName = codegenContext.symbolProvider.toMemberName(member) + val field = safeName("member") + val fieldSerialization = + writable { + rustTemplate( + "s.serialize_field($serializedName, #{member})?;", + "member" to serializeMember(member, field), + ) + } + if (codegenContext.symbolProvider.toSymbol(member).isOptional()) { + rust("if let Some($field) = &inner.$fieldName { #T }", fieldSerialization) + } else { + rust("let $field = &inner.$fieldName; #T", fieldSerialization) } - if (codegenContext.symbolProvider.toSymbol(member).isOptional()) { - rust("if let Some($field) = &inner.$fieldName { #T }", fieldSerialization) - } else { - rust("let $field = &inner.$fieldName; #T", fieldSerialization) } } rust("s.end()") @@ -285,10 +338,14 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { "${symbolProvider.toMemberName(member)}(inner)" } withBlock("#T::$variantName => {", "},", symbolProvider.toSymbol(shape)) { - rustTemplate( - "serializer.serialize_newtype_variant(${unionName.dq()}, $index, $fieldName, #{member})", - "member" to serializeMember(member, "inner"), - ) + when (member.isTargetUnit()) { + true -> rust("serializer.serialize_unit_variant(${unionName.dq()}, $index, $fieldName)") + false -> + rustTemplate( + "serializer.serialize_newtype_variant(${unionName.dq()}, $index, $fieldName, #{member})", + "member" to serializeMember(member, "inner"), + ) + } } } if (codegenContext.target.renderUnknownVariant()) { @@ -325,6 +382,26 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) { } } + private fun serializeByteStream(shape: BlobShape): RuntimeType = + RuntimeType.forInlineFun("SerializeByteStream", SerdeModule) { + implSerializeConfigured(RuntimeType.byteStream(codegenContext.runtimeConfig).toSymbol()) { + // This doesn't work yet—there is no way to get data out of a ByteStream from a sync context + rustTemplate( + """ + let Some(bytes) = self.value.bytes() else { + return serializer.serialize_str("streaming data") + }; + if serializer.is_human_readable() { + serializer.serialize_str(&#{base64_encode}(bytes)) + } else { + serializer.serialize_bytes(&bytes) + } + """, + "base64_encode" to RuntimeType.base64Encode(codegenContext.runtimeConfig), + ) + } + } + private fun serializeDocument(shape: DocumentShape): RuntimeType = RuntimeType.forInlineFun("SerializeDocument", SerdeModule) { implSerializeConfigured(codegenContext.symbolProvider.toSymbol(shape)) { diff --git a/codegen-serde/src/main/kotlin/software/amazon/smithy/rust/codegen/serde/Traits.kt b/codegen-serde/src/main/kotlin/software/amazon/smithy/rust/codegen/serde/Traits.kt index 9c11c13831d..448bd4549a3 100644 --- a/codegen-serde/src/main/kotlin/software/amazon/smithy/rust/codegen/serde/Traits.kt +++ b/codegen-serde/src/main/kotlin/software/amazon/smithy/rust/codegen/serde/Traits.kt @@ -12,7 +12,7 @@ import software.amazon.smithy.model.traits.AbstractTrait import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.core.util.orNull -class SerdeTrait private constructor( +class SerdeTrait constructor( private val serialize: Boolean, private val deserialize: Boolean, private val tag: String?, diff --git a/codegen-serde/src/test/kotlin/software/amazon/smithy/rust/codegen/serde/SerdeDecoratorTest.kt b/codegen-serde/src/test/kotlin/software/amazon/smithy/rust/codegen/serde/SerdeDecoratorTest.kt index cbff52ad585..c24cb264883 100644 --- a/codegen-serde/src/test/kotlin/software/amazon/smithy/rust/codegen/serde/SerdeDecoratorTest.kt +++ b/codegen-serde/src/test/kotlin/software/amazon/smithy/rust/codegen/serde/SerdeDecoratorTest.kt @@ -19,6 +19,8 @@ import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest class SerdeDecoratorTest { + private val params = + IntegrationTestParams(cargoCommand = "cargo test --all-features", service = "com.example#HelloService") private val simpleModel = """ namespace com.example @@ -27,13 +29,28 @@ class SerdeDecoratorTest { use smithy.framework#ValidationException @awsJson1_0 service HelloService { - operations: [SayHello, SayGoodBye], + operations: [SayHello, SayGoodBye, Streaming], version: "1" } operation SayHello { input: TestInput errors: [ValidationException] } + + @serde + operation Streaming { + input: StreamingInput + errors: [ValidationException] + } + + structure StreamingInput { + @required + data: StreamingBlob + } + + @streaming + blob StreamingBlob + @serde structure TestInput { foo: SensitiveString, @@ -85,7 +102,8 @@ class SerdeDecoratorTest { @sensitive union U { nested: Nested, - enum: TestEnum + enum: TestEnum, + other: Unit } structure Nested { @@ -95,6 +113,7 @@ class SerdeDecoratorTest { notSensitive: AlsoTimestamps, manyEnums: TestEnumList, sparse: SparseList + map: SparseMap } list TestEnumList { @@ -111,6 +130,12 @@ class SerdeDecoratorTest { value: Timestamp } + @sparse + map SparseMap { + key: String + value: SparseList + } + @sensitive timestamp SensitiveTimestamp @@ -127,7 +152,7 @@ class SerdeDecoratorTest { @Test fun generateSerializersThatWorkServer() { - serverIntegrationTest(simpleModel, params = IntegrationTestParams(cargoCommand = "cargo test --all-features")) { ctx, crate -> + serverIntegrationTest(simpleModel, params = params) { ctx, crate -> val codegenScope = arrayOf( "crate" to RustType.Opaque(ctx.moduleUseName()), @@ -174,6 +199,22 @@ class SerdeDecoratorTest { ) } + unitTest("serde_of_bytestream") { + rustTemplate( + """ + use #{crate}::input::StreamingInput; + use #{crate}::types::ByteStream; + use #{crate}::serde::*; + let input = StreamingInput::builder().data(ByteStream::from_static(b"123")).build().unwrap(); + let settings = SerializationSettings::default(); + let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize"); + assert_eq!(serialized, ${expectedStreaming.dq()}); + + """, + *codegenScope, + ) + } + unitTest("delegated_serde") { rustTemplate( """ @@ -241,7 +282,7 @@ class SerdeDecoratorTest { "manyEnums": [ "" ], - "sparse": ["", "", ""] + "sparse": [null, "", ""] }, "union": "", "document": "hello!", @@ -249,9 +290,11 @@ class SerdeDecoratorTest { } """.replace("\\s".toRegex(), "") + private val expectedStreaming = """{"data":"MTIz"}""" + @Test fun generateSerializersThatWorkClient() { - clientIntegrationTest(simpleModel, params = IntegrationTestParams(cargoCommand = "cargo test --all-features")) { ctx, crate -> + clientIntegrationTest(simpleModel, params = params) { ctx, crate -> val codegenScope = arrayOf( "crate" to RustType.Opaque(ctx.moduleUseName()), @@ -295,6 +338,21 @@ class SerdeDecoratorTest { ) } + unitTest("serde_of_bytestream") { + rustTemplate( + """ + use #{crate}::operation::streaming::StreamingInput; + use #{crate}::primitives::ByteStream; + use #{crate}::serde::*; + let input = StreamingInput::builder().data(ByteStream::from_static(b"123")).build().unwrap(); + let settings = SerializationSettings::default(); + let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize"); + assert_eq!(serialized, ${expectedStreaming.dq()}); + """, + *codegenScope, + ) + } + unitTest("delegated_serde") { rustTemplate( """ diff --git a/codegen-serde/src/test/kotlin/software/amazon/smithy/rust/codegen/serde/SerdeProtocolTestTest.kt b/codegen-serde/src/test/kotlin/software/amazon/smithy/rust/codegen/serde/SerdeProtocolTestTest.kt new file mode 100644 index 00000000000..692c026e67c --- /dev/null +++ b/codegen-serde/src/test/kotlin/software/amazon/smithy/rust/codegen/serde/SerdeProtocolTestTest.kt @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.serde + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.SourceLocation +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.util.letIf + +class SerdeProtocolTestTest { + @Test + fun testSmithyModels() { + val serviceShapeId = ShapeId.from(ServiceShapeId.REST_JSON) + var model = Model.assembler().discoverModels().assemble().result.get() + val service = + model.expectShape(serviceShapeId, ServiceShape::class.java).toBuilder().addTrait( + SerdeTrait(true, false, null, null, SourceLocation.NONE), + ).build() + model = + ModelTransformer.create().mapShapes(model) { serviceShape -> + serviceShape.letIf(serviceShape.id == serviceShapeId) { + service + } + } + clientIntegrationTest(model, IntegrationTestParams(service = ServiceShapeId.REST_JSON, cargoCommand = "cargo test --all-features")) { clientCodegenContext, rustCrate -> + } + } +} diff --git a/rust-runtime/aws-smithy-types/Cargo.toml b/rust-runtime/aws-smithy-types/Cargo.toml index 46433f3be9d..fe5c6d8872e 100644 --- a/rust-runtime/aws-smithy-types/Cargo.toml +++ b/rust-runtime/aws-smithy-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-types" -version = "1.2.0" +version = "1.2.1" authors = [ "AWS Rust SDK Team ", "Russell Cohen ", diff --git a/rust-runtime/aws-smithy-types/src/byte_stream.rs b/rust-runtime/aws-smithy-types/src/byte_stream.rs index 1c0df424832..dbe28edb872 100644 --- a/rust-runtime/aws-smithy-types/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-types/src/byte_stream.rs @@ -322,6 +322,13 @@ impl ByteStream { self.next().await.transpose() } + /// Returns a reference to the data if it is already available in memory + pub fn bytes(&self) -> Option<&[u8]> { + match &self.inner { + Inner { body } => body.bytes(), + } + } + /// Return the bounds on the remaining length of the `ByteStream`. pub fn size_hint(&self) -> (u64, Option) { self.inner.size_hint()