Skip to content

Commit

Permalink
Add Eq for generated types that implement PartialEq
Browse files Browse the repository at this point in the history
  • Loading branch information
jjant committed Oct 7, 2022
1 parent 3da634a commit 5169bd9
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ class StreamingShapeMetadataProvider(
override fun structureMeta(structureShape: StructureShape): RustMetadata {
val baseMetadata = base.toSymbol(structureShape).expectRustMetadata()
return if (structureShape.hasStreamingMember(model)) {
baseMetadata.withoutDerives(RuntimeType.Clone, RuntimeType.PartialEq)
baseMetadata.withoutDerives(RuntimeType.Clone, RuntimeType.PartialEq, RuntimeType.Eq)
} else baseMetadata
}

override fun unionMeta(unionShape: UnionShape): RustMetadata {
val baseMetadata = base.toSymbol(unionShape).expectRustMetadata()
return if (unionShape.hasStreamingMember(model)) {
baseMetadata.withoutDerives(RuntimeType.Clone, RuntimeType.PartialEq)
baseMetadata.withoutDerives(RuntimeType.Clone, RuntimeType.PartialEq, RuntimeType.Eq)
} else baseMetadata
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
val From = RuntimeType("From", dependency = null, namespace = "std::convert")
val TryFrom = RuntimeType("TryFrom", dependency = null, namespace = "std::convert")
val PartialEq = std.member("cmp::PartialEq")
val Eq = std.member("cmp::Eq")
val StdError = RuntimeType("Error", dependency = null, namespace = "std::error")
val String = RuntimeType("String", dependency = null, namespace = "std::string")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class BaseSymbolMetadataProvider(
companion object {
private val defaultDerives by lazy {
with(RuntimeType) {
listOf(Debug, PartialEq, Clone)
listOf(Debug, PartialEq, Eq, Clone)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class BuilderGenerator(
writer.docs("A builder for #D.", structureSymbol)
// Matching derives to the main structure + `Default` since we are a builder and everything is optional.
val baseDerives = structureSymbol.expectRustMetadata().derives
val derives = baseDerives.derives.intersect(setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone)) + RuntimeType.Default
val derives = baseDerives.derives.intersect(setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Eq, RuntimeType.Clone)) + RuntimeType.Default
baseDerives.copy(derives = derives).render(writer)
writer.rustBlock("pub struct $builderName") {
for (member in members) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider,
override fun structureMeta(structureShape: StructureShape): RustMetadata {
val baseMetadata = base.toSymbol(structureShape).expectRustMetadata()
return if (structureShape.hasStreamingMember(model)) {
baseMetadata.withoutDerives(RuntimeType.PartialEq)
baseMetadata.withoutDerives(RuntimeType.PartialEq, RuntimeType.Eq)
} else baseMetadata
}

override fun unionMeta(unionShape: UnionShape): RustMetadata {
val baseMetadata = base.toSymbol(unionShape).expectRustMetadata()
return if (unionShape.hasStreamingMember(model)) {
baseMetadata.withoutDerives(RuntimeType.PartialEq)
baseMetadata.withoutDerives(RuntimeType.PartialEq, RuntimeType.Eq)
} else baseMetadata
}

Expand Down

0 comments on commit 5169bd9

Please sign in to comment.