diff --git a/Microsoft.Azure.Cosmos/src/Linq/CosmosElementToSqlScalarExpressionVisitor.cs b/Microsoft.Azure.Cosmos/src/Linq/CosmosElementToSqlScalarExpressionVisitor.cs new file mode 100644 index 0000000000..68232330de --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Linq/CosmosElementToSqlScalarExpressionVisitor.cs @@ -0,0 +1,93 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Linq +{ + using System; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Diagnostics; + using Microsoft.Azure.Cosmos.CosmosElements; + using Microsoft.Azure.Cosmos.CosmosElements.Numbers; + using Microsoft.Azure.Cosmos.SqlObjects; + + internal sealed class CosmosElementToSqlScalarExpressionVisitor : ICosmosElementVisitor + { + public static readonly CosmosElementToSqlScalarExpressionVisitor Singleton = new CosmosElementToSqlScalarExpressionVisitor(); + + private CosmosElementToSqlScalarExpressionVisitor() + { + // Private constructor, since this class is a singleton. + } + + public SqlScalarExpression Visit(CosmosArray cosmosArray) + { + List items = new List(); + foreach (CosmosElement item in cosmosArray) + { + items.Add(item.Accept(this)); + } + + return SqlArrayCreateScalarExpression.Create(items.ToImmutableArray()); + } + + public SqlScalarExpression Visit(CosmosBinary cosmosBinary) + { + // Can not convert binary to scalar expression without knowing the API type. + Debug.Fail("CosmosElementToSqlScalarExpressionVisitor Assert", "Unreachable"); + throw new InvalidOperationException(); + } + + public SqlScalarExpression Visit(CosmosBoolean cosmosBoolean) + { + return SqlLiteralScalarExpression.Create(SqlBooleanLiteral.Create(cosmosBoolean.Value)); + } + + public SqlScalarExpression Visit(CosmosGuid cosmosGuid) + { + // Can not convert guid to scalar expression without knowing the API type. + Debug.Fail("CosmosElementToSqlScalarExpressionVisitor Assert", "Unreachable"); + throw new InvalidOperationException(); + } + + public SqlScalarExpression Visit(CosmosNull cosmosNull) + { + return SqlLiteralScalarExpression.Create(SqlNullLiteral.Create()); + } + + public SqlScalarExpression Visit(CosmosNumber cosmosNumber) + { + if (!(cosmosNumber is CosmosNumber64 cosmosNumber64)) + { + throw new ArgumentException($"Unknown {nameof(CosmosNumber)} type: {cosmosNumber.GetType()}."); + } + + return SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create(cosmosNumber64.GetValue())); + } + + public SqlScalarExpression Visit(CosmosObject cosmosObject) + { + List properties = new List(); + foreach (KeyValuePair prop in cosmosObject) + { + SqlPropertyName name = SqlPropertyName.Create(prop.Key); + CosmosElement value = prop.Value; + SqlScalarExpression expression = value.Accept(this); + SqlObjectProperty property = SqlObjectProperty.Create(name, expression); + properties.Add(property); + } + + return SqlObjectCreateScalarExpression.Create(properties.ToImmutableArray()); + } + + public SqlScalarExpression Visit(CosmosString cosmosString) + { + return SqlLiteralScalarExpression.Create(SqlStringLiteral.Create(cosmosString.Value)); + } + + public SqlScalarExpression Visit(CosmosUndefined cosmosUndefined) + { + return SqlLiteralScalarExpression.Create(SqlUndefinedLiteral.Create()); + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Linq/DefaultCosmosLinqSerializer.cs b/Microsoft.Azure.Cosmos/src/Linq/DefaultCosmosLinqSerializer.cs new file mode 100644 index 0000000000..024c33fab0 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Linq/DefaultCosmosLinqSerializer.cs @@ -0,0 +1,110 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Linq +{ + using System; + using System.Diagnostics; + using System.Linq; + using System.Linq.Expressions; + using System.Reflection; + using System.Runtime.Serialization; + using Microsoft.Azure.Documents; + using Newtonsoft.Json; + + internal class DefaultCosmosLinqSerializer : ICosmosLinqSerializer + { + public bool RequiresCustomSerialization(MemberExpression memberExpression, Type memberType) + { + // There are two ways to specify a custom attribute + // 1- by specifying the JsonConverterAttribute on a Class/Enum + // [JsonConverter(typeof(StringEnumConverter))] + // Enum MyEnum + // { + // ... + // } + // + // 2- by specifying the JsonConverterAttribute on a property + // class MyClass + // { + // [JsonConverter(typeof(StringEnumConverter))] + // public MyEnum MyEnum; + // } + // + // Newtonsoft gives high precedence to the attribute specified + // on a property over on a type (class/enum) + // so we check both attributes and apply the same precedence rules + // JsonConverterAttribute doesn't allow duplicates so it's safe to + // use FirstOrDefault() + CustomAttributeData memberAttribute = memberExpression.Member.CustomAttributes.FirstOrDefault(ca => ca.AttributeType == typeof(Newtonsoft.Json.JsonConverterAttribute)); + CustomAttributeData typeAttribute = memberType.GetsCustomAttributes().FirstOrDefault(ca => ca.AttributeType == typeof(Newtonsoft.Json.JsonConverterAttribute)); + + return memberAttribute != null || typeAttribute != null; + } + + public string Serialize(object value, MemberExpression memberExpression, Type memberType) + { + CustomAttributeData memberAttribute = memberExpression.Member.CustomAttributes.FirstOrDefault(ca => ca.AttributeType == typeof(Newtonsoft.Json.JsonConverterAttribute)); + CustomAttributeData typeAttribute = memberType.GetsCustomAttributes().FirstOrDefault(ca => ca.AttributeType == typeof(Newtonsoft.Json.JsonConverterAttribute)); + CustomAttributeData converterAttribute = memberAttribute ?? typeAttribute; + + Debug.Assert(converterAttribute.ConstructorArguments.Count > 0, $"{nameof(DefaultCosmosLinqSerializer)} Assert!", "At least one constructor argument exists."); + Type converterType = (Type)converterAttribute.ConstructorArguments[0].Value; + + string serializedValue = converterType.GetConstructor(Type.EmptyTypes) != null + ? JsonConvert.SerializeObject(value, (Newtonsoft.Json.JsonConverter)Activator.CreateInstance(converterType)) + : JsonConvert.SerializeObject(value); + + return serializedValue; + } + + public string SerializeScalarExpression(ConstantExpression inputExpression) + { + return JsonConvert.SerializeObject(inputExpression.Value); + } + + public string SerializeMemberName(MemberInfo memberInfo, CosmosLinqSerializerOptions linqSerializerOptions = null) + { + string memberName = null; + + // Check if Newtonsoft JsonExtensionDataAttribute is present on the member, if so, return empty member name. + Newtonsoft.Json.JsonExtensionDataAttribute jsonExtensionDataAttribute = memberInfo.GetCustomAttribute(true); + if (jsonExtensionDataAttribute != null && jsonExtensionDataAttribute.ReadData) + { + return null; + } + + // Json.Net honors JsonPropertyAttribute more than DataMemberAttribute + // So we check for JsonPropertyAttribute first. + JsonPropertyAttribute jsonPropertyAttribute = memberInfo.GetCustomAttribute(true); + if (jsonPropertyAttribute != null && !string.IsNullOrEmpty(jsonPropertyAttribute.PropertyName)) + { + memberName = jsonPropertyAttribute.PropertyName; + } + else + { + DataContractAttribute dataContractAttribute = memberInfo.DeclaringType.GetCustomAttribute(true); + if (dataContractAttribute != null) + { + DataMemberAttribute dataMemberAttribute = memberInfo.GetCustomAttribute(true); + if (dataMemberAttribute != null && !string.IsNullOrEmpty(dataMemberAttribute.Name)) + { + memberName = dataMemberAttribute.Name; + } + } + } + + if (memberName == null) + { + memberName = memberInfo.Name; + } + + if (linqSerializerOptions != null) + { + memberName = CosmosSerializationUtil.GetStringWithPropertyNamingPolicy(linqSerializerOptions, memberName); + } + + return memberName; + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs index fef136d8d2..5153f29115 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs @@ -14,7 +14,6 @@ namespace Microsoft.Azure.Cosmos.Linq using System.Linq.Expressions; using System.Reflection; using Microsoft.Azure.Cosmos.CosmosElements; - using Microsoft.Azure.Cosmos.CosmosElements.Numbers; using Microsoft.Azure.Cosmos.Spatial; using Microsoft.Azure.Cosmos.SqlObjects; using Microsoft.Azure.Documents; @@ -94,7 +93,7 @@ public static SqlQuery TranslateQuery( TranslationContext context = new TranslationContext(linqSerializerOptions, parameters); ExpressionToSql.Translate(inputExpression, context); // ignore result here - QueryUnderConstruction query = context.currentQuery; + QueryUnderConstruction query = context.CurrentQuery; query = query.FlattenAsPossible(); SqlQuery result = query.GetSqlQuery(); @@ -159,7 +158,7 @@ private static Collection TranslateInput(ConstantExpression inputExpression, Tra throw new DocumentQueryException(ClientResources.InputIsNotIDocumentQuery); } - context.currentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); + context.CurrentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); Type elemType = TypeSystem.GetElementType(inputExpression.Type); context.SetInputParameter(elemType, ParameterSubstitution.InputParameterName); // ignore result @@ -169,7 +168,7 @@ private static Collection TranslateInput(ConstantExpression inputExpression, Tra } /// - /// Get a paramter name to be binded to the a collection from the next lambda. + /// Get a parameter name to be binded to the collection from the next lambda. /// It's merely for readability purpose. If that is not possible, use a default /// parameter name. /// @@ -189,7 +188,7 @@ private static string GetBindingParameterName(TranslationContext context) } } - if (parameterName == null) parameterName = ExpressionToSql.DefaultParameterName; + parameterName ??= ExpressionToSql.DefaultParameterName; return parameterName; } @@ -474,17 +473,17 @@ private static SqlScalarExpression VisitBinary(BinaryExpression inputExpression, if (left is SqlMemberIndexerScalarExpression && right is SqlLiteralScalarExpression literalScalarExpression) { - right = ExpressionToSql.ApplyCustomConverters(inputExpression.Left, literalScalarExpression); + right = ExpressionToSql.ApplyCustomConverters(inputExpression.Left, literalScalarExpression, context); } else if (right is SqlMemberIndexerScalarExpression && left is SqlLiteralScalarExpression sqlLiteralScalarExpression) { - left = ExpressionToSql.ApplyCustomConverters(inputExpression.Right, sqlLiteralScalarExpression); + left = ExpressionToSql.ApplyCustomConverters(inputExpression.Right, sqlLiteralScalarExpression, context); } return SqlBinaryScalarExpression.Create(op, left, right); } - private static SqlScalarExpression ApplyCustomConverters(Expression left, SqlLiteralScalarExpression right) + private static SqlScalarExpression ApplyCustomConverters(Expression left, SqlLiteralScalarExpression right, TranslationContext context) { MemberExpression memberExpression; if (left is UnaryExpression unaryExpression) @@ -504,48 +503,28 @@ private static SqlScalarExpression ApplyCustomConverters(Expression left, SqlLit memberType = memberType.NullableUnderlyingType(); } - // There are two ways to specify a custom attribute - // 1- by specifying the JsonConverterAttribute on a Class/Enum - // [JsonConverter(typeof(StringEnumConverter))] - // Enum MyEnum - // { - // ... - // } - // - // 2- by specifying the JsonConverterAttribute on a property - // class MyClass - // { - // [JsonConverter(typeof(StringEnumConverter))] - // public MyEnum MyEnum; - // } - // - // Newtonsoft gives high precedence to the attribute specified - // on a property over on a type (class/enum) - // so we check both attributes and apply the same precedence rules - // JsonConverterAttribute doesn't allow duplicates so it's safe to - // use FirstOrDefault() - CustomAttributeData memberAttribute = memberExpression.Member.CustomAttributes.Where(ca => ca.AttributeType == typeof(JsonConverterAttribute)).FirstOrDefault(); - CustomAttributeData typeAttribute = memberType.GetsCustomAttributes().Where(ca => ca.AttributeType == typeof(JsonConverterAttribute)).FirstOrDefault(); - - CustomAttributeData converterAttribute = memberAttribute ?? typeAttribute; - if (converterAttribute != null) + bool requiresCustomSerializatior = context.CosmosLinqSerializer.RequiresCustomSerialization(memberExpression, memberType); + if (requiresCustomSerializatior) { - Debug.Assert(converterAttribute.ConstructorArguments.Count > 0); - - Type converterType = (Type)converterAttribute.ConstructorArguments[0].Value; - object value = default(object); // Enum if (memberType.IsEnum()) { - Number64 number64 = ((SqlNumberLiteral)right.Literal).Value; - if (number64.IsDouble) + try { - value = Enum.ToObject(memberType, Number64.ToDouble(number64)); + Number64 number64 = ((SqlNumberLiteral)right.Literal).Value; + if (number64.IsDouble) + { + value = Enum.ToObject(memberType, Number64.ToDouble(number64)); + } + else + { + value = Enum.ToObject(memberType, Number64.ToLong(number64)); + } } - else + catch { - value = Enum.ToObject(memberType, Number64.ToLong(number64)); + value = ((SqlStringLiteral)right.Literal).Value; } } @@ -558,17 +537,7 @@ private static SqlScalarExpression ApplyCustomConverters(Expression left, SqlLit if (value != default(object)) { - string serializedValue; - - if (converterType.GetConstructor(Type.EmptyTypes) != null) - { - serializedValue = JsonConvert.SerializeObject(value, (JsonConverter)Activator.CreateInstance(converterType)); - } - else - { - serializedValue = JsonConvert.SerializeObject(value); - } - + string serializedValue = context.CosmosLinqSerializer.Serialize(value, memberExpression, memberType); return CosmosElement.Parse(serializedValue).Accept(CosmosElementToSqlScalarExpressionVisitor.Singleton); } } @@ -717,17 +686,17 @@ public static SqlScalarExpression VisitConstant(ConstantExpression inputExpressi if (inputExpression.Type.IsNullable()) { - return ExpressionToSql.VisitConstant(Expression.Constant(inputExpression.Value, Nullable.GetUnderlyingType(inputExpression.Type)), context); + return VisitConstant(Expression.Constant(inputExpression.Value, Nullable.GetUnderlyingType(inputExpression.Type)), context); } - if (context.parameters != null && context.parameters.TryGetValue(inputExpression.Value, out string paramName)) + if (context.Parameters != null && context.Parameters.TryGetValue(inputExpression.Value, out string paramName)) { SqlParameter sqlParameter = SqlParameter.Create(paramName); return SqlParameterRefScalarExpression.Create(sqlParameter); } Type constantType = inputExpression.Value.GetType(); - if (constantType.IsValueType()) + if (constantType.IsValueType) { if (inputExpression.Value is bool boolValue) { @@ -764,13 +733,15 @@ public static SqlScalarExpression VisitConstant(ConstantExpression inputExpressi foreach (object item in enumerable) { - arrayItems.Add(ExpressionToSql.VisitConstant(Expression.Constant(item), context)); + arrayItems.Add(VisitConstant(Expression.Constant(item), context)); } return SqlArrayCreateScalarExpression.Create(arrayItems.ToImmutableArray()); } - return CosmosElement.Parse(JsonConvert.SerializeObject(inputExpression.Value)).Accept(CosmosElementToSqlScalarExpressionVisitor.Singleton); + string serializedConstant = context.CosmosLinqSerializer.SerializeScalarExpression(inputExpression); + + return CosmosElement.Parse(serializedConstant).Accept(CosmosElementToSqlScalarExpressionVisitor.Singleton); } private static SqlScalarExpression VisitConditional(ConditionalExpression inputExpression, TranslationContext context) @@ -798,7 +769,7 @@ private static SqlScalarExpression VisitParameter(ParameterExpression inputExpre private static SqlScalarExpression VisitMemberAccess(MemberExpression inputExpression, TranslationContext context) { SqlScalarExpression memberExpression = ExpressionToSql.VisitScalarExpression(inputExpression.Expression, context); - string memberName = inputExpression.Member.GetMemberName(context.linqSerializerOptions); + string memberName = inputExpression.Member.GetMemberName(context); // If the resulting memberName is null, then the indexer should be on the root of the object. if (memberName == null) @@ -809,7 +780,7 @@ private static SqlScalarExpression VisitMemberAccess(MemberExpression inputExpre // if expression is nullable if (inputExpression.Expression.Type.IsNullable()) { - MemberNames memberNames = context.memberNames; + MemberNames memberNames = context.MemberNames; // ignore .Value if (memberName == memberNames.Value) @@ -853,7 +824,7 @@ private static SqlScalarExpression[] VisitExpressionList(ReadOnlyCollectionThe scalar Any collection private static Collection ConvertToScalarAnyCollection(TranslationContext context) { - SqlQuery query = context.currentQuery.FlattenAsPossible().GetSqlQuery(); + SqlQuery query = context.CurrentQuery.FlattenAsPossible().GetSqlQuery(); SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query); ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); Binding binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true); - context.currentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); - context.currentQuery.AddBinding(binding); + context.CurrentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); + context.CurrentQuery.AddBinding(binding); SqlSelectSpec selectSpec = SqlSelectValueSpec.Create( SqlBinaryScalarExpression.Create( @@ -1032,7 +1003,7 @@ private static Collection ConvertToScalarAnyCollection(TranslationContext contex SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterExpression.Name))), SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create(0)))); SqlSelectClause selectClause = SqlSelectClause.Create(selectSpec); - context.currentQuery.AddSelectClause(selectClause); + context.CurrentQuery.AddSelectClause(selectClause); return new Collection(LinqMethods.Any); } @@ -1173,106 +1144,106 @@ private static Collection VisitMethodCall(MethodCallExpression inputExpression, context.PushCollection(collection); Collection result = new Collection(inputExpression.Method.Name); - bool shouldBeOnNewQuery = context.currentQuery.ShouldBeOnNewQuery(inputExpression.Method.Name, inputExpression.Arguments.Count); + bool shouldBeOnNewQuery = context.CurrentQuery.ShouldBeOnNewQuery(inputExpression.Method.Name, inputExpression.Arguments.Count); context.PushSubqueryBinding(shouldBeOnNewQuery); switch (inputExpression.Method.Name) { case LinqMethods.Select: { SqlSelectClause select = ExpressionToSql.VisitSelect(inputExpression.Arguments, context); - context.currentQuery = context.currentQuery.AddSelectClause(select, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } case LinqMethods.Where: { SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); - context.currentQuery = context.currentQuery.AddWhereClause(where, context); + context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); break; } case LinqMethods.SelectMany: { - context.currentQuery = context.PackageCurrentQueryIfNeccessary(); + context.CurrentQuery = context.PackageCurrentQueryIfNeccessary(); result = ExpressionToSql.VisitSelectMany(inputExpression.Arguments, context); break; } case LinqMethods.OrderBy: { SqlOrderByClause orderBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); - context.currentQuery = context.currentQuery.AddOrderByClause(orderBy, context); + context.CurrentQuery = context.CurrentQuery.AddOrderByClause(orderBy, context); break; } case LinqMethods.OrderByDescending: { SqlOrderByClause orderBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, true, context); - context.currentQuery = context.currentQuery.AddOrderByClause(orderBy, context); + context.CurrentQuery = context.CurrentQuery.AddOrderByClause(orderBy, context); break; } case LinqMethods.ThenBy: { SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); - context.currentQuery = context.currentQuery.UpdateOrderByClause(thenBy, context); + context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); break; } case LinqMethods.ThenByDescending: { SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, true, context); - context.currentQuery = context.currentQuery.UpdateOrderByClause(thenBy, context); + context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); break; } case LinqMethods.Skip: { SqlOffsetSpec offsetSpec = ExpressionToSql.VisitSkip(inputExpression.Arguments, context); - context.currentQuery = context.currentQuery.AddOffsetSpec(offsetSpec, context); + context.CurrentQuery = context.CurrentQuery.AddOffsetSpec(offsetSpec, context); break; } case LinqMethods.Take: { - if (context.currentQuery.HasOffsetSpec()) + if (context.CurrentQuery.HasOffsetSpec()) { SqlLimitSpec limitSpec = ExpressionToSql.VisitTakeLimit(inputExpression.Arguments, context); - context.currentQuery = context.currentQuery.AddLimitSpec(limitSpec, context); + context.CurrentQuery = context.CurrentQuery.AddLimitSpec(limitSpec, context); } else { SqlTopSpec topSpec = ExpressionToSql.VisitTakeTop(inputExpression.Arguments, context); - context.currentQuery = context.currentQuery.AddTopSpec(topSpec); + context.CurrentQuery = context.CurrentQuery.AddTopSpec(topSpec); } break; } case LinqMethods.Distinct: { SqlSelectClause select = ExpressionToSql.VisitDistinct(inputExpression.Arguments, context); - context.currentQuery = context.currentQuery.AddSelectClause(select, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } case LinqMethods.Max: { SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Max); - context.currentQuery = context.currentQuery.AddSelectClause(select, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } case LinqMethods.Min: { SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Min); - context.currentQuery = context.currentQuery.AddSelectClause(select, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } case LinqMethods.Average: { SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Avg); - context.currentQuery = context.currentQuery.AddSelectClause(select, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } case LinqMethods.Count: { SqlSelectClause select = ExpressionToSql.VisitCount(inputExpression.Arguments, context); - context.currentQuery = context.currentQuery.AddSelectClause(select, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } case LinqMethods.Sum: { SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Sum); - context.currentQuery = context.currentQuery.AddSelectClause(select, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } case LinqMethods.Any: @@ -1282,7 +1253,7 @@ private static Collection VisitMethodCall(MethodCallExpression inputExpression, { // Any is translated to an SELECT VALUE EXISTS() where Any operation itself is treated as a Where. SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); - context.currentQuery = context.currentQuery.AddWhereClause(where, context); + context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); } break; } @@ -1509,7 +1480,7 @@ private static SqlScalarExpression VisitScalarExpression(Expression expression, ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); SqlCollection subqueryCollection = ExpressionToSql.CreateSubquerySqlCollection( - query, context, + query, isMinMaxAvgMethod ? SubqueryKind.ArrayScalarExpression : expressionObjKind.Value); Binding newBinding = new Binding(parameterExpression, subqueryCollection, @@ -1536,9 +1507,8 @@ private static SqlScalarExpression VisitScalarExpression(Expression expression, /// Create a subquery SQL collection object for a SQL query /// /// The SQL query object - /// The translation context /// The subquery type - private static SqlCollection CreateSubquerySqlCollection(SqlQuery query, TranslationContext context, SubqueryKind subqueryType) + private static SqlCollection CreateSubquerySqlCollection(SqlQuery query, SubqueryKind subqueryType) { SqlCollection subqueryCollection; switch (subqueryType) @@ -1583,18 +1553,18 @@ private static SqlQuery CreateSubquery(Expression expression, ReadOnlyCollection { bool shouldBeOnNewQuery = context.CurrentSubqueryBinding.ShouldBeOnNewQuery; - QueryUnderConstruction queryBeforeVisit = context.currentQuery; - QueryUnderConstruction packagedQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc(), context.currentQuery); - packagedQuery.fromParameters.SetInputParameter(typeof(object), context.currentQuery.GetInputParameterInContext(shouldBeOnNewQuery).Name, context.InScope); - context.currentQuery = packagedQuery; + QueryUnderConstruction queryBeforeVisit = context.CurrentQuery; + QueryUnderConstruction packagedQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc(), context.CurrentQuery); + packagedQuery.fromParameters.SetInputParameter(typeof(object), context.CurrentQuery.GetInputParameterInContext(shouldBeOnNewQuery).Name, context.InScope); + context.CurrentQuery = packagedQuery; if (shouldBeOnNewQuery) context.CurrentSubqueryBinding.ShouldBeOnNewQuery = false; Collection collection = ExpressionToSql.VisitCollectionExpression(expression, parameters, context); - QueryUnderConstruction subquery = context.currentQuery.GetSubquery(queryBeforeVisit); + QueryUnderConstruction subquery = context.CurrentQuery.GetSubquery(queryBeforeVisit); context.CurrentSubqueryBinding.ShouldBeOnNewQuery = shouldBeOnNewQuery; - context.currentQuery = queryBeforeVisit; + context.CurrentQuery = queryBeforeVisit; SqlQuery sqlSubquery = subquery.FlattenAsPossible().GetSqlQuery(); return sqlSubquery; @@ -1665,7 +1635,7 @@ private static Collection VisitSelectMany(ReadOnlyCollection argumen SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query); ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true); - context.currentQuery.fromParameters.Add(binding); + context.CurrentQuery.fromParameters.Add(binding); } return collection; @@ -1902,7 +1872,7 @@ private static SqlSelectClause VisitCount( if (arguments.Count == 2) { SqlWhereClause whereClause = ExpressionToSql.VisitWhere(arguments, context); - context.currentQuery = context.currentQuery.AddWhereClause(whereClause, context); + context.CurrentQuery = context.CurrentQuery.AddWhereClause(whereClause, context); } else if (arguments.Count != 1) { @@ -2004,83 +1974,6 @@ private static SqlInputPathCollection ConvertMemberIndexerToPath(SqlMemberIndexe #endregion LINQ Specific Visitors - private sealed class CosmosElementToSqlScalarExpressionVisitor : ICosmosElementVisitor - { - public static readonly CosmosElementToSqlScalarExpressionVisitor Singleton = new CosmosElementToSqlScalarExpressionVisitor(); - - private CosmosElementToSqlScalarExpressionVisitor() - { - // Private constructor, since this class is a singleton. - } - - public SqlScalarExpression Visit(CosmosArray cosmosArray) - { - List items = new List(); - foreach (CosmosElement item in cosmosArray) - { - items.Add(item.Accept(this)); - } - - return SqlArrayCreateScalarExpression.Create(items.ToImmutableArray()); - } - - public SqlScalarExpression Visit(CosmosBinary cosmosBinary) - { - // Can not convert binary to scalar expression without knowing the API type. - throw new NotImplementedException(); - } - - public SqlScalarExpression Visit(CosmosBoolean cosmosBoolean) - { - return SqlLiteralScalarExpression.Create(SqlBooleanLiteral.Create(cosmosBoolean.Value)); - } - - public SqlScalarExpression Visit(CosmosGuid cosmosGuid) - { - // Can not convert guid to scalar expression without knowing the API type. - throw new NotImplementedException(); - } - - public SqlScalarExpression Visit(CosmosNull cosmosNull) - { - return SqlLiteralScalarExpression.Create(SqlNullLiteral.Create()); - } - - public SqlScalarExpression Visit(CosmosNumber cosmosNumber) - { - if (!(cosmosNumber is CosmosNumber64 cosmosNumber64)) - { - throw new ArgumentException($"Unknown {nameof(CosmosNumber)} type: {cosmosNumber.GetType()}."); - } - - return SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create(cosmosNumber64.GetValue())); - } - - public SqlScalarExpression Visit(CosmosObject cosmosObject) - { - List properties = new List(); - foreach (KeyValuePair prop in cosmosObject) - { - SqlPropertyName name = SqlPropertyName.Create(prop.Key); - CosmosElement value = prop.Value; - SqlScalarExpression expression = value.Accept(this); - SqlObjectProperty property = SqlObjectProperty.Create(name, expression); - properties.Add(property); - } - - return SqlObjectCreateScalarExpression.Create(properties.ToImmutableArray()); - } - - public SqlScalarExpression Visit(CosmosString cosmosString) - { - return SqlLiteralScalarExpression.Create(SqlStringLiteral.Create(cosmosString.Value)); - } - - public SqlScalarExpression Visit(CosmosUndefined cosmosUndefined) - { - return SqlLiteralScalarExpression.Create(SqlUndefinedLiteral.Create()); - } - } private enum SubqueryKind { ArrayScalarExpression, diff --git a/Microsoft.Azure.Cosmos/src/Linq/ICosmosLinqSerializer.cs b/Microsoft.Azure.Cosmos/src/Linq/ICosmosLinqSerializer.cs new file mode 100644 index 0000000000..f31490832d --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Linq/ICosmosLinqSerializer.cs @@ -0,0 +1,33 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Linq +{ + using System; + using System.Linq.Expressions; + using System.Reflection; + + internal interface ICosmosLinqSerializer + { + /// + /// Returns true if there are custom attributes on a member expression. + /// + bool RequiresCustomSerialization(MemberExpression memberExpression, Type memberType); + + // TODO : Clean up this interface member for better generalizability + /// + /// Serializes object. + /// + string Serialize(object value, MemberExpression memberExpression, Type memberType); + + /// + /// Serializes a ConstantExpression. + /// + string SerializeScalarExpression(ConstantExpression inputExpression); + + /// + /// Serializes a member name with LINQ serializer options applied. + /// + string SerializeMemberName(MemberInfo memberInfo, CosmosLinqSerializerOptions linqSerializerOptions = null); + } +} diff --git a/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs b/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs index c5c0573a14..d2e19046e1 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs @@ -583,13 +583,13 @@ public QueryUnderConstruction AddOrderByClause(SqlOrderByClause orderBy, Transla public QueryUnderConstruction UpdateOrderByClause(SqlOrderByClause thenBy, TranslationContext context) { - List items = new List(context.currentQuery.orderByClause.OrderByItems); + List items = new List(context.CurrentQuery.orderByClause.OrderByItems); items.AddRange(thenBy.OrderByItems); - context.currentQuery.orderByClause = SqlOrderByClause.Create(items.ToImmutableArray()); + context.CurrentQuery.orderByClause = SqlOrderByClause.Create(items.ToImmutableArray()); - foreach (Binding binding in context.CurrentSubqueryBinding.TakeBindings()) context.currentQuery.AddBinding(binding); + foreach (Binding binding in context.CurrentSubqueryBinding.TakeBindings()) context.CurrentQuery.AddBinding(binding); - return context.currentQuery; + return context.CurrentQuery; } public QueryUnderConstruction AddOffsetSpec(SqlOffsetSpec offsetSpec, TranslationContext context) diff --git a/Microsoft.Azure.Cosmos/src/Linq/SQLTranslator.cs b/Microsoft.Azure.Cosmos/src/Linq/SQLTranslator.cs index 8648019219..c0bd6d38a0 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/SQLTranslator.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/SQLTranslator.cs @@ -6,7 +6,6 @@ namespace Microsoft.Azure.Cosmos.Linq using System.Collections.Generic; using System.Linq.Expressions; using Microsoft.Azure.Cosmos.Query.Core; - using Microsoft.Azure.Cosmos.Serializer; using Microsoft.Azure.Cosmos.SqlObjects; /// diff --git a/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs b/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs index 8fc95d8701..f5d53bd1e7 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs @@ -7,7 +7,6 @@ namespace Microsoft.Azure.Cosmos.Linq using System; using System.Collections.Generic; using System.Linq.Expressions; - using Microsoft.Azure.Cosmos.Serializer; using Microsoft.Azure.Cosmos.SqlObjects; using static Microsoft.Azure.Cosmos.Linq.ExpressionToSql; using static Microsoft.Azure.Cosmos.Linq.FromParameterBindings; @@ -20,39 +19,54 @@ internal sealed class TranslationContext /// /// Member names for special mapping cases /// - internal readonly MemberNames memberNames; + public readonly MemberNames MemberNames; + + /// + /// The LINQ serializer + /// + public readonly ICosmosLinqSerializer CosmosLinqSerializer; + + /// + /// User-provided LINQ serializer options + /// + public CosmosLinqSerializerOptions LinqSerializerOptions; /// /// Set of parameters in scope at any point; used to generate fresh parameter names if necessary. /// public HashSet InScope; + /// /// Query that is being assembled. /// - public QueryUnderConstruction currentQuery; + public QueryUnderConstruction CurrentQuery; /// /// Dictionary for parameter name and value /// - public IDictionary parameters; + public IDictionary Parameters; /// /// If the FROM clause uses a parameter name, it will be substituted for the parameter used in /// the lambda expressions for the WHERE and SELECT clauses. /// private ParameterSubstitution substitutions; + /// /// We are currently visiting these methods. /// private List methodStack; + /// /// Stack of parameters from lambdas currently in scope. /// private List lambdaParametersStack; + /// /// Stack of collection-valued inputs. /// private List collectionStack; + /// /// The stack of subquery binding information. /// @@ -65,15 +79,14 @@ public TranslationContext(CosmosLinqSerializerOptions linqSerializerOptions, IDi this.methodStack = new List(); this.lambdaParametersStack = new List(); this.collectionStack = new List(); - this.currentQuery = new QueryUnderConstruction(this.GetGenFreshParameterFunc()); + this.CurrentQuery = new QueryUnderConstruction(this.GetGenFreshParameterFunc()); this.subqueryBindingStack = new Stack(); - this.linqSerializerOptions = linqSerializerOptions; - this.parameters = parameters; - this.memberNames = new MemberNames(linqSerializerOptions); + this.LinqSerializerOptions = linqSerializerOptions; + this.Parameters = parameters; + this.MemberNames = new MemberNames(linqSerializerOptions); + this.CosmosLinqSerializer = new DefaultCosmosLinqSerializer(); } - public CosmosLinqSerializerOptions linqSerializerOptions; - public Expression LookupSubstitution(ParameterExpression parameter) { return this.substitutions.Lookup(parameter); @@ -103,12 +116,12 @@ public void PushParameter(ParameterExpression parameter, bool shouldBeOnNewQuery if (last.isOuter) { // substitute - ParameterExpression inputParam = this.currentQuery.GetInputParameterInContext(shouldBeOnNewQuery); + ParameterExpression inputParam = this.CurrentQuery.GetInputParameterInContext(shouldBeOnNewQuery); this.substitutions.AddSubstitution(parameter, inputParam); } else { - this.currentQuery.Bind(parameter, last.inner); + this.CurrentQuery.Bind(parameter, last.inner); } } @@ -182,7 +195,7 @@ public void PopCollection() /// Suggested name for the input parameter. public ParameterExpression SetInputParameter(Type type, string name) { - return this.currentQuery.fromParameters.SetInputParameter(type, name, this.InScope); + return this.CurrentQuery.fromParameters.SetInputParameter(type, name, this.InScope); } /// @@ -193,7 +206,7 @@ public ParameterExpression SetInputParameter(Type type, string name) public void SetFromParameter(ParameterExpression parameter, SqlCollection collection) { Binding binding = new Binding(parameter, collection, isInCollection: true); - this.currentQuery.fromParameters.Add(binding); + this.CurrentQuery.fromParameters.Add(binding); } /// @@ -258,11 +271,11 @@ public QueryUnderConstruction PackageCurrentQueryIfNeccessary() { if (this.CurrentSubqueryBinding.ShouldBeOnNewQuery) { - this.currentQuery = this.currentQuery.PackageQuery(this.InScope); + this.CurrentQuery = this.CurrentQuery.PackageQuery(this.InScope); this.CurrentSubqueryBinding.ShouldBeOnNewQuery = false; } - return this.currentQuery; + return this.CurrentQuery; } public class SubqueryBinding diff --git a/Microsoft.Azure.Cosmos/src/Linq/TypeSystem.cs b/Microsoft.Azure.Cosmos/src/Linq/TypeSystem.cs index 7c8e62d69c..e11c145c71 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/TypeSystem.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/TypeSystem.cs @@ -10,10 +10,7 @@ namespace Microsoft.Azure.Cosmos.Linq using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; - using System.Runtime.Serialization; - using Microsoft.Azure.Cosmos.Serializer; using Microsoft.Azure.Documents; - using Newtonsoft.Json; internal static class TypeSystem { @@ -22,48 +19,9 @@ public static Type GetElementType(Type type) return GetElementType(type, new HashSet()); } - public static string GetMemberName(this MemberInfo memberInfo, CosmosLinqSerializerOptions linqSerializerOptions = null) + public static string GetMemberName(this MemberInfo memberInfo, TranslationContext context) { - string memberName = null; - - // Check if Newtonsoft JsonExtensionDataAttribute is present on the member, if so, return empty member name. - JsonExtensionDataAttribute jsonExtensionDataAttribute = memberInfo.GetCustomAttribute(true); - if (jsonExtensionDataAttribute != null && jsonExtensionDataAttribute.ReadData) - { - return null; - } - - // Json.Net honors JsonPropertyAttribute more than DataMemberAttribute - // So we check for JsonPropertyAttribute first. - JsonPropertyAttribute jsonPropertyAttribute = memberInfo.GetCustomAttribute(true); - if (jsonPropertyAttribute != null && !string.IsNullOrEmpty(jsonPropertyAttribute.PropertyName)) - { - memberName = jsonPropertyAttribute.PropertyName; - } - else - { - DataContractAttribute dataContractAttribute = memberInfo.DeclaringType.GetCustomAttribute(true); - if (dataContractAttribute != null) - { - DataMemberAttribute dataMemberAttribute = memberInfo.GetCustomAttribute(true); - if (dataMemberAttribute != null && !string.IsNullOrEmpty(dataMemberAttribute.Name)) - { - memberName = dataMemberAttribute.Name; - } - } - } - - if (memberName == null) - { - memberName = memberInfo.Name; - } - - if (linqSerializerOptions != null) - { - memberName = CosmosSerializationUtil.GetStringWithPropertyNamingPolicy(linqSerializerOptions, memberName); - } - - return memberName; + return context.CosmosLinqSerializer.SerializeMemberName(memberInfo, context.LinqSerializerOptions); } private static Type GetElementType(Type type, HashSet visitedSet) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqAttributeContractBaselineTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqAttributeContractBaselineTests.cs index 7ffc36bdbd..5ea03bf0b6 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqAttributeContractBaselineTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqAttributeContractBaselineTests.cs @@ -9,14 +9,13 @@ namespace Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests using System.Collections.Generic; using System.Linq; using System.Runtime.Serialization; - using Newtonsoft.Json; - using Newtonsoft.Json.Linq; - using VisualStudio.TestTools.UnitTesting; + using System.Threading.Tasks; using BaselineTest; using Microsoft.Azure.Cosmos.Linq; using Microsoft.Azure.Cosmos.SDK.EmulatorTests; using Microsoft.Azure.Documents; - using System.Threading.Tasks; + using Newtonsoft.Json; + using VisualStudio.TestTools.UnitTesting; /// /// Class that tests to see that we honor the attributes for members in a class / struct when we create LINQ queries. @@ -173,10 +172,11 @@ public Datum2(string jsonProperty, string dataMember, string defaultMember, stri [TestMethod] public void TestAttributePriority() { - Assert.AreEqual("jsonProperty", TypeSystem.GetMemberName(typeof(Datum).GetMember("JsonProperty").First())); - Assert.AreEqual("dataMember", TypeSystem.GetMemberName(typeof(Datum).GetMember("DataMember").First())); - Assert.AreEqual("Default", TypeSystem.GetMemberName(typeof(Datum).GetMember("Default").First())); - Assert.AreEqual("jsonPropertyHasHigherPriority", TypeSystem.GetMemberName(typeof(Datum).GetMember("JsonPropertyAndDataMember").First())); + ICosmosLinqSerializer cosmosLinqSerializer = new DefaultCosmosLinqSerializer(); + Assert.AreEqual("jsonProperty", cosmosLinqSerializer.SerializeMemberName(typeof(Datum).GetMember("JsonProperty").First())); + Assert.AreEqual("dataMember", cosmosLinqSerializer.SerializeMemberName(typeof(Datum).GetMember("DataMember").First())); + Assert.AreEqual("Default", cosmosLinqSerializer.SerializeMemberName(typeof(Datum).GetMember("Default").First())); + Assert.AreEqual("jsonPropertyHasHigherPriority", cosmosLinqSerializer.SerializeMemberName(typeof(Datum).GetMember("JsonPropertyAndDataMember").First())); } ///