Skip to content

Commit

Permalink
Support Constraint Traits & Sparse collections of collections
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Jul 26, 2024
1 parent 0e2f994 commit 2296ab9
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ class StructureGeneratorTest {
ts: Timestamp,
inner: Inner,
byteValue: Byte,
sensitiveMap: SensitiveMap
}
map SensitiveMap {
key: String
value: Password
}
// Intentionally empty
Expand Down
1 change: 1 addition & 0 deletions codegen-serde/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ if (isTestingEnabled.toBoolean()) {
runtimeOnly(project(":rust-runtime"))
testImplementation("org.junit.jupiter:junit-jupiter:5.6.1")
testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion")
testImplementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion")
testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,26 @@ package software.amazon.smithy.rust.codegen.serde

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.SensitiveTrait
import software.amazon.smithy.model.traits.SparseTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.plus
import software.amazon.smithy.rust.codegen.core.rustlang.rust
Expand All @@ -28,6 +35,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
Expand All @@ -44,6 +52,8 @@ import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTrait

class SerializeImplGenerator(private val codegenContext: CodegenContext) {
private val model = codegenContext.model

fun generateRootSerializerForShape(shape: Shape): Writable = serializerFn(shape, null)

/**
Expand All @@ -59,17 +69,32 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
shape: Shape,
applyTo: Writable?,
): Writable {
if (shape is ServiceShape) {
return shape.operations.map { serializerFn(model.expectShape(it), null) }.join("\n")
} else if (shape is OperationShape) {
return writable {
serializerFn(model.expectShape(shape.inputShape), null)(this)
serializerFn(model.expectShape(shape.outputShape), null)(this)
}
}
val name = codegenContext.symbolProvider.shapeFunctionName(codegenContext.serviceShape, shape) + "_serde"
val deps =
when (shape) {
is StructureShape -> RuntimeType.forInlineFun(name, SerdeModule, structSerdeImpl(shape))
is UnionShape -> RuntimeType.forInlineFun(name, SerdeModule, serializeUnionImpl(shape))
is TimestampShape -> serializeDateTime(shape)
is BlobShape -> serializeBlob(shape)
is StringShape, is NumberShape -> directSerde(shape)
is BlobShape ->
if (shape.hasTrait<StreamingTrait>()) {
serializeByteStream(shape)
} else {
serializeBlob(shape)
}

is StringShape, is NumberShape, is BooleanShape -> directSerde(shape)
is DocumentShape -> serializeDocument(shape)
else -> null
}

return writable {
val wrapper =
when {
Expand All @@ -82,10 +107,7 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
if (wrapper != null && applyTo != null) {
rustTemplate(
"&#{wrapper}(#{applyTo})", "wrapper" to wrapper,
"applyTo" to
applyTo.letIf(shape.hasConstraintTrait()) {
it.plus { rust("") }
},
"applyTo" to applyTo,
)
} else {
deps?.toSymbol().also { addDependency(it) }
Expand All @@ -96,38 +118,63 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {

private fun serializeMap(shape: MapShape): RuntimeType =
serializeWithWrapper(shape) { value ->
val member = serializeMember(shape.value, "v")
val writeEntry =
writable {
when (shape.hasTrait<SparseTrait>()) {
true ->
rust(
"""
match v {
Some(v) => map.serialize_entry(k, #T)?,
None => map.serialize_entry(k, &None::<usize>)?
};
""",
member,
)
false -> rust("map.serialize_entry(k, #T)?;", member)
}
}
writable {
rustTemplate(
"""
use #{serde}::ser::SerializeMap;
let mut map = serializer.serialize_map(Some(#{value}.len()))?;
for (k, v) in #{value}.iter() {
map.serialize_entry(k, #{member})?;
#{writeEntry}
}
map.end()
""",
*SupportStructures.codegenScope,
"value" to value,
"member" to serializeMember(shape.value, "v"),
"writeEntry" to writeEntry,
)
}
}

private fun serializeList(shape: CollectionShape): RuntimeType =
serializeWithWrapper(shape) { value ->
val member = serializeMember(shape.member, "v")
val serializeElement =
writable {
when (shape.hasTrait<SparseTrait>()) {
false -> rust("seq.serialize_element(#T)?;", member)
true -> rust("match v { Some(v) => seq.serialize_element(#T)?, None => seq.serialize_element(&None::<usize>)? };", member)
}
}
writable {
rustTemplate(
"""
use #{serde}::ser::SerializeSeq;
let mut seq = serializer.serialize_seq(Some(#{value}.len()))?;
for v in #{value}.iter() {
seq.serialize_element(#{member})?;
#{element}
}
seq.end()
""",
*SupportStructures.codegenScope,
"element" to serializeElement,
"value" to value,
"member" to serializeMember(shape.member, "v"),
)
}
}
Expand All @@ -136,7 +183,10 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
* Serialize a type that already implements `Serialize` directly via `value.serialize(serializer)`
*/
private fun directSerde(shape: Shape): RuntimeType {
return RuntimeType.forInlineFun(codegenContext.symbolProvider.toSymbol(shape).rustType().toString(), SerdeModule) {
return RuntimeType.forInlineFun(
codegenContext.symbolProvider.toSymbol(shape).rustType().toString(),
SerdeModule,
) {
implSerializeConfigured(codegenContext.symbolProvider.toSymbol(shape)) {
val baseValue =
writable { rust("self.value") }.letIf(shape.isStringShape) { it.plus(writable(".as_str()")) }
Expand Down Expand Up @@ -179,7 +229,7 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
// awkward hack to recover the symbol referring to the type
val type = SerdeModule.toType().resolve(name).toSymbol()
val base =
writable { rust("self.value.0") }.letIf(shape.hasConstraintTrait()) {
writable { rust("self.value.0") }.letIf(shape.hasConstraintTrait() && codegenContext.target == CodegenTarget.SERVER) {
it.plus { rust(".0") }
}
val serialization =
Expand All @@ -201,7 +251,7 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
}

private fun Shape.unwrapConstraints(): Boolean =
hasConstraintTrait() && (isBlobShape || isTimestampShape || isDocumentShape)
codegenContext.target == CodegenTarget.SERVER && hasConstraintTrait() && (isBlobShape || isTimestampShape || isDocumentShape)

/**
* Serialize the field of a structure, union, list or map.
Expand Down Expand Up @@ -240,27 +290,30 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
""",
*SupportStructures.codegenScope,
)
Attribute.AllowUnusedMut.render(this)
rust(
"let mut s = serializer.serialize_struct(${
shape.contextName(codegenContext.serviceShape).dq()
}, ${shape.members().size})?;",
)
rust("let inner = &self.value;")
for (member in shape.members()) {
val serializedName = member.memberName.dq()
val fieldName = codegenContext.symbolProvider.toMemberName(member)
val field = safeName("member")
val fieldSerialization =
writable {
rustTemplate(
"s.serialize_field($serializedName, #{member})?;",
"member" to serializeMember(member, field),
)
if (!shape.members().isEmpty()) {
rust("let inner = &self.value;")
for (member in shape.members()) {
val serializedName = member.memberName.dq()
val fieldName = codegenContext.symbolProvider.toMemberName(member)
val field = safeName("member")
val fieldSerialization =
writable {
rustTemplate(
"s.serialize_field($serializedName, #{member})?;",
"member" to serializeMember(member, field),
)
}
if (codegenContext.symbolProvider.toSymbol(member).isOptional()) {
rust("if let Some($field) = &inner.$fieldName { #T }", fieldSerialization)
} else {
rust("let $field = &inner.$fieldName; #T", fieldSerialization)
}
if (codegenContext.symbolProvider.toSymbol(member).isOptional()) {
rust("if let Some($field) = &inner.$fieldName { #T }", fieldSerialization)
} else {
rust("let $field = &inner.$fieldName; #T", fieldSerialization)
}
}
rust("s.end()")
Expand All @@ -285,10 +338,14 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
"${symbolProvider.toMemberName(member)}(inner)"
}
withBlock("#T::$variantName => {", "},", symbolProvider.toSymbol(shape)) {
rustTemplate(
"serializer.serialize_newtype_variant(${unionName.dq()}, $index, $fieldName, #{member})",
"member" to serializeMember(member, "inner"),
)
when (member.isTargetUnit()) {
true -> rust("serializer.serialize_unit_variant(${unionName.dq()}, $index, $fieldName)")
false ->
rustTemplate(
"serializer.serialize_newtype_variant(${unionName.dq()}, $index, $fieldName, #{member})",
"member" to serializeMember(member, "inner"),
)
}
}
}
if (codegenContext.target.renderUnknownVariant()) {
Expand Down Expand Up @@ -325,6 +382,26 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
}
}

private fun serializeByteStream(shape: BlobShape): RuntimeType =
RuntimeType.forInlineFun("SerializeByteStream", SerdeModule) {
implSerializeConfigured(RuntimeType.byteStream(codegenContext.runtimeConfig).toSymbol()) {
// This doesn't work yet—there is no way to get data out of a ByteStream from a sync context
rustTemplate(
"""
let Some(bytes) = self.value.bytes() else {
return serializer.serialize_str("streaming data")
};
if serializer.is_human_readable() {
serializer.serialize_str(&#{base64_encode}(bytes))
} else {
serializer.serialize_bytes(&bytes)
}
""",
"base64_encode" to RuntimeType.base64Encode(codegenContext.runtimeConfig),
)
}
}

private fun serializeDocument(shape: DocumentShape): RuntimeType =
RuntimeType.forInlineFun("SerializeDocument", SerdeModule) {
implSerializeConfigured(codegenContext.symbolProvider.toSymbol(shape)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import software.amazon.smithy.model.traits.AbstractTrait
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.core.util.orNull

class SerdeTrait private constructor(
class SerdeTrait constructor(
private val serialize: Boolean,
private val deserialize: Boolean,
private val tag: String?,
Expand Down
Loading

0 comments on commit 2296ab9

Please sign in to comment.