diff --git a/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs b/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs index b83dc622604..eb55e255aa8 100644 --- a/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs +++ b/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs @@ -5,6 +5,8 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Utilities; @@ -13,7 +15,7 @@ namespace Microsoft.EntityFrameworkCore { /// - /// Cosmos DB specific extension methods for LINQ queries. + /// Cosmos-specific extension methods for LINQ queries. /// public static class CosmosQueryableExtensions { @@ -46,5 +48,56 @@ source.Provider is EntityQueryProvider Expression.Constant(partitionKey))) : source; } + + /// + /// + /// Creates a LINQ query based on a raw SQL query. + /// + /// + /// You can compose on top of the raw SQL query using LINQ operators: + /// + /// context.Blogs.FromSqlRaw("SELECT * FROM root c).OrderBy(b => b.Name) + /// + /// As with any API that accepts SQL it is important to parameterize any user input to protect against a SQL injection + /// attack. You can include parameter place holders in the SQL query string and then supply parameter values as additional + /// arguments. Any parameter values you supply will automatically be converted to a Cosmos parameter: + /// + /// context.Blogs.FromSqlRaw(""SELECT * FROM root c WHERE c["Name"] = {0})", userSuppliedSearchTerm) + /// + /// The type of the elements of . + /// + /// An to use as the base of the raw SQL query (typically a ). + /// + /// The raw SQL query. + /// The values to be assigned to parameters. + /// An representing the raw SQL query. + [StringFormatMethod("sql")] + public static IQueryable FromSqlRaw( + this IQueryable source, + [NotParameterized] string sql, + params object[] parameters) + where TEntity : class + { + Check.NotNull(source, nameof(source)); + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(parameters, nameof(parameters)); + + var queryRootExpression = (QueryRootExpression)source.Expression; + + var entityType = queryRootExpression.EntityType; + + Check.DebugAssert( + (entityType.BaseType is null && !entityType.GetDirectlyDerivedTypes().Any()) + || entityType.FindDiscriminatorProperty() is not null, + "Found FromSql on a TPT entity type, but TPT isn't supported on Cosmos"); + + var fromSqlQueryRootExpression = new FromSqlQueryRootExpression( + queryRootExpression.QueryProvider!, + entityType, + sql, + Expression.Constant(parameters)); + + return source.Provider.CreateQuery(fromSqlQueryRootExpression); + } } } diff --git a/src/EFCore.Cosmos/Metadata/Conventions/CosmosDiscriminatorConvention.cs b/src/EFCore.Cosmos/Metadata/Conventions/CosmosDiscriminatorConvention.cs index 60d76dbfb03..623a51283c6 100644 --- a/src/EFCore.Cosmos/Metadata/Conventions/CosmosDiscriminatorConvention.cs +++ b/src/EFCore.Cosmos/Metadata/Conventions/CosmosDiscriminatorConvention.cs @@ -102,14 +102,14 @@ private void ProcessEntityType(IConventionEntityTypeBuilder entityTypeBuilder) return; } - if (!entityType.IsDocumentRoot()) + if (entityType.IsDocumentRoot()) { - entityTypeBuilder.HasNoDiscriminator(); + entityTypeBuilder.HasDiscriminator(typeof(string)) + ?.HasValue(entityType, entityType.ShortName()); } else { - entityTypeBuilder.HasDiscriminator(typeof(string)) - ?.HasValue(entityType, entityType.ShortName()); + entityTypeBuilder.HasNoDiscriminator(); } } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryTranslationPostprocessor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryTranslationPostprocessor.cs index 4f520ef4a66..6c61757941d 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryTranslationPostprocessor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryTranslationPostprocessor.cs @@ -44,8 +44,7 @@ public override Expression Process(Expression query) { query = base.Process(query); - if (query is ShapedQueryExpression shapedQueryExpression - && shapedQueryExpression.QueryExpression is SelectExpression selectExpression) + if (query is ShapedQueryExpression { QueryExpression: SelectExpression selectExpression }) { // Cosmos does not have nested select expression so this should be safe. selectExpression.ApplyProjection(); diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index 6977c168f13..5966236baf8 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -128,7 +128,7 @@ public override Expression Visit(Expression expression) var readItemExpression = new ReadItemExpression(entityType, propertyParameterList); - return CreateShapedQueryExpression(readItemExpression, entityType) + return CreateShapedQueryExpression(entityType, readItemExpression) .UpdateResultCardinality(ResultCardinality.Single); } } @@ -187,6 +187,24 @@ static bool TryGetPartitionKeyProperty(IEntityType entityType, out IProperty par } } + /// + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case FromSqlQueryRootExpression fromSqlQueryRootExpression: + return CreateShapedQueryExpression( + fromSqlQueryRootExpression.EntityType, + _sqlExpressionFactory.Select( + fromSqlQueryRootExpression.EntityType, + fromSqlQueryRootExpression.Sql, + fromSqlQueryRootExpression.Argument)); + + default: + return base.VisitExtension(extensionExpression); + } + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -246,10 +264,10 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType var selectExpression = _sqlExpressionFactory.Select(entityType); - return CreateShapedQueryExpression(selectExpression, entityType); + return CreateShapedQueryExpression(entityType, selectExpression); } - private ShapedQueryExpression CreateShapedQueryExpression(Expression queryExpression, IEntityType entityType) + private ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, Expression queryExpression) => new( queryExpression, new EntityShaperExpression( diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index 7474b41f5b4..5045fc389f5 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -65,7 +65,6 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery switch (shapedQueryExpression.QueryExpression) { case SelectExpression selectExpression: - shaperBody = new CosmosProjectionBindingRemovingExpressionVisitor( selectExpression, jObjectParameter, QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll) @@ -92,7 +91,6 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery Expression.Constant(_threadSafetyChecksEnabled)); case ReadItemExpression readItemExpression: - shaperBody = new CosmosProjectionBindingRemovingReadItemExpressionVisitor( readItemExpression, jObjectParameter, QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll) diff --git a/src/EFCore.Cosmos/Query/Internal/FromSqlExpression.cs b/src/EFCore.Cosmos/Query/Internal/FromSqlExpression.cs new file mode 100644 index 00000000000..160e8e36180 --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/FromSqlExpression.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Utilities; + +#nullable disable + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class FromSqlExpression : RootReferenceExpression, ICloneable, IPrintableExpression + { + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public FromSqlExpression(IEntityType entityType, string alias, string sql, Expression arguments) : base(entityType, alias) + { + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(arguments, nameof(arguments)); + + Sql = sql; + Arguments = arguments; + } + + /// + public override string Alias => base.Alias!; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual string Sql { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Arguments { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual FromSqlExpression Update(Expression arguments) + { + Check.NotNull(arguments, nameof(arguments)); + + return arguments != Arguments + ? new FromSqlExpression(EntityType, Alias, Sql, arguments) + : this; + } + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + Check.NotNull(visitor, nameof(visitor)); + + return this; + } + + /// + public override Type Type + => typeof(object); + + /// + public virtual object Clone() => new FromSqlExpression(EntityType, Alias, Sql, Arguments); + + /// + void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + { + Check.NotNull(expressionPrinter, nameof(expressionPrinter)); + + expressionPrinter.Append(Sql); + } + + /// + public override bool Equals(object obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is FromSqlExpression fromSqlExpression + && Equals(fromSqlExpression)); + + private bool Equals(FromSqlExpression fromSqlExpression) + => base.Equals(fromSqlExpression) + && Sql == fromSqlExpression.Sql + && ExpressionEqualityComparer.Instance.Equals(Arguments, fromSqlExpression.Arguments); + + /// + public override int GetHashCode() + => HashCode.Combine(base.GetHashCode(), Sql); + } +} diff --git a/src/EFCore.Cosmos/Query/Internal/FromSqlQueryRootExpression.cs b/src/EFCore.Cosmos/Query/Internal/FromSqlQueryRootExpression.cs new file mode 100644 index 00000000000..2ac394e4c4a --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/FromSqlQueryRootExpression.cs @@ -0,0 +1,141 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class FromSqlQueryRootExpression : QueryRootExpression + { + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public FromSqlQueryRootExpression( + IAsyncQueryProvider queryProvider, + IEntityType entityType, + string sql, + Expression argument) + : base(queryProvider, entityType) + { + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(argument, nameof(argument)); + + Sql = sql; + Argument = argument; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public FromSqlQueryRootExpression( + IEntityType entityType, + string sql, + Expression argument) + : base(entityType) + { + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(argument, nameof(argument)); + + Sql = sql; + Argument = argument; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual string Sql { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Argument { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression DetachQueryProvider() + => new FromSqlQueryRootExpression(EntityType, Sql, Argument); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var argument = visitor.Visit(Argument); + + return argument != Argument + ? new FromSqlQueryRootExpression(EntityType, Sql, argument) + : this; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override void Print(ExpressionPrinter expressionPrinter) + { + Check.NotNull(expressionPrinter, nameof(expressionPrinter)); + + base.Print(expressionPrinter); + expressionPrinter.Append($".FromSql({Sql}, "); + expressionPrinter.Visit(Argument); + expressionPrinter.AppendLine(")"); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override bool Equals(object? obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is FromSqlQueryRootExpression queryRootExpression + && Equals(queryRootExpression)); + + private bool Equals(FromSqlQueryRootExpression queryRootExpression) + => base.Equals(queryRootExpression) + && Sql == queryRootExpression.Sql + && ExpressionEqualityComparer.Instance.Equals(Argument, queryRootExpression.Argument); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override int GetHashCode() + => HashCode.Combine(base.GetHashCode(), Sql, ExpressionEqualityComparer.Instance.GetHashCode(Argument)); + } +} diff --git a/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs index 9b5e762bbcb..e6f9ef0d41c 100644 --- a/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs @@ -285,5 +285,13 @@ SqlConditionalExpression Condition( /// doing so can result in application failures when updating to a new Entity Framework Core release. /// SelectExpression Select(IEntityType entityType); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + SelectExpression Select(IEntityType entityType, string sql, Expression argument); } } diff --git a/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs b/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs index 1a75b103766..82a028b177c 100644 --- a/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs +++ b/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs @@ -8,6 +8,7 @@ using System.Text; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; using Newtonsoft.Json; @@ -25,10 +26,11 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal /// public class QuerySqlGenerator : SqlExpressionVisitor { - private readonly StringBuilder _sqlBuilder = new(); + private readonly IndentedStringBuilder _sqlBuilder = new(); private IReadOnlyDictionary _parameterValues; private List _sqlParameters; private bool _useValueProjection; + private ParameterNameGenerator _parameterNameGenerator; private readonly IDictionary _operatorMap = new Dictionary { @@ -77,6 +79,7 @@ public virtual CosmosSqlQuery GetSqlQuery( _sqlBuilder.Clear(); _parameterValues = parameterValues; _sqlParameters = new List(); + _parameterNameGenerator = new ParameterNameGenerator(); Visit(selectExpression); @@ -108,7 +111,7 @@ protected override Expression VisitObjectArrayProjection(ObjectArrayProjectionEx { Check.NotNull(objectArrayProjectionExpression, nameof(objectArrayProjectionExpression)); - _sqlBuilder.Append(objectArrayProjectionExpression); + _sqlBuilder.Append(objectArrayProjectionExpression.ToString()); return objectArrayProjectionExpression; } @@ -123,7 +126,7 @@ protected override Expression VisitKeyAccess(KeyAccessExpression keyAccessExpres { Check.NotNull(keyAccessExpression, nameof(keyAccessExpression)); - _sqlBuilder.Append(keyAccessExpression); + _sqlBuilder.Append(keyAccessExpression.ToString()); return keyAccessExpression; } @@ -138,7 +141,7 @@ protected override Expression VisitObjectAccess(ObjectAccessExpression objectAcc { Check.NotNull(objectAccessExpression, nameof(objectAccessExpression)); - _sqlBuilder.Append(objectAccessExpression); + _sqlBuilder.Append(objectAccessExpression.ToString()); return objectAccessExpression; } @@ -180,7 +183,7 @@ protected override Expression VisitRootReference(RootReferenceExpression rootRef { Check.NotNull(rootReferenceExpression, nameof(rootReferenceExpression)); - _sqlBuilder.Append(rootReferenceExpression); + _sqlBuilder.Append(rootReferenceExpression.ToString()); return rootReferenceExpression; } @@ -225,7 +228,14 @@ protected override Expression VisitSelect(SelectExpression selectExpression) _sqlBuilder.AppendLine(); - _sqlBuilder.Append("FROM root "); + if (selectExpression.FromExpression is FromSqlExpression) + { + _sqlBuilder.Append("FROM "); + } + else + { + _sqlBuilder.Append("FROM root "); + } Visit(selectExpression.FromExpression); _sqlBuilder.AppendLine(); @@ -272,6 +282,51 @@ protected override Expression VisitSelect(SelectExpression selectExpression) return selectExpression; } + /// + protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression) + { + Check.NotNull(fromSqlExpression, nameof(fromSqlExpression)); + + var sql = fromSqlExpression.Sql; + + var arguments = fromSqlExpression.Arguments switch + { + ConstantExpression { Value : object[] constantValues } + => constantValues, + ParameterExpression { Name : not null } parameterExpression + when _parameterValues.TryGetValue(parameterExpression.Name, out var parameterValue) + && parameterValue is object[] parameterValues + => parameterValues, + _ => null + }; + + if (arguments is not null) + { + var substitutions = new string[arguments.Length]; + for (var i = 0; i < arguments.Length; i++) + { + var parameterName = _parameterNameGenerator.GenerateNext(); + _sqlParameters.Add(new SqlParameter(parameterName, arguments[i])); + substitutions[i] = parameterName; + } + + sql = string.Format(sql, substitutions); + } + + _sqlBuilder.AppendLine("("); + + using (_sqlBuilder.Indent()) + { + _sqlBuilder.AppendLines(sql); + } + + _sqlBuilder + .Append(") ") + .Append(fromSqlExpression.Alias); + + return fromSqlExpression; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -350,7 +405,7 @@ protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpressio private void GenerateList( IReadOnlyList items, Action generationAction, - Action joinAction = null) + Action joinAction = null) { joinAction ??= (isb => isb.Append(", ")); @@ -488,5 +543,16 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction return sqlFunctionExpression; } + + private class ParameterNameGenerator + { + private int _count; + + public string GenerateNext() + => "@p" + _count++; + + public void Reset() + => _count = 0; + } } } diff --git a/src/EFCore.Cosmos/Query/Internal/SelectExpression.cs b/src/EFCore.Cosmos/Query/Internal/SelectExpression.cs index cae1cbbed00..7a861eaf5b2 100644 --- a/src/EFCore.Cosmos/Query/Internal/SelectExpression.cs +++ b/src/EFCore.Cosmos/Query/Internal/SelectExpression.cs @@ -45,6 +45,19 @@ public SelectExpression(IEntityType entityType) _projectionMapping[new ProjectionMember()] = new EntityProjectionExpression(entityType, FromExpression); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SelectExpression(IEntityType entityType, string sql, Expression argument) + { + Container = entityType.GetContainer(); + FromExpression = new FromSqlExpression(entityType, RootAlias, sql, argument); + _projectionMapping[new ProjectionMember()] = new EntityProjectionExpression(entityType, new RootReferenceExpression(entityType, RootAlias)); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -166,18 +179,14 @@ public virtual void SetPartitionKey(IProperty partitionKeyProperty, Expression e /// public virtual string GetPartitionKey(IReadOnlyDictionary parameterValues) { - switch (_partitionKeyValue) + return _partitionKeyValue switch { - case ConstantExpression constantExpression: - return GetString(_partitionKeyValueConverter, constantExpression.Value); - - case ParameterExpression parameterExpression - when parameterValues.TryGetValue(parameterExpression.Name, out var value): - return GetString(_partitionKeyValueConverter, value); - - default: - return null; - } + ConstantExpression constantExpression + => GetString(_partitionKeyValueConverter, constantExpression.Value), + ParameterExpression parameterExpression when parameterValues.TryGetValue(parameterExpression.Name, out var value) + => GetString(_partitionKeyValueConverter, value), + _ => null + }; static string GetString(ValueConverter converter, object value) => converter is null diff --git a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs index 335c2c8c380..1db49aaa3b8 100644 --- a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs @@ -526,6 +526,20 @@ public virtual SelectExpression Select(IEntityType entityType) return selectExpression; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SelectExpression Select(IEntityType entityType, string sql, Expression argument) + { + var selectExpression = new SelectExpression(entityType, sql, argument); + AddDiscriminator(selectExpression, entityType); + + return selectExpression; + } + private void AddDiscriminator(SelectExpression selectExpression, IEntityType entityType) { var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); diff --git a/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs index 151de23dc5f..5bdc0c898fe 100644 --- a/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs @@ -46,6 +46,9 @@ protected override Expression VisitExtension(Expression extensionExpression) case ObjectArrayProjectionExpression arrayProjectionExpression: return VisitObjectArrayProjection(arrayProjectionExpression); + case FromSqlExpression fromSqlExpression: + return VisitFromSql(fromSqlExpression); + case RootReferenceExpression rootReferenceExpression: return VisitRootReference(rootReferenceExpression); @@ -83,6 +86,14 @@ protected override Expression VisitExtension(Expression extensionExpression) return base.VisitExtension(extensionExpression); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected abstract Expression VisitFromSql(FromSqlExpression fromSqlExpression); + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs index 08f9baf030a..08bc511144a 100644 --- a/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs @@ -92,94 +92,88 @@ public virtual SelectExpression Expand( [return: NotNullIfNotNull("expression")] public override Expression? Visit(Expression? expression) { - if (expression is FromSqlExpression fromSql) + if (expression is not FromSqlExpression fromSql) { - if (!_visitedFromSqlExpressions.TryGetValue(fromSql, out var updatedFromSql)) - { - switch (fromSql.Arguments) + return base.Visit(expression); + } + + if (_visitedFromSqlExpressions.TryGetValue(fromSql, out var visitedFromSql)) + { + return visitedFromSql; + } + + switch (fromSql.Arguments) + { + case ParameterExpression parameterExpression: + // parameter value will never be null. It could be empty object[] + var parameterValues = (object[])_parametersValues[parameterExpression.Name!]!; + _canCache = false; + + var subParameters = new List(parameterValues.Length); + // ReSharper disable once ForCanBeConvertedToForeach + for (var i = 0; i < parameterValues.Length; i++) { - case ParameterExpression parameterExpression: - // parameter value will never be null. It could be empty object[] - var parameterValues = (object[])_parametersValues[parameterExpression.Name!]!; - _canCache = false; - - var subParameters = new List(parameterValues.Length); - // ReSharper disable once ForCanBeConvertedToForeach - for (var i = 0; i < parameterValues.Length; i++) + var parameterName = _parameterNameGenerator.GenerateNext(); + if (parameterValues[i] is DbParameter dbParameter) + { + if (string.IsNullOrEmpty(dbParameter.ParameterName)) + { + dbParameter.ParameterName = parameterName; + } + else { - var parameterName = _parameterNameGenerator.GenerateNext(); - if (parameterValues[i] is DbParameter dbParameter) - { - if (string.IsNullOrEmpty(dbParameter.ParameterName)) - { - dbParameter.ParameterName = parameterName; - } - else - { - parameterName = dbParameter.ParameterName; - } - - subParameters.Add(new RawRelationalParameter(parameterName, dbParameter)); - } - else - { - subParameters.Add( - new TypeMappedRelationalParameter( - parameterName, - parameterName, - _typeMappingSource.GetMappingForValue(parameterValues[i]), - parameterValues[i]?.GetType().IsNullableType())); - } + parameterName = dbParameter.ParameterName; } - updatedFromSql = fromSql.Update( - Expression.Constant(new CompositeRelationalParameter(parameterExpression.Name!, subParameters))); + subParameters.Add(new RawRelationalParameter(parameterName, dbParameter)); + } + else + { + subParameters.Add( + new TypeMappedRelationalParameter( + parameterName, + parameterName, + _typeMappingSource.GetMappingForValue(parameterValues[i]), + parameterValues[i]?.GetType().IsNullableType())); + } + } - _visitedFromSqlExpressions[fromSql] = updatedFromSql; - break; + return _visitedFromSqlExpressions[fromSql] = fromSql.Update( + Expression.Constant(new CompositeRelationalParameter(parameterExpression.Name!, subParameters))); - case ConstantExpression constantExpression: - var existingValues = constantExpression.GetConstantValue(); - var constantValues = new object?[existingValues.Length]; - for (var i = 0; i < existingValues.Length; i++) + case ConstantExpression constantExpression: + var existingValues = constantExpression.GetConstantValue(); + var constantValues = new object?[existingValues.Length]; + for (var i = 0; i < existingValues.Length; i++) + { + var value = existingValues[i]; + if (value is DbParameter dbParameter) + { + var parameterName = _parameterNameGenerator.GenerateNext(); + if (string.IsNullOrEmpty(dbParameter.ParameterName)) { - var value = existingValues[i]; - if (value is DbParameter dbParameter) - { - var parameterName = _parameterNameGenerator.GenerateNext(); - if (string.IsNullOrEmpty(dbParameter.ParameterName)) - { - dbParameter.ParameterName = parameterName; - } - else - { - parameterName = dbParameter.ParameterName; - } - - constantValues[i] = new RawRelationalParameter(parameterName, dbParameter); - } - else - { - constantValues[i] = _sqlExpressionFactory.Constant( - value, _typeMappingSource.GetMappingForValue(value)); - } + dbParameter.ParameterName = parameterName; + } + else + { + parameterName = dbParameter.ParameterName; } - updatedFromSql = fromSql.Update(Expression.Constant(constantValues, typeof(object?[]))); - - _visitedFromSqlExpressions[fromSql] = updatedFromSql; - break; - - default: - Check.DebugAssert(false, "FromSql.Arguments must be Constant/ParameterExpression"); - break; + constantValues[i] = new RawRelationalParameter(parameterName, dbParameter); + } + else + { + constantValues[i] = _sqlExpressionFactory.Constant( + value, _typeMappingSource.GetMappingForValue(value)); + } } - } - return updatedFromSql; - } + return _visitedFromSqlExpressions[fromSql] = fromSql.Update(Expression.Constant(constantValues, typeof(object?[]))); - return base.Visit(expression); + default: + Check.DebugAssert(false, "FromSql.Arguments must be Constant/ParameterExpression"); + return null; + } } } } diff --git a/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs b/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs index 041f2cbc066..ca7dc4cae04 100644 --- a/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs +++ b/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs @@ -125,7 +125,7 @@ public override bool Equals(object? obj) private bool Equals(FromSqlQueryRootExpression queryRootExpression) => base.Equals(queryRootExpression) - && string.Equals(Sql, queryRootExpression.Sql, StringComparison.OrdinalIgnoreCase) + && Sql == queryRootExpression.Sql && ExpressionEqualityComparer.Instance.Equals(Argument, queryRootExpression.Argument); /// diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index 032941ba7cd..fe716535d00 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -363,8 +363,7 @@ private void GenerateFromSql(FromSqlExpression fromSqlExpression) switch (fromSqlExpression.Arguments) { - case ConstantExpression constantExpression - when constantExpression.Value is CompositeRelationalParameter compositeRelationalParameter: + case ConstantExpression { Value: CompositeRelationalParameter compositeRelationalParameter }: { var subParameters = compositeRelationalParameter.RelationalParameters; substitutions = new string[subParameters.Count]; @@ -378,8 +377,7 @@ private void GenerateFromSql(FromSqlExpression fromSqlExpression) break; } - case ConstantExpression constantExpression - when constantExpression.Value is object[] constantValues: + case ConstantExpression { Value: object[] constantValues }: { substitutions = new string[constantValues.Length]; for (var i = 0; i < constantValues.Length; i++) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 656cc9793c2..d4cb21f6d02 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -2818,7 +2818,7 @@ EntityProjectionExpression LiftEntityProjectionFromSubquery(EntityProjectionExpr } /// - /// Checks whether this representes a which is not composed upon. + /// Checks whether this represents a which is not composed upon. /// /// A bool value indicating a non-composed . public bool IsNonComposedFromSql() diff --git a/test/EFCore.Cosmos.FunctionalTests/EFCore.Cosmos.FunctionalTests.csproj b/test/EFCore.Cosmos.FunctionalTests/EFCore.Cosmos.FunctionalTests.csproj index ca49776d0ea..c4671ee3f09 100644 --- a/test/EFCore.Cosmos.FunctionalTests/EFCore.Cosmos.FunctionalTests.csproj +++ b/test/EFCore.Cosmos.FunctionalTests/EFCore.Cosmos.FunctionalTests.csproj @@ -15,6 +15,10 @@ + + + + diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/FromSqlQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/FromSqlQueryCosmosTest.cs new file mode 100644 index 00000000000..0e14467b999 --- /dev/null +++ b/test/EFCore.Cosmos.FunctionalTests/Query/FromSqlQueryCosmosTest.cs @@ -0,0 +1,617 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.TestModels.Northwind; +using Microsoft.EntityFrameworkCore.TestUtilities; +using Microsoft.EntityFrameworkCore.Utilities; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.EntityFrameworkCore.Query +{ + public class FromSqlQueryCosmosTest : QueryTestBase> + { + private static readonly string _eol = Environment.NewLine; + + public FromSqlQueryCosmosTest( + NorthwindQueryCosmosFixture fixture, + ITestOutputHelper testOutputHelper) + : base(fixture) + { + ClearLog(); + Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); + } + + protected NorthwindContext CreateContext() + => Fixture.CreateContext(); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_simple(bool async) + { + using var context = CreateContext(); + var query = context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""ContactName""] LIKE '%z%'"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(14, actual.Length); + Assert.Equal(14, context.ChangeTracker.Entries().Count()); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c WHERE c[""ContactName""] LIKE '%z%' +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_simple_columns_out_of_order(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT c[""id""], c[""Discriminator""], c[""Region""], c[""PostalCode""], c[""Phone""], c[""Fax""], c[""CustomerID""], c[""Country""], c[""ContactTitle""], c[""ContactName""], c[""CompanyName""], c[""City""], c[""Address""] FROM root c"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(91, actual.Length); + Assert.Equal(91, context.ChangeTracker.Entries().Count()); + + AssertSql( + @"SELECT c +FROM ( + SELECT c[""id""], c[""Discriminator""], c[""Region""], c[""PostalCode""], c[""Phone""], c[""Fax""], c[""CustomerID""], c[""Country""], c[""ContactTitle""], c[""ContactName""], c[""CompanyName""], c[""City""], c[""Address""] FROM root c +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_simple_columns_out_of_order_and_extra_columns(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT c[""id""], c[""Discriminator""], c[""Region""], c[""PostalCode""], c[""PostalCode""] AS Foo, c[""Phone""], c[""Fax""], c[""CustomerID""], c[""Country""], c[""ContactTitle""], c[""ContactName""], c[""CompanyName""], c[""City""], c[""Address""] FROM root c"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(91, actual.Length); + Assert.Equal(91, context.ChangeTracker.Entries().Count()); + + AssertSql( + @"SELECT c +FROM ( + SELECT c[""id""], c[""Discriminator""], c[""Region""], c[""PostalCode""], c[""PostalCode""] AS Foo, c[""Phone""], c[""Fax""], c[""CustomerID""], c[""Country""], c[""ContactTitle""], c[""ContactName""], c[""CompanyName""], c[""City""], c[""Address""] FROM root c +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_composed(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c").Where(c => c.ContactName.Contains("z")); + + var sql = query.ToQueryString(); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(14, actual.Length); + Assert.Equal(14, context.ChangeTracker.Entries().Count()); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], ""z""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_composed_after_removing_whitespaces(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + _eol + " " + _eol + _eol + _eol + "SELECT" + _eol + "* FROM root c") + .Where(c => c.ContactName.Contains("z")); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(14, actual.Length); + +AssertSql( + @"SELECT c +FROM ( + + + + + SELECT + * FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], ""z""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_composed_compiled(bool async) + { + if (async) + { + var query = EF.CompileAsyncQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw(@"SELECT * FROM root c") + .Where(c => c.ContactName.Contains("z"))); + + using (var context = CreateContext()) + { + var actual = await query(context).ToListAsync(); + + Assert.Equal(14, actual.Count); + } + } + else + { + var query = EF.CompileQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw(@"SELECT * FROM root c") + .Where(c => c.ContactName.Contains("z"))); + + using (var context = CreateContext()) + { + var actual = query(context).ToArray(); + + Assert.Equal(14, actual.Length); + } + } + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], ""z""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_composed_compiled_with_parameter(bool async) + { + if (async) + { + var query = EF.CompileAsyncQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""CustomerID""] = {0}", "CONSH") + .Where(c => c.ContactName.Contains("z"))); + + using (var context = CreateContext()) + { + var actual = await query(context).ToListAsync(); + + Assert.Single(actual); + } + } + else + { + var query = EF.CompileQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""CustomerID""] = {0}", "CONSH") + .Where(c => c.ContactName.Contains("z"))); + + using (var context = CreateContext()) + { + var actual = query(context).ToArray(); + + Assert.Single(actual); + } + } + + AssertSql( + @"@p0='CONSH' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""CustomerID""] = @p0 +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], ""z""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_multiple_line_query(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * +FROM root c +WHERE c[""City""] = 'London'"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(6, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + + AssertSql( + @"SELECT c +FROM ( + SELECT * + FROM root c + WHERE c[""City""] = 'London' +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_composed_multiple_line_query(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * +FROM root c") + .Where(c => c.City == "London"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(6, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + + AssertSql( + @"SELECT c +FROM ( + SELECT * + FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""City""] = ""London""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_with_parameters(bool async) + { + var city = "London"; + var contactTitle = "Sales Representative"; + + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c WHERE c[""City""] = {0} AND c[""ContactTitle""] = {1}", city, + contactTitle); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(3, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + Assert.True(actual.All(c => c.ContactTitle == "Sales Representative")); + + AssertSql( + @"@p0='London' +@p1='Sales Representative' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 AND c[""ContactTitle""] = @p1 +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_with_parameters_inline(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c WHERE c[""City""] = {0} AND c[""ContactTitle""] = {1}", "London", + "Sales Representative"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(3, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + Assert.True(actual.All(c => c.ContactTitle == "Sales Representative")); + + AssertSql( + @"@p0='London' +@p1='Sales Representative' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 AND c[""ContactTitle""] = @p1 +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_with_null_parameter(bool async) + { + uint? reportsTo = null; + + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c WHERE c[""ReportsTo""] = {0} OR (IS_NULL(c[""ReportsTo""]) AND IS_NULL({0}))", reportsTo); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Single(actual); + + AssertSql( + @"@p0=null + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""ReportsTo""] = @p0 OR (IS_NULL(c[""ReportsTo""]) AND IS_NULL(@p0)) +) c +WHERE (c[""Discriminator""] = ""Employee"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_with_parameters_and_closure(bool async) + { + var city = "London"; + var contactTitle = "Sales Representative"; + + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c WHERE c[""City""] = {0}", city) + .Where(c => c.ContactTitle == contactTitle); + var queryString = query.ToQueryString(); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(3, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + Assert.True(actual.All(c => c.ContactTitle == "Sales Representative")); + + AssertSql( + @"@p0='London' +@__contactTitle_1='Sales Representative' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""ContactTitle""] = @__contactTitle_1))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_simple_cache_key_includes_query_string(bool async) + { + using var context = CreateContext(); + var query = context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""City""] = 'London'"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(6, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + + query = context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""City""] = 'Seattle'"); + + actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Single(actual); + Assert.True(actual.All(c => c.City == "Seattle")); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = 'London' +) c +WHERE (c[""Discriminator""] = ""Customer"")", + // + @"SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = 'Seattle' +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_with_parameters_cache_key_includes_parameters(bool async) + { + var city = "London"; + var contactTitle = "Sales Representative"; + var sql = @"SELECT * FROM root c WHERE c[""City""] = {0} AND c[""ContactTitle""] = {1}"; + + using var context = CreateContext(); + var query = context.Set().FromSqlRaw(sql, city, contactTitle); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(3, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + Assert.True(actual.All(c => c.ContactTitle == "Sales Representative")); + + city = "Madrid"; + contactTitle = "Accounting Manager"; + + query = context.Set().FromSqlRaw(sql, city, contactTitle); + + actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(2, actual.Length); + Assert.True(actual.All(c => c.City == "Madrid")); + Assert.True(actual.All(c => c.ContactTitle == "Accounting Manager")); + + AssertSql( + @"@p0='London' +@p1='Sales Representative' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 AND c[""ContactTitle""] = @p1 +) c +WHERE (c[""Discriminator""] = ""Customer"")", + // + @"@p0='Madrid' +@p1='Accounting Manager' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 AND c[""ContactTitle""] = @p1 +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_simple_as_no_tracking_not_composed(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw(@"SELECT * FROM root c") + .AsNoTracking(); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(91, actual.Length); + Assert.Empty(context.ChangeTracker.Entries()); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_simple_projection_composed(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * +FROM root c +WHERE NOT c[""Discontinued""] AND ((c[""UnitsInStock""] + c[""UnitsOnOrder""]) < c[""ReorderLevel""])") + .Select(p => p.ProductName); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(2, actual.Length); + + AssertSql( + @"SELECT c[""ProductName""] +FROM ( + SELECT * + FROM root c + WHERE NOT c[""Discontinued""] AND ((c[""UnitsInStock""] + c[""UnitsOnOrder""]) < c[""ReorderLevel""]) +) c +WHERE (c[""Discriminator""] = ""Product"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_composed_with_nullable_predicate(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw(@"SELECT * FROM root c") + .Where(c => c.ContactName == c.CompanyName); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Empty(actual); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""ContactName""] = c[""CompanyName""]))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_does_not_parameterize_interpolated_string(bool async) + { + using var context = CreateContext(); + var propertyName = "OrderID"; + var max = 10250; + var query = context.Orders.FromSqlRaw($@"SELECT * FROM root c WHERE c[""{propertyName}""] < {{0}}", max); + + var actual = async + ? await query.ToListAsync() + : query.ToList(); + + Assert.Equal(2, actual.Count); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_simple_projection_not_composed(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw(@"SELECT * FROM root c") + .Select( + c => new { c.CustomerID, c.City }) + .AsNoTracking(); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(91, actual.Length); + Assert.Empty(context.ChangeTracker.Entries()); + + AssertSql( + @"SELECT c[""CustomerID""], c[""City""] +FROM ( + SELECT * FROM root c +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + private void AssertSql(params string[] expected) + => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); + + protected void ClearLog() + => Fixture.TestSqlLoggerFactory.Clear(); + } +} diff --git a/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs index 7354517548a..503bc5d2504 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs @@ -297,41 +297,41 @@ public virtual async Task FromSqlRaw_queryable_composed_compiled(bool async) } } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] - public virtual async Task FromSqlRaw_queryable_composed_compiled_with_parameter(bool async) - { - if (async) + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_composed_compiled_with_parameter(bool async) { - var query = EF.CompileAsyncQuery( - (NorthwindContext context) => context.Set() - .FromSqlRaw( - NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = {0}"), "CONSH") - .Where(c => c.ContactName.Contains("z"))); - - using (var context = CreateContext()) + if (async) { - var actual = await query(context).ToListAsync(); + var query = EF.CompileAsyncQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw( + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = {0}"), "CONSH") + .Where(c => c.ContactName.Contains("z"))); - Assert.Single(actual); - } - } - else - { - var query = EF.CompileQuery( - (NorthwindContext context) => context.Set() - .FromSqlRaw( - NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = {0}"), "CONSH") - .Where(c => c.ContactName.Contains("z"))); + using (var context = CreateContext()) + { + var actual = await query(context).ToListAsync(); - using (var context = CreateContext()) + Assert.Single(actual); + } + } + else { - var actual = query(context).ToArray(); + var query = EF.CompileQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw( + NormalizeDelimitersInRawString("SELECT * FROM [Customers] WHERE [CustomerID] = {0}"), "CONSH") + .Where(c => c.ContactName.Contains("z"))); - Assert.Single(actual); + using (var context = CreateContext()) + { + var actual = query(context).ToArray(); + + Assert.Single(actual); + } } } - } [ConditionalTheory] [MemberData(nameof(IsAsyncData))] diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/FromSqlQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/FromSqlQuerySqlServerTest.cs index a9177f58fcd..d76bd2a10bc 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/FromSqlQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/FromSqlQuerySqlServerTest.cs @@ -35,7 +35,7 @@ public override async Task FromSqlRaw_queryable_simple_columns_out_of_order(bool @"SELECT ""Region"", ""PostalCode"", ""Phone"", ""Fax"", ""CustomerID"", ""Country"", ""ContactTitle"", ""ContactName"", ""CompanyName"", ""City"", ""Address"" FROM ""Customers"""); } - public override async Task FromSqlRaw_queryable_simple_columns_out_of_order_and_extra_columns(bool async) + public override async Task FromSqlRaw_queryable_simple_columns_out_of_order_and_extra_columns(bool async) { await base.FromSqlRaw_queryable_simple_columns_out_of_order_and_extra_columns(async); @@ -68,7 +68,7 @@ public override async Task FromSqlRaw_queryable_composed_after_removing_whitespa @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM ( - + SELECT