Skip to content

Commit

Permalink
Update serde implementation to support out of range floats (#3825)
Browse files Browse the repository at this point in the history
## Motivation and Context
Fix serde behavior to match generated code. This is important to avoid
loosing data during serialization, especially as out-of-range floats
often indicate an error.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
rcoh authored Sep 30, 2024
1 parent d8fbf47 commit 12cf916
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
Expand Down Expand Up @@ -208,13 +210,45 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
* For enums, it adds `as_str()` to convert it into a string directly.
*/
private fun serializeNumber(shape: NumberShape): RuntimeType {
val numericType = SimpleShapes.getValue(shape::class)
return when (shape) {
is FloatShape, is DoubleShape -> serializeFloat(shape)
else ->
RuntimeType.forInlineFun(
numericType.toString(),
PrimitiveShapesModule,
) {
implSerializeConfigured(symbolBuilder(shape, numericType).build()) {
rustTemplate("self.value.serialize(serializer)")
}
}
}
}

private fun serializeFloat(shape: NumberShape): RuntimeType {
val numericType = SimpleShapes.getValue(shape::class)
return RuntimeType.forInlineFun(
numericType.toString(),
PrimitiveShapesModule,
) {
implSerializeConfigured(symbolBuilder(shape, numericType).build()) {
rustTemplate("self.value.serialize(serializer)")
rustTemplate(
"""
if !self.settings.out_of_range_floats_as_strings {
return self.value.serialize(serializer)
}
if self.value.is_nan() {
serializer.serialize_str("NaN")
} else if *self.value == #{ty}::INFINITY {
serializer.serialize_str("Infinity")
} else if *self.value == #{ty}::NEG_INFINITY {
serializer.serialize_str("-Infinity")
} else {
self.value.serialize(serializer)
}
""",
"ty" to numericType,
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object SupportStructures {
{
use #{serde}::Serialize;
value
.serialize_ref(&#{SerializationSettings} { redact_sensitive_fields: true })
.serialize_ref(&#{SerializationSettings}::redact_sensitive_fields())
.serialize(serializer)
}
""",
Expand All @@ -70,7 +70,7 @@ object SupportStructures {
{
use #{serde}::Serialize;
value
.serialize_ref(&#{SerializationSettings} { redact_sensitive_fields: false })
.serialize_ref(&#{SerializationSettings}::leak_sensitive_fields())
.serialize(serializer)
}
""",
Expand Down Expand Up @@ -211,7 +211,6 @@ object SupportStructures {

private fun serializationSettings() =
RuntimeType.forInlineFun("SerializationSettings", supportModule) {
// TODO(serde): Consider removing `derive(Default)`
rustTemplate(
"""
/// Settings for use when serializing structures
Expand All @@ -220,17 +219,23 @@ object SupportStructures {
pub struct SerializationSettings {
/// Replace all sensitive fields with `<redacted>` during serialization
pub redact_sensitive_fields: bool,
/// Serialize Nan, infinity and negative infinity as strings.
///
/// For protocols like JSON, this avoids the loss-of-information that occurs when these out-of-range values
/// are serialized as null.
pub out_of_range_floats_as_strings: bool,
}
impl SerializationSettings {
/// Replace all `@sensitive` fields with `<redacted>` when serializing.
///
/// Note: This may alter the type of the serialized output and make it impossible to deserialize as
/// numerical fields will be replaced with strings.
pub const fn redact_sensitive_fields() -> Self { Self { redact_sensitive_fields: true } }
pub const fn redact_sensitive_fields() -> Self { Self { redact_sensitive_fields: true, out_of_range_floats_as_strings: false } }
/// Preserve the contents of sensitive fields during serializing
pub const fn leak_sensitive_fields() -> Self { Self { redact_sensitive_fields: false } }
pub const fn leak_sensitive_fields() -> Self { Self { redact_sensitive_fields: false, out_of_range_floats_as_strings: false } }
}
""",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ class SerdeDecoratorTest {
blob: SensitiveBlob,
constrained: Constrained,
recursive: Recursive,
map: EnumKeyedMap
map: EnumKeyedMap,
float: Float,
double: Double
}
structure Constrained {
Expand Down Expand Up @@ -134,6 +136,8 @@ class SerdeDecoratorTest {
structure Nested {
@required
int: Integer,
float: Float,
double: Double,
sensitive: Timestamps,
notSensitive: AlsoTimestamps,
manyEnums: TestEnumList,
Expand Down Expand Up @@ -202,8 +206,12 @@ class SerdeDecoratorTest {
.e(Some(TestEnum::A))
.document(Some(Document::String("hello!".into())))
.blob(Some(Blob::new("hello")))
.float(Some(f32::INFINITY))
.double(Some(f64::NAN))
.nested(Some(Nested::builder()
.int(5)
.float(Some(f32::NEG_INFINITY))
.double(Some(f64::NEG_INFINITY))
.sensitive(Some(sensitive_map.clone()))
.not_sensitive(Some(sensitive_map))
.many_enums(Some(vec![TestEnum::A]))
Expand Down Expand Up @@ -274,6 +282,8 @@ class SerdeDecoratorTest {
"e": "A",
"nested": {
"int": 5,
"float": "-Infinity",
"double": "-Infinity",
"sensitive": {
"a": "1970-01-01T00:00:00Z"
},
Expand All @@ -289,7 +299,9 @@ class SerdeDecoratorTest {
"enum": "B"
},
"document": "hello!",
"blob": "aGVsbG8="
"blob": "aGVsbG8=",
"float": "Infinity",
"double": "NaN"
}""".replace("\\s".toRegex(), "")

private val expectedRedacted =
Expand All @@ -298,6 +310,8 @@ class SerdeDecoratorTest {
"e": "<redacted>",
"nested": {
"int": 5,
"float": "-Infinity",
"double": "-Infinity",
"sensitive": {
"a": "<redacted>"
},
Expand All @@ -311,7 +325,9 @@ class SerdeDecoratorTest {
},
"union": "<redacted>",
"document": "hello!",
"blob": "<redacted>"
"blob": "<redacted>",
"float": "Infinity",
"double": "NaN"
}
""".replace("\\s".toRegex(), "")

Expand Down Expand Up @@ -343,8 +359,12 @@ class SerdeDecoratorTest {
.e("A".into())
.document(Document::String("hello!".into()))
.blob(Blob::new("hello"))
.float(f32::INFINITY)
.double(f64::NAN)
.nested(Nested::builder()
.int(5)
.float(f32::NEG_INFINITY)
.double(f64::NEG_INFINITY)
.sensitive("a", DateTime::from(UNIX_EPOCH))
.not_sensitive("a", DateTime::from(UNIX_EPOCH))
.many_enums("A".into())
Expand All @@ -355,11 +375,15 @@ class SerdeDecoratorTest {
.build()
.unwrap();
let mut settings = #{crate}::serde::SerializationSettings::default();
settings.out_of_range_floats_as_strings = true;
let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
assert_eq!(serialized, ${expectedNoRedactions.dq()});
settings.redact_sensitive_fields = true;
let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
assert_eq!(serialized, ${expectedRedacted.dq()});
settings.out_of_range_floats_as_strings = false;
let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
assert_ne!(serialized, ${expectedRedacted.dq()});
""",
*codegenScope,
)
Expand Down

0 comments on commit 12cf916

Please sign in to comment.