diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt new file mode 100644 index 0000000000..5e47701f55 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt @@ -0,0 +1,25 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator + +private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable { + rust("#T::from($data)", enumSymbol) +} + +fun clientInstantiator(codegenContext: CodegenContext) = + Instantiator( + codegenContext.symbolProvider, + codegenContext.model, + codegenContext.runtimeConfig, + ::enumFromStringFn, + ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt index a3e1ed6f9d..6f670b641f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait +import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata @@ -33,9 +34,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.testutil.TokioTest @@ -66,9 +65,7 @@ class ProtocolTestGenerator( private val operationSymbol = codegenContext.symbolProvider.toSymbol(operationShape) private val operationIndex = OperationIndex.of(codegenContext.model) - private val instantiator = with(codegenContext) { - Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) - } + private val instantiator = clientInstantiator(codegenContext) private val codegenScope = arrayOf( "SmithyHttp" to CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType(), diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 1e44de1aa6..912a0f668a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -332,7 +332,7 @@ class HttpBoundProtocolTraitImplGenerator( } } - val err = if (StructureGenerator.fallibleBuilder(outputShape, symbolProvider)) { + val err = if (StructureGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { ".map_err(${format(errorSymbol)}::unhandled)?" } else "" diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt new file mode 100644 index 0000000000..e55f98c0d3 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt @@ -0,0 +1,87 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.lookup + +internal class ClientInstantiatorTest { + private val model = """ + namespace com.test + + @enum([ + { value: "t2.nano" }, + { value: "t2.micro" }, + ]) + string UnnamedEnum + + @enum([ + { + value: "t2.nano", + name: "T2_NANO", + }, + { + value: "t2.micro", + name: "T2_MICRO", + }, + ]) + string NamedEnum + """.asSmithyModel() + + private val codegenContext = testCodegenContext(model) + private val symbolProvider = codegenContext.symbolProvider + + @Test + fun `generate named enums`() { + val shape = model.lookup("com.test#NamedEnum") + val sut = clientInstantiator(codegenContext) + val data = Node.parse("t2.nano".dq()) + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + unitTest("generate_named_enums") { + withBlock("let result = ", ";") { + sut.render(this, shape, data) + } + rust("assert_eq!(result, NamedEnum::T2Nano);") + } + } + project.compileAndTest() + } + + @Test + fun `generate unnamed enums`() { + val shape = model.lookup("com.test#UnnamedEnum") + val sut = clientInstantiator(codegenContext) + val data = Node.parse("t2.nano".dq()) + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + unitTest("generate_unnamed_enums") { + withBlock("let result = ", ";") { + sut.render(this, shape, data) + } + rust("""assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));""") + } + } + project.compileAndTest() + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 5ff4f33dd1..9ae2da62ad 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -79,7 +79,7 @@ class BuilderGenerator( } private fun renderBuildFn(implBlockWriter: RustWriter) { - val fallibleBuilder = StructureGenerator.fallibleBuilder(shape, symbolProvider) + val fallibleBuilder = StructureGenerator.hasFallibleBuilder(shape, symbolProvider) val outputSymbol = symbolProvider.toSymbol(shape) val returnType = when (fallibleBuilder) { true -> "Result<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index a869e73255..a58c0c65be 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.node.ArrayNode import software.amazon.smithy.model.node.Node @@ -29,9 +30,11 @@ import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait +import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asType import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.core.rustlang.escape @@ -40,7 +43,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider @@ -49,57 +51,48 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectMember import software.amazon.smithy.rust.codegen.core.util.hasTrait -import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.letIf /** - * Instantiator generates code to instantiate a given Shape given a `Node` representing the value + * Instantiator generates code to instantiate a given Shape given a `Node` representing the value. * - * This is primarily used during Protocol test generation + * This is only used during protocol test generation. */ -class Instantiator( +open class Instantiator( private val symbolProvider: RustSymbolProvider, private val model: Model, private val runtimeConfig: RuntimeConfig, - private val target: CodegenTarget, + /** + * A function that given a symbol for an enum shape and a string, returns a writable to instantiate the enum with + * the string value. + **/ + private val enumFromStringFn: (Symbol, String) -> Writable, + /** Fill out required fields with a default value **/ + private val defaultsForRequiredFields: Boolean = false, ) { data class Ctx( - // The Rust HTTP library lower cases headers but Smithy protocol tests - // contain httpPrefix headers with uppercase keys - val lowercaseMapKeys: Boolean, - val streaming: Boolean, - // Whether we are instantiating with a Builder, in which case all setters take Option - val builder: Boolean, - // Fill out `required` fields with a default value. - val defaultsForRequiredFields: Boolean, + // The `http` crate requires that headers be lowercase, but Smithy protocol tests + // contain headers with uppercase keys. + val lowercaseMapKeys: Boolean = false, ) - companion object { - fun defaultContext() = Ctx(lowercaseMapKeys = false, streaming = false, builder = false, defaultsForRequiredFields = false) - } - - fun render( - writer: RustWriter, - shape: Shape, - arg: Node, - ctx: Ctx = defaultContext(), - ) { + fun render(writer: RustWriter, shape: Shape, data: Node, ctx: Ctx = Ctx()) { when (shape) { // Compound Shapes - is StructureShape -> renderStructure(writer, shape, arg as ObjectNode, ctx.copy(builder = true)) - is UnionShape -> renderUnion(writer, shape, arg as ObjectNode, ctx) + is StructureShape -> renderStructure(writer, shape, data as ObjectNode, ctx) + is UnionShape -> renderUnion(writer, shape, data as ObjectNode, ctx) // Collections - is ListShape -> renderList(writer, shape, arg as ArrayNode, ctx) - is MapShape -> renderMap(writer, shape, arg as ObjectNode, ctx) - is SetShape -> renderSet(writer, shape, arg as ArrayNode, ctx) + is ListShape -> renderList(writer, shape, data as ArrayNode, ctx) + is MapShape -> renderMap(writer, shape, data as ObjectNode, ctx) + is SetShape -> renderSet(writer, shape, data as ArrayNode, ctx) // Members, supporting potentially optional members - is MemberShape -> renderMember(writer, shape, arg, ctx) + is MemberShape -> renderMember(writer, shape, data, ctx) // Wrapped Shapes - is TimestampShape -> writer.write( - "#T::from_secs(${(arg as NumberNode).value})", + is TimestampShape -> writer.rust( + "#T::from_secs(${(data as NumberNode).value})", RuntimeType.DateTime(runtimeConfig), ) @@ -108,38 +101,40 @@ class Instantiator( * Blob::new("arg") * ``` */ - is BlobShape -> if (ctx.streaming) { - writer.write( - "#T::from_static(b${(arg as StringNode).value.dq()})", + is BlobShape -> if (shape.hasTrait()) { + writer.rust( + "#T::from_static(b${(data as StringNode).value.dq()})", RuntimeType.ByteStream(runtimeConfig), ) } else { - writer.write( - "#T::new(${(arg as StringNode).value.dq()})", + writer.rust( + "#T::new(${(data as StringNode).value.dq()})", RuntimeType.Blob(runtimeConfig), ) } // Simple Shapes - is StringShape -> renderString(writer, shape, arg as StringNode) - is NumberShape -> when (arg) { + is StringShape -> renderString(writer, shape, data as StringNode) + is NumberShape -> when (data) { is StringNode -> { val numberSymbol = symbolProvider.toSymbol(shape) // support Smithy custom values, such as Infinity writer.rust( - """<#T as #T>::parse_smithy_primitive(${arg.value.dq()}).expect("invalid string for number")""", + """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", numberSymbol, CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse"), ) } - is NumberNode -> writer.write(arg.value) + + is NumberNode -> writer.write(data.value) } - is BooleanShape -> writer.write(arg.asBooleanNode().get().toString()) + + is BooleanShape -> writer.rust(data.asBooleanNode().get().toString()) is DocumentShape -> writer.rustBlock("") { val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() rustTemplate( """ - let json_bytes = br##"${Node.prettyPrintJson(arg)}"##; + let json_bytes = br##"${Node.prettyPrintJson(data)}"##; let mut tokens = #{json_token_iter}(json_bytes).peekable(); #{expect_document}(&mut tokens).expect("well formed json") """, @@ -147,26 +142,30 @@ class Instantiator( "json_token_iter" to smithyJson.member("deserialize::json_token_iter"), ) } - else -> writer.writeWithNoFormatting("todo!() /* $shape $arg */") + + else -> writer.writeWithNoFormatting("todo!() /* $shape $data */") } } /** - * If the shape is optional: `Some(inner)` or `None` - * otherwise: `inner` + * If the shape is optional: `Some(inner)` or `None`. + * Otherwise: `inner`. */ - private fun renderMember(writer: RustWriter, shape: MemberShape, arg: Node, ctx: Ctx) { - val target = model.expectShape(shape.target) - val symbol = symbolProvider.toSymbol(shape) - if (arg is NullNode) { - check( - symbol.isOptional(), - ) { "A null node was provided for $shape but the symbol was not optional. This is invalid input data." } - writer.write("None") + private fun renderMember(writer: RustWriter, memberShape: MemberShape, data: Node, ctx: Ctx) { + val targetShape = model.expectShape(memberShape.target) + val symbol = symbolProvider.toSymbol(memberShape) + if (data is NullNode) { + check(symbol.isOptional()) { + "A null node was provided for $memberShape but the symbol was not optional. This is invalid input data." + } + writer.rust("None") } else { + // Structure builder setters for structure shape members _always_ take in `Option`. + // Other aggregate shapes' members are optional only when their symbol is. writer.conditionalBlock( - "Some(", ")", - conditional = ctx.builder || symbol.isOptional(), + "Some(", + ")", + conditional = model.expectShape(memberShape.container) is StructureShape || symbol.isOptional(), ) { writer.conditionalBlock( "Box::new(", @@ -175,13 +174,11 @@ class Instantiator( ) { render( this, - target, - arg, - ctx.copy(builder = false) - .letIf(shape.getMemberTrait(model, HttpPrefixHeadersTrait::class.java).isPresent) { + targetShape, + data, + ctx.copy() + .letIf(memberShape.hasTrait()) { it.copy(lowercaseMapKeys = true) - }.letIf(shape.isStreaming(model)) { - it.copy(streaming = true) }, ) } @@ -189,9 +186,7 @@ class Instantiator( } } - private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode, ctx: Ctx) { - renderList(writer, shape, data, ctx) - } + private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode, ctx: Ctx) = renderList(writer, shape, data, ctx) /** * ```rust @@ -201,13 +196,14 @@ class Instantiator( * ret.insert("k2", ...); * ret * } + * ``` */ private fun renderMap(writer: RustWriter, shape: MapShape, data: ObjectNode, ctx: Ctx) { if (data.members.isEmpty()) { - writer.write("#T::new()", RustType.HashMap.RuntimeType) + writer.rust("#T::new()", RustType.HashMap.RuntimeType) } else { writer.rustBlock("") { - write("let mut ret = #T::new();", RustType.HashMap.RuntimeType) + rust("let mut ret = #T::new();", RustType.HashMap.RuntimeType) for ((key, value) in data.members) { withBlock("ret.insert(", ");") { renderMember(this, shape.key, key, ctx) @@ -218,7 +214,7 @@ class Instantiator( renderMember(this, shape.value, value, ctx) } } - write("ret") + rust("ret") } } } @@ -231,7 +227,7 @@ class Instantiator( private fun renderUnion(writer: RustWriter, shape: UnionShape, data: ObjectNode, ctx: Ctx) { val unionSymbol = symbolProvider.toSymbol(shape) - val variant = if (ctx.defaultsForRequiredFields && data.members.isEmpty()) { + val variant = if (defaultsForRequiredFields && data.members.isEmpty()) { val (name, memberShape) = shape.allMembers.entries.first() val targetShape = model.expectShape(memberShape.target) Node.from(name) to fillDefaultValue(targetShape) @@ -243,8 +239,8 @@ class Instantiator( val memberName = variant.first.value val member = shape.expectMember(memberName) - writer.write("#T::${symbolProvider.toMemberName(member)}", unionSymbol) - // unions should specify exactly one member + writer.rust("#T::${symbolProvider.toMemberName(member)}", unionSymbol) + // Unions should specify exactly one member. writer.withBlock("(", ")") { renderMember(this, member, variant.second, ctx) } @@ -259,7 +255,7 @@ class Instantiator( writer.withBlock("vec![", "]") { data.elements.forEach { v -> renderMember(this, shape.member, v, ctx) - write(",") + rust(",") } } } @@ -267,14 +263,10 @@ class Instantiator( private fun renderString(writer: RustWriter, shape: StringShape, arg: StringNode) { val data = writer.escape(arg.value).dq() if (!shape.hasTrait()) { - writer.rust("$data.to_string()") + writer.rust("$data.to_owned()") } else { val enumSymbol = symbolProvider.toSymbol(shape) - if (target == CodegenTarget.SERVER) { - writer.rust("""#T::try_from($data).expect("This is used in tests ONLY")""", enumSymbol) - } else { - writer.rust("#T::from($data)", enumSymbol) - } + writer.rustTemplate("#{EnumFromStringFn:W}", "EnumFromStringFn" to enumFromStringFn(enumSymbol, data)) } } @@ -290,11 +282,11 @@ class Instantiator( } } - writer.write("#T::builder()", symbolProvider.toSymbol(shape)) - if (ctx.defaultsForRequiredFields) { + writer.rust("#T::builder()", symbolProvider.toSymbol(shape)) + if (defaultsForRequiredFields) { shape.allMembers.entries .filter { (name, memberShape) -> - memberShape.isRequired && !data.members.containsKey(Node.from(name)) + !symbolProvider.toSymbol(memberShape).isOptional() && !data.members.containsKey(Node.from(name)) } .forEach { (_, memberShape) -> renderMemberHelper(memberShape, fillDefaultValue(memberShape)) @@ -305,9 +297,9 @@ class Instantiator( val memberShape = shape.expectMember(key.value) renderMemberHelper(memberShape, value) } - writer.write(".build()") - if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { - writer.write(".unwrap()") + writer.rust(".build()") + if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { + writer.rust(".unwrap()") } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt index 9af2df8091..cbd0a1395c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt @@ -70,7 +70,7 @@ open class StructureGenerator( companion object { /** Returns whether a structure shape requires a fallible builder to be generated. */ - fun fallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = + fun hasFallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = // All operation inputs should have fallible builders in case a new required field is added in the future. structureShape.hasTrait() || structureShape diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index c56484e969..43c18287a5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -413,7 +413,7 @@ class JsonParserGenerator( rustTemplate("let mut builder = #{Shape}::builder();", *codegenScope, "Shape" to symbol) deserializeStructInner(shape.members()) withBlock("Ok(Some(builder.build()", "))") { - if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { + if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { rustTemplate( """.map_err(|err| #{Error}::new( #{ErrorReason}::Custom(format!("{}", err).into()), None) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index 3805c67c80..841975c229 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -476,7 +476,7 @@ class XmlBindingTraitParserGenerator( rust("let _ = decoder;") } withBlock("Ok(builder.build()", ")") { - if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { + if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { // NOTE:(rcoh) This branch is unreachable given the current nullability rules. // Only synthetic inputs can have fallible builders, but synthetic inputs can never be parsed // (because they're inputs, only outputs will be parsed!) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index 848762ff18..f5f08e9076 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -101,7 +101,12 @@ internal fun testCodegenContext( /** * In tests, we frequently need to generate a struct, a builder, and an impl block to access said builder. */ -fun StructureShape.renderWithModelBuilder(model: Model, symbolProvider: RustSymbolProvider, writer: RustWriter, forWhom: CodegenTarget = CodegenTarget.CLIENT) { +fun StructureShape.renderWithModelBuilder( + model: Model, + symbolProvider: RustSymbolProvider, + writer: RustWriter, + forWhom: CodegenTarget = CodegenTarget.CLIENT, +) { StructureGenerator(model, symbolProvider, writer, this).render(forWhom) val modelBuilder = BuilderGenerator(model, symbolProvider, this) modelBuilder.render(writer) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt index c6f592adeb..73dbe86d5b 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt @@ -6,31 +6,32 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import org.junit.jupiter.api.Test +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.node.StringNode import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.raw +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.lookup class InstantiatorTest { private val model = """ namespace com.test + @documentation("this documents the shape") structure MyStruct { foo: String, @@ -69,29 +70,6 @@ class InstantiatorTest { value: Integer } - structure MyStructRequired { - @required - str: String, - @required - primitiveInt: PrimitiveInteger, - @required - int: Integer, - @required - ts: Timestamp, - @required - byte: Byte - @required - union: NestedUnion, - @required - structure: NestedStruct, - @required - list: MyList, - @required - map: NestedMap, - @required - doc: Document - } - union NestedUnion { struct: NestedStruct, int: Integer @@ -108,216 +86,176 @@ class InstantiatorTest { private val symbolProvider = testSymbolProvider(model) private val runtimeConfig = TestRuntimeConfig - fun RustWriter.test(block: Writable) { - raw("#[test]") - rustBlock("fn inst()") { - block() - } - } + private fun enumFromStringFn(symbol: Symbol, data: String) = writable { } @Test fun `generate unions`() { val union = model.lookup("com.test#MyUnion") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) - val data = Node.parse( - """{ - "stringVariant": "ok!" - }""", - ) - val writer = RustWriter.forModule("model") - UnionGenerator(model, symbolProvider, writer, union).render() - writer.test { - writer.withBlock("let result = ", ";") { - sut.render(this, union, data) + val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val data = Node.parse("""{ "stringVariant": "ok!" }""") + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + UnionGenerator(model, symbolProvider, this, union).render() + unitTest("generate_unions") { + withBlock("let result = ", ";") { + sut.render(this, union, data) + } + rust("assert_eq!(result, MyUnion::StringVariant(\"ok!\".to_owned()));") } - writer.write("assert_eq!(result, MyUnion::StringVariant(\"ok!\".to_string()));") } + project.compileAndTest() } @Test fun `generate struct builders`() { val structure = model.lookup("com.test#MyStruct") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) + val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) val data = Node.parse("""{ "bar": 10, "foo": "hello" }""") - val writer = RustWriter.forModule("model") - structure.renderWithModelBuilder(model, symbolProvider, writer) - writer.test { - writer.withBlock("let result = ", ";") { - sut.render(this, structure, data) + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + structure.renderWithModelBuilder(model, symbolProvider, this) + unitTest("generate_struct_builders") { + withBlock("let result = ", ";") { + sut.render(this, structure, data) + } + rust( + """ + assert_eq!(result.bar, 10); + assert_eq!(result.foo.unwrap(), "hello"); + """, + ) } - writer.write("assert_eq!(result.bar, 10);") - writer.write("assert_eq!(result.foo.unwrap(), \"hello\");") } - writer.compileAndTest() + project.compileAndTest() } @Test fun `generate builders for boxed structs`() { val structure = model.lookup("com.test#WithBox") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) + val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) val data = Node.parse( - """ { - "member": { - "member": { } - }, "value": 10 + """ + { + "member": { + "member": { } + }, + "value": 10 } - """.trimIndent(), + """, ) - val writer = RustWriter.forModule("model") - structure.renderWithModelBuilder(model, symbolProvider, writer) - writer.test { - withBlock("let result = ", ";") { - sut.render(this, structure, data) + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + structure.renderWithModelBuilder(model, symbolProvider, this) + unitTest("generate_builders_for_boxed_structs") { + withBlock("let result = ", ";") { + sut.render(this, structure, data) + } + rust( + """ + assert_eq!(result, WithBox { + value: Some(10), + member: Some(Box::new(WithBox { + value: None, + member: Some(Box::new(WithBox { value: None, member: None })), + })) + }); + """, + ) } - rust( - """ - assert_eq!(result, WithBox { - value: Some(10), - member: Some(Box::new(WithBox { - value: None, - member: Some(Box::new(WithBox { value: None, member: None })), - })) - }); - """, - ) } - writer.compileAndTest() + project.compileAndTest() } @Test fun `generate lists`() { - val data = Node.parse( - """ [ - "bar", - "foo" - ] - """, - ) - val writer = RustWriter.forModule("lib") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) - writer.test { - writer.withBlock("let result = ", ";") { - sut.render(writer, model.lookup("com.test#MyList"), data) + val data = Node.parse("""["bar", "foo"]""") + val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + unitTest("generate_lists") { + withBlock("let result = ", ";") { + sut.render(this, model.lookup("com.test#MyList"), data) + } + rust("""assert_eq!(result, vec!["bar".to_owned(), "foo".to_owned()]);""") } - writer.write("""assert_eq!(result, vec!["bar".to_string(), "foo".to_string()]);""") } - writer.compileAndTest() + project.compileAndTest() } @Test fun `generate sparse lists`() { - val data = Node.parse( - """ [ - "bar", - "foo", - null - ] - """, - ) - val writer = RustWriter.forModule("lib") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) - writer.test { - writer.withBlock("let result = ", ";") { - sut.render(writer, model.lookup("com.test#MySparseList"), data) + val data = Node.parse(""" [ "bar", "foo", null ] """) + val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + unitTest("generate_sparse_lists") { + withBlock("let result = ", ";") { + sut.render(this, model.lookup("com.test#MySparseList"), data) + } + rust("""assert_eq!(result, vec![Some("bar".to_owned()), Some("foo".to_owned()), None]);""") } - writer.write("""assert_eq!(result, vec![Some("bar".to_string()), Some("foo".to_string()), None]);""") } - writer.compileAndTest() + project.compileAndTest() } @Test fun `generate maps of maps`() { val data = Node.parse( - """{ - "k1": { "map": {} }, - "k2": { "map": { "k3": {} } }, - "k3": { } + """ + { + "k1": { "map": {} }, + "k2": { "map": { "k3": {} } }, + "k3": { } } """, ) - val writer = RustWriter.forModule("model") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) - val inner: StructureShape = model.lookup("com.test#Inner") - inner.renderWithModelBuilder(model, symbolProvider, writer) - writer.test { - writer.withBlock("let result = ", ";") { - sut.render(writer, model.lookup("com.test#NestedMap"), data) + val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val inner = model.lookup("com.test#Inner") + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + inner.renderWithModelBuilder(model, symbolProvider, this) + unitTest("generate_maps_of_maps") { + withBlock("let result = ", ";") { + sut.render(this, model.lookup("com.test#NestedMap"), data) + } + rust( + """ + assert_eq!(result.len(), 3); + assert_eq!(result.get("k1").unwrap().map.as_ref().unwrap().len(), 0); + assert_eq!(result.get("k2").unwrap().map.as_ref().unwrap().len(), 1); + assert_eq!(result.get("k3").unwrap().map, None); + """, + ) } - writer.write( - """ - assert_eq!(result.len(), 3); - assert_eq!(result.get("k1").unwrap().map.as_ref().unwrap().len(), 0); - assert_eq!(result.get("k2").unwrap().map.as_ref().unwrap().len(), 1); - assert_eq!(result.get("k3").unwrap().map, None); - """, - ) } - writer.compileAndTest(clippy = true) + project.compileAndTest(runClippy = true) } @Test fun `blob inputs are binary data`() { // "Parameter values that contain binary data MUST be defined using values // that can be represented in plain text (for example, use "foo" and not "Zm9vCg==")." - val writer = RustWriter.forModule("lib") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.CLIENT) - writer.test { - withBlock("let blob = ", ";") { - sut.render( - this, - BlobShape.builder().id(ShapeId.from("com.example#Blob")).build(), - StringNode.parse("foo".dq()), - ) - } - write("assert_eq!(std::str::from_utf8(blob.as_ref()).unwrap(), \"foo\");") - } - writer.compileAndTest() - } + val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) - @Test - fun `generate struct with missing required members`() { - val structure = model.lookup("com.test#MyStructRequired") - val inner = model.lookup("com.test#Inner") - val nestedStruct = model.lookup("com.test#NestedStruct") - val union = model.lookup("com.test#NestedUnion") - val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER) - val data = Node.parse("{}") - val writer = RustWriter.forModule("model") - structure.renderWithModelBuilder(model, symbolProvider, writer) - inner.renderWithModelBuilder(model, symbolProvider, writer) - nestedStruct.renderWithModelBuilder(model, symbolProvider, writer) - UnionGenerator(model, symbolProvider, writer, union).render() - writer.test { - writer.withBlock("let result = ", ";") { - sut.render(this, structure, data, Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + unitTest("blob_inputs_are_binary_data") { + withBlock("let blob = ", ";") { + sut.render( + this, + BlobShape.builder().id(ShapeId.from("com.example#Blob")).build(), + StringNode.parse("foo".dq()), + ) + } + rust("assert_eq!(std::str::from_utf8(blob.as_ref()).unwrap(), \"foo\");") } - writer.write( - """ - use std::collections::HashMap; - use aws_smithy_types::{DateTime, Document}; - - let expected = MyStructRequired { - str: Some("".into()), - primitive_int: 0, - int: Some(0), - ts: Some(DateTime::from_secs(0)), - byte: Some(0), - union: Some(NestedUnion::Struct(NestedStruct { - str: Some("".into()), - num: Some(0), - })), - structure: Some(NestedStruct { - str: Some("".into()), - num: Some(0), - }), - list: Some(vec![]), - map: Some(HashMap::new()), - doc: Some(Document::Object(HashMap::new())), - }; - assert_eq!(result, expected); - """, - ) } - writer.compileAndTest() + project.compileAndTest() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt new file mode 100644 index 0000000000..9189380dde --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -0,0 +1,34 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator + +/** + * Server enums do not have an `Unknown` variant like client enums do, so constructing an enum from + * a string is a fallible operation (hence `try_from`). It's ok to panic here if construction fails, + * since this is only used in protocol tests. + */ +private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable { + rust( + """#T::try_from($data).expect("This is used in tests ONLY")""", + enumSymbol, + ) +} + +fun serverInstantiator(codegenContext: CodegenContext) = + Instantiator( + codegenContext.symbolProvider, + codegenContext.model, + codegenContext.runtimeConfig, + ::enumFromStringFn, + defaultsForRequiredFields = true, + ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index d68e1cb605..cf70965107 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -39,9 +39,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.core.testutil.TokioTest @@ -57,6 +55,7 @@ import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverInstantiator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator import java.util.logging.Logger import kotlin.reflect.KFunction1 @@ -99,9 +98,7 @@ class ServerProtocolTestGenerator( inputT to outputT } - private val instantiator = with(codegenContext) { - Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER) - } + private val instantiator = serverInstantiator(codegenContext) private val codegenScope = arrayOf( "Bytes" to RuntimeType.Bytes, @@ -174,7 +171,7 @@ class ServerProtocolTestGenerator( operations.withIndex().forEach { val (inputT, outputT) = operationInputOutputTypes[it.value]!! val operationName = operationNames[it.index] - write(".$operationName((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )") + rust(".$operationName((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )") } } @@ -353,7 +350,7 @@ class ServerProtocolTestGenerator( testModuleWriter.writeWithNoFormatting(testCase.documentation) } - testModuleWriter.write("Test ID: ${testCase.id}") + testModuleWriter.rust("Test ID: ${testCase.id}") testModuleWriter.newlinePrefix = "" TokioTest.render(testModuleWriter) @@ -442,7 +439,7 @@ class ServerProtocolTestGenerator( } writeInline("let output =") instantiator.render(this, shape, testCase.params) - write(";") + rust(";") val operationImpl = if (operationShape.allErrors(model).isNotEmpty()) { if (shape.hasTrait()) { val variant = symbolProvider.toSymbol(shape).name @@ -555,18 +552,13 @@ class ServerProtocolTestGenerator( // Construct a dummy response. withBlock("let response = ", ";") { - instantiator.render( - this, - outputShape, - Node.objectNode(), - Instantiator.defaultContext().copy(defaultsForRequiredFields = true), - ) + instantiator.render(this, outputShape, Node.objectNode()) } if (operationShape.errors.isEmpty()) { - write("response") + rust("response") } else { - write("Ok(response)") + rust("Ok(response)") } } @@ -759,7 +751,7 @@ class ServerProtocolTestGenerator( ) } else { assertOk(rustWriter) { - rustWriter.write( + rustWriter.rust( "#T(&body, ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", @@ -815,7 +807,7 @@ class ServerProtocolTestGenerator( ) } assertOk(rustWriter) { - write( + rust( "#T($actualExpression, $variableName)", RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "validate_headers"), ) @@ -836,7 +828,7 @@ class ServerProtocolTestGenerator( strSlice(this, params) } assertOk(rustWriter) { - rustWriter.write( + rustWriter.rust( "#T($actualExpression, $expectedVariableName)", RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, checkFunction), ) @@ -848,14 +840,14 @@ class ServerProtocolTestGenerator( * for pretty printing protocol test helper results */ private fun assertOk(rustWriter: RustWriter, inner: Writable) { - rustWriter.write("#T(", RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "assert_ok")) + rustWriter.rust("#T(", RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "assert_ok")) inner(rustWriter) - rustWriter.write(");") + rustWriter.rust(");") } private fun strSlice(writer: RustWriter, args: List) { writer.withBlock("&[", "]") { - write(args.joinToString(",") { it.dq() }) + rust(args.joinToString(",") { it.dq() }) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index e98f156f5a..2cf86be5eb 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -750,7 +750,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } } - val err = if (StructureGenerator.fallibleBuilder(inputShape, symbolProvider)) { + val err = if (StructureGenerator.hasFallibleBuilder(inputShape, symbolProvider)) { "?" } else "" rustTemplate("input.build()$err", *codegenScope) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt index 4159e732c7..0ac168801c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt @@ -10,7 +10,6 @@ import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig @@ -70,7 +69,7 @@ fun serverTestCodegenContext( protocolShapeId: ShapeId? = null, ): ServerCodegenContext = ServerCodegenContext( model, - testSymbolProvider(model), + serverTestSymbolProvider(model), serviceShape ?: model.serviceShapes.firstOrNull() ?: ServiceShape.builder().version("test").id("test#Service").build(), diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt new file mode 100644 index 0000000000..ae50d8a0a5 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt @@ -0,0 +1,221 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +class ServerInstantiatorTest { + // This model started off from the one in `InstantiatorTest.kt` from `codegen-core`. + private val model = """ + namespace com.test + + use smithy.framework#ValidationException + + @documentation("this documents the shape") + structure MyStruct { + foo: String, + @documentation("This *is* documentation about the member.") + bar: PrimitiveInteger, + baz: Integer, + ts: Timestamp, + byteValue: Byte + } + + list MyList { + member: String + } + + @sparse + list MySparseList { + member: String + } + + union MyUnion { + stringVariant: String, + numVariant: Integer + } + + structure Inner { + map: NestedMap + } + + map NestedMap { + key: String, + value: Inner + } + + structure WithBox { + member: WithBox, + value: Integer + } + + union NestedUnion { + struct: NestedStruct, + int: Integer + } + + structure NestedStruct { + @required + str: String, + @required + num: Integer + } + + structure MyStructRequired { + @required + str: String, + @required + primitiveInt: PrimitiveInteger, + @required + int: Integer, + @required + ts: Timestamp, + @required + byte: Byte + @required + union: NestedUnion, + @required + structure: NestedStruct, + @required + list: MyList, + @required + map: NestedMap, + @required + doc: Document + } + + @enum([ + { value: "t2.nano" }, + { value: "t2.micro" }, + ]) + string UnnamedEnum + + @enum([ + { + value: "t2.nano", + name: "T2_NANO", + }, + { + value: "t2.micro", + name: "T2_MICRO", + }, + ]) + string NamedEnum + """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } + + private val codegenContext = serverTestCodegenContext(model) + private val symbolProvider = codegenContext.symbolProvider + + @Test + fun `generate struct with missing required members`() { + val structure = model.lookup("com.test#MyStructRequired") + val inner = model.lookup("com.test#Inner") + val nestedStruct = model.lookup("com.test#NestedStruct") + val union = model.lookup("com.test#NestedUnion") + val sut = serverInstantiator(codegenContext) + val data = Node.parse("{}") + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + structure.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) + inner.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) + nestedStruct.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) + UnionGenerator(model, symbolProvider, this, union).render() + + unitTest("server_instantiator_test") { + withBlock("let result = ", ";") { + sut.render(this, structure, data) + } + + rust( + """ + use std::collections::HashMap; + use aws_smithy_types::{DateTime, Document}; + + let expected = MyStructRequired { + str: "".to_owned(), + primitive_int: 0, + int: 0, + ts: DateTime::from_secs(0), + byte: 0, + union: NestedUnion::Struct(NestedStruct { + str: "".to_owned(), + num: 0, + }), + structure: NestedStruct { + str: "".to_owned(), + num: 0, + }, + list: Vec::new(), + map: HashMap::new(), + doc: Document::Object(HashMap::new()), + }; + assert_eq!(result, expected); + """, + ) + } + } + project.compileAndTest() + } + + @Test + fun `generate named enums`() { + val shape = model.lookup("com.test#NamedEnum") + val sut = serverInstantiator(codegenContext) + val data = Node.parse("t2.nano".dq()) + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + unitTest("generate_named_enums") { + withBlock("let result = ", ";") { + sut.render(this, shape, data) + } + rust("assert_eq!(result, NamedEnum::T2Nano);") + } + } + project.compileAndTest() + } + + @Test + fun `generate unnamed enums`() { + val shape = model.lookup("com.test#UnnamedEnum") + val sut = serverInstantiator(codegenContext) + val data = Node.parse("t2.nano".dq()) + + val project = TestWorkspace.testProject() + project.withModule(RustModule.Model) { + EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + unitTest("generate_unnamed_enums") { + withBlock("let result = ", ";") { + sut.render(this, shape, data) + } + rust("""assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));""") + } + } + project.compileAndTest() + } +}