Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't generate prop if custom prop exists in base #4682

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Linq;
using System.Net;
using System.Text.Json;
using Microsoft.CodeAnalysis;
using Microsoft.Generator.CSharp.ClientModel.Snippets;
using Microsoft.Generator.CSharp.Expressions;
using Microsoft.Generator.CSharp.Input;
Expand Down Expand Up @@ -710,6 +711,19 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
Dictionary<JsonValueKind, List<MethodBodyStatement>> additionalPropsValueKindBodyStatements = [];
var parameters = SerializationConstructor.Signature.Parameters;

// Parse the custom serialization attributes
IEnumerable<AttributeData> serializationAttributes = _model.CustomCodeView?.GetAttributes()
.Where(a => a.AttributeClass?.Name == CodeGenAttributes.CodeGenSerializationAttributeName) ?? [];
var baseModelProvider = _model.BaseModelProvider;

while (baseModelProvider != null)
{
serializationAttributes = serializationAttributes
.Concat(baseModelProvider.CustomCodeView?.GetAttributes()
.Where(a => a.AttributeClass?.Name == CodeGenAttributes.CodeGenSerializationAttributeName) ?? []);
jorgerangel-msft marked this conversation as resolved.
Show resolved Hide resolved
baseModelProvider = baseModelProvider.BaseModelProvider;
}

// Create each property's deserialization statement
for (int i = 0; i < parameters.Count; i++)
{
Expand All @@ -731,7 +745,7 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
var propertySerializationName = wireInfo.SerializedName;
var checkIfJsonPropEqualsName = new IfStatement(jsonProperty.NameEquals(propertySerializationName))
{
DeserializeProperty(property, jsonProperty)
DeserializeProperty(property, jsonProperty, serializationAttributes)
};
propertyDeserializationStatements.Add(checkIfJsonPropEqualsName);
}
Expand All @@ -752,7 +766,7 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
var rawBinaryData = _rawDataField;
if (rawBinaryData == null)
{
var baseModelProvider = _model.BaseModelProvider;
baseModelProvider = _model.BaseModelProvider;
while (baseModelProvider != null)
{
var field = baseModelProvider.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
Expand Down Expand Up @@ -1033,7 +1047,8 @@ private static SwitchStatement CreateDeserializeAdditionalPropsValueKindCheck(

private MethodBodyStatement[] DeserializeProperty(
PropertyProvider property,
ScopedApi<JsonProperty> jsonProperty)
ScopedApi<JsonProperty> jsonProperty,
IEnumerable<AttributeData> serializationAttributes)
{
var serializationFormat = property.WireInfo?.SerializationFormat ?? SerializationFormat.Default;
var propertyVarReference = property.AsVariableExpression;
Expand All @@ -1043,8 +1058,7 @@ private MethodBodyStatement[] DeserializeProperty(
propertyVarReference.Assign(value).Terminate()
};

foreach (var attribute in _model.CustomCodeView?.GetAttributes()
.Where(a => a.AttributeClass?.Name == CodeGenAttributes.CodeGenSerializationAttributeName) ?? [])
foreach (var attribute in serializationAttributes)
{
if (CodeGenAttributes.TryGetCodeGenSerializationAttributeValue(
attribute,
Expand All @@ -1059,6 +1073,7 @@ private MethodBodyStatement[] DeserializeProperty(
deserializationHook,
jsonProperty,
ByRef(propertyVarReference)).Terminate()];
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,52 @@ public async Task CanChangePropertyName()
var expected = Helpers.GetExpectedFromFile();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

// Validates that if a custom property is added to the base model, and the CodeGenSerialization attribute is used,
// then the derived model includes the custom property in the serialization ctor.
[Test]
public async Task CanSerializeCustomPropertyFromBase()
{
var baseModel = InputFactory.Model(
"baseModel",
usage: InputModelTypeUsage.Input,
properties: [InputFactory.Property("BaseProp", InputPrimitiveType.Int32, isRequired: true)]);
var plugin = await MockHelpers.LoadMockPluginAsync(
inputModels: () => [
InputFactory.Model(
"mockInputModel",
// use Input so that we generate a public ctor
usage: InputModelTypeUsage.Input,
properties:
[
InputFactory.Property("OtherProp", InputPrimitiveType.Int32, isRequired: true),
],
baseModel: baseModel),
],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var modelTypeProvider = plugin.Object.OutputLibrary.TypeProviders.FirstOrDefault(t => t is ModelProvider && t.Name == "MockInputModel");
Assert.IsNotNull(modelTypeProvider);

var baseModelTypeProvider = (modelTypeProvider as ModelProvider)?.BaseModelProvider;
Assert.IsNotNull(baseModelTypeProvider);
var customCodeView = baseModelTypeProvider!.CustomCodeView;
Assert.IsNotNull(customCodeView);
Assert.IsNull(modelTypeProvider!.CustomCodeView);

Assert.AreEqual(1, baseModelTypeProvider!.Properties.Count);
Assert.AreEqual("BaseProp", baseModelTypeProvider.Properties[0].Name);
Assert.AreEqual(new CSharpType(typeof(int)), baseModelTypeProvider.Properties[0].Type);
Assert.AreEqual(1, customCodeView!.Properties.Count);
Assert.AreEqual("Prop1", customCodeView.Properties[0].Name);

Assert.AreEqual(1, modelTypeProvider.Properties.Count);
Assert.AreEqual("OtherProp", modelTypeProvider.Properties[0].Name);

// the custom property should exist in the full ctor
var fullCtor = modelTypeProvider.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal));
Assert.IsNotNull(fullCtor);
Assert.IsTrue(fullCtor!.Signature.Parameters.Any(p => p.Name == "prop1"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,38 @@ public async Task CanCustomizeSerializationMethodForRenamedProperty()
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

// Validates that the custom serialization method is used in the serialization provider
// for the custom property that exists in the base model.
[Test]
public async Task CanCustomizeSerializationMethodForPropertyInBase()
{
var baseModel = InputFactory.Model(
"baseModel",
usage: InputModelTypeUsage.Input,
properties: [InputFactory.Property("Prop1", InputPrimitiveType.Int32, isRequired: true)]);
var plugin = await MockHelpers.LoadMockPluginAsync(
inputModels: () => [
InputFactory.Model(
"mockInputModel",
usage: InputModelTypeUsage.Json,
properties:
[
InputFactory.Property("OtherProp", InputPrimitiveType.Int32, isRequired: true),
],
baseModel: baseModel),
],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var modelProvider = plugin.Object.OutputLibrary.TypeProviders.FirstOrDefault(t => t is ModelProvider);
Assert.IsNotNull(modelProvider);
var serializationProvider = modelProvider!.SerializationProviders.Single(t => t is MrwSerializationTypeDefinition);
Assert.IsNotNull(serializationProvider);

var writer = new TypeProviderWriter(serializationProvider);
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

// Validates that a properties serialization name can be changed using custom code.
[Test]
public async Task CanChangePropertySerializedName()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using Sample;
using System;
using System.Collections.Generic;
using Microsoft.Generator.CSharp.Customization;

namespace Sample.Models;

[CodeGenSerialization(nameof(Prop1), DeserializationValueHook = nameof(DeserializationMethod))]
public partial class BaseModel
{
internal string Prop1 { get; set; }

private static void DeserializationMethod(JsonProperty property, ref string fieldValue)
=> fieldValue = property.Value.GetString();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Text.Json;
using Sample;

namespace Sample.Models
{
/// <summary></summary>
public partial class MockInputModel : global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>
{
internal MockInputModel()
{
}

void global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Write(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
writer.WriteStartObject();
this.JsonModelWriteCore(writer, options);
writer.WriteEndObject();
}

/// <param name="writer"> The JSON writer. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "J"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{format}' format.");
}
base.JsonModelWriteCore(writer, options);
writer.WritePropertyName("otherProp"u8);
writer.WriteNumberValue(OtherProp);
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options));

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override global::Sample.Models.BaseModel JsonModelCreateCore(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "J"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support reading '{format}' format.");
}
using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.ParseValue(ref reader);
return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, options);
}

internal static global::Sample.Models.MockInputModel DeserializeMockInputModel(global::System.Text.Json.JsonElement element, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
if ((element.ValueKind == global::System.Text.Json.JsonValueKind.Null))
{
return null;
}
int otherProp = default;
int prop1 = default;
global::System.Collections.Generic.IDictionary<string, global::System.BinaryData> additionalBinaryDataProperties = new global::Sample.ChangeTrackingDictionary<string, global::System.BinaryData>();
foreach (var prop in element.EnumerateObject())
{
if (prop.NameEquals("otherProp"u8))
{
otherProp = prop.Value.GetInt32();
continue;
}
if (prop.NameEquals("prop1"u8))
{
DeserializationMethod(prop, ref prop1);
continue;
}
if ((options.Format != "W"))
{
additionalBinaryDataProperties.Add(prop.Name, global::System.BinaryData.FromString(prop.Value.GetRawText()));
}
}
return new global::Sample.Models.MockInputModel(otherProp, prop1, additionalBinaryDataProperties);
}

global::System.BinaryData global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Write(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelWriteCore(options);

/// <param name="options"> The client options for reading and writing models. </param>
protected override global::System.BinaryData PersistableModelWriteCore(global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
switch (format)
{
case "J":
return global::System.ClientModel.Primitives.ModelReaderWriter.Write(this, options);
default:
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{options.Format}' format.");
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options));

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override global::Sample.Models.BaseModel PersistableModelCreateCore(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
switch (format)
{
case "J":
using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(data))
{
return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, options);
}
default:
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support reading '{options.Format}' format.");
}
}

string global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.GetFormatFromOptions(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => "J";

/// <param name="mockInputModel"> The <see cref="global::Sample.Models.MockInputModel"/> to serialize into <see cref="global::System.ClientModel.BinaryContent"/>. </param>
public static implicit operator BinaryContent(global::Sample.Models.MockInputModel mockInputModel)
{
return global::System.ClientModel.BinaryContent.Create(mockInputModel, global::Sample.ModelSerializationExtensions.WireOptions);
}

/// <param name="result"> The <see cref="global::System.ClientModel.ClientResult"/> to deserialize the <see cref="global::Sample.Models.MockInputModel"/> from. </param>
public static explicit operator MockInputModel(global::System.ClientModel.ClientResult result)
{
using global::System.ClientModel.Primitives.PipelineResponse response = result.GetRawResponse();
using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(response.Content);
return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, global::Sample.ModelSerializationExtensions.WireOptions);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

using Microsoft.Generator.CSharp.Customization;

namespace Sample.Models
{
[CodeGenSerialization(nameof(Prop1), SerializationValueHook = nameof(SerializationMethod), DeserializationValueHook = nameof(DeserializationMethod))]
public partial class BaseModel
{
private void SerializationMethod(Utf8JsonWriter writer, ModelReaderWriterOptions options)
=> writer.WriteObjectValue(Prop1, options);

private static void DeserializationMethod(JsonProperty property, ref string fieldValue)
=> fieldValue = property.Value.GetString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ private IReadOnlyList<ModelProvider> BuildDerivedModels()

return [.. derivedModels];
}
internal override TypeProvider? BaseTypeProvider => BaseModelProvider;

public ModelProvider? BaseModelProvider
=> _baseModelProvider ??= (_baseTypeProvider?.Value is ModelProvider baseModelProvider ? baseModelProvider : null);
Expand Down
Loading
Loading