diff --git a/codegen/smithy-go-codegen-test/model/main.smithy b/codegen/smithy-go-codegen-test/model/main.smithy index 7f1952135..e340feb30 100644 --- a/codegen/smithy-go-codegen-test/model/main.smithy +++ b/codegen/smithy-go-codegen-test/model/main.smithy @@ -3,6 +3,7 @@ namespace example.weather use smithy.test#httpRequestTests use smithy.test#httpResponseTests +use smithy.waiters#waitable /// Provides weather forecasts. @fakeProtocol @@ -36,6 +37,56 @@ resource CityImage { string CityId @readonly +@waitable( + CityExists: { + description: "Waits until a city has been created", + acceptors: [ + // Fail-fast if the thing transitions to a "failed" state. + { + state: "failure", + matcher: { + errorType: "NoSuchResource" + } + }, + // Fail-fast if the thing transitions to a "failed" state. + { + state: "failure", + matcher: { + errorType: "UnModeledError" + } + }, + // Succeed when the city image value is not empty i.e. enters into a "success" state. + { + state: "success", + matcher: { + success: true + } + }, + // Retry if city id input is of same length as city name in output + { + state: "retry", + matcher: { + inputOutput: { + path: "length(input.cityId) == length(output.name)", + comparator: "booleanEquals", + expected: "true", + } + } + }, + // Success if city name in output is seattle + { + state: "success", + matcher: { + output: { + path: "name", + comparator: "stringEquals", + expected: "seattle", + } + } + } + ] + } +) @http(method: "GET", uri: "/cities/{cityId}") operation GetCity { input: GetCityInput, @@ -178,6 +229,35 @@ apply NoSuchResource @httpResponseTests([ // return truncated results. @readonly @paginated(items: "items") +@waitable( + "ListContainsCity": { + description: "Wait until ListCities operation response matches a given state", + acceptors: [ + // failure in case all items returned match to seattle + { + state: "failure", + matcher: { + output: { + path: "items", + comparator: "allStringEquals", + expected: "seattle", + } + } + }, + // success in case any items returned match to NewYork + { + state: "success", + matcher: { + output: { + path: "items", + comparator: "anyStringEquals", + expected: "NewYork", + } + } + } + ] + } +) @http(method: "GET", uri: "/cities") operation ListCities { input: ListCitiesInput, diff --git a/codegen/smithy-go-codegen/build.gradle.kts b/codegen/smithy-go-codegen/build.gradle.kts index 1d472d60a..dcea37619 100644 --- a/codegen/smithy-go-codegen/build.gradle.kts +++ b/codegen/smithy-go-codegen/build.gradle.kts @@ -19,6 +19,7 @@ extra["moduleName"] = "software.amazon.smithy.go.codegen" dependencies { api("software.amazon.smithy:smithy-codegen-core:[1.3.0,2.0.0[") + implementation("software.amazon.smithy:smithy-waiters:[1.4.0,2.0.0[") compile("com.atlassian.commonmark:commonmark:0.15.2") api("org.jsoup:jsoup:1.13.1") implementation("software.amazon.smithy:smithy-protocol-test-traits:[1.3.0,2.0.0[") diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java index 071669b93..db7906f07 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java @@ -53,12 +53,16 @@ public final class SmithyGoDependency { public static final GoDependency SMITHY_RAND = smithy("rand", "smithyrand"); public static final GoDependency SMITHY_TESTING = smithy("testing", "smithytesting"); public static final GoDependency SMITHY_XML = smithy("xml", "smithyxml"); + public static final GoDependency SMITHY_WAITERS = smithy("waiter", "smithywaiter"); public static final GoDependency GO_CMP = goCmp("cmp"); public static final GoDependency GO_CMP_OPTIONS = goCmp("cmp/cmpopts"); + public static final GoDependency GO_JMESPATH = goJmespath(null); + private static final String SMITHY_SOURCE_PATH = "github.com/awslabs/smithy-go"; private static final String GO_CMP_SOURCE_PATH = "github.com/google/go-cmp"; + private static final String GO_JMESPATH_SOURCE_PATH = "github.com/jmespath/go-jmespath"; private SmithyGoDependency() { } @@ -94,6 +98,10 @@ private static GoDependency goCmp(String relativePath) { return relativePackage(GO_CMP_SOURCE_PATH, relativePath, Versions.GO_CMP, null); } + private static GoDependency goJmespath(String relativePath) { + return relativePackage(GO_JMESPATH_SOURCE_PATH, relativePath, Versions.GO_JMESPATH, null); + } + private static GoDependency relativePackage( String moduleImportPath, String relativePath, @@ -110,6 +118,7 @@ private static GoDependency relativePackage( private static final class Versions { private static final String GO_STDLIB = "1.15"; private static final String GO_CMP = "v0.5.4"; - private static final String SMITHY_GO = "v0.4.0"; + private static final String SMITHY_GO = "v0.4.1-0.20201208232924-b8cdbaa577ff"; + private static final String GO_JMESPATH = "v0.4.0"; } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/OperationInterfaceGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/OperationInterfaceGenerator.java new file mode 100644 index 000000000..57827cf0e --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/OperationInterfaceGenerator.java @@ -0,0 +1,132 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.integration; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.TreeSet; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.GoDelegator; +import software.amazon.smithy.go.codegen.GoSettings; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.PaginatedIndex; +import software.amazon.smithy.model.knowledge.TopDownIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.waiters.WaitableTrait; + +/** + * Generates API client Interfaces as per API operation. + */ +public class OperationInterfaceGenerator implements GoIntegration { + + private static Map> mapOfClientInterfaceOperations = new HashMap<>(); + + /** + * Returns name of an API client interface. + * + * @param operationSymbol Symbol of operation shape for which Api client interface is being generated. + * @return name of the interface. + */ + public static String getApiClientInterfaceName( + Symbol operationSymbol + ) { + return String.format("%sAPIClient", operationSymbol.getName()); + } + + @Override + public void processFinalizedModel( + GoSettings settings, + Model model + ) { + ServiceShape serviceShape = settings.getService(model); + TopDownIndex topDownIndex = TopDownIndex.of(model); + PaginatedIndex paginatedIndex = PaginatedIndex.of(model); + + Set listOfClientInterfaceOperations = new TreeSet<>(); + + // fetch operations for which paginators are generated + topDownIndex.getContainedOperations(serviceShape).stream() + .map(operationShape -> paginatedIndex.getPaginationInfo(serviceShape, operationShape)) + .filter(Optional::isPresent) + .map(Optional::get) + .forEach(paginationInfo -> listOfClientInterfaceOperations.add(paginationInfo.getOperation().getId())); + + // fetch operations for which waitable trait is applied + topDownIndex.getContainedOperations(serviceShape).stream() + .filter(operationShape -> operationShape.hasTrait(WaitableTrait.class)) + .forEach(operationShape -> listOfClientInterfaceOperations.add(operationShape.getId())); + + if (!listOfClientInterfaceOperations.isEmpty()) { + mapOfClientInterfaceOperations.put(serviceShape.getId(), listOfClientInterfaceOperations); + } + } + + @Override + public void writeAdditionalFiles( + GoSettings settings, + Model model, + SymbolProvider symbolProvider, + GoDelegator goDelegator + ) { + ShapeId serviceId = settings.getService(model).getId(); + + if (mapOfClientInterfaceOperations.containsKey(serviceId)) { + Set listOfClientInterfaceOperations = mapOfClientInterfaceOperations.get(serviceId); + listOfClientInterfaceOperations.stream().forEach(shapeId -> { + OperationShape operationShape = model.expectShape(shapeId, OperationShape.class); + goDelegator.useShapeWriter(operationShape, writer -> { + generateApiClientInterface(writer, model, symbolProvider, operationShape); + }); + }); + } + } + + private void generateApiClientInterface( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + OperationShape operationShape + ) { + Symbol contextSymbol = SymbolUtils.createValueSymbolBuilder("Context", SmithyGoDependency.CONTEXT) + .build(); + + Symbol operationSymbol = symbolProvider.toSymbol(operationShape); + + Symbol interfaceSymbol = SymbolUtils.createValueSymbolBuilder(getApiClientInterfaceName(operationSymbol)) + .build(); + + Symbol inputSymbol = symbolProvider.toSymbol(model.expectShape(operationShape.getInput().get())); + Symbol outputSymbol = symbolProvider.toSymbol(model.expectShape(operationShape.getOutput().get())); + + writer.writeDocs(String.format("%s is a client that implements the %s operation.", + interfaceSymbol.getName(), operationSymbol.getName())); + writer.openBlock("type $T interface {", "}", interfaceSymbol, () -> { + writer.write("$L($T, $P, ...func(*Options)) ($P, error)", operationSymbol.getName(), contextSymbol, + inputSymbol, outputSymbol); + }); + writer.write(""); + writer.write("var _ $T = (*Client)(nil)", interfaceSymbol); + writer.write(""); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Paginators.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Paginators.java index 89a88788b..165af5ba7 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Paginators.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Paginators.java @@ -72,14 +72,14 @@ private void generateOperationPaginator( ) { Symbol operationSymbol = symbolProvider.toSymbol(paginationInfo.getOperation()); - Symbol interfaceSymbol = SymbolUtils.createValueSymbolBuilder(String.format("%sAPIClient", - operationSymbol.getName())).build(); + Symbol interfaceSymbol = SymbolUtils.createValueSymbolBuilder( + OperationInterfaceGenerator.getApiClientInterfaceName(operationSymbol) + ).build(); Symbol paginatorSymbol = SymbolUtils.createPointableSymbolBuilder(String.format("%sPaginator", operationSymbol.getName())).build(); Symbol optionsSymbol = SymbolUtils.createPointableSymbolBuilder(String.format("%sOptions", paginatorSymbol.getName())).build(); - writeClientOperationInterface(writer, symbolProvider, paginationInfo, interfaceSymbol); writePaginatorOptions(writer, model, symbolProvider, paginationInfo, operationSymbol, optionsSymbol); writePaginator(writer, model, symbolProvider, paginationInfo, interfaceSymbol, paginatorSymbol, optionsSymbol); } @@ -249,30 +249,4 @@ private void writePaginatorOptions( }); writer.write(""); } - - private void writeClientOperationInterface( - GoWriter writer, - SymbolProvider symbolProvider, - PaginationInfo paginationInfo, - Symbol interfaceSymbol - ) { - Symbol contextSymbol = SymbolUtils.createValueSymbolBuilder("Context", SmithyGoDependency.CONTEXT) - .build(); - - Symbol operationSymbol = symbolProvider.toSymbol(paginationInfo.getOperation()); - Symbol inputSymbol = symbolProvider.toSymbol(paginationInfo.getInput()); - Symbol outputSymbol = symbolProvider.toSymbol(paginationInfo.getOutput()); - - writer.writeDocs(String.format("%s is a client that implements the %s operation.", - interfaceSymbol.getName(), operationSymbol.getName())); - writer.openBlock("type $T interface {", "}", interfaceSymbol, () -> { - writer.write("$L($T, $P, ...func(*Options)) ($P, error)", operationSymbol.getName(), contextSymbol, - inputSymbol, outputSymbol); - }); - writer.write(""); - writer.write("var _ $T = (*Client)(nil)", interfaceSymbol); - writer.write(""); - } - - } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Waiters.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Waiters.java new file mode 100644 index 000000000..1d07dd2b2 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Waiters.java @@ -0,0 +1,694 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.integration; + +import java.util.Map; +import java.util.Optional; +import software.amazon.smithy.codegen.core.CodegenException; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.GoDelegator; +import software.amazon.smithy.go.codegen.GoSettings; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.TopDownIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.shapes.StructureShape; +import software.amazon.smithy.utils.StringUtils; +import software.amazon.smithy.waiters.Acceptor; +import software.amazon.smithy.waiters.Matcher; +import software.amazon.smithy.waiters.PathComparator; +import software.amazon.smithy.waiters.WaitableTrait; +import software.amazon.smithy.waiters.Waiter; + +/** + * Implements support for WaitableTrait. + */ +public class Waiters implements GoIntegration { + private static final String WAITER_INVOKER_FUNCTION_NAME = "Wait"; + + @Override + public void writeAdditionalFiles( + GoSettings settings, + Model model, + SymbolProvider symbolProvider, + GoDelegator goDelegator + ) { + ServiceShape serviceShape = settings.getService(model); + TopDownIndex topDownIndex = TopDownIndex.of(model); + + topDownIndex.getContainedOperations(serviceShape).stream() + .forEach(operation -> { + if (!operation.hasTrait(WaitableTrait.ID)) { + return; + } + + Map waiters = operation.expectTrait(WaitableTrait.class).getWaiters(); + + goDelegator.useShapeWriter(operation, writer -> { + generateOperationWaiter(model, symbolProvider, writer, operation, waiters); + }); + }); + } + + + /** + * Generates all waiter components used for the operation. + */ + private void generateOperationWaiter( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operation, + Map waiters + ) { + // generate waiter function + waiters.forEach((name, waiter) -> { + // write waiter options + generateWaiterOptions(model, symbolProvider, writer, operation, name, waiter); + + // write waiter client + generateWaiterClient(model, symbolProvider, writer, operation, name, waiter); + + // write waiter specific invoker + generateWaiterInvoker(model, symbolProvider, writer, operation, name, waiter); + + // write waiter state mutator for each waiter + generateRetryable(model, symbolProvider, writer, operation, name, waiter); + + }); + } + + /** + * Generates waiter options to configure a waiter client. + */ + private void generateWaiterOptions( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + String optionsName = generateWaiterOptionsName(waiterName); + String waiterClientName = generateWaiterClientName(waiterName); + + StructureShape inputShape = model.expectShape( + operationShape.getInput().get(), StructureShape.class + ); + StructureShape outputShape = model.expectShape( + operationShape.getOutput().get(), StructureShape.class + ); + + Symbol inputSymbol = symbolProvider.toSymbol(inputShape); + Symbol outputSymbol = symbolProvider.toSymbol(outputShape); + + writer.write(""); + writer.writeDocs( + String.format("%s are waiter options for %s", optionsName, waiterClientName) + ); + + writer.openBlock("type $L struct {", "}", + optionsName, () -> { + writer.addUseImports(SmithyGoDependency.TIME); + + writer.write(""); + writer.writeDocs( + "Set of options to modify how an operation is invoked. These apply to all operations " + + "invoked for this client. Use functional options on operation call to modify " + + "this list for per operation behavior." + ); + Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack", + SmithyGoDependency.SMITHY_MIDDLEWARE) + .build(); + writer.write("APIOptions []func($P) error", stackSymbol); + + writer.write(""); + writer.writeDocs( + String.format("MinDelay is the minimum amount of time to delay between retries. " + + "If unset, %s will use default minimum delay of %s seconds. " + + "Note that MinDelay must resolve to a value lesser than or equal " + + "to the MaxDelay.", waiterClientName, waiter.getMinDelay()) + ); + writer.write("MinDelay time.Duration"); + + writer.write(""); + writer.writeDocs( + String.format("MaxDelay is the maximum amount of time to delay between retries. " + + "If unset or set to zero, %s will use default max delay of %s seconds. " + + "Note that MaxDelay must resolve to value greater than or equal " + + "to the MinDelay.", waiterClientName, waiter.getMaxDelay()) + ); + writer.write("MaxDelay time.Duration"); + + writer.write(""); + writer.writeDocs("LogWaitAttempts is used to enable logging for waiter retry attempts"); + writer.write("LogWaitAttempts bool"); + + writer.write(""); + writer.writeDocs( + "Retryable is function that can be used to override the " + + "service defined waiter-behavior based on operation output, or returned error. " + + "This function is used by the waiter to decide if a state is retryable " + + "or a terminal state.\n\nBy default service-modeled logic " + + "will populate this option. This option can thus be used to define a custom " + + "waiter state with fall-back to service-modeled waiter state mutators." + + "The function returns an error in case of a failure state. " + + "In case of retry state, this function returns a bool value of true and " + + "nil error, while in case of success it returns a bool value of false and " + + "nil error." + ); + writer.write( + "Retryable func(context.Context, $P, $P, error) " + + "(bool, error)", inputSymbol, outputSymbol); + } + ); + writer.write(""); + } + + + /** + * Generates waiter client used to invoke waiter function. The waiter client is specific to a modeled waiter. + * Each waiter client is unique within a enclosure of a service. + * This function also generates a waiter client constructor that takes in a API client interface, and waiter options + * to configure a waiter client. + */ + private void generateWaiterClient( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + Symbol operationSymbol = symbolProvider.toSymbol(operationShape); + String clientName = generateWaiterClientName(waiterName); + + writer.write(""); + writer.writeDocs( + String.format("%s defines the waiters for %s", clientName, waiterName) + ); + writer.openBlock("type $L struct {", "}", + clientName, () -> { + writer.write(""); + writer.write("client $L", OperationInterfaceGenerator.getApiClientInterfaceName(operationSymbol)); + + writer.write(""); + writer.write("options $L", generateWaiterOptionsName(waiterName)); + }); + + writer.write(""); + + String constructorName = String.format("New%s", clientName); + + Symbol waiterOptionsSymbol = SymbolUtils.createPointableSymbolBuilder( + generateWaiterOptionsName(waiterName) + ).build(); + + Symbol clientSymbol = SymbolUtils.createPointableSymbolBuilder( + clientName + ).build(); + + writer.writeDocs( + String.format("%s constructs a %s.", constructorName, clientName) + ); + writer.openBlock("func $L(client $L, optFns ...func($P)) $P {", "}", + constructorName, OperationInterfaceGenerator.getApiClientInterfaceName(operationSymbol), + waiterOptionsSymbol, clientSymbol, () -> { + writer.write("options := $T{}", waiterOptionsSymbol); + writer.addUseImports(SmithyGoDependency.TIME); + + // set defaults + writer.write("options.MinDelay = $L * time.Second", waiter.getMinDelay()); + writer.write("options.MaxDelay = $L * time.Second", waiter.getMaxDelay()); + writer.write("options.Retryable = $L", generateRetryableName(waiterName)); + writer.write(""); + + writer.openBlock("for _, fn := range optFns {", + "}", () -> { + writer.write("fn(&options)"); + }); + + writer.openBlock("return &$T {", "}", clientSymbol, () -> { + writer.write("client: client, "); + writer.write("options: options, "); + }); + }); + } + + /** + * Generates waiter invoker functions to call specific operation waiters + * These waiter invoker functions is defined on each modeled waiter client. + * The invoker function takes in a context, along with operation input, and + * optional functional options for the waiter. + */ + private void generateWaiterInvoker( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + StructureShape inputShape = model.expectShape( + operationShape.getInput().get(), StructureShape.class + ); + + Symbol operationSymbol = symbolProvider.toSymbol(operationShape); + Symbol inputSymbol = symbolProvider.toSymbol(inputShape); + + Symbol waiterOptionsSymbol = SymbolUtils.createPointableSymbolBuilder( + generateWaiterOptionsName(waiterName) + ).build(); + + Symbol clientSymbol = SymbolUtils.createPointableSymbolBuilder( + generateWaiterClientName(waiterName) + ).build(); + + writer.write(""); + writer.addUseImports(SmithyGoDependency.CONTEXT); + writer.addUseImports(SmithyGoDependency.TIME); + writer.writeDocs( + String.format( + "%s calls the waiter function for %s waiter. The maxWaitDur is the maximum wait duration " + + "the waiter will wait. The maxWaitDur is required and must be greater than zero.", + WAITER_INVOKER_FUNCTION_NAME, waiterName) + ); + writer.openBlock( + "func (w $P) $L(ctx context.Context, params $P, maxWaitDur time.Duration, optFns ...func($P)) error {", + "}", + clientSymbol, WAITER_INVOKER_FUNCTION_NAME, inputSymbol, waiterOptionsSymbol, + () -> { + writer.openBlock("if maxWaitDur <= 0 {", "}", () -> { + writer.addUseImports(SmithyGoDependency.FMT); + writer.write("fmt.Errorf(\"maximum wait time for waiter must be greater than zero\")"); + }).write(""); + + writer.write("options := w.options"); + + writer.openBlock("for _, fn := range optFns {", + "}", () -> { + writer.write("fn(&options)"); + }); + writer.write(""); + + // validate values for MaxDelay from options + writer.openBlock("if options.MaxDelay <= 0 {", "}", () -> { + writer.write("options.MaxDelay = $L * time.Second", waiter.getMaxDelay()); + }); + writer.write(""); + + // validate that MinDelay is lesser than or equal to resolved MaxDelay + writer.openBlock("if options.MinDelay > options.MaxDelay {", "}", () -> { + writer.addUseImports(SmithyGoDependency.FMT); + writer.write("return fmt.Errorf(\"minimum waiter delay %v must be lesser than or equal to " + + "maximum waiter delay of %v.\", options.MinDelay, options.MaxDelay)"); + }).write(""); + + writer.addUseImports(SmithyGoDependency.CONTEXT); + writer.write("ctx, cancelFn := context.WithTimeout(ctx, maxWaitDur)"); + writer.write("defer cancelFn()"); + writer.write(""); + + Symbol loggerMiddleware = SymbolUtils.createValueSymbolBuilder( + "Logger", SmithyGoDependency.SMITHY_WAITERS + ).build(); + writer.write("logger := $T{}", loggerMiddleware); + writer.write("remainingTime := maxWaitDur").write(""); + + writer.write("var attempt int64"); + writer.openBlock("for {", "}", () -> { + writer.write(""); + writer.write("attempt++"); + + writer.write("apiOptions := options.APIOptions"); + writer.write("start := time.Now()").write(""); + + // add waiter logger middleware to log an attempt, if LogWaitAttempts is enabled. + writer.openBlock("if options.LogWaitAttempts {", "}", () -> { + writer.write("logger.Attempt = attempt"); + writer.write( + "apiOptions = append([]func(*middleware.Stack) error{}, options.APIOptions...)"); + writer.write("apiOptions = append(apiOptions, logger.AddLogger)"); + }).write(""); + + // make a request + writer.openBlock("out, err := w.client.$T(ctx, params, func (o *Options) { ", "})", + operationSymbol, () -> { + writer.write("o.APIOptions = append(o.APIOptions, apiOptions...)"); + }); + writer.write(""); + + // handle response and identify waiter state + writer.write("retryable, err := options.Retryable(ctx, params, out, err)"); + writer.write("if err != nil { return err }"); + writer.write("if !retryable { return nil }").write(""); + + // update remaining time + writer.write("remainingTime -= time.Since(start)"); + + // check if next iteration is possible + writer.openBlock("if remainingTime < options.MinDelay || remainingTime <= 0 {", "}", () -> { + writer.write("break"); + }); + writer.write(""); + + // handle retry delay computation, sleep. + Symbol computeDelaySymbol = SymbolUtils.createValueSymbolBuilder( + "ComputeDelay", SmithyGoDependency.SMITHY_WAITERS + ).build(); + writer.writeDocs("compute exponential backoff between waiter retries"); + writer.openBlock("delay, err := $T(", ")", computeDelaySymbol, () -> { + writer.write("attempt, options.MinDelay, options.MaxDelay, remainingTime,"); + }); + + writer.addUseImports(SmithyGoDependency.FMT); + writer.write( + "if err != nil { return fmt.Errorf(\"error computing waiter delay, %w\", err)}"); + writer.write(""); + + // update remaining time as per computed delay + writer.write("remainingTime -= delay"); + + // sleep for delay + Symbol sleepWithContextSymbol = SymbolUtils.createValueSymbolBuilder( + "SleepWithContext", SmithyGoDependency.SMITHY_TIME + ).build(); + writer.writeDocs("sleep for the delay amount before invoking a request"); + writer.openBlock("if err := $T(ctx, delay); err != nil {", "}", sleepWithContextSymbol, + () -> { + writer.write( + "return fmt.Errorf(\"request cancelled while waiting, %w\", err)"); + }); + }); + writer.write("return fmt.Errorf(\"exceeded max wait time for $L waiter\")", waiterName); + }); + } + + /** + * Generates a waiter state mutator function which is used by the waiter retrier Middleware to mutate + * waiter state as per the defined logic and returned operation response. + * + * @param model the smithy model + * @param symbolProvider symbol provider + * @param writer the Gowriter + * @param operationShape operation shape on which the waiter is modeled + * @param waiterName the waiter name + * @param waiter the waiter structure that contains info on modeled waiter + */ + private void generateRetryable( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + StructureShape inputShape = model.expectShape( + operationShape.getInput().get(), StructureShape.class + ); + StructureShape outputShape = model.expectShape( + operationShape.getOutput().get(), StructureShape.class + ); + + Symbol inputSymbol = symbolProvider.toSymbol(inputShape); + Symbol outputSymbol = symbolProvider.toSymbol(outputShape); + + writer.write(""); + writer.openBlock("func $L(ctx context.Context, input $P, output $P, err error) (bool, error) {", + "}", generateRetryableName(waiterName), inputSymbol, outputSymbol, () -> { + waiter.getAcceptors().forEach(acceptor -> { + writer.write(""); + // scope each acceptor to avoid name collisions + Matcher matcher = acceptor.getMatcher(); + switch (matcher.getMemberName()) { + case "output": + writer.addUseImports(SmithyGoDependency.GO_JMESPATH); + writer.addUseImports(SmithyGoDependency.FMT); + + Matcher.OutputMember outputMember = (Matcher.OutputMember) matcher; + String path = outputMember.getValue().getPath(); + String expectedValue = outputMember.getValue().getExpected(); + PathComparator comparator = outputMember.getValue().getComparator(); + writer.openBlock("if err == nil {", "}", () -> { + writer.write("pathValue, err := jmespath.Search($S, output)", path); + writer.openBlock("if err != nil {", "}", () -> { + writer.write( + "return false, " + + "fmt.Errorf(\"error evaluating waiter state: %w\", err)"); + }).write(""); + writer.write("expectedValue := $S", expectedValue); + writeWaiterComparator(writer, acceptor, comparator, "pathValue", + "expectedValue"); + }); + + break; + + case "inputOutput": + writer.addUseImports(SmithyGoDependency.GO_JMESPATH); + writer.addUseImports(SmithyGoDependency.FMT); + + Matcher.InputOutputMember ioMember = (Matcher.InputOutputMember) matcher; + path = ioMember.getValue().getPath(); + expectedValue = ioMember.getValue().getExpected(); + comparator = ioMember.getValue().getComparator(); + writer.openBlock("if err == nil {", "}", () -> { + writer.openBlock("pathValue, err := jmespath.Search($S, &struct{", + "})", path, () -> { + writer.write("Input $P \n Output $P \n }{", inputSymbol, + outputSymbol); + writer.write("Input: input, \n Output: output, \n"); + }); + writer.openBlock("if err != nil {", "}", () -> { + writer.write( + "return false, " + + "fmt.Errorf(\"error evaluating waiter state: %w\", err)"); + }); + writer.write(""); + writer.write("expectedValue := $S", expectedValue); + writeWaiterComparator(writer, acceptor, comparator, "pathValue", + "expectedValue"); + }); + break; + + case "success": + Matcher.SuccessMember successMember = (Matcher.SuccessMember) matcher; + writer.openBlock("if err == nil {", "}", + () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + case "errorType": + Matcher.ErrorTypeMember errorTypeMember = (Matcher.ErrorTypeMember) matcher; + String errorType = errorTypeMember.getValue(); + + writer.openBlock("if err != nil {", "}", () -> { + + // identify if this is a modeled error shape + Optional errorShape = operationShape.getErrors().stream().filter( + shapeId -> { + return shapeId.getName().equalsIgnoreCase(errorType); + }).findFirst(); + + // if modeled error shape + if (errorShape.isPresent()) { + Symbol modeledErrorSymbol = SymbolUtils.createValueSymbolBuilder( + errorShape.get().getName(), "types" + ).build(); + writer.addUseImports(SmithyGoDependency.ERRORS); + writer.write("var errorType *$T", modeledErrorSymbol); + writer.openBlock("if errors.As(err, &errorType) {", "}", () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + } else { + // fall back to un-modeled error shape matching + writer.addUseImports(SmithyGoDependency.SMITHY); + writer.addUseImports(SmithyGoDependency.ERRORS); + + // assert unmodeled error to smithy's API error + writer.write("var apiErr smithy.APIError"); + writer.write("ok := errors.As(err, &apiErr)"); + writer.openBlock("if !ok {", "}", () -> { + writer.write("return false, " + + "fmt.Errorf(\"expected err to be of type smithy.APIError\")"); + }); + writer.write(""); + + writer.openBlock("if $S == apiErr.ErrorCode() {", "}", + errorType, () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + } + }); + break; + + default: + throw new CodegenException( + String.format("unknown waiter state : %v", matcher.getMemberName()) + ); + } + }); + + writer.write(""); + writer.write("return true, nil"); + }); + } + + /** + * writes comparators for a given waiter. The comparators are defined within the waiter acceptor. + * + * @param writer the Gowriter + * @param acceptor the waiter acceptor that defines the comparator and acceptor states + * @param comparator the comparator + * @param actual the variable carrying the actual value obtained. + * This may be computed via a jmespath expression or operation response status (success/failure) + * @param expected the variable carrying the expected value. This value is as per the modeled waiter. + */ + private void writeWaiterComparator( + GoWriter writer, + Acceptor acceptor, + PathComparator comparator, + String actual, + String expected + ) { + switch (comparator) { + case STRING_EQUALS: + writer.write("value, ok := $L.(string)", actual); + writer.openBlock(" if !ok {", "}", () -> { + writer.write("return false, " + + "fmt.Errorf(\"waiter comparator expected string value got %T\", $L)", actual); + }); + writer.write(""); + + writer.openBlock("if value == $L {", "}", expected, () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + case BOOLEAN_EQUALS: + writer.addUseImports(SmithyGoDependency.STRCONV); + writer.write("bv, err := strconv.ParseBool($L)", expected); + writer.write( + "if err != nil { return false, " + + "fmt.Errorf(\"error parsing boolean from string %w\", err)}"); + + writer.write("value, ok := $L.(bool)", actual); + writer.openBlock(" if !ok {", "}", () -> { + writer.write("return false, " + + "fmt.Errorf(\"waiter comparator expected bool value got %T\", $L)", actual); + }); + writer.write(""); + + writer.openBlock("if value == bv {", "}", () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + case ALL_STRING_EQUALS: + writer.write("var match = true"); + writer.write("listOfValues, ok := $L.([]string)", actual); + writer.openBlock(" if !ok {", "}", () -> { + writer.write("return false, " + + "fmt.Errorf(\"waiter comparator expected []string value got %T\", $L)", actual); + }); + writer.write(""); + + writer.write("if len(listOfValues) == 0 { match = false }"); + + writer.openBlock("for _, v := range listOfValues {", "}", () -> { + writer.write("if v != $L { match = false }", expected); + }); + writer.write(""); + + writer.openBlock("if match {", "}", () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + case ANY_STRING_EQUALS: + writer.write("listOfValues, ok := $L.([]string)", actual); + writer.openBlock(" if !ok {", "}", () -> { + writer.write("return false, " + + "fmt.Errorf(\"waiter comparator expected []string value got %T\", $L)", actual); + }); + writer.write(""); + + writer.openBlock("for _, v := range listOfValues {", "}", () -> { + writer.openBlock("if v == $L {", "}", expected, () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + }); + break; + + default: + throw new CodegenException( + String.format("Found unknown waiter path comparator, %s", comparator.toString())); + } + } + + + /** + * Writes return statement for state where a waiter's acceptor state is a match. + * + * @param writer the Go writer + * @param acceptor the waiter acceptor who's state is used to write an appropriate return statement. + */ + private void writeMatchedAcceptorReturn(GoWriter writer, Acceptor acceptor) { + switch (acceptor.getState()) { + case SUCCESS: + writer.write("return false, nil"); + break; + + case FAILURE: + writer.addUseImports(SmithyGoDependency.FMT); + writer.write("return false, fmt.Errorf(\"waiter state transitioned to Failure\")"); + break; + + case RETRY: + writer.write("return true, nil"); + break; + + default: + throw new CodegenException("unknown acceptor state defined for the waiter"); + } + } + + private String generateWaiterOptionsName( + String waiterName + ) { + waiterName = StringUtils.capitalize(waiterName); + return String.format("%sWaiterOptions", waiterName); + } + + private String generateWaiterClientName( + String waiterName + ) { + waiterName = StringUtils.capitalize(waiterName); + return String.format("%sWaiter", waiterName); + } + + private String generateRetryableName( + String waiterName + ) { + waiterName = StringUtils.uncapitalize(waiterName); + return String.format("%sStateRetryable", waiterName); + } +} diff --git a/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration b/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration index 1c670e6da..fccd43e43 100644 --- a/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration +++ b/codegen/smithy-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration @@ -4,4 +4,6 @@ software.amazon.smithy.go.codegen.integration.AddChecksumRequiredMiddleware software.amazon.smithy.go.codegen.integration.RequiresLengthTraitSupport software.amazon.smithy.go.codegen.integration.EndpointHostPrefixMiddleware software.amazon.smithy.go.codegen.integration.ClientLogger +software.amazon.smithy.go.codegen.integration.OperationInterfaceGenerator software.amazon.smithy.go.codegen.integration.Paginators +software.amazon.smithy.go.codegen.integration.Waiters diff --git a/rand/rand.go b/rand/rand.go new file mode 100644 index 000000000..4dc176d7d --- /dev/null +++ b/rand/rand.go @@ -0,0 +1,31 @@ +package rand + +import ( + "crypto/rand" + "fmt" + "io" + "math/big" +) + +func init() { + Reader = rand.Reader +} + +// Reader provides a random reader that can reset during testing. +var Reader io.Reader + +// Int63n returns a int64 between zero and value of max, read from an io.Reader source. +func Int63n(reader io.Reader, max int64) (int64, error) { + bi, err := rand.Int(reader, big.NewInt(max)) + if err != nil { + return 0, fmt.Errorf("failed to read random value, %w", err) + } + + return bi.Int64(), nil +} + +// CryptoRandInt63n returns a random int64 between zero and value of max +//obtained from the crypto rand source. +func CryptoRandInt63n(max int64) (int64, error) { + return Int63n(Reader, max) +} diff --git a/time/time.go b/time/time.go index 241c555b2..0a3fbe48c 100644 --- a/time/time.go +++ b/time/time.go @@ -1,6 +1,7 @@ package time import ( + "context" "time" ) @@ -53,3 +54,20 @@ func FormatEpochSeconds(value time.Time) float64 { func ParseEpochSeconds(value float64) time.Time { return time.Unix(0, int64(value*float64(time.Second))).UTC() } + +// SleepWithContext will wait for the timer duration to expire, or the context +// is canceled. Which ever happens first. If the context is canceled the +// Context's error will be returned. +func SleepWithContext(ctx context.Context, dur time.Duration) error { + t := time.NewTimer(dur) + defer t.Stop() + + select { + case <-t.C: + break + case <-ctx.Done(): + return ctx.Err() + } + + return nil +} diff --git a/waiter/logger.go b/waiter/logger.go new file mode 100644 index 000000000..4853ace01 --- /dev/null +++ b/waiter/logger.go @@ -0,0 +1,35 @@ +package waiter + +import ( + "context" + "fmt" + "github.com/awslabs/smithy-go/logging" + "github.com/awslabs/smithy-go/middleware" +) + +// Logger is the Logger middleware used by the waiter to log an attempt +type Logger struct { + // Attempt is the current attempt to be logged + Attempt int64 +} + +// ID representing the Logger middleware +func (*Logger) ID() string { + return "WaiterLogger" +} + +// HandleInitialize performs handling of request in initialize stack step +func (m *Logger) HandleInitialize(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + logger := middleware.GetLogger(ctx) + + logger.Logf(logging.Debug, fmt.Sprintf("attempting waiter request, attempt count: %d", m.Attempt)) + + return next.HandleInitialize(ctx, in) +} + +// AddLogger is helper util to add waiter logger after `SetLogger` middleware in +func (m Logger) AddLogger(stack *middleware.Stack) error { + return stack.Initialize.Insert(&m, "SetLogger", middleware.After) +} diff --git a/waiter/waiter.go b/waiter/waiter.go new file mode 100644 index 000000000..4b35a96ff --- /dev/null +++ b/waiter/waiter.go @@ -0,0 +1,65 @@ +package waiter + +import ( + "fmt" + "github.com/awslabs/smithy-go/rand" + "math" + "time" +) + +// ComputeDelay computes delay between waiter attempts. The function takes in a current attempt count, +// minimum delay, maximum delay, and remaining wait time for waiter as input. The inputs minDelay and maxDelay +// must always be greater than 0, along with minDelay lesser than or equal to maxDelay. +// +// Returns the computed delay and if next attempt count is possible within the given input time constraints. +// Note that the zeroth attempt results in no delay. +func ComputeDelay(attempt int64, minDelay, maxDelay, remainingTime time.Duration) (delay time.Duration, err error) { + // zeroth attempt, no delay + if attempt <= 0 { + return 0, nil + } + + // remainingTime is zero or less, no delay + if remainingTime <= 0 { + return 0, nil + } + + // validate min delay is greater than 0 + if minDelay == 0 { + return 0, fmt.Errorf("minDelay must be greater than zero when computing Delay") + } + + // validate max delay is greater than 0 + if maxDelay == 0 { + return 0, fmt.Errorf("maxDelay must be greater than zero when computing Delay") + } + + // Get attempt ceiling to prevent integer overflow. + attemptCeiling := (math.Log(float64(maxDelay/minDelay)) / math.Log(2)) + 1 + + if attempt > int64(attemptCeiling) { + delay = maxDelay + } else { + // Compute exponential delay based on attempt. + ri := 1 << uint64(attempt-1) + // compute delay + delay = minDelay * time.Duration(ri) + } + + if delay != minDelay { + // randomize to get jitter between min delay and delay value + d, err := rand.CryptoRandInt63n(int64(delay - minDelay)) + if err != nil { + return 0, fmt.Errorf("error computing retry jitter, %w", err) + } + + delay = time.Duration(d) + minDelay + } + + // check if this is the last attempt possible and compute delay accordingly + if remainingTime-delay <= minDelay { + delay = remainingTime - minDelay + } + + return delay, nil +} diff --git a/waiter/waiter_test.go b/waiter/waiter_test.go new file mode 100644 index 000000000..9f749869b --- /dev/null +++ b/waiter/waiter_test.go @@ -0,0 +1,143 @@ +package waiter + +import ( + "github.com/awslabs/smithy-go/rand" + mathrand "math/rand" + "strings" + "testing" + "time" +) + +func TestComputeDelay(t *testing.T) { + cases := map[string]struct { + totalAttempts int64 + minDelay time.Duration + maxDelay time.Duration + maxWaitTime time.Duration + expectedMaxDelays []time.Duration + expectedError string + expectedMinAttempts int + }{ + "standard": { + totalAttempts: 8, + minDelay: 2 * time.Second, + maxDelay: 120 * time.Second, + maxWaitTime: 300 * time.Second, + expectedMaxDelays: []time.Duration{2, 4, 8, 16, 32, 64, 120, 120}, + expectedMinAttempts: 8, + }, + "zero minDelay": { + totalAttempts: 3, + minDelay: 0, + maxDelay: 120 * time.Second, + maxWaitTime: 300 * time.Second, + expectedError: "minDelay must be greater than zero", + }, + "zero maxDelay": { + totalAttempts: 3, + minDelay: 10 * time.Second, + maxDelay: 0, + maxWaitTime: 300 * time.Second, + expectedError: "maxDelay must be greater than zero", + }, + "zero remaining time": { + totalAttempts: 3, + minDelay: 10 * time.Second, + maxDelay: 20 * time.Second, + maxWaitTime: 0, + expectedMaxDelays: []time.Duration{0}, + expectedMinAttempts: 1, + }, + "max wait time is less than min delay": { + totalAttempts: 3, + minDelay: 10 * time.Second, + maxDelay: 20 * time.Second, + maxWaitTime: 5 * time.Second, + expectedMaxDelays: []time.Duration{0}, + expectedMinAttempts: 1, + }, + "large minDelay": { + totalAttempts: 80, + minDelay: 150 * time.Minute, + maxDelay: 200 * time.Minute, + maxWaitTime: 250 * time.Minute, + expectedMinAttempts: 1, + }, + "large maxDelay": { + totalAttempts: 80, + minDelay: 15 * time.Minute, + maxDelay: 2000 * time.Minute, + maxWaitTime: 250 * time.Minute, + expectedMinAttempts: 5, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + // mock smithy-go rand/#Reader + r := rand.Reader + defer func() { + rand.Reader = r + }() + rand.Reader = mathrand.New(mathrand.NewSource(1)) + + // mock waiter call + delays, err := mockwait(c.totalAttempts, c.minDelay, c.maxDelay, c.maxWaitTime) + + if len(c.expectedError) != 0 { + if err == nil { + t.Fatalf("expected error, got none") + } + if e, a := c.expectedError, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expected error %v, got %v instead", e, a) + } + } else if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if e, a := c.expectedMinAttempts, len(delays); e > a { + t.Logf("%v", delays) + t.Fatalf("expected minimum attempts to be %v, got %v", e, a) + } + + for i, expectedDelay := range c.expectedMaxDelays { + if e, a := expectedDelay*time.Second, delays[i]; e < a { + t.Fatalf("attempt %d : expected delay to be less than %v, got %v", i+1, e, a) + } + + if e, a := c.minDelay, delays[i]; e > a && c.maxWaitTime > c.minDelay { + t.Fatalf("attempt %d : expected delay to be more than %v, got %v", i+1, e, a) + } + } + t.Logf("delays : %v", delays) + }) + } +} + +func mockwait(maxAttempts int64, minDelay, maxDelay, maxWaitTime time.Duration) ([]time.Duration, error) { + delays := make([]time.Duration, 0) + remainingTime := maxWaitTime + var attempt int64 + + for { + attempt++ + + if maxAttempts < attempt { + break + } + + delay, err := ComputeDelay(attempt, minDelay, maxDelay, remainingTime) + if err != nil { + return delays, err + } + + delays = append(delays, delay) + + remainingTime -= delay + if remainingTime < minDelay { + break + } + } + + return delays, nil +}