Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add [BatchSize] and pass thru to multi-row execute #76

Merged
merged 1 commit into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ enum ParameterMode
}
}

int? batchSize = null;
foreach (var attrib in methodAttribs)
{
if (IsDapperAttribute(attrib))
Expand Down Expand Up @@ -778,6 +779,12 @@ enum ParameterMode
case Types.CommandPropertyAttribute:
cmdPropsCount++;
break;
case Types.BatchSizeAttribute:
if (attrib.ConstructorArguments.Length == 1 && attrib.ConstructorArguments[0].Value is int batchTmp)
{
batchSize = batchTmp;
}
break;
}
}
}
Expand Down Expand Up @@ -806,8 +813,8 @@ enum ParameterMode
}


return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null
? null : new(rowCountHint, rowCountHintMember?.Member.Name, cmdProps);
return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null && batchSize is null
? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps);
}

internal static ImmutableArray<ElementMember>? SharedGetParametersToInclude(MemberMap? map, ref OperationFlags flags, string? sql, Action<Diagnostic>? reportDiagnostic, out SqlParseOutputFlags parseFlags)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void WriteMultiExecExpression(ITypeSymbol elementType, string castType)
bool isAsync = flags.HasAny(OperationFlags.Async);
sb.Append("Execute").Append(isAsync ? "Async" : "").Append("(");
sb.Append("(").Append(castType).Append(")param!");
if (additionalCommandState?.BatchSize is { } batchSize)
{
sb.Append(", batchSize: ").Append(batchSize);
}
if (isAsync && HasParam(methodParameters, "cancellationToken"))
{
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));
Expand Down
16 changes: 12 additions & 4 deletions src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public bool Equals(in CommandProperty other)
internal sealed class AdditionalCommandState : IEquatable<AdditionalCommandState>
{
public readonly int RowCountHint;
public readonly int? BatchSize;
public readonly string? RowCountHintMemberName;
public readonly ImmutableArray<CommandProperty> CommandProperties;

Expand Down Expand Up @@ -72,7 +73,8 @@ private static AdditionalCommandState Combine(AdditionalCommandState inherited,
countMember = null;
}

return new(count, countMember, Concat(inherited.CommandProperties, overrides.CommandProperties));
return new(count, countMember, inherited.BatchSize ?? overrides.BatchSize,
Concat(inherited.CommandProperties, overrides.CommandProperties));
}

static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x, ImmutableArray<CommandProperty> y)
Expand All @@ -85,10 +87,13 @@ static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x,
return builder.ToImmutable();
}

internal AdditionalCommandState(int rowCountHint, string? rowCountHintMemberName, ImmutableArray<CommandProperty> commandProperties)
internal AdditionalCommandState(
int rowCountHint, string? rowCountHintMemberName, int? batchSize,
ImmutableArray<CommandProperty> commandProperties)
{
RowCountHint = rowCountHint;
RowCountHintMemberName = rowCountHintMemberName;
BatchSize = batchSize;
CommandProperties = commandProperties;
}

Expand All @@ -98,7 +103,9 @@ internal AdditionalCommandState(int rowCountHint, string? rowCountHintMemberName
bool IEquatable<AdditionalCommandState>.Equals(AdditionalCommandState other) => Equals(in other);

public bool Equals(in AdditionalCommandState other)
=> RowCountHint == other.RowCountHint && RowCountHintMemberName == other.RowCountHintMemberName
=> RowCountHint == other.RowCountHint
&& BatchSize == other.BatchSize
&& RowCountHintMemberName == other.RowCountHintMemberName
&& ((CommandProperties.IsDefaultOrEmpty && other.CommandProperties.IsDefaultOrEmpty) || Equals(CommandProperties, other.CommandProperties));

private static bool Equals(in ImmutableArray<CommandProperty> x, in ImmutableArray<CommandProperty> y)
Expand Down Expand Up @@ -136,6 +143,7 @@ static int GetHashCode(in ImmutableArray<CommandProperty> x)
}

public override int GetHashCode()
=> (RowCountHint + (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode()))
=> (RowCountHint + BatchSize.GetValueOrDefault()
+ (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode()))
^ (CommandProperties.IsDefaultOrEmpty ? 0 : GetHashCode(in CommandProperties));
}
3 changes: 2 additions & 1 deletion src/Dapper.AOT.Analyzers/Internal/Types.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ public const string
IDynamicParameters = nameof(IDynamicParameters),
SqlMapper = nameof(SqlMapper),
SqlAttribute = nameof(SqlAttribute),
ExplicitConstructorAttribute = nameof(ExplicitConstructorAttribute);
ExplicitConstructorAttribute = nameof(ExplicitConstructorAttribute),
BatchSizeAttribute = nameof(BatchSizeAttribute);
}
19 changes: 19 additions & 0 deletions src/Dapper.AOT/BatchSizeAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.ComponentModel;
using System.Diagnostics;

namespace Dapper;

/// <summary>
/// Indicates the batch size to use when executing commands with a sequence of argument rows.
/// </summary>
[Conditional("DEBUG")] // not needed post-build, so: evaporate
[ImmutableObject(true)]
[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module | AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, AllowMultiple = false)]
public sealed class BatchSizeAttribute : Attribute
{
/// <summary>
/// Indicates the batch size to use when executing commands with a sequence of argument row; a value of zero disables batch usage; a negative value uses a single batch for all rows.
/// </summary>
public BatchSizeAttribute(int batchSize) => _ = batchSize;
}
23 changes: 23 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/BatchSize.input.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Dapper;
using System.Data.Common;

[module: DapperAot]

public static class Foo
{
[BatchSize(10)] // should be passed explicitly
static void SomeCode(DbConnection connection, string sql, string bar)
{
var objs = new[] { new { id = 12, bar }, new { id = 34, bar = "def" } };

connection.Execute("insert Foo (Id, Value) values (@id, @bar)", objs);
}

// no batch size, should be passed implicitly
static void SomeOtherCode(DbConnection connection, string sql, string bar)
{
var objs = new[] { new { id = 12, bar }, new { id = 34, bar = "def" } };

connection.Execute("insert Foo (Id, Value) values (@id, @bar)", objs);
}
}
144 changes: 144 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/BatchSize.output.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#nullable enable
namespace Dapper.AOT // interceptors must be in a known namespace
{
file static class DapperGeneratedInterceptors
{
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 13, 20)]
internal static int Execute0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Execute, HasParameters, Text, KnownParameters
// takes parameter: global::<anonymous type: int id, string bar>[]
// parameter map: bar id
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).Execute((object?[])param!, batchSize: 10);

}

[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 21, 20)]
internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Execute, HasParameters, Text, KnownParameters
// takes parameter: global::<anonymous type: int id, string bar>[]
// parameter map: bar id
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).Execute((object?[])param!);

}

private class CommonCommandFactory<T> : global::Dapper.CommandFactory<T>
{
public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args)
{
var cmd = base.GetCommand(connection, sql, commandType, args);
// apply special per-provider command initialization logic for OracleCommand
if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0)
{
cmd0.BindByName = true;
cmd0.InitialLONGFetchSize = -1;

}
return cmd;
}

}

private static readonly CommonCommandFactory<object?> DefaultCommandFactory = new();

private sealed class CommandFactory0 : CommonCommandFactory<object?> // <anonymous type: int id, string bar>
{
internal static readonly CommandFactory0 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "id";
p.DbType = global::System.Data.DbType.Int32;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.id);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "bar";
p.DbType = global::System.Data.DbType.String;
p.Size = -1;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.bar);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
ps[0].Value = AsValue(typed.id);
ps[1].Value = AsValue(typed.bar);

}
public override bool CanPrepare => true;

}

private sealed class CommandFactory1 : CommonCommandFactory<object?> // <anonymous type: int id, string bar>
{
internal static readonly CommandFactory1 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "id";
p.DbType = global::System.Data.DbType.Int32;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.id);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "bar";
p.DbType = global::System.Data.DbType.String;
p.Size = -1;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.bar);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
ps[0].Value = AsValue(typed.id);
ps[1].Value = AsValue(typed.bar);

}
public override bool CanPrepare => true;

}


}
}
namespace System.Runtime.CompilerServices
{
// this type is needed by the compiler to implement interceptors - it doesn't need to
// come from the runtime itself, though

[global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate
[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)]
sealed file class InterceptsLocationAttribute : global::System.Attribute
{
public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber)
{
_ = path;
_ = lineNumber;
_ = columnNumber;
}
}
}
Loading
Loading