Skip to content

Commit

Permalink
feat: Support RPCv2 CBOR wire protocol (#1850)
Browse files Browse the repository at this point in the history
  • Loading branch information
dayaffe authored Jan 10, 2025
1 parent 4631b40 commit 6aa9ebb
Show file tree
Hide file tree
Showing 22 changed files with 350 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import ClientRuntime
import SmithyHTTPAPI

public struct CborValidateResponseHeaderMiddleware<Input, Output> {
public let id: Swift.String = "CborValidateResponseHeaderMiddleware"

public init() {}
}

public enum ServiceResponseError: Error {
case missingHeader(String)
case badHeaderValue(String)
}

extension CborValidateResponseHeaderMiddleware: Interceptor {

public typealias InputType = Input
public typealias OutputType = Output
public typealias RequestType = HTTPRequest
public typealias ResponseType = HTTPResponse

public func readBeforeDeserialization(
context: some BeforeDeserialization<InputType, RequestType, ResponseType>
) async throws {
let response = context.getResponse()
let smithyProtocolHeader = response.headers.value(for: "smithy-protocol")

guard let smithyProtocolHeader else {
throw ServiceResponseError.missingHeader(
"smithy-protocol header is missing from a response over RpcV2 Cbor!"
)
}

guard smithyProtocolHeader == "rpc-v2-cbor" else {
throw ServiceResponseError.badHeaderValue(
"smithy-protocol header is set to \(smithyProtocolHeader) instead of expected value rpc-v2-cbor"
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ extension UserAgentMiddleware: Interceptor {
serviceID: serviceID,
version: version,
config: UserAgentValuesFromConfig(config: config),
context: context.getAttributes()
context: context.getAttributes(),
headers: context.getRequest().headers
).userAgent
let builder = context.getRequest().toBuilder()
builder.withHeader(name: USER_AGENT, value: awsUserAgentString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public struct AWSJSONError: BaseError {

extension AWSJSONError {
@_spi(SmithyReadWrite)
public static func makeQueryCompatibleAWSJsonError(
public static func makeQueryCompatibleError(
httpResponse: HTTPResponse,
responseReader: Reader,
noErrorWrapping: Bool,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import protocol ClientRuntime.BaseError
import enum ClientRuntime.BaseErrorDecodeError
import class SmithyHTTPAPI.HTTPResponse
@_spi(SmithyReadWrite) import class SmithyCBOR.Reader

public struct RpcV2CborError: BaseError {
public let code: String
public let message: String?
public let requestID: String?
@_spi(SmithyReadWrite) public var errorBodyReader: Reader { responseReader }

public let httpResponse: HTTPResponse
private let responseReader: Reader

@_spi(SmithyReadWrite)
public init(httpResponse: HTTPResponse, responseReader: Reader, noErrorWrapping: Bool, code: String? = nil) throws {
switch responseReader.cborValue {
case .map(let errorDetails):
if case let .text(errorCode) = errorDetails["__type"] {
self.code = sanitizeErrorType(errorCode)
} else {
self.code = "UnknownError"
}

if case let .text(errorMessage) = errorDetails["Message"] {
self.message = errorMessage
} else {
self.message = nil
}
default:
self.code = "UnknownError"
self.message = nil
}

self.httpResponse = httpResponse
self.responseReader = responseReader
self.requestID = nil
}
}

// support awsQueryCompatible trait
extension RpcV2CborError {
@_spi(SmithyReadWrite)
public static func makeQueryCompatibleError(
httpResponse: HTTPResponse,
responseReader: Reader,
noErrorWrapping: Bool,
errorDetails: String?
) throws -> RpcV2CborError {
let errorCode = try AwsQueryCompatibleErrorDetails.parse(errorDetails).code
return try RpcV2CborError(
httpResponse: httpResponse,
responseReader: responseReader,
noErrorWrapping: noErrorWrapping,
code: errorCode
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import ClientRuntime
import class Smithy.Context
import struct SmithyHTTPAPI.Headers

public struct AWSUserAgentMetadata {
let sdkMetadata: SDKMetadata
Expand Down Expand Up @@ -73,7 +74,8 @@ public struct AWSUserAgentMetadata {
serviceID: String,
version: String,
config: UserAgentValuesFromConfig,
context: Context
context: Context,
headers: Headers
) -> AWSUserAgentMetadata {
let apiMetadata = APIMetadata(serviceID: serviceID, version: version)
let sdkMetadata = SDKMetadata(version: apiMetadata.version)
Expand All @@ -82,7 +84,7 @@ public struct AWSUserAgentMetadata {
let osVersion = PlatformOperationSystemVersion.operatingSystemVersion()
let osMetadata = OSMetadata(family: currentOS, version: osVersion)
let languageMetadata = LanguageMetadata(version: swiftVersion)
let businessMetrics = BusinessMetrics(config: config, context: context)
let businessMetrics = BusinessMetrics(config: config, context: context, headers: headers)
let appIDMetadata = AppIDMetadata(name: config.appID)
let frameworkMetadata = [FrameworkMetadata]()
return AWSUserAgentMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@
import ClientRuntime
import class Smithy.Context
import struct Smithy.AttributeKey
import struct SmithyHTTPAPI.Headers

struct BusinessMetrics {
// Mapping of human readable feature ID to the corresponding metric value
let features: [String: String]

init(
config: UserAgentValuesFromConfig,
context: Context
context: Context,
headers: Headers
) {
setFlagsIntoContext(config: config, context: context)
setFlagsIntoContext(config: config, context: context, headers: headers)
self.features = context.businessMetrics
}
}
Expand Down Expand Up @@ -73,7 +75,7 @@ public let businessMetricsKey = AttributeKey<Dictionary<String, String>>(name: "
"S3_EXPRESS_BUCKET" : "J" :
"S3_ACCESS_GRANTS" : "K" :
"GZIP_REQUEST_COMPRESSION" : "L" :
"PROTOCOL_RPC_V2_CBOR" : "M" :
"PROTOCOL_RPC_V2_CBOR" : "M" : Y
"ENDPOINT_OVERRIDE" : "N" : Y
"ACCOUNT_ID_ENDPOINT" : "O" :
"ACCOUNT_ID_MODE_PREFERRED" : "P" :
Expand All @@ -84,7 +86,8 @@ public let businessMetricsKey = AttributeKey<Dictionary<String, String>>(name: "
*/
private func setFlagsIntoContext(
config: UserAgentValuesFromConfig,
context: Context
context: Context,
headers: Headers
) {
// Handle D, E, F
switch config.awsRetryMode {
Expand All @@ -103,4 +106,8 @@ private func setFlagsIntoContext(
if context.selectedAuthScheme?.schemeID == "aws.auth#sigv4a" {
context.businessMetrics = ["SIGV4A_SIGNING": "S"]
}
// Handle M
if headers.value(for: "smithy-protocol") == "rpc-v2-cbor" {
context.businessMetrics = ["PROTOCOL_RPC_V2_CBOR": "M"]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@ import ClientRuntime
@testable import AWSClientRuntime
import SmithyRetriesAPI
import SmithyHTTPAuthAPI
import SmithyHTTPAPI
import SmithyIdentity
import SmithyRetriesAPI
import Smithy

class BusinessMetricsTests: XCTestCase {
var context: Context!
var headers: Headers!

override func setUp() async throws {
context = Context(attributes: Attributes())
headers = Headers()
}

func test_business_metrics_section_truncation() {
Expand All @@ -29,7 +32,8 @@ class BusinessMetricsTests: XCTestCase {
serviceID: "test",
version: "1.0",
config: UserAgentValuesFromConfig(appID: nil, endpoint: nil, awsRetryMode: .standard),
context: context
context: context,
headers: headers
)
// Assert values in context match with values assigned to user agent
XCTAssertEqual(userAgent.businessMetrics?.features, context.businessMetrics)
Expand All @@ -51,7 +55,8 @@ class BusinessMetricsTests: XCTestCase {
serviceID: "test",
version: "1.0",
config: UserAgentValuesFromConfig(appID: nil, endpoint: "test-endpoint", awsRetryMode: .adaptive),
context: context
context: context,
headers: headers
)
// F comes from retry mode being adaptive & N comes from endpoint override
let expectedString = "m/A,B,F,N,S"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ This SDK is open-source. Code is available on Github [here](https://github.com/

[Smithy](../../../../../swift/api/smithy/latest)

[SmithyCBOR](../../../../../swift/api/smithycbor/latest)

[SmithyChecksums](../../../../../swift/api/smithychecksums/latest)

[SmithyChecksumsAPI](../../../../../swift/api/smithychecksumsapi/latest)
Expand Down
3 changes: 2 additions & 1 deletion codegen/Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ private var protocolTestTargets: [Target] {
.init(name: "EventStream", sourcePath: "\(baseDirLocal)/EventStream", buildOnly: true),
.init(name: "RPCEventStream", sourcePath: "\(baseDirLocal)/RPCEventStream", buildOnly: true),
.init(name: "Waiters", sourcePath: "\(baseDirLocal)/Waiters", testPath: "../codegen/protocol-test-codegen-local/Tests"),
.init(name: "StringArrayEndpointParam", sourcePath: "\(baseDirLocal)/StringArrayEndpointParam")
.init(name: "StringArrayEndpointParam", sourcePath: "\(baseDirLocal)/StringArrayEndpointParam"),
.init(name: "RPCV2CBORTestSDK", sourcePath: "\(baseDir)/smithy-rpcv2-cbor")
]
return protocolTests.flatMap { protocolTest in
let target = Target.target(
Expand Down
4 changes: 3 additions & 1 deletion codegen/protocol-test-codegen/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies {
implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion")
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation(project(":smithy-aws-swift-codegen"))
implementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion")
}

val enabledProtocols = listOf(
Expand All @@ -39,7 +40,8 @@ val enabledProtocols = listOf(
ProtocolTest("apigateway", "com.amazonaws.apigateway#BackplaneControlService", "APIGatewayTestSDK"),
ProtocolTest("glacier", "com.amazonaws.glacier#Glacier", "GlacierTestSDK"),
ProtocolTest("s3", "com.amazonaws.s3#AmazonS3", "S3TestSDK"),
ProtocolTest("machinelearning", "com.amazonaws.machinelearning#AmazonML_20141212", "MachineLearningTestSDK")
ProtocolTest("machinelearning", "com.amazonaws.machinelearning#AmazonML_20141212", "MachineLearningTestSDK"),
ProtocolTest("smithy-rpcv2-cbor", "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "RPCV2CBORTestSDK"),
)

// This project doesn't produce a JAR.
Expand Down
1 change: 1 addition & 0 deletions codegen/smithy-aws-swift-codegen/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies {
api("software.amazon.smithy:smithy-aws-iam-traits:$smithyVersion")
api("software.amazon.smithy:smithy-aws-cloudformation-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-traits:$smithyVersion")
testImplementation("org.junit.jupiter:junit-jupiter:$junitVersion")
testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion")
implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ abstract class AWSHTTPBindingProtocolGenerator(
operation,
OperationEndpointResolverMiddleware(ctx, customizations.endpointMiddlewareSymbol)
)
}

override fun addUserAgentMiddleware(ctx: ProtocolGenerator.GenerationContext, operation: OperationShape) {
operationMiddleware.appendMiddleware(operation, UserAgentMiddleware(ctx.settings))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,25 @@ package software.amazon.smithy.aws.swift.codegen

import software.amazon.smithy.aws.traits.ServiceTrait
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.swift.codegen.model.expectTrait
import software.amazon.smithy.swift.codegen.model.getTrait

/**
* Get the [sdkId](https://smithy.io/2.0/aws/aws-core.html#sdkid) from the (AWS) service shape
* or return a default value if the trait is not present.
*/
val ServiceShape.sdkId: String
get() = expectTrait<ServiceTrait>().sdkId
get() = getTrait<ServiceTrait>()?.sdkId ?: "defaultSdkId"

/**
* Get the [arnNamespace](https://smithy.io/2.0/aws/aws-core.html#arnnamespace)
* from the (AWS) service shape
* from the (AWS) service shape or return a default value if the trait is not present.
*/
val ServiceShape.arnNamespace: String
get() = expectTrait<ServiceTrait>().arnNamespace
get() = getTrait<ServiceTrait>()?.arnNamespace ?: "defaultArnNamespace"

/**
* Get the [endpointPrefix](https://smithy.io/2.0/aws/aws-core.html#endpointprefix)
* from the (AWS) service shape
* from the (AWS) service shape or return a default value if the trait is not present.
*/
val ServiceShape.endpointPrefix: String
get() = expectTrait<ServiceTrait>().endpointPrefix
get() = getTrait<ServiceTrait>()?.endpointPrefix ?: "defaultEndpointPrefix"
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ class AWSSmokeTestGenerator(
)

override fun getServiceName(): String {
return "AWS" + ctx.service.getTrait(ServiceTrait::class.java).get().sdkId.toUpperCamelCase()
val serviceTrait = ctx.service.getTrait(ServiceTrait::class.java).orElse(null)
val sdkId = serviceTrait?.sdkId?.toUpperCamelCase() ?: "DefaultService"
return "AWS$sdkId"
}

override fun getClientName(): String {
return ctx.service.getTrait(ServiceTrait::class.java).get().sdkId.toUpperCamelCase().removeSuffix("Service") + "Client"
val serviceTrait = ctx.service.getTrait(ServiceTrait::class.java).orElse(null)
val sdkId = serviceTrait?.sdkId?.toUpperCamelCase()?.removeSuffix("Service") ?: "Default"
return "${sdkId}Client"
}

override fun renderCustomFilePrivateVariables(writer: SwiftWriter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.aws.swift.codegen.protocols.awsquery.AWSQueryProto
import software.amazon.smithy.aws.swift.codegen.protocols.ec2query.EC2QueryProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.protocols.restjson.AWSRestJson1ProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.protocols.restxml.RestXMLProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.protocols.rpcv2cbor.RpcV2CborProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration

Expand All @@ -31,6 +32,7 @@ class AddProtocols : SwiftIntegration {
AWSJSON1_1ProtocolGenerator(),
RestXMLProtocolGenerator(),
AWSQueryProtocolGenerator(),
EC2QueryProtocolGenerator()
EC2QueryProtocolGenerator(),
RpcV2CborProtocolGenerator()
)
}
Loading

0 comments on commit 6aa9ebb

Please sign in to comment.