diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt index 9d33c1b1c5..3ff80ee259 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt @@ -102,6 +102,8 @@ object ClientRustModule { /** crate::types::error */ val Error = RustModule.public("error", parent = self) } + + val waiters = RustModule.pubCrate("waiters") } class ClientModuleDocProvider( diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGenerator.kt new file mode 100644 index 0000000000..805930d58d --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGenerator.kt @@ -0,0 +1,192 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rust.codegen.client.smithy.generators.waiters + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.jmespath.JmespathExpression +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.replaceLifetimes +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.waiters.Matcher +import software.amazon.smithy.waiters.Matcher.ErrorTypeMember +import software.amazon.smithy.waiters.Matcher.InputOutputMember +import software.amazon.smithy.waiters.Matcher.OutputMember +import software.amazon.smithy.waiters.Matcher.SuccessMember +import software.amazon.smithy.waiters.PathComparator +import java.security.MessageDigest + +private typealias Scope = Array> + +/** + * Generates the Rust code for the Smithy waiter "matcher union". + * See https://smithy.io/2.0/additional-specs/waiters.html#matcher-union + */ +class RustWaiterMatcherGenerator( + private val codegenContext: ClientCodegenContext, + private val operationName: String, + private val inputShape: Shape, + private val outputShape: Shape, +) { + private val runtimeConfig = codegenContext.runtimeConfig + private val module = RustModule.pubCrate("matchers", ClientRustModule.waiters) + private val inputSymbol = codegenContext.symbolProvider.toSymbol(inputShape) + private val outputSymbol = codegenContext.symbolProvider.toSymbol(outputShape) + + fun generate( + errorSymbol: Symbol, + matcher: Matcher<*>, + ): RuntimeType { + val fnName = fnName(operationName, matcher) + val scope = + arrayOf( + *preludeScope, + "Input" to inputSymbol, + "Output" to outputSymbol, + "Error" to errorSymbol, + "ProvideErrorMetadata" to RuntimeType.provideErrorMetadataTrait(runtimeConfig), + ) + return RuntimeType.forInlineFun(fnName, module) { + docs("Matcher union: " + Node.printJson(matcher.toNode())) + rustBlockTemplate("pub(crate) fn $fnName(_input: &#{Input}, _result: &#{Result}<#{Output}, #{Error}>) -> bool", *scope) { + when (matcher) { + is OutputMember -> generateOutputMember(outputShape, matcher, scope) + is InputOutputMember -> generateInputOutputMember(matcher, scope) + is SuccessMember -> generateSuccessMember(matcher) + is ErrorTypeMember -> generateErrorTypeMember(matcher, scope) + else -> throw CodegenException("Unknown waiter matcher type: $matcher") + } + } + } + } + + private fun RustWriter.generateOutputMember( + outputShape: Shape, + matcher: OutputMember, + scope: Scope, + ) { + val pathExpression = JmespathExpression.parse(matcher.value.path) + val pathTraversal = + RustJmespathShapeTraversalGenerator(codegenContext).generate( + pathExpression, + listOf(TraversalBinding.Global("_output", outputShape)), + ) + + generatePathTraversalMatcher(pathTraversal, matcher.value.expected, matcher.value.comparator, scope) + } + + private fun RustWriter.generateInputOutputMember( + matcher: InputOutputMember, + scope: Scope, + ) { + val pathExpression = JmespathExpression.parse(matcher.value.path) + val pathTraversal = + RustJmespathShapeTraversalGenerator(codegenContext).generate( + pathExpression, + listOf( + TraversalBinding.Named("input", "_input", inputShape), + TraversalBinding.Named("output", "_output", outputShape), + ), + ) + + generatePathTraversalMatcher(pathTraversal, matcher.value.expected, matcher.value.comparator, scope) + } + + private fun RustWriter.generatePathTraversalMatcher( + pathTraversal: GeneratedExpression, + expected: String, + comparatorKind: PathComparator, + scope: Scope, + ) { + val comparator = + writable { + rust( + when (comparatorKind) { + PathComparator.ALL_STRING_EQUALS -> "value.iter().all(|s| s == ${expected.dq()})" + PathComparator.ANY_STRING_EQUALS -> "value.iter().any(|s| s == ${expected.dq()})" + PathComparator.STRING_EQUALS -> "value == ${expected.dq()}" + PathComparator.BOOLEAN_EQUALS -> + when (pathTraversal.outputType is RustType.Reference) { + true -> "*value == $expected" + else -> "value == $expected" + } + else -> throw CodegenException("Unknown path matcher comparator: $comparatorKind") + }, + ) + } + + rustTemplate( + """ + fn path_traversal<'a>(_input: &'a #{Input}, _output: &'a #{Output}) -> #{Option}<#{TraversalOutput}> { + #{traversal} + #{Some}(${pathTraversal.identifier}) + } + _result.as_ref() + .ok() + .and_then(|output| path_traversal(_input, output)) + .map(|value| #{comparator}) + .unwrap_or_default() + """, + *scope, + "traversal" to pathTraversal.output, + "TraversalOutput" to pathTraversal.outputType.replaceLifetimes("a"), + "comparator" to comparator, + ) + } + + private fun RustWriter.generateSuccessMember(matcher: SuccessMember) { + rust( + if (matcher.value) { + "_result.is_ok()" + } else { + "_result.is_err()" + }, + ) + } + + private fun RustWriter.generateErrorTypeMember( + matcher: ErrorTypeMember, + scope: Scope, + ) { + rustTemplate( + """ + if let #{Err}(err) = _result { + if let #{Some}(code) = #{ProvideErrorMetadata}::code(err) { + return code == ${matcher.value.dq()}; + } + } + false + """, + *scope, + ) + } + + private fun fnName( + operationName: String, + matcher: Matcher<*>, + ): String { + // Smithy models don't give us anything useful to name these functions with, so just + // SHA-256 hash the matcher JSON and truncate it to a reasonable length. This will have + // a nice side-effect of de-duplicating identical matchers within a given operation. + val jsonValue = Node.printJson(matcher.toNode()) + val bytes = MessageDigest.getInstance("SHA-256").digest(jsonValue.toByteArray()) + val hex = bytes.map { byte -> String.format("%02x", byte) }.joinToString("") + return "match_${operationName.toSnakeCase()}_${hex.substring(0..16)}" + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGeneratorTest.kt new file mode 100644 index 0000000000..79d5aca9fd --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustWaiterMatcherGeneratorTest.kt @@ -0,0 +1,349 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rust.codegen.client.smithy.generators.waiters + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.waiters.Matcher.SuccessMember + +private typealias Scope = Array> + +class RustWaiterMatcherGeneratorTest { + class TestCase( + codegenContext: ClientCodegenContext, + private val rustCrate: RustCrate, + matcherJson: String, + ) { + val operationShape = codegenContext.model.lookup("test#TestOperation") + val inputShape = operationShape.inputShape(codegenContext.model) + val outputShape = operationShape.outputShape(codegenContext.model) + val errorShape = codegenContext.model.lookup("test#SomeError") + val inputSymbol = codegenContext.symbolProvider.toSymbol(inputShape) + val outputSymbol = codegenContext.symbolProvider.toSymbol(outputShape) + val errorSymbol = codegenContext.symbolProvider.toSymbol(errorShape) + + val matcher = SuccessMember.fromNode(Node.parse(matcherJson)) + val matcherFn = + RustWaiterMatcherGenerator(codegenContext, "TestOperation", inputShape, outputShape) + .generate(errorSymbol, matcher) + + val scope = + arrayOf( + *preludeScope, + "Input" to inputSymbol, + "Output" to outputSymbol, + "Error" to errorSymbol, + "ErrorMetadata" to RuntimeType.errorMetadata(codegenContext.runtimeConfig), + "matcher_fn" to matcherFn, + ) + + fun renderTest( + name: String, + writeTest: TestCase.() -> Writable, + ) { + rustCrate.lib { + rustTemplate( + """ + /// Make the unit test public and document it so that compiler + /// doesn't complain about dead code. + pub fn ${name}_test_case() { + #{test} + } + ##[cfg(test)] + ##[test] + fn $name() { + ${name}_test_case(); + } + """, + *scope, + "test" to writeTest(), + ) + } + } + } + + @Test + fun tests() { + clientIntegrationTest(testModel()) { codegenContext, rustCrate -> + successMatcher(codegenContext, rustCrate) + errorMatcher(codegenContext, rustCrate) + outputPathMatcher(codegenContext, rustCrate) + inputOutputPathMatcher(codegenContext, rustCrate) + } + } + + private fun testCase( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + name: String, + matcherJson: String, + writeFn: RustWriter.(Scope) -> Unit, + ) { + TestCase(codegenContext, rustCrate, matcherJson).renderTest(name) { + writable { + writeFn(scope) + } + } + } + + private fun successMatcher( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) = testCase( + codegenContext, + rustCrate, + name = "success_matcher", + matcherJson = """{"success":true}""", + ) { scope -> + rustTemplate( + """ + let input = #{Input}::builder().foo("foo").build().unwrap(); + let result = #{Ok}(#{Output}::builder().some_string("bar").build()); + assert!(#{matcher_fn}(&input, &result)); + + let result = #{Err}(#{Error}::builder().message("asdf").build()); + assert!(!#{matcher_fn}(&input, &result)); + """, + *scope, + ) + } + + private fun errorMatcher( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) = testCase( + codegenContext, + rustCrate, + name = "error_matcher", + matcherJson = """{"errorType":"SomeError"}""", + ) { scope -> + rustTemplate( + """ + let input = #{Input}::builder().foo("foo").build().unwrap(); + let result = #{Ok}(#{Output}::builder().some_string("bar").build()); + assert!(!#{matcher_fn}(&input, &result)); + + let result = #{Err}( + #{Error}::builder() + .message("asdf") + .meta(#{ErrorMetadata}::builder().code("SomeOtherError").build()) + .build() + ); + assert!(!#{matcher_fn}(&input, &result)); + + let result = #{Err}( + #{Error}::builder() + .message("asdf") + .meta(#{ErrorMetadata}::builder().code("SomeError").build()) + .build() + ); + assert!(#{matcher_fn}(&input, &result)); + """, + *scope, + ) + } + + private fun outputPathMatcher( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { + fun test( + name: String, + matcherJson: String, + writeFn: RustWriter.(Scope) -> Unit, + ) = testCase(codegenContext, rustCrate, name, matcherJson, writeFn) + + fun matcherJson( + path: String, + expected: String, + comparator: String, + ) = """{"output":{"path":${path.dq()}, "expected":${expected.dq()}, "comparator": ${comparator.dq()}}}""" + + test( + "output_path_matcher_string_equals", + matcherJson( + path = "someString", + expected = "expected-value", + comparator = "stringEquals", + ), + ) { scope -> + rustTemplate( + """ + let input = #{Input}::builder().foo("foo").build().unwrap(); + let result = #{Ok}(#{Output}::builder().some_string("bar").build()); + assert!(!#{matcher_fn}(&input, &result)); + + let result = #{Ok}(#{Output}::builder().some_string("expected-value").build()); + assert!(#{matcher_fn}(&input, &result)); + """, + *scope, + ) + } + + test( + "output_path_matcher_bool_equals", + matcherJson( + path = "someBool", + expected = "true", + comparator = "booleanEquals", + ), + ) { scope -> + rustTemplate( + """ + let input = #{Input}::builder().foo("foo").build().unwrap(); + let result = #{Ok}(#{Output}::builder().some_bool(false).build()); + assert!(!#{matcher_fn}(&input, &result)); + + let result = #{Ok}(#{Output}::builder().some_bool(true).build()); + assert!(#{matcher_fn}(&input, &result)); + """, + *scope, + ) + } + + test( + "output_path_matcher_all_string_equals", + matcherJson( + path = "someList", + expected = "foo", + comparator = "allStringEquals", + ), + ) { scope -> + rustTemplate( + """ + let input = #{Input}::builder().foo("foo").build().unwrap(); + let result = #{Ok}(#{Output}::builder() + .some_list("foo") + .some_list("bar") + .build()); + assert!(!#{matcher_fn}(&input, &result)); + + let result = #{Ok}(#{Output}::builder() + .some_list("foo") + .some_list("foo") + .build()); + assert!(#{matcher_fn}(&input, &result)); + """, + *scope, + ) + } + + test( + "output_path_matcher_any_string_equals", + matcherJson( + path = "someList", + expected = "foo", + comparator = "anyStringEquals", + ), + ) { scope -> + rustTemplate( + """ + let input = #{Input}::builder().foo("foo").build().unwrap(); + let result = #{Ok}(#{Output}::builder() + .some_list("bar") + .build()); + assert!(!#{matcher_fn}(&input, &result)); + + let result = #{Ok}(#{Output}::builder() + .some_list("bar") + .some_list("foo") + .build()); + assert!(#{matcher_fn}(&input, &result)); + """, + *scope, + ) + } + } + + private fun inputOutputPathMatcher( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { + fun test( + name: String, + matcherJson: String, + writeFn: RustWriter.(Scope) -> Unit, + ) = testCase(codegenContext, rustCrate, name, matcherJson, writeFn) + + fun matcherJson( + path: String, + expected: String, + comparator: String, + ) = """{"inputOutput":{"path":${path.dq()}, "expected":${expected.dq()}, "comparator": ${comparator.dq()}}}""" + + test( + "input_output_path_matcher_boolean_equals", + matcherJson( + path = "input.foo == 'foo' && output.someString == 'bar'", + expected = "true", + comparator = "booleanEquals", + ), + ) { scope -> + rustTemplate( + """ + let input = #{Input}::builder().foo("foo").build().unwrap(); + let result = #{Ok}(#{Output}::builder().some_string("bar").build()); + assert!(#{matcher_fn}(&input, &result)); + + let input = #{Input}::builder().foo("asdf").build().unwrap(); + assert!(!#{matcher_fn}(&input, &result)); + """, + *scope, + ) + } + } + + private fun testModel() = + """ + ${'$'}version: "2" + namespace test + + @aws.protocols#awsJson1_0 + service TestService { + operations: [TestOperation], + } + + operation TestOperation { + input: GetEntityRequest, + output: GetEntityResponse, + errors: [SomeError], + } + + @error("server") + structure SomeError { + message: String, + } + + structure GetEntityRequest { + foo: String, + } + + structure GetEntityResponse { + someString: String, + someBool: Boolean, + someList: SomeList, + } + + list SomeList { + member: String + } + """.asSmithyModel() +}