diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs index ef300aac..da9719ef 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs @@ -747,6 +747,7 @@ enum ParameterMode } } + int? batchSize = null; foreach (var attrib in methodAttribs) { if (IsDapperAttribute(attrib)) @@ -778,6 +779,12 @@ enum ParameterMode case Types.CommandPropertyAttribute: cmdPropsCount++; break; + case Types.BatchSizeAttribute: + if (attrib.ConstructorArguments.Length == 1 && attrib.ConstructorArguments[0].Value is int batchTmp) + { + batchSize = batchTmp; + } + break; } } } @@ -806,8 +813,8 @@ enum ParameterMode } - return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null - ? null : new(rowCountHint, rowCountHintMember?.Member.Name, cmdProps); + return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null && batchSize is null + ? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps); } internal static ImmutableArray? SharedGetParametersToInclude(MemberMap? map, ref OperationFlags flags, string? sql, Action? reportDiagnostic, out SqlParseOutputFlags parseFlags) diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Multi.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Multi.cs index 7f47f49f..df84eb6f 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Multi.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Multi.cs @@ -44,6 +44,10 @@ void WriteMultiExecExpression(ITypeSymbol elementType, string castType) bool isAsync = flags.HasAny(OperationFlags.Async); sb.Append("Execute").Append(isAsync ? "Async" : "").Append("("); sb.Append("(").Append(castType).Append(")param!"); + if (additionalCommandState?.BatchSize is { } batchSize) + { + sb.Append(", batchSize: ").Append(batchSize); + } if (isAsync && HasParam(methodParameters, "cancellationToken")) { sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken")); diff --git a/src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs b/src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs index ed0e3ec1..816b1847 100644 --- a/src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs +++ b/src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs @@ -35,6 +35,7 @@ public bool Equals(in CommandProperty other) internal sealed class AdditionalCommandState : IEquatable { public readonly int RowCountHint; + public readonly int? BatchSize; public readonly string? RowCountHintMemberName; public readonly ImmutableArray CommandProperties; @@ -72,7 +73,8 @@ private static AdditionalCommandState Combine(AdditionalCommandState inherited, countMember = null; } - return new(count, countMember, Concat(inherited.CommandProperties, overrides.CommandProperties)); + return new(count, countMember, inherited.BatchSize ?? overrides.BatchSize, + Concat(inherited.CommandProperties, overrides.CommandProperties)); } static ImmutableArray Concat(ImmutableArray x, ImmutableArray y) @@ -85,10 +87,13 @@ static ImmutableArray Concat(ImmutableArray x, return builder.ToImmutable(); } - internal AdditionalCommandState(int rowCountHint, string? rowCountHintMemberName, ImmutableArray commandProperties) + internal AdditionalCommandState( + int rowCountHint, string? rowCountHintMemberName, int? batchSize, + ImmutableArray commandProperties) { RowCountHint = rowCountHint; RowCountHintMemberName = rowCountHintMemberName; + BatchSize = batchSize; CommandProperties = commandProperties; } @@ -98,7 +103,9 @@ internal AdditionalCommandState(int rowCountHint, string? rowCountHintMemberName bool IEquatable.Equals(AdditionalCommandState other) => Equals(in other); public bool Equals(in AdditionalCommandState other) - => RowCountHint == other.RowCountHint && RowCountHintMemberName == other.RowCountHintMemberName + => RowCountHint == other.RowCountHint + && BatchSize == other.BatchSize + && RowCountHintMemberName == other.RowCountHintMemberName && ((CommandProperties.IsDefaultOrEmpty && other.CommandProperties.IsDefaultOrEmpty) || Equals(CommandProperties, other.CommandProperties)); private static bool Equals(in ImmutableArray x, in ImmutableArray y) @@ -136,6 +143,7 @@ static int GetHashCode(in ImmutableArray x) } public override int GetHashCode() - => (RowCountHint + (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode())) + => (RowCountHint + BatchSize.GetValueOrDefault() + + (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode())) ^ (CommandProperties.IsDefaultOrEmpty ? 0 : GetHashCode(in CommandProperties)); } diff --git a/src/Dapper.AOT.Analyzers/Internal/Types.cs b/src/Dapper.AOT.Analyzers/Internal/Types.cs index 2fe64963..74cbe3ed 100644 --- a/src/Dapper.AOT.Analyzers/Internal/Types.cs +++ b/src/Dapper.AOT.Analyzers/Internal/Types.cs @@ -16,5 +16,6 @@ public const string IDynamicParameters = nameof(IDynamicParameters), SqlMapper = nameof(SqlMapper), SqlAttribute = nameof(SqlAttribute), - ExplicitConstructorAttribute = nameof(ExplicitConstructorAttribute); + ExplicitConstructorAttribute = nameof(ExplicitConstructorAttribute), + BatchSizeAttribute = nameof(BatchSizeAttribute); } diff --git a/src/Dapper.AOT/BatchSizeAttribute.cs b/src/Dapper.AOT/BatchSizeAttribute.cs new file mode 100644 index 00000000..1315a883 --- /dev/null +++ b/src/Dapper.AOT/BatchSizeAttribute.cs @@ -0,0 +1,19 @@ +using System; +using System.ComponentModel; +using System.Diagnostics; + +namespace Dapper; + +/// +/// Indicates the batch size to use when executing commands with a sequence of argument rows. +/// +[Conditional("DEBUG")] // not needed post-build, so: evaporate +[ImmutableObject(true)] +[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module | AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, AllowMultiple = false)] +public sealed class BatchSizeAttribute : Attribute +{ + /// + /// Indicates the batch size to use when executing commands with a sequence of argument row; a value of zero disables batch usage; a negative value uses a single batch for all rows. + /// + public BatchSizeAttribute(int batchSize) => _ = batchSize; +} diff --git a/test/Dapper.AOT.Test/Interceptors/BatchSize.input.cs b/test/Dapper.AOT.Test/Interceptors/BatchSize.input.cs new file mode 100644 index 00000000..3fe2b9cf --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/BatchSize.input.cs @@ -0,0 +1,23 @@ +using Dapper; +using System.Data.Common; + +[module: DapperAot] + +public static class Foo +{ + [BatchSize(10)] // should be passed explicitly + static void SomeCode(DbConnection connection, string sql, string bar) + { + var objs = new[] { new { id = 12, bar }, new { id = 34, bar = "def" } }; + + connection.Execute("insert Foo (Id, Value) values (@id, @bar)", objs); + } + + // no batch size, should be passed implicitly + static void SomeOtherCode(DbConnection connection, string sql, string bar) + { + var objs = new[] { new { id = 12, bar }, new { id = 34, bar = "def" } }; + + connection.Execute("insert Foo (Id, Value) values (@id, @bar)", objs); + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/BatchSize.output.cs b/test/Dapper.AOT.Test/Interceptors/BatchSize.output.cs new file mode 100644 index 00000000..8c603f88 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/BatchSize.output.cs @@ -0,0 +1,144 @@ +#nullable enable +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 13, 20)] + internal static int Execute0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Execute, HasParameters, Text, KnownParameters + // takes parameter: global::[] + // parameter map: bar id + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).Execute((object?[])param!, batchSize: 10); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 21, 20)] + internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Execute, HasParameters, Text, KnownParameters + // takes parameter: global::[] + // parameter map: bar id + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).Execute((object?[])param!); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class CommandFactory0 : CommonCommandFactory // + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "id"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.id); + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "bar"; + p.DbType = global::System.Data.DbType.String; + p.Size = -1; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.bar); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape + var ps = cmd.Parameters; + ps[0].Value = AsValue(typed.id); + ps[1].Value = AsValue(typed.bar); + + } + public override bool CanPrepare => true; + + } + + private sealed class CommandFactory1 : CommonCommandFactory // + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "id"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.id); + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "bar"; + p.DbType = global::System.Data.DbType.String; + p.Size = -1; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.bar); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape + var ps = cmd.Parameters; + ps[0].Value = AsValue(typed.id); + ps[1].Value = AsValue(typed.bar); + + } + public override bool CanPrepare => true; + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/BatchSize.output.netfx.cs b/test/Dapper.AOT.Test/Interceptors/BatchSize.output.netfx.cs new file mode 100644 index 00000000..8c603f88 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/BatchSize.output.netfx.cs @@ -0,0 +1,144 @@ +#nullable enable +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 13, 20)] + internal static int Execute0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Execute, HasParameters, Text, KnownParameters + // takes parameter: global::[] + // parameter map: bar id + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).Execute((object?[])param!, batchSize: 10); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 21, 20)] + internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Execute, HasParameters, Text, KnownParameters + // takes parameter: global::[] + // parameter map: bar id + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).Execute((object?[])param!); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class CommandFactory0 : CommonCommandFactory // + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "id"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.id); + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "bar"; + p.DbType = global::System.Data.DbType.String; + p.Size = -1; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.bar); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape + var ps = cmd.Parameters; + ps[0].Value = AsValue(typed.id); + ps[1].Value = AsValue(typed.bar); + + } + public override bool CanPrepare => true; + + } + + private sealed class CommandFactory1 : CommonCommandFactory // + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "id"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.id); + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "bar"; + p.DbType = global::System.Data.DbType.String; + p.Size = -1; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.bar); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape + var ps = cmd.Parameters; + ps[0].Value = AsValue(typed.id); + ps[1].Value = AsValue(typed.bar); + + } + public override bool CanPrepare => true; + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/BatchSize.output.netfx.txt b/test/Dapper.AOT.Test/Interceptors/BatchSize.output.netfx.txt new file mode 100644 index 00000000..29906b6b --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/BatchSize.output.netfx.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 2 of 2 possible call-sites using 2 interceptors, 2 commands and 0 readers diff --git a/test/Dapper.AOT.Test/Interceptors/BatchSize.output.txt b/test/Dapper.AOT.Test/Interceptors/BatchSize.output.txt new file mode 100644 index 00000000..29906b6b --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/BatchSize.output.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 2 of 2 possible call-sites using 2 interceptors, 2 commands and 0 readers