Skip to content

Commit

Permalink
Cleanup module handling & add support for Cargo features (#253)
Browse files Browse the repository at this point in the history
* Cleanup module handling & add support for Cargo features

* Fix AWS tests

* Set optional in the Cargo toml
  • Loading branch information
rcoh authored Mar 16, 2021
1 parent 90f116c commit 5fde528
Show file tree
Hide file tree
Showing 19 changed files with 158 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ internal class EndpointConfigCustomizationTest {
@Test
fun `write an endpoint into the config`() {
val project = stubConfigProject(EndpointConfigCustomization(TestRuntimeConfig, model.lookup("test#TestService")))
project.useFileWriter("src/lib.rs", "crate") {
project.lib {
it.addDependency(awsTypes(TestRuntimeConfig))
it.addDependency(CargoDependency.Http)
it.unitTest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal class SigV4SigningCustomizationTest {
@Test
fun `generates a valid config`() {
val project = stubConfigProject(SigV4SigningConfig(SigV4Trait.builder().name("test-service").build()))
project.useFileWriter("src/lib.rs", "crate") {
project.lib {
it.unitTest(
"""
let conf = crate::config::Config::builder().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,16 @@ class InlineDependency(
fun CargoDependency.asType(): RuntimeType =
RuntimeType(null, dependency = this, namespace = this.name.replace("-", "_"))

data class Feature(val name: String, val default: Boolean, val deps: List<String>)

/**
* A dependency on an internal or external Cargo Crate
*/
data class CargoDependency(
override val name: String,
private val location: DependencyLocation,
val scope: DependencyScope = DependencyScope.Compile,
val optional: Boolean = false,
private val features: List<String> = listOf()
) : RustDependency(name) {

Expand All @@ -137,6 +140,9 @@ data class CargoDependency(
attribs["features"] = this
}
}
if (optional) {
attribs["optional"] = true
}
return attribs
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,8 @@ data class RustModule(val name: String, val rustMetadata: RustMetadata) {
}*/
return RustModule(name, RustMetadata(public = public))
}

val Config = default("config", public = true)
val Error = default("error", public = true)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.rustlang

import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.util.dq

/**
* A hierarchy of types handled by Smithy codegen
Expand Down Expand Up @@ -149,14 +150,15 @@ inline fun <reified T : RustType.Container> RustType.stripOuter(): RustType {
* Meta information about a Rust construction (field, struct, or enum)
*/
data class RustMetadata(
val derives: Derives = Derives.Empty,
val derives: Attribute.Derives = Attribute.Derives.Empty,
val additionalAttributes: List<Attribute> = listOf(),
val public: Boolean
) {
fun withDerives(vararg newDerive: RuntimeType): RustMetadata =
this.copy(derives = derives.copy(derives = derives.derives + newDerive))

fun attributes(): List<Attribute> = additionalAttributes + derives
private fun attributes(): List<Attribute> = additionalAttributes + derives

fun renderAttributes(writer: RustWriter): RustMetadata {
attributes().forEach {
it.render(writer)
Expand Down Expand Up @@ -201,33 +203,40 @@ sealed class Attribute {
val NonExhaustive = Custom("non_exhaustive")
val AllowUnused = Custom("allow(dead_code)")
}
}

data class Derives(val derives: Set<RuntimeType>) : Attribute() {
override fun render(writer: RustWriter) {
if (derives.isEmpty()) {
return
data class Derives(val derives: Set<RuntimeType>) : Attribute() {
override fun render(writer: RustWriter) {
if (derives.isEmpty()) {
return
}
writer.raw("#[derive(")
derives.sortedBy { it.name }.forEach { derive ->
writer.writeInline("#T, ", derive)
}
writer.write(")]")
}
writer.raw("#[derive(")
derives.sortedBy { it.name }.forEach { derive ->
writer.writeInline("#T, ", derive)

companion object {
val Empty = Derives(setOf())
}
writer.write(")]")
}

companion object {
val Empty = Derives(setOf())
data class Custom(val annotation: String, val symbols: List<RuntimeType> = listOf()) : Attribute() {
override fun render(writer: RustWriter) {
writer.raw("#[$annotation]")
symbols.forEach {
writer.addDependency(it.dependency)
}
}
}
}

data class Custom(val annot: String, val symbols: List<RuntimeType> = listOf()) : Attribute() {
override fun render(writer: RustWriter) {
writer.raw("#[")
writer.writeInline(annot)
writer.write("]")
data class Cfg(val cond: String) : Attribute() {
override fun render(writer: RustWriter) {
writer.raw("#[cfg($cond)]")
}

symbols.forEach {
writer.addDependency(it.dependency)
companion object {
fun feature(feature: String) = Cfg("feature = ${feature.dq()}")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@

package software.amazon.smithy.rust.codegen.smithy

import software.amazon.smithy.build.FileManifest
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.codegen.core.writer.CodegenWriterDelegator
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.Feature
import software.amazon.smithy.rust.codegen.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.rustlang.RustDependency
import software.amazon.smithy.rust.codegen.rustlang.RustModule
Expand All @@ -15,58 +19,90 @@ import software.amazon.smithy.rust.codegen.smithy.generators.CargoTomlGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsGenerator

private fun CodegenWriterDelegator<RustWriter>.includedModules(): List<String> = this.writers.values.mapNotNull { it.module() }
open class RustCrate(
fileManifest: FileManifest,
symbolProvider: SymbolProvider,
baseModules: Map<String, RustModule>
) {
private val inner = CodegenWriterDelegator(fileManifest, symbolProvider, RustWriter.Factory)
private val modules: MutableMap<String, RustModule> = baseModules.toMutableMap()
private val features: MutableSet<Feature> = mutableSetOf()
fun useShapeWriter(shape: Shape, f: (RustWriter) -> Unit) {
inner.useShapeWriter(shape, f)
}

fun lib(moduleWriter: (RustWriter) -> Unit) {
inner.useFileWriter("src/lib.rs", "crate", moduleWriter)
}

fun addFeature(feature: Feature) = this.features.add(feature)

fun finalize(settings: RustSettings, libRsCustomizations: List<LibRsCustomization>) {
injectInlineDependencies()
val modules = inner.writers.values.mapNotNull { it.module() }.filter { it != "lib" }
.map { modules[it] ?: RustModule.default(it, false) }
inner.finalize(settings, libRsCustomizations, modules, this.features.toList())
}

private fun injectInlineDependencies() {
val unloadedDepdencies = {
this
.inner.dependencies
.map { dep -> RustDependency.fromSymbolDependency(dep) }
.filterIsInstance<InlineDependency>().distinctBy { it.key() }
.filter { !modules.contains(it.module) }
}
while (unloadedDepdencies().isNotEmpty()) {
unloadedDepdencies().forEach { dep ->
this.withModule(RustModule.default(dep.module, false)) {
dep.renderer(it)
}
}
}
}

fun withModule(
module: RustModule,
moduleWriter: (RustWriter) -> Unit
): RustCrate {
val moduleName = module.name
modules[moduleName] = module
inner.useFileWriter("src/$moduleName.rs", "crate::$moduleName", moduleWriter)
return this
}
}

// TODO: this should _probably_ be configurable via RustSettings; 2h
/**
* Allowlist of modules that will be exposed publicly in generated crates
*/
private val PublicModules = setOf("error", "operation", "model", "input", "output", "config")
val DefaultPublicModules =
setOf("error", "operation", "model", "input", "output", "config").map { it to RustModule.default(it, true) }.toMap()

/**
* Finalize all the writers by:
* - inlining inline dependencies that have been used
* - generating (and writing) a Cargo.toml based on the settings & the required dependencies
*/
fun CodegenWriterDelegator<RustWriter>.finalize(settings: RustSettings, libRsCustomizations: List<LibRsCustomization>) {
val loadDependencies = { this.dependencies.map { dep -> RustDependency.fromSymbolDependency(dep) } }
val inlineDependencies = loadDependencies().filterIsInstance<InlineDependency>().distinctBy { it.key() }
inlineDependencies.forEach { dep ->
this.useFileWriter("src/${dep.module}.rs", "crate::${dep.module}") {
dep.renderer(it)
}
}
val newDeps = loadDependencies().filterIsInstance<InlineDependency>().distinctBy { it.key() }
newDeps.forEach { dep ->
if (!this.writers.containsKey("src/${dep.module}.rs")) {
this.useFileWriter("src/${dep.module}.rs", "crate::${dep.module}") {
dep.renderer(it)
}
}
}
fun CodegenWriterDelegator<RustWriter>.finalize(
settings: RustSettings,
libRsCustomizations: List<LibRsCustomization>,
modules: List<RustModule>,
features: List<Feature>
) {
this.useFileWriter("src/lib.rs", "crate::lib") { writer ->
val includedModules = this.includedModules().toSet().filter { it != "lib" }
val modules = includedModules.map { moduleName ->
RustModule.default(moduleName, PublicModules.contains(moduleName))
}
LibRsGenerator(settings.moduleDescription, modules, libRsCustomizations).render(writer)
}
val cargoDependencies = loadDependencies().filterIsInstance<CargoDependency>().distinct()
val cargoDependencies =
this.dependencies.map { RustDependency.fromSymbolDependency(it) }.filterIsInstance<CargoDependency>().distinct()
this.useFileWriter("Cargo.toml") {
val cargoToml = CargoTomlGenerator(
settings,
it,
cargoDependencies,
features
)
cargoToml.render()
}
flushWriters()
}

fun CodegenWriterDelegator<RustWriter>.withModule(
moduleName: String,
moduleWriter: RustWriter.() -> Unit
): CodegenWriterDelegator<RustWriter> {
this.useFileWriter("src/$moduleName.rs", "crate::$moduleName", moduleWriter)
return this
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package software.amazon.smithy.rust.codegen.smithy

import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.codegen.core.writer.CodegenWriterDelegator
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.neighbor.Walker
import software.amazon.smithy.model.shapes.ServiceShape
Expand All @@ -16,7 +15,6 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.CrateVersionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
Expand All @@ -43,7 +41,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
private val settings = RustSettings.from(context.model, context.settings)

private val symbolProvider: RustSymbolProvider
private val writers: CodegenWriterDelegator<RustWriter>
private val rustCrate: RustCrate
private val fileManifest = context.fileManifest
private val model: Model
private val protocolConfig: ProtocolConfig
Expand All @@ -68,10 +66,10 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC

protocolConfig =
ProtocolConfig(model, symbolProvider, settings.runtimeConfig, service, protocol, settings.moduleName)
writers = CodegenWriterDelegator(
rustCrate = RustCrate(
context.fileManifest,
symbolProvider,
RustWriter.Factory
DefaultPublicModules
)
httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig)
}
Expand All @@ -84,7 +82,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
val serviceShapes = Walker(model).walkShapes(service)
serviceShapes.forEach { it.accept(this) }
// TODO: if we end up with a lot of these on-by-default customizations, we may want to refactor them somewhere
writers.finalize(
rustCrate.finalize(
settings,
codegenDecorator.libRsCustomizations(
protocolConfig,
Expand All @@ -104,7 +102,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC

override fun structureShape(shape: StructureShape) {
logger.fine("generating a structure...")
writers.useShapeWriter(shape) { writer ->
rustCrate.useShapeWriter(shape) { writer ->
StructureGenerator(model, symbolProvider, writer, shape).render()
if (!shape.hasTrait(SyntheticInputTrait::class.java)) {
val builderGenerator = ModelBuilderGenerator(protocolConfig.model, protocolConfig.symbolProvider, shape)
Expand All @@ -118,19 +116,19 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC

override fun stringShape(shape: StringShape) {
shape.getTrait(EnumTrait::class.java).map { enum ->
writers.useShapeWriter(shape) { writer ->
rustCrate.useShapeWriter(shape) { writer ->
EnumGenerator(symbolProvider, writer, shape, enum).render()
}
}
}

override fun unionShape(shape: UnionShape) {
writers.useShapeWriter(shape) {
rustCrate.useShapeWriter(shape) {
UnionGenerator(model, symbolProvider, it, shape).render()
}
}

override fun serviceShape(shape: ServiceShape) {
ServiceGenerator(writers, httpGenerator, protocolGenerator.support(), protocolConfig, codegenDecorator).render()
ServiceGenerator(rustCrate, httpGenerator, protocolGenerator.support(), protocolConfig, codegenDecorator).render()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.Attribute.Companion.NonExhaustive
import software.amazon.smithy.rust.codegen.rustlang.Derives
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata

/**
Expand Down Expand Up @@ -62,7 +62,7 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr

class BaseSymbolMetadataProvider(base: RustSymbolProvider) : SymbolMetadataProvider(base) {
private val containerDefault = RustMetadata(
Derives(defaultDerives.toSet()),
Attribute.Derives(defaultDerives.toSet()),
additionalAttributes = listOf(NonExhaustive),
public = true
)
Expand Down
Loading

0 comments on commit 5fde528

Please sign in to comment.