Skip to content

Commit 8a0fca5

Browse files
authored
Support arbitrary enumerables in NpgsqlArrayConverter (#3290)
Closes #3286
1 parent 30cebf0 commit 8a0fca5

File tree

3 files changed

+176
-62
lines changed

3 files changed

+176
-62
lines changed

src/EFCore.PG/Storage/Internal/Mapping/NpgsqlArrayTypeMapping.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,10 @@ public override DbParameter CreateParameter(
222222
// In queries which compose non-server-correlated LINQ operators over an array parameter (e.g. Where(b => ids.Skip(1)...) we
223223
// get an enumerable parameter value that isn't an array/list - but those aren't supported at the Npgsql ADO level.
224224
// Detect this here and evaluate the enumerable to get a fully materialized List.
225+
// Note that when we have a value converter (e.g. for HashSet), we don't want to convert it to a List, since the value converter
226+
// expects the original type.
225227
// TODO: Make Npgsql support IList<> instead of only arrays and List<>
226-
if (value is not null && !value.GetType().IsArrayOrGenericList())
228+
if (value is not null && Converter is null && !value.GetType().IsArrayOrGenericList())
227229
{
228230
switch (value)
229231
{

src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs

+153-61
Original file line numberDiff line numberDiff line change
@@ -105,93 +105,185 @@ private static Expression<Func<TInput, TOutput>> ArrayConversionExpression<TInpu
105105
p);
106106
}
107107

108-
var input = Parameter(typeof(TInput), "value");
108+
var input = Parameter(typeof(TInput), "input");
109+
var convertedInput = input;
109110
var output = Parameter(typeof(TConcreteOutput), "result");
110-
var loopVariable = Parameter(typeof(int), "i");
111111
var lengthVariable = Variable(typeof(int), "length");
112112

113113
var expressions = new List<Expression>();
114-
var variables = new List<ParameterExpression>(4)
115-
{
116-
output,
117-
lengthVariable,
118-
};
114+
var variables = new List<ParameterExpression> { output, lengthVariable };
119115

120116
Expression getInputLength;
121-
Func<Expression, Expression> indexer;
117+
Func<Expression, Expression>? indexer;
122118

123-
if (typeof(TInput).IsArray)
124-
{
125-
getInputLength = ArrayLength(input);
126-
indexer = i => ArrayAccess(input, i);
127-
}
128-
else if (typeof(TInput).IsGenericType
129-
&& typeof(TInput).GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IList<>)))
130-
{
131-
getInputLength = Property(
132-
input,
133-
typeof(TInput).GetProperty("Count")
134-
// If TInput is an interface (IList<T>), its Count property needs to be found on ICollection<T>
135-
?? typeof(ICollection<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]).GetProperty("Count")!);
136-
indexer = i => Property(input, input.Type.FindIndexerProperty()!, i);
137-
}
138-
else
119+
// The conversion is going to depend on what kind of input we have: array, list, collection, or arbitrary IEnumerable.
120+
// For array/list we can get the length and index inside, so we can do an efficient for loop.
121+
// For other ICollections (e.g. HashSet) we can get the length (and so pre-allocate the output), but we can't index; so we
122+
// get an enumerator and use that.
123+
// For arbitrary IEnumerable, we can't get the length so we can't preallocate output arrays; so we to call ToList() on it and then
124+
// process that (note that we could avoid that when the output is a List rather than an array).
125+
var inputInterfaces = input.Type.GetInterfaces();
126+
switch (input.Type)
139127
{
140-
// Input collection isn't typed as an ICollection<T>; it can be *typed* as an IEnumerable<T>, but we only support concrete
141-
// instances being ICollection<T>. Emit code that casts the type at runtime.
142-
var iListType = typeof(IList<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]);
128+
// Input is typed as an array - we can get its length and index into it
129+
case { IsArray: true }:
130+
getInputLength = ArrayLength(input);
131+
indexer = i => ArrayAccess(input, i);
132+
break;
133+
134+
// Input is typed as an IList - we can get its length and index into it
135+
case { IsGenericType: true } when inputInterfaces.Append(input.Type)
136+
.Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IList<>)):
137+
{
138+
getInputLength = Property(
139+
input,
140+
input.Type.GetProperty("Count")
141+
// If TInput is an interface (IList<T>), its Count property needs to be found on ICollection<T>
142+
?? typeof(ICollection<>).MakeGenericType(input.Type.GetGenericArguments()[0]).GetProperty("Count")!);
143+
indexer = i => Property(input, input.Type.FindIndexerProperty()!, i);
144+
break;
145+
}
143146

144-
var convertedInput = Variable(iListType, "convertedInput");
145-
variables.Add(convertedInput);
147+
// Input is typed as an ICollection - we can get its length, but we can't index into it
148+
case { IsGenericType: true } when inputInterfaces.Append(input.Type)
149+
.Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(ICollection<>)):
150+
{
151+
getInputLength = Property(
152+
input, typeof(ICollection<>).MakeGenericType(input.Type.GetGenericArguments()[0]).GetProperty("Count")!);
153+
indexer = null;
154+
break;
155+
}
146156

147-
expressions.Add(Assign(convertedInput, Convert(input, convertedInput.Type)));
157+
// Input is typed as an IEnumerable - we can't get its length, and we can't index into it.
158+
// All we can do is call ToList() on it and then process that.
159+
case { IsGenericType: true } when inputInterfaces.Append(input.Type)
160+
.Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)):
161+
{
162+
// TODO: In theory, we could add runtime checks for array/list/collection, downcast for those cases and include
163+
// the logic from the other switch cases here.
164+
convertedInput = Variable(typeof(List<>).MakeGenericType(inputElementType), "convertedInput");
165+
variables.Add(convertedInput);
166+
expressions.Add(
167+
Assign(
168+
convertedInput,
169+
Call(typeof(Enumerable).GetMethod(nameof(Enumerable.ToList))!.MakeGenericMethod(inputElementType), input)));
170+
getInputLength = Property(convertedInput, convertedInput.Type.GetProperty("Count")!);
171+
indexer = i => Property(convertedInput, convertedInput.Type.FindIndexerProperty()!, i);
172+
break;
173+
}
148174

149-
// TODO: Check and properly throw for non-IList<T>, e.g. set
150-
getInputLength = Property(
151-
convertedInput, typeof(ICollection<>).MakeGenericType(typeof(TInput).GetGenericArguments()[0]).GetProperty("Count")!);
152-
indexer = i => Property(convertedInput, iListType.FindIndexerProperty()!, i);
175+
default:
176+
throw new NotSupportedException($"Array value converter input type must be an IEnumerable, but is {typeof(TInput)}");
153177
}
154178

155179
expressions.AddRange(
156180
[
157181
// Get the length of the input array or list
158-
// var length = input.Length;
159-
Assign(lengthVariable, getInputLength),
160-
161-
// Allocate an output array or list
162-
// var result = new int[length];
163-
Assign(
164-
output, typeof(TConcreteOutput).IsArray
165-
? NewArrayBounds(outputElementType, lengthVariable)
166-
: typeof(TConcreteOutput).GetConstructor([typeof(int)]) is ConstructorInfo ctorWithLength
167-
? New(ctorWithLength, lengthVariable)
168-
: New(typeof(TConcreteOutput).GetConstructor([])!)),
169-
170-
// Loop over the elements, applying the element converter on them one by one
171-
// for (var i = 0; i < length; i++)
172-
// {
173-
// result[i] = input[i];
174-
// }
182+
// var length = input.Length;
183+
Assign(lengthVariable, getInputLength),
184+
185+
// Allocate an output array or list
186+
// var result = new int[length];
187+
Assign(
188+
output, typeof(TConcreteOutput).IsArray
189+
? NewArrayBounds(outputElementType, lengthVariable)
190+
: typeof(TConcreteOutput).GetConstructor([typeof(int)]) is ConstructorInfo ctorWithLength
191+
? New(ctorWithLength, lengthVariable)
192+
: New(typeof(TConcreteOutput).GetConstructor([])!))
193+
]);
194+
195+
if (indexer is not null)
196+
{
197+
// Good case: the input is an array or list, so we can index into it. Generate code for an efficient for loop, which applies
198+
// the element converter on each element.
199+
// for (var i = 0; i < length; i++)
200+
// {
201+
// result[i] = input[i];
202+
// }
203+
var counter = Parameter(typeof(int), "i");
204+
205+
expressions.Add(
175206
ForLoop(
176-
loopVar: loopVariable,
207+
loopVar: counter,
177208
initValue: Constant(0),
178-
condition: LessThan(loopVariable, lengthVariable),
179-
increment: AddAssign(loopVariable, Constant(1)),
209+
condition: LessThan(counter, lengthVariable),
210+
increment: AddAssign(counter, Constant(1)),
180211
loopContent:
181212
typeof(TConcreteOutput).IsArray
182213
? Assign(
183-
ArrayAccess(output, loopVariable),
214+
ArrayAccess(output, counter),
184215
elementConversionExpression is null
185-
? indexer(loopVariable)
186-
: Invoke(elementConversionExpression, indexer(loopVariable)))
216+
? indexer(counter)
217+
: Invoke(elementConversionExpression, indexer(counter)))
187218
: Call(
188219
output,
189220
typeof(TConcreteOutput).GetMethod("Add", [outputElementType])!,
190221
elementConversionExpression is null
191-
? indexer(loopVariable)
192-
: Invoke(elementConversionExpression, indexer(loopVariable)))),
193-
output
194-
]);
222+
? indexer(counter)
223+
: Invoke(elementConversionExpression, indexer(counter)))));
224+
}
225+
else
226+
{
227+
// Bad case: the input is not an array or list, but is a collection (e.g. HashSet), so we can't index into it.
228+
// Generate code for a less efficient enumerator-based iteration.
229+
// enumerator = input.GetEnumerator();
230+
// counter = 0;
231+
// while (enumerator.MoveNext())
232+
// {
233+
// output[counter] = enumerator.Current;
234+
// counter++;
235+
// }
236+
var enumerableType = typeof(IEnumerable<>).MakeGenericType(inputElementType);
237+
var enumeratorType = typeof(IEnumerator<>).MakeGenericType(inputElementType);
238+
239+
var enumeratorVariable = Variable(enumeratorType, "enumerator");
240+
var counterVariable = Variable(typeof(int), "variable");
241+
variables.AddRange([enumeratorVariable, counterVariable]);
242+
243+
expressions.AddRange(
244+
[
245+
// enumerator = input.GetEnumerator();
246+
Assign(enumeratorVariable, Call(input, enumerableType.GetMethod(nameof(IEnumerable<object>.GetEnumerator))!)),
247+
248+
// counter = 0;
249+
Assign(counterVariable, Constant(0))
250+
]);
251+
252+
var breakLabel = Label("LoopBreak");
253+
254+
var loop =
255+
Loop(
256+
IfThenElse(
257+
Equal(Call(enumeratorVariable, typeof(IEnumerator).GetMethod(nameof(IEnumerator.MoveNext))!), Constant(true)),
258+
Block(
259+
typeof(TConcreteOutput).IsArray
260+
// output[counter] = enumerator.Current;
261+
? Assign(
262+
ArrayAccess(output, counterVariable),
263+
elementConversionExpression is null
264+
? Property(enumeratorVariable, "Current")
265+
: Invoke(elementConversionExpression, Property(enumeratorVariable, "Current")))
266+
// output.Add(enumerator.Current);
267+
: Call(
268+
output,
269+
typeof(TConcreteOutput).GetMethod("Add", [outputElementType])!,
270+
elementConversionExpression is null
271+
? Property(enumeratorVariable, "Current")
272+
: Invoke(elementConversionExpression, Property(enumeratorVariable, "Current"))),
273+
274+
// counter++;
275+
AddAssign(counterVariable, Constant(1))),
276+
Break(breakLabel)),
277+
breakLabel);
278+
279+
expressions.Add(
280+
TryFinally(
281+
loop,
282+
Call(enumeratorVariable, typeof(IDisposable).GetMethod(nameof(IDisposable.Dispose))!)));
283+
}
284+
285+
// return output;
286+
expressions.Add(output);
195287

196288
return Lambda<Func<TInput, TOutput>>(
197289
// First, check if the given array value is null and return null immediately if so

test/EFCore.PG.FunctionalTests/Query/PrimitiveCollectionsQueryNpgsqlTest.cs

+20
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,26 @@ WHERE NOT (p."Int" = ANY (@__ints_0) AND p."Int" = ANY (@__ints_0) IS NOT NULL)
508508
""");
509509
}
510510

511+
[ConditionalTheory]
512+
[MemberData(nameof(IsAsyncData))]
513+
public virtual async Task Parameter_collection_HashSet_with_value_converter_Contains(bool async)
514+
{
515+
HashSet<MyEnum> enums = [MyEnum.Value1, MyEnum.Value4];
516+
517+
await AssertQuery(
518+
async,
519+
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => enums.Contains(c.Enum)));
520+
521+
AssertSql(
522+
"""
523+
@__enums_0={ '0', '3' } (DbType = Object)
524+
525+
SELECT p."Id", p."Bool", p."Bools", p."DateTime", p."DateTimes", p."Enum", p."Enums", p."Int", p."Ints", p."NullableInt", p."NullableInts", p."NullableString", p."NullableStrings", p."String", p."Strings"
526+
FROM "PrimitiveCollectionsEntity" AS p
527+
WHERE p."Enum" = ANY (@__enums_0)
528+
""");
529+
}
530+
511531
public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
512532
{
513533
await base.Parameter_collection_of_ints_Contains_nullable_int(async);

0 commit comments

Comments
 (0)