diff --git a/src/Core/Models/GraphQLFilterParsers.cs b/src/Core/Models/GraphQLFilterParsers.cs index db99537313..54eb05b7a0 100644 --- a/src/Core/Models/GraphQLFilterParsers.cs +++ b/src/Core/Models/GraphQLFilterParsers.cs @@ -11,6 +11,7 @@ using Azure.DataApiBuilder.Service.Exceptions; using Azure.DataApiBuilder.Service.GraphQLBuilder.Directives; using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; +using Azure.DataApiBuilder.Service.Services; using HotChocolate.Language; using HotChocolate.Resolvers; using Microsoft.AspNetCore.Http; @@ -65,12 +66,12 @@ public Predicate Parse( string dataSourceName = _configProvider.GetConfig().GetDataSourceNameFromEntityName(entityName); ISqlMetadataProvider metadataProvider = _metadataProviderFactory.GetMetadataProvider(dataSourceName); - InputObjectType filterArgumentObject = ResolverMiddleware.InputObjectTypeFromIInputField(filterArgumentSchema); + InputObjectType filterArgumentObject = ExecutionHelper.InputObjectTypeFromIInputField(filterArgumentSchema); List predicates = new(); foreach (ObjectFieldNode field in fields) { - object? fieldValue = ResolverMiddleware.ExtractValueFromIValueNode( + object? fieldValue = ExecutionHelper.ExtractValueFromIValueNode( value: field.Value, argumentSchema: filterArgumentObject.Fields[field.Name.Value], variables: ctx.Variables); @@ -85,7 +86,7 @@ public Predicate Parse( bool fieldIsAnd = string.Equals(name, $"{PredicateOperation.AND}", StringComparison.OrdinalIgnoreCase); bool fieldIsOr = string.Equals(name, $"{PredicateOperation.OR}", StringComparison.OrdinalIgnoreCase); - InputObjectType filterInputObjectType = ResolverMiddleware.InputObjectTypeFromIInputField(filterArgumentObject.Fields[name]); + InputObjectType filterInputObjectType = ExecutionHelper.InputObjectTypeFromIInputField(filterArgumentObject.Fields[name]); if (fieldIsAnd || fieldIsOr) { PredicateOperation op = fieldIsAnd ? PredicateOperation.AND : PredicateOperation.OR; @@ -509,7 +510,7 @@ private Predicate ParseAndOr( List operands = new(); foreach (IValueNode field in fields) { - object? fieldValue = ResolverMiddleware.ExtractValueFromIValueNode( + object? fieldValue = ExecutionHelper.ExtractValueFromIValueNode( value: field, argumentSchema: argumentSchema, ctx.Variables); @@ -598,11 +599,11 @@ public static Predicate Parse( { List predicates = new(); - InputObjectType argumentObject = ResolverMiddleware.InputObjectTypeFromIInputField(argumentSchema); + InputObjectType argumentObject = ExecutionHelper.InputObjectTypeFromIInputField(argumentSchema); foreach (ObjectFieldNode field in fields) { string name = field.Name.ToString(); - object? value = ResolverMiddleware.ExtractValueFromIValueNode( + object? value = ExecutionHelper.ExtractValueFromIValueNode( value: field.Value, argumentSchema: argumentObject.Fields[field.Name.Value], variables: ctx.Variables); diff --git a/src/Core/Resolvers/ArrayPoolWriter.cs b/src/Core/Resolvers/ArrayPoolWriter.cs new file mode 100644 index 0000000000..5d617b059d --- /dev/null +++ b/src/Core/Resolvers/ArrayPoolWriter.cs @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Buffers; + +/// +/// A helper to write to pooled arrays. +/// +internal sealed class ArrayPoolWriter : IBufferWriter, IDisposable +{ + private const int INITIAL_BUFFER_SIZE = 512; + private byte[] _buffer; + private int _capacity; + private int _start; + private bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + public ArrayPoolWriter() + { + _buffer = ArrayPool.Shared.Rent(INITIAL_BUFFER_SIZE); + _capacity = _buffer.Length; + _start = 0; + } + + /// + /// Gets the part of the buffer that has been written to. + /// + /// + /// A of the written portion of the buffer. + /// + public ReadOnlyMemory GetWrittenMemory() + => _buffer.AsMemory()[.._start]; + + /// + /// Gets the part of the buffer that has been written to. + /// + /// + /// A of the written portion of the buffer. + /// + public ReadOnlySpan GetWrittenSpan() + => _buffer.AsSpan()[.._start]; + + /// + /// Advances the writer by the specified number of bytes. + /// + /// + /// The number of bytes to advance the writer by. + /// + /// + /// Thrown if is negative or + /// if is greater than the + /// available capacity on the internal buffer. + /// + public void Advance(int count) + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(ArrayPoolWriter)); + } + + if (count < 0) + { + throw new ArgumentOutOfRangeException(nameof(count)); + } + + if (count > _capacity) + { + throw new ArgumentOutOfRangeException(nameof(count), count, "Cannot advance past the end of the buffer."); + } + + _start += count; + _capacity -= count; + } + + /// + /// Gets a to write to. + /// + /// + /// The minimum size of the returned . + /// + /// + /// A to write to. + /// + /// + /// Thrown if is negative. + /// + public Memory GetMemory(int sizeHint = 0) + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(ArrayPoolWriter)); + } + + if (sizeHint < 0) + { + throw new ArgumentOutOfRangeException(nameof(sizeHint)); + } + + int size = sizeHint < 1 ? INITIAL_BUFFER_SIZE : sizeHint; + EnsureBufferCapacity(size); + return _buffer.AsMemory().Slice(_start, size); + } + + /// + /// Gets a to write to. + /// + /// + /// The minimum size of the returned . + /// + /// + /// A to write to. + /// + /// + /// Thrown if is negative. + /// + public Span GetSpan(int sizeHint = 0) + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(ArrayPoolWriter)); + } + + if (sizeHint < 0) + { + throw new ArgumentOutOfRangeException(nameof(sizeHint)); + } + + int size = sizeHint < 1 ? INITIAL_BUFFER_SIZE : sizeHint; + EnsureBufferCapacity(size); + return _buffer.AsSpan().Slice(_start, size); + } + + /// + /// Ensures that the internal buffer has the needed capacity. + /// + /// + /// The needed capacity on the internal buffer. + /// + private void EnsureBufferCapacity(int neededCapacity) + { + // check if we have enough capacity available on the buffer. + if (_capacity < neededCapacity) + { + // if we need to expand the buffer we first capture the original buffer. + byte[] buffer = _buffer; + + // next we determine the new size of the buffer, we at least double the size to avoid + // expanding the buffer too often. + int newSize = buffer.Length * 2; + + // if that new buffer size is not enough to satisfy the needed capacity + // we add the needed capacity to the doubled buffer capacity. + if (neededCapacity > newSize - _start) + { + newSize += neededCapacity; + } + + // next we will rent a new array from the array pool that supports + // the new capacity requirements. + _buffer = ArrayPool.Shared.Rent(newSize); + + // the rented array might have a larger size than the needed capacity, + // so we will take the buffer length and calculate from that the free capacity. + _capacity += _buffer.Length - buffer.Length; + + // finally we copy the data from the original buffer to the new buffer. + buffer.AsSpan().CopyTo(_buffer); + + // last but not least we return the original buffer to the array pool. + ArrayPool.Shared.Return(buffer); + } + } + + /// + public void Dispose() + { + if (!_disposed) + { + ArrayPool.Shared.Return(_buffer); + _buffer = Array.Empty(); + _capacity = 0; + _start = 0; + _disposed = true; + } + } +} diff --git a/src/Core/Resolvers/CosmosQueryEngine.cs b/src/Core/Resolvers/CosmosQueryEngine.cs index f84a11d66a..1f9ea1e12d 100644 --- a/src/Core/Resolvers/CosmosQueryEngine.cs +++ b/src/Core/Resolvers/CosmosQueryEngine.cs @@ -189,14 +189,14 @@ public Task ExecuteAsync(StoredProcedureRequestContext context, s } /// - public JsonDocument ResolveInnerObject(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata) + public JsonElement ResolveObject(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata) { - //TODO: Try to avoid additional deserialization/serialization here. - return JsonDocument.Parse(element.ToString()); + return element; } /// - public object ResolveListType(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata) + /// metadata is not used in this method, but it is required by the interface. + public object ResolveList(JsonElement array, IObjectField fieldSchema, ref IMetadata metadata) { IType listType = fieldSchema.Type; // Is the List type nullable? [...]! vs [...] @@ -217,10 +217,10 @@ public object ResolveListType(JsonElement element, IObjectField fieldSchema, ref if (listType.IsObjectType()) { - return JsonSerializer.Deserialize>(element); + return JsonSerializer.Deserialize>(array); } - return JsonSerializer.Deserialize(element, fieldSchema.RuntimeType); + return JsonSerializer.Deserialize(array, fieldSchema.RuntimeType); } /// diff --git a/src/Core/Resolvers/IQueryEngine.cs b/src/Core/Resolvers/IQueryEngine.cs index 8ebd4d4c21..72d93db0f5 100644 --- a/src/Core/Resolvers/IQueryEngine.cs +++ b/src/Core/Resolvers/IQueryEngine.cs @@ -44,11 +44,11 @@ public interface IQueryEngine /// /// Resolves a jsonElement representing an inner object based on the field's schema and metadata /// - public JsonDocument? ResolveInnerObject(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata); + public JsonElement ResolveObject(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata); /// /// Resolves a jsonElement representing a list type based on the field's schema and metadata /// - public object? ResolveListType(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata); + public object ResolveList(JsonElement array, IObjectField fieldSchema, ref IMetadata? metadata); } } diff --git a/src/Core/Resolvers/JsonObjectExtensions.cs b/src/Core/Resolvers/JsonObjectExtensions.cs new file mode 100644 index 0000000000..670fb9c894 --- /dev/null +++ b/src/Core/Resolvers/JsonObjectExtensions.cs @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Buffers; +using System.Text.Json; +using System.Text.Json.Nodes; + +/// +/// This extension class provides helpers to convert a mutable JSON object +/// to a JSON element or JSON document. +/// +internal static class JsonObjectExtensions +{ + /// + /// Converts a mutable JSON object to an immutable JSON element. + /// + /// + /// The mutable JSON object to convert. + /// + /// + /// An immutable JSON element. + /// + /// + /// Thrown if is . + /// + public static JsonElement ToJsonElement(this JsonObject obj) + { + if (obj == null) + { + throw new ArgumentNullException(nameof(obj)); + } + + // we first write the mutable JsonObject to the pooled buffer and avoid serializing + // to a full JSON string. + using ArrayPoolWriter buffer = new(); + obj.WriteTo(buffer); + + // next we take the reader here and parse the JSON element from the buffer. + Utf8JsonReader reader = new(buffer.GetWrittenSpan()); + + // the underlying JsonDocument will not use pooled arrays to store metadata on it ... + // this JSON element can be safely returned. + return JsonElement.ParseValue(ref reader); + } + + /// + /// Converts a mutable JSON object to an immutable JSON document. + /// + /// + /// The mutable JSON object to convert. + /// + /// + /// An immutable JSON document. + /// + /// + /// Thrown if is . + /// + public static JsonDocument ToJsonDocument(this JsonObject obj) + { + if (obj == null) + { + throw new ArgumentNullException(nameof(obj)); + } + + // we first write the mutable JsonObject to the pooled buffer and avoid serializing + // to a full JSON string. + using ArrayPoolWriter buffer = new(); + obj.WriteTo(buffer); + + // next we parse the JSON document from the buffer. + // this JSON document will be disposed by the GraphQL execution engine. + return JsonDocument.Parse(buffer.GetWrittenMemory()); + } + + private static void WriteTo(this JsonObject obj, IBufferWriter bufferWriter) + { + if (obj == null) + { + throw new ArgumentNullException(nameof(obj)); + } + + if (bufferWriter == null) + { + throw new ArgumentNullException(nameof(bufferWriter)); + } + + using Utf8JsonWriter writer = new(bufferWriter); + obj.WriteTo(writer); + writer.Flush(); + } +} diff --git a/src/Core/Resolvers/Sql Query Structures/BaseSqlQueryStructure.cs b/src/Core/Resolvers/Sql Query Structures/BaseSqlQueryStructure.cs index 1921d14ad8..7238db4883 100644 --- a/src/Core/Resolvers/Sql Query Structures/BaseSqlQueryStructure.cs +++ b/src/Core/Resolvers/Sql Query Structures/BaseSqlQueryStructure.cs @@ -11,6 +11,7 @@ using Azure.DataApiBuilder.Core.Parsers; using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Service.Exceptions; +using Azure.DataApiBuilder.Service.Services; using HotChocolate.Language; using HotChocolate.Resolvers; using Microsoft.AspNetCore.Http; @@ -186,8 +187,8 @@ public void AddJoinPredicatesForRelatedEntity( { // Case where fk in parent entity references the nested entity. // Verify this is a valid fk definition before adding the join predicate. - if (foreignKeyDefinition.ReferencingColumns.Count() > 0 - && foreignKeyDefinition.ReferencedColumns.Count() > 0) + if (foreignKeyDefinition.ReferencingColumns.Count > 0 + && foreignKeyDefinition.ReferencedColumns.Count > 0) { subQuery.Predicates.AddRange(CreateJoinPredicates( SourceAlias, @@ -199,8 +200,8 @@ public void AddJoinPredicatesForRelatedEntity( else if (foreignKeyDefinition.Pair.ReferencingDbTable.Equals(relatedEntityDbObject)) { // Case where fk in nested entity references the parent entity. - if (foreignKeyDefinition.ReferencingColumns.Count() > 0 - && foreignKeyDefinition.ReferencedColumns.Count() > 0) + if (foreignKeyDefinition.ReferencingColumns.Count > 0 + && foreignKeyDefinition.ReferencedColumns.Count > 0) { subQuery.Predicates.AddRange(CreateJoinPredicates( relatedSourceAlias, @@ -431,18 +432,17 @@ internal static List GetSubArgumentNamesFromGQLMutArguments { IObjectField fieldSchema = context.Selection.Field; IInputField itemsArgumentSchema = fieldSchema.Arguments[fieldName]; - InputObjectType itemsArgumentObject = ResolverMiddleware.InputObjectTypeFromIInputField(itemsArgumentSchema); + InputObjectType itemsArgumentObject = ExecutionHelper.InputObjectTypeFromIInputField(itemsArgumentSchema); - Dictionary mutationInput; // An inline argument was set // TODO: This assumes the input was NOT nullable. if (item is List mutationInputRaw) { - mutationInput = new Dictionary(); + Dictionary mutationInput = new(); foreach (ObjectFieldNode node in mutationInputRaw) { string nodeName = node.Name.Value; - mutationInput.Add(nodeName, ResolverMiddleware.ExtractValueFromIValueNode( + mutationInput.Add(nodeName, ExecutionHelper.ExtractValueFromIValueNode( value: node.Value, argumentSchema: itemsArgumentObject.Fields[nodeName], variables: context.Variables)); diff --git a/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs b/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs index f08b76ded6..dd82e0369a 100644 --- a/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs +++ b/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs @@ -13,6 +13,7 @@ using Azure.DataApiBuilder.Service.GraphQLBuilder; using Azure.DataApiBuilder.Service.GraphQLBuilder.GraphQLTypes; using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; +using Azure.DataApiBuilder.Service.Services; using HotChocolate.Language; using HotChocolate.Resolvers; using Microsoft.AspNetCore.Http; @@ -635,7 +636,7 @@ private void AddGraphQLFields(IReadOnlyList selections, RuntimeC subStatusCode: DataApiBuilderException.SubStatusCodes.UnexpectedError); } - IDictionary subqueryParams = ResolverMiddleware.GetParametersFromSchemaAndQueryFields(subschemaField, field, _ctx.Variables); + IDictionary subqueryParams = ExecutionHelper.GetParametersFromSchemaAndQueryFields(subschemaField, field, _ctx.Variables); SqlQueryStructure subquery = new( _ctx, subqueryParams, @@ -720,10 +721,10 @@ private List ProcessGqlOrderByArg(List orderByFi HashSet remainingPkCols = new(PrimaryKey()); - InputObjectType orderByArgumentObject = ResolverMiddleware.InputObjectTypeFromIInputField(orderByArgumentSchema); + InputObjectType orderByArgumentObject = ExecutionHelper.InputObjectTypeFromIInputField(orderByArgumentSchema); foreach (ObjectFieldNode field in orderByFields) { - object? fieldValue = ResolverMiddleware.ExtractValueFromIValueNode( + object? fieldValue = ExecutionHelper.ExtractValueFromIValueNode( value: field.Value, argumentSchema: orderByArgumentObject.Fields[field.Name.Value], variables: _ctx.Variables); diff --git a/src/Core/Resolvers/SqlPaginationUtil.cs b/src/Core/Resolvers/SqlPaginationUtil.cs index bda4f4dd6c..0e419e548e 100644 --- a/src/Core/Resolvers/SqlPaginationUtil.cs +++ b/src/Core/Resolvers/SqlPaginationUtil.cs @@ -4,6 +4,7 @@ using System.Collections.Specialized; using System.Net; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; @@ -29,16 +30,36 @@ public static class SqlPaginationUtil /// *Connection.hasNextPage which is decided on whether structure.Limit() elements have been returned /// /// - public static JsonDocument CreatePaginationConnectionFromJsonElement(JsonElement root, PaginationMetadata paginationMetadata) + public static JsonElement CreatePaginationConnectionFromJsonElement(JsonElement root, PaginationMetadata paginationMetadata) + => CreatePaginationConnection(root, paginationMetadata).ToJsonElement(); + + /// + /// Wrapper for CreatePaginationConnectionFromJsonElement + /// + public static JsonDocument CreatePaginationConnectionFromJsonDocument(JsonDocument? jsonDocument, PaginationMetadata paginationMetadata) + { + // necessary for MsSql because it doesn't coalesce list query results like Postgres + if (jsonDocument is null) + { + jsonDocument = JsonDocument.Parse("[]"); + } + + JsonElement root = jsonDocument.RootElement.Clone(); + + // create the connection object. + return CreatePaginationConnection(root, paginationMetadata).ToJsonDocument(); + } + + private static JsonObject CreatePaginationConnection(JsonElement root, PaginationMetadata paginationMetadata) { // Maintains the connection JSON object *Connection - Dictionary connectionJson = new(); + JsonObject connection = new(); // in dw we wrap array with "" and hence jsonValueKind is string instead of array. if (root.ValueKind is JsonValueKind.String) { - JsonDocument document = JsonDocument.Parse(root.GetString()!); - root = document.RootElement; + using JsonDocument document = JsonDocument.Parse(root.GetString()!); + root = document.RootElement.Clone(); } IEnumerable rootEnumerated = root.EnumerateArray(); @@ -51,7 +72,7 @@ public static JsonDocument CreatePaginationConnectionFromJsonElement(JsonElement hasExtraElement = rootEnumerated.Count() == paginationMetadata.Structure!.Limit(); // add hasNextPage to connection elements - connectionJson.Add(QueryBuilder.HAS_NEXT_PAGE_FIELD_NAME, hasExtraElement ? true : false); + connection.Add(QueryBuilder.HAS_NEXT_PAGE_FIELD_NAME, hasExtraElement); if (hasExtraElement) { @@ -68,12 +89,12 @@ public static JsonDocument CreatePaginationConnectionFromJsonElement(JsonElement { // use rootEnumerated to make the *Connection.items since the last element of rootEnumerated // is removed if the result has an extra element - connectionJson.Add(QueryBuilder.PAGINATION_FIELD_NAME, JsonSerializer.Serialize(rootEnumerated.ToArray())); + connection.Add(QueryBuilder.PAGINATION_FIELD_NAME, JsonSerializer.Serialize(rootEnumerated.ToArray())); } else { // if the result doesn't have an extra element, just return the dbResult for *Connection.items - connectionJson.Add(QueryBuilder.PAGINATION_FIELD_NAME, root.ToString()!); + connection.Add(QueryBuilder.PAGINATION_FIELD_NAME, root.ToString()!); } } @@ -84,7 +105,7 @@ public static JsonDocument CreatePaginationConnectionFromJsonElement(JsonElement if (returnedElemNo > 0) { JsonElement lastElemInRoot = rootEnumerated.ElementAtOrDefault(returnedElemNo - 1); - connectionJson.Add(QueryBuilder.PAGINATION_TOKEN_FIELD_NAME, + connection.Add(QueryBuilder.PAGINATION_TOKEN_FIELD_NAME, MakeCursorFromJsonElement( lastElemInRoot, paginationMetadata.Structure!.PrimaryKey(), @@ -96,30 +117,7 @@ public static JsonDocument CreatePaginationConnectionFromJsonElement(JsonElement } } - return JsonDocument.Parse(JsonSerializer.Serialize(connectionJson)); - } - - /// - /// Wrapper for CreatePaginationConnectionFromJsonElement - /// Disposes the JsonDocument passed to it - /// - public static JsonDocument CreatePaginationConnectionFromJsonDocument(JsonDocument? jsonDocument, PaginationMetadata paginationMetadata) - { - // necessary for MsSql because it doesn't coalesce list query results like Postgres - if (jsonDocument is null) - { - jsonDocument = JsonDocument.Parse("[]"); - } - - JsonElement root = jsonDocument.RootElement; - - // this is intentionally not disposed since it will be used for processing later - JsonDocument result = CreatePaginationConnectionFromJsonElement(root, paginationMetadata); - - // no longer needed, so it is disposed - jsonDocument.Dispose(); - - return result; + return connection; } /// diff --git a/src/Core/Resolvers/SqlQueryEngine.cs b/src/Core/Resolvers/SqlQueryEngine.cs index ec2205bdf7..689c8426a1 100644 --- a/src/Core/Resolvers/SqlQueryEngine.cs +++ b/src/Core/Resolvers/SqlQueryEngine.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Text; using System.Text.Json; using System.Text.Json.Nodes; using Azure.DataApiBuilder.Auth; @@ -11,6 +12,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; using HotChocolate.Resolvers; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; @@ -170,9 +172,14 @@ public async Task ExecuteAsync(StoredProcedureRequestContext cont } /// - public JsonDocument? ResolveInnerObject(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata) + public JsonElement ResolveObject(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata) { PaginationMetadata parentMetadata = (PaginationMetadata)metadata; + if (parentMetadata.Subqueries.TryGetValue(QueryBuilder.PAGINATION_FIELD_NAME, out PaginationMetadata? paginationObjectMetadata)) + { + parentMetadata = paginationObjectMetadata; + } + PaginationMetadata currentMetadata = parentMetadata.Subqueries[fieldSchema.Name.Value]; metadata = currentMetadata; @@ -180,22 +187,68 @@ public async Task ExecuteAsync(StoredProcedureRequestContext cont { return SqlPaginationUtil.CreatePaginationConnectionFromJsonElement(element, currentMetadata); } - else + + // In certain cirumstances (e.g. when processing a DW result), the JsonElement will be JsonValueKind.String instead + // of JsonValueKind.Object. In this case, we need to parse the JSON. This snippet can be removed when DW result is consistent + // with MSSQL result. + if (element.ValueKind is JsonValueKind.String) { - //TODO: Try to avoid additional deserialization/serialization here. - return ResolverMiddleware.RepresentsNullValue(element) ? null : JsonDocument.Parse(element.ToString()); + return JsonDocument.Parse(element.ToString()).RootElement.Clone(); } + + return element; } - /// - public object? ResolveListType(JsonElement element, IObjectField fieldSchema, ref IMetadata metadata) + /// + /// Resolves the JsonElement, an array, into a list of jsonelements where each element represents + /// an entry in the original array. + /// + /// JsonElement representing a JSON array. The possible representations: + /// JsonValueKind.Array -> ["item1","itemN"] + /// JsonValueKind.String -> "[ { "field1": "field1Value" }, { "field2": "field2Value" }, { ... } ]" + /// - Input JsonElement is JsonValueKind.String because the array and enclosed objects haven't been deserialized yet. + /// - This method deserializes the JSON string (representing a JSON array) and collects each element (Json object) within the + /// list of json elements returned by this method. + /// Definition of field being resolved. For lists: [/]items:[entity!]!] + /// PaginationMetadata of the parent field of the currently processed field in HC middlewarecontext. + /// List of JsonElements parsed from the provided JSON array. + /// Return type is 'object' instead of a 'List of JsonElements' because when this function returns JsonElement, + /// the HC12 engine doesn't know how to handle the JsonElement and results in requests failing at runtime. + public object ResolveList(JsonElement array, IObjectField fieldSchema, ref IMetadata? metadata) { - PaginationMetadata parentMetadata = (PaginationMetadata)metadata; - PaginationMetadata currentMetadata = parentMetadata.Subqueries[fieldSchema.Name.Value]; - metadata = currentMetadata; + if (metadata is not null) + { + PaginationMetadata parentMetadata = (PaginationMetadata)metadata; + PaginationMetadata currentMetadata = parentMetadata.Subqueries[fieldSchema.Name.Value]; + metadata = currentMetadata; + } + + List resolvedList = new(); + + if (array.ValueKind is JsonValueKind.Array) + { + foreach (JsonElement element in array.EnumerateArray()) + { + resolvedList.Add(element); + } + } + else if (array.ValueKind is JsonValueKind.String) + { + using ArrayPoolWriter buffer = new(); + + string text = array.GetString()!; + int neededCapacity = Encoding.UTF8.GetMaxByteCount(text.Length); + int written = Encoding.UTF8.GetBytes(text, buffer.GetSpan(neededCapacity)); + buffer.Advance(written); + + Utf8JsonReader reader = new(buffer.GetWrittenSpan()); + foreach (JsonElement element in JsonElement.ParseValue(ref reader).EnumerateArray()) + { + resolvedList.Add(element); + } + } - //TODO: Try to avoid additional deserialization/serialization here. - return JsonSerializer.Deserialize>(element.ToString()); + return resolvedList; } // diff --git a/src/Core/Services/BuildRequestStateMiddleware.cs b/src/Core/Services/BuildRequestStateMiddleware.cs new file mode 100644 index 0000000000..05d9eaa8a1 --- /dev/null +++ b/src/Core/Services/BuildRequestStateMiddleware.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Core.Authorization; +using HotChocolate.Execution; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Primitives; +using RequestDelegate = HotChocolate.Execution.RequestDelegate; + +/// +/// This request middleware will build up our request state and will be invoked once per request. +/// +public sealed class BuildRequestStateMiddleware +{ + private readonly RequestDelegate _next; + + public BuildRequestStateMiddleware(RequestDelegate next) + { + _next = next; + } + + /// + /// Middleware invocation method which attempts to replicate the + /// http context's "X-MS-API-ROLE" header/value to HotChocolate's request context. + /// + /// HotChocolate execution request context. + public async ValueTask InvokeAsync(IRequestContext context) + { + if (context.ContextData.TryGetValue(nameof(HttpContext), out object? value) && + value is HttpContext httpContext) + { + // Because Request.Headers is a NameValueCollection type, key not found will return StringValues.Empty and not an exception. + StringValues clientRoleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]; + context.ContextData.TryAdd(key: AuthorizationResolver.CLIENT_ROLE_HEADER, value: clientRoleHeader); + } + + await _next(context).ConfigureAwait(false); + } +} diff --git a/src/Core/Services/ExecutionHelper.cs b/src/Core/Services/ExecutionHelper.cs new file mode 100644 index 0000000000..9a13b1296f --- /dev/null +++ b/src/Core/Services/ExecutionHelper.cs @@ -0,0 +1,596 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Globalization; +using System.Text.Json; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Resolvers; +using Azure.DataApiBuilder.Core.Resolvers.Factories; +using Azure.DataApiBuilder.Service.GraphQLBuilder; +using Azure.DataApiBuilder.Service.GraphQLBuilder.CustomScalars; +using Azure.DataApiBuilder.Service.GraphQLBuilder.GraphQLTypes; +using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; +using HotChocolate.Execution; +using HotChocolate.Language; +using HotChocolate.Resolvers; +using HotChocolate.Types.NodaTime; +using NodaTime.Text; + +namespace Azure.DataApiBuilder.Service.Services +{ + /// + /// This helper class provides the various resolvers and middlewares used + /// during query execution. + /// + internal sealed class ExecutionHelper + { + internal readonly IQueryEngineFactory _queryEngineFactory; + internal readonly IMutationEngineFactory _mutationEngineFactory; + internal readonly RuntimeConfigProvider _runtimeConfigProvider; + + private const string PURE_RESOLVER_CONTEXT_SUFFIX = "_PURE_RESOLVER_CTX"; + + public ExecutionHelper( + IQueryEngineFactory queryEngineFactory, + IMutationEngineFactory mutationEngineFactory, + RuntimeConfigProvider runtimeConfigProvider) + { + _queryEngineFactory = queryEngineFactory; + _mutationEngineFactory = mutationEngineFactory; + _runtimeConfigProvider = runtimeConfigProvider; + } + + /// + /// Represents the root query resolver and fetches the initial data from the query engine. + /// + /// + /// The middleware context. + /// + public async ValueTask ExecuteQueryAsync(IMiddlewareContext context) + { + string dataSourceName = GraphQLUtils.GetDataSourceNameFromGraphQLContext(context, _runtimeConfigProvider.GetConfig()); + DataSource ds = _runtimeConfigProvider.GetConfig().GetDataSourceFromDataSourceName(dataSourceName); + IQueryEngine queryEngine = _queryEngineFactory.GetQueryEngine(ds.DatabaseType); + + IDictionary parameters = GetParametersFromContext(context); + + if (context.Selection.Type.IsListType()) + { + Tuple, IMetadata?> result = + await queryEngine.ExecuteListAsync(context, parameters, dataSourceName); + + // this will be run after the query / mutation has completed. + context.RegisterForCleanup( + () => + { + foreach (JsonDocument document in result.Item1) + { + document.Dispose(); + } + }); + + context.Result = result.Item1.Select(t => t.RootElement).ToArray(); + SetNewMetadata(context, result.Item2); + } + else + { + Tuple result = + await queryEngine.ExecuteAsync(context, parameters, dataSourceName); + SetContextResult(context, result.Item1); + SetNewMetadata(context, result.Item2); + } + } + + /// + /// Represents the root mutation resolver and invokes the mutation on the query engine. + /// + /// + /// The middleware context. + /// + public async ValueTask ExecuteMutateAsync(IMiddlewareContext context) + { + string dataSourceName = GraphQLUtils.GetDataSourceNameFromGraphQLContext(context, _runtimeConfigProvider.GetConfig()); + DataSource ds = _runtimeConfigProvider.GetConfig().GetDataSourceFromDataSourceName(dataSourceName); + IQueryEngine queryEngine = _queryEngineFactory.GetQueryEngine(ds.DatabaseType); + + IDictionary parameters = GetParametersFromContext(context); + + // Only Stored-Procedure has ListType as returnType for Mutation + if (context.Selection.Type.IsListType()) + { + // Both Query and Mutation execute the same SQL statement for Stored Procedure. + Tuple, IMetadata?> result = + await queryEngine.ExecuteListAsync(context, parameters, dataSourceName); + + // this will be run after the query / mutation has completed. + context.RegisterForCleanup( + () => + { + foreach (JsonDocument document in result.Item1) + { + document.Dispose(); + } + }); + + context.Result = result.Item1.Select(t => t.RootElement).ToArray(); + SetNewMetadata(context, result.Item2); + } + else + { + IMutationEngine mutationEngine = _mutationEngineFactory.GetMutationEngine(ds.DatabaseType); + Tuple result = + await mutationEngine.ExecuteAsync(context, parameters, dataSourceName); + SetContextResult(context, result.Item1); + SetNewMetadata(context, result.Item2); + } + } + + /// + /// Represents a pure resolver for a leaf field. + /// This resolver extracts the field value from the json object. + /// + /// + /// The pure resolver context. + /// + /// + /// Returns the runtime field value. + /// + public static object? ExecuteLeafField(IPureResolverContext context) + { + // This means this field is a scalar, so we don't need to do + // anything for it. + if (TryGetPropertyFromParent(context, out JsonElement fieldValue) && + fieldValue.ValueKind is not (JsonValueKind.Undefined or JsonValueKind.Null)) + { + // The selection type can be a wrapper type like NonNullType or ListType. + // To get the most inner type (aka the named type) we use our named type helper. + INamedType namedType = context.Selection.Field.Type.NamedType(); + + // Each scalar in HotChocolate has a runtime type representation. + // In order to let scalar values flow through the GraphQL type completion + // efficiently we want the leaf types to match the runtime type. + // If that is not the case a value will go through the type converter to try to + // transform it into the runtime type. + // We also want to ensure here that we do not unnecessarily convert values to + // strings and then force the conversion to parse them. + return namedType switch + { + StringType => fieldValue.GetString(), // spec + ByteType => fieldValue.GetByte(), + ShortType => fieldValue.GetInt16(), + IntType => fieldValue.GetInt32(), // spec + LongType => fieldValue.GetInt64(), + FloatType => fieldValue.GetDouble(), // spec + SingleType => fieldValue.GetSingle(), + DecimalType => fieldValue.GetDecimal(), + DateTimeType => DateTimeOffset.Parse(fieldValue.GetString()!, DateTimeFormatInfo.InvariantInfo, DateTimeStyles.AssumeUniversal), + DateType => DateTimeOffset.Parse(fieldValue.GetString()!), + LocalTimeType => LocalTimePattern.ExtendedIso.Parse(fieldValue.GetString()!).Value, + ByteArrayType => fieldValue.GetBytesFromBase64(), + BooleanType => fieldValue.GetBoolean(), // spec + UrlType => new Uri(fieldValue.GetString()!), + UuidType => fieldValue.GetGuid(), + TimeSpanType => TimeSpan.Parse(fieldValue.GetString()!), + _ => fieldValue.GetString() + }; + } + + return null; + } + + /// + /// Represents a pure resolver for an object field. + /// This resolver extracts another json object from the parent json property. + /// + /// + /// The pure resolver context. + /// + /// + /// Returns a new json object. + /// + public object? ExecuteObjectField(IPureResolverContext context) + { + string dataSourceName = GraphQLUtils.GetDataSourceNameFromGraphQLContext(context, _runtimeConfigProvider.GetConfig()); + DataSource ds = _runtimeConfigProvider.GetConfig().GetDataSourceFromDataSourceName(dataSourceName); + IQueryEngine queryEngine = _queryEngineFactory.GetQueryEngine(ds.DatabaseType); + + if (TryGetPropertyFromParent(context, out JsonElement objectValue) && + objectValue.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined) + { + IMetadata metadata = GetMetadataObjectField(context); + objectValue = queryEngine.ResolveObject(objectValue, context.Selection.Field, ref metadata); + + // Since the query engine could null the object out we need to check again + // if it's null. + if (objectValue.ValueKind is JsonValueKind.Null or JsonValueKind.Undefined) + { + return null; + } + + SetNewMetadataChildren(context, metadata); + return objectValue; + } + + return null; + } + + /// + /// The ListField pure resolver is executed when processing "list" fields. + /// For example, when executing the query { myEntity { items { entityField1 } } } + /// this pure resolver will be executed when processing the field "items" because + /// it will contain the "list" of results. + /// + /// PureResolver context provided by HC middleware. + /// The resolved list, a JSON array, returned as type 'object?'. + /// Return type is 'object?' instead of a 'List of JsonElements' because when this function returns JsonElement, + /// the HC12 engine doesn't know how to handle the JsonElement and results in requests failing at runtime. + public object? ExecuteListField(IPureResolverContext context) + { + string dataSourceName = GraphQLUtils.GetDataSourceNameFromGraphQLContext(context, _runtimeConfigProvider.GetConfig()); + DataSource ds = _runtimeConfigProvider.GetConfig().GetDataSourceFromDataSourceName(dataSourceName); + IQueryEngine queryEngine = _queryEngineFactory.GetQueryEngine(ds.DatabaseType); + + if (TryGetPropertyFromParent(context, out JsonElement listValue) && + listValue.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined) + { + IMetadata? metadata = GetMetadata(context); + object result = queryEngine.ResolveList(listValue, context.Selection.Field, ref metadata); + SetNewMetadataChildren(context, metadata); + return result; + } + + return null; + } + + /// + /// Set the context's result and dispose properly. If result is not null + /// clone root and dispose, otherwise set to null. + /// + /// Context to store result. + /// Result to store in context. + private static void SetContextResult(IMiddlewareContext context, JsonDocument? result) + { + if (result is not null) + { + context.RegisterForCleanup(() => result.Dispose()); + // Since the JsonDocument instance is registered for disposal, + // we don't need to clone the root element. Since the JsonDocument + // won't be disposed of after this code block. + context.Result = result.RootElement; + } + else + { + context.Result = null; + } + } + + private static bool TryGetPropertyFromParent( + IPureResolverContext context, + out JsonElement propertyValue) + { + JsonElement parent = context.Parent(); + + if (parent.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null) + { + propertyValue = default; + return false; + } + + return parent.TryGetProperty(context.Selection.Field.Name.Value, out propertyValue); + } + + /// + /// Extracts the value from an IValueNode. That includes extracting the value of the variable + /// if the IValueNode is a variable and extracting the correct type from the IValueNode + /// + /// the IValueNode from which to extract the value + /// describes the schema of the argument that the IValueNode represents + /// the request context variable values needed to resolve value nodes represented as variables + public static object? ExtractValueFromIValueNode( + IValueNode value, + IInputField argumentSchema, + IVariableValueCollection variables) + { + // extract value from the variable if the IValueNode is a variable + if (value.Kind == SyntaxKind.Variable) + { + string variableName = ((VariableNode)value).Name.Value; + IValueNode? variableValue = variables.GetVariable(variableName); + + if (variableValue is null) + { + return null; + } + + return ExtractValueFromIValueNode(variableValue, argumentSchema, variables); + } + + if (value is NullValueNode) + { + return null; + } + + return argumentSchema.Type.TypeName().Value switch + { + SupportedHotChocolateTypes.BYTE_TYPE => ((IntValueNode)value).ToByte(), + SupportedHotChocolateTypes.SHORT_TYPE => ((IntValueNode)value).ToInt16(), + SupportedHotChocolateTypes.INT_TYPE => ((IntValueNode)value).ToInt32(), + SupportedHotChocolateTypes.LONG_TYPE => ((IntValueNode)value).ToInt64(), + SupportedHotChocolateTypes.SINGLE_TYPE => ((FloatValueNode)value).ToSingle(), + SupportedHotChocolateTypes.FLOAT_TYPE => ((FloatValueNode)value).ToDouble(), + SupportedHotChocolateTypes.DECIMAL_TYPE => ((FloatValueNode)value).ToDecimal(), + SupportedHotChocolateTypes.UUID_TYPE => Guid.TryParse(value.Value!.ToString(), out Guid guidValue) ? guidValue : value.Value, + _ => value.Value + }; + } + + /// + /// First: Creates parameters using the GraphQL schema's ObjectTypeDefinition metadata + /// and metadata from the request's (query) field. + /// Then: Creates parameters from schema argument fields when they have default values. + /// Lastly: Gets the user provided argument values from the query to either: + /// 1. Overwrite the parameter value if it exists in the collectedParameters dictionary + /// or + /// 2. Adds the parameter/parameter value to the dictionary. + /// + /// + /// Dictionary of parameters + /// Key: (string) argument field name + /// Value: (object) argument value + /// + public static IDictionary GetParametersFromSchemaAndQueryFields( + IObjectField schema, + FieldNode query, + IVariableValueCollection variables) + { + IDictionary collectedParameters = new Dictionary(); + + // Fill the parameters dictionary with the default argument values + IFieldCollection schemaArguments = schema.Arguments; + + // Example 'argumentSchemas' IInputField objects of type 'HotChocolate.Types.Argument': + // These are all default arguments defined in the schema for queries. + // {first:int} + // {after:String} + // {filter:entityFilterInput} + // {orderBy:entityOrderByInput} + // The values in schemaArguments will have default values when the backing + // entity is a stored procedure with runtime config defined default parameter values. + foreach (IInputField argument in schemaArguments) + { + if (argument.DefaultValue != null) + { + collectedParameters.Add( + argument.Name.Value, + ExtractValueFromIValueNode( + value: argument.DefaultValue, + argumentSchema: argument, + variables: variables)); + } + } + + // Overwrite the default values with the passed in arguments + // Example: { myEntity(first: $first, orderBy: {entityField: ASC) { items { entityField } } } + // User supplied $first filter variable overwrites the default value of 'first'. + // User supplied 'orderBy' filter overwrites the default value of 'orderBy'. + IReadOnlyList passedArguments = query.Arguments; + + foreach (ArgumentNode argument in passedArguments) + { + string argumentName = argument.Name.Value; + IInputField argumentSchema = schemaArguments[argumentName]; + + object? nodeValue = ExtractValueFromIValueNode( + value: argument.Value, + argumentSchema: argumentSchema, + variables: variables); + + if (!collectedParameters.TryAdd(argumentName, nodeValue)) + { + collectedParameters[argumentName] = nodeValue; + } + } + + return collectedParameters; + } + + /// + /// InnerMostType is innermost type of the passed Graph QL type. + /// This strips all modifiers, such as List and Non-Null. + /// So the following GraphQL types would all have the underlyingType Book: + /// - Book + /// - [Book] + /// - Book! + /// - [Book]! + /// - [Book!]! + /// + internal static IType InnerMostType(IType type) + { + if (type.ToString() == type.InnerType().ToString()) + { + return type; + } + + return InnerMostType(type.InnerType()); + } + + public static InputObjectType InputObjectTypeFromIInputField(IInputField field) + { + return (InputObjectType)(InnerMostType(field.Type)); + } + + /// + /// Creates a dictionary of parameters and associated values from + /// the GraphQL request's MiddlewareContext from arguments provided + /// in the request. e.g. first, after, filter, orderBy, and stored procedure + /// parameters. + /// + /// GraphQL HotChocolate MiddlewareContext + /// Dictionary of parameters and their values. + private static IDictionary GetParametersFromContext( + IMiddlewareContext context) + { + return GetParametersFromSchemaAndQueryFields( + context.Selection.Field, + context.Selection.SyntaxNode, + context.Variables); + } + + /// + /// Get metadata from HotChocolate's GraphQL request MiddlewareContext. + /// The metadata key is the root field name + _PURE_RESOLVER_CTX + :: + PathDepth. + /// CosmosDB does not utilize pagination metadata. So this function will return null + /// when executing GraphQl queries against CosmosDB. + /// + private static IMetadata? GetMetadata(IPureResolverContext context) + { + if (context.Selection.ResponseName == QueryBuilder.PAGINATION_FIELD_NAME && context.Path.Parent is not null) + { + // entering this block means that: + // context.Selection.ResponseName: items + // context.Path: /entityA/items (Depth: 1) + // context.Path.Parent: /entityA (Depth: 0) + // The parent's metadata will be stored in ContextData with a depth of context.Path minus 1. -> "::0" + // The resolved metadata key is entityA_PURE_RESOLVER_CTX and is appended with "::0" + // Another case would be: + // context.Path: /books/items[0]/authors/items + // context.Path.Parent: /books/items[0]/authors + // The nuance here is that HC counts the depth when the path is expanded as + // /books/items/items[idx]/authors -> Depth: 3 (0-indexed) which maps to the + // pagination metadata for the "authors/items" subquery. + string paginationObjectParentName = GetMetadataKey(context.Path) + "::" + context.Path.Parent.Depth; + return (IMetadata?)context.ContextData[paginationObjectParentName]; + } + + // This section would be reached when processing a Cosmos query of the form: + // { planet_by_pk (id: $id, _partitionKeyValue: $partitionKeyValue) { tags } } + // where nested entities like the entity 'tags' are not nested within an "items" field + // like for SQL databases. + string metadataKey = GetMetadataKey(context.Path) + "::" + context.Path.Depth; + + if (context.ContextData.TryGetValue(key: metadataKey, out object? paginationMetadata) && paginationMetadata is not null) + { + return (IMetadata)paginationMetadata; + } + else + { + // CosmosDB database type does not utilize pagination metadata. + return PaginationMetadata.MakeEmptyPaginationMetadata(); + } + } + + /// + /// Get the pagination metadata object for the field represented by the + /// pure resolver context. + /// e.g. when Context.Path is "/books/items[0]/authors", this function gets + /// the pagination metadata for authors, which is stored in the global middleware + /// context under key: "books_PURE_RESOLVER_CTX::1", where "books" is the parent object + /// and depth of "1" implicitly represents the path "/books/items". When "/books/items" + /// is processed by the pure resolver, the available pagination metadata maps to the object + /// type that enumerated in "items" + /// + /// Pure resolver context + /// Pagination metadata + private static IMetadata GetMetadataObjectField(IPureResolverContext context) + { + // Depth Levels: / 0 / 1 / 2 / 3 + // Example Path: /books/items/items[0]/publishers + // Depth of 1 should have key in context.ContextData + // Depth of 2 will not have context.ContextData entry because non-Indexed path element is the path that is cached. + // PaginationMetadata for items will be consistent across each subitem. So we can use the same metadata for each subitem. + // An indexer path segment is a segment that looks like -> items[n] + if (context.Path.Parent is IndexerPathSegment) + { + // When context.Path is "/books/items[0]/authors" + // Parent -> "/books/items[0]" + // Parent -> "/books/items" -> Depth of this path is used to create the key to get + // paginationmetadata from context.ContextData + // The PaginationMetadata fetched has subquery metadata for "authors" from path "/books/items/authors" + string objectParentName = GetMetadataKey(context.Path) + "::" + context.Path.Parent!.Parent!.Depth; + return (IMetadata)context.ContextData[objectParentName]!; + } + else if (context.Path.Parent is not null && ((NamePathSegment)context.Path.Parent).Name != PURE_RESOLVER_CONTEXT_SUFFIX) + { + // This check handles when the current selection is a relationship field because in that case, + // there will be no context data entry. + // e.g. metadata for index 4 will not exist. only 3. + // Depth: / 0 / 1 / 2 / 3 / 4 + // Path: /books/items/items[0]/publishers/books + string objectParentName = GetMetadataKey(context.Path) + "::" + context.Path.Parent!.Depth; + return (IMetadata)context.ContextData[objectParentName]!; + } + + string metadataKey = GetMetadataKey(context.Path) + "::" + context.Path.Depth; + return (IMetadata)context.ContextData[metadataKey]!; + } + + private static string GetMetadataKey(HotChocolate.Path path) + { + HotChocolate.Path currentPath = path; + + if (currentPath.Parent is RootPathSegment or null) + { + // current: "/entity/items -> "items" + return ((NamePathSegment)currentPath).Name + PURE_RESOLVER_CONTEXT_SUFFIX; + } + + // If execution reaches this point, the state of currentPath looks something + // like the following where there exists a Parent path element: + // "/entity/items -> current.Parent: "entity" + return GetMetadataKey(path: currentPath.Parent); + } + + /// + /// Resolves the name of the root object of a selection set to + /// use as the beginning of a key used to index pagination metadata in the + /// global HC middleware context. + /// + /// Root object field of query. + /// "rootObjectName_PURE_RESOLVER_CTX" + private static string GetMetadataKey(IFieldSelection rootSelection) + { + return rootSelection.ResponseName + PURE_RESOLVER_CONTEXT_SUFFIX; + } + + /// + /// Persist new metadata with a key denoting the depth of the current path. + /// The pagination metadata persisted here correlates to the top-level object type + /// denoted in the request. + /// e.g. books_PURE_RESOLVER_CTX::0 for: + /// context.Path -> /books depth(0) + /// context.Selection -> books { items {id, title}} + /// + private static void SetNewMetadata(IPureResolverContext context, IMetadata? metadata) + { + string metadataKey = GetMetadataKey(context.Selection) + "::" + context.Path.Depth; + context.ContextData.Add(metadataKey, metadata); + } + + /// + /// Stores the pagination metadata in the global context.ContextData accessible to + /// all pure resolvers for query fields referencing nested entities. + /// + /// Pure resolver context + /// Pagination metadata + private static void SetNewMetadataChildren(IPureResolverContext context, IMetadata? metadata) + { + // When context.Path is /entity/items the metadata key is "entity" + // The context key will use the depth of "items" so that the provided + // pagination metadata (which holds the subquery metadata for "/entity/items/nestedEntity") + // can be stored for future access when the "/entity/items/nestedEntity" pure resolver executes. + // When context.Path takes the form: "/entity/items[index]/nestedEntity" HC counts the depth as + // if the path took the form: "/entity/items/items[index]/nestedEntity" -> Depth of "nestedEntity" + // is 3 because depth is 0-indexed. + string contextKey = GetMetadataKey(context.Path) + "::" + context.Path.Depth; + + // It's okay to overwrite the context when we are visiting a different item in items e.g. books/items/items[1]/publishers since + // context for books/items/items[0]/publishers processing is done and that context isn't needed anymore. + if (!context.ContextData.TryAdd(contextKey, metadata)) + { + context.ContextData[contextKey] = metadata; + } + } + } +} diff --git a/src/Core/Services/GraphQLSchemaCreator.cs b/src/Core/Services/GraphQLSchemaCreator.cs index e46ae0de93..89d58d05ec 100644 --- a/src/Core/Services/GraphQLSchemaCreator.cs +++ b/src/Core/Services/GraphQLSchemaCreator.cs @@ -17,6 +17,7 @@ using Azure.DataApiBuilder.Service.GraphQLBuilder.Mutations; using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; using Azure.DataApiBuilder.Service.GraphQLBuilder.Sql; +using Azure.DataApiBuilder.Service.Services; using HotChocolate.Language; using Microsoft.Extensions.DependencyInjection; @@ -100,8 +101,8 @@ private ISchemaBuilder Parse( .AddDocument(mutationNode) // Enable the OneOf directive (https://github.com/graphql/graphql-spec/pull/825) to support the DefaultValue type .ModifyOptions(o => o.EnableOneOf = true) - // Add our custom middleware for GraphQL resolvers - .Use((services, next) => new ResolverMiddleware(next, _queryEngineFactory, _mutationEngineFactory, _runtimeConfigProvider)); + // Adds our type interceptor that will create the resolvers. + .TryAddTypeInterceptor(new ResolverTypeInterceptor(new ExecutionHelper(_queryEngineFactory, _mutationEngineFactory, _runtimeConfigProvider))); } /// diff --git a/src/Core/Services/ResolverMiddleware.cs b/src/Core/Services/ResolverMiddleware.cs deleted file mode 100644 index 505e5ef0ad..0000000000 --- a/src/Core/Services/ResolverMiddleware.cs +++ /dev/null @@ -1,388 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Globalization; -using System.Text.Json; -using Azure.DataApiBuilder.Config.ObjectModel; -using Azure.DataApiBuilder.Core.Authorization; -using Azure.DataApiBuilder.Core.Configurations; -using Azure.DataApiBuilder.Core.Models; -using Azure.DataApiBuilder.Core.Resolvers; -using Azure.DataApiBuilder.Core.Resolvers.Factories; -using Azure.DataApiBuilder.Service.GraphQLBuilder; -using Azure.DataApiBuilder.Service.GraphQLBuilder.CustomScalars; -using HotChocolate.Execution; -using HotChocolate.Language; -using HotChocolate.Resolvers; -using HotChocolate.Types.NodaTime; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Primitives; -using NodaTime.Text; -using static Azure.DataApiBuilder.Service.GraphQLBuilder.GraphQLTypes.SupportedHotChocolateTypes; - -namespace Azure.DataApiBuilder.Core.Services -{ - /// - /// The field resolver middleware that is used by the schema executor to resolve - /// the queries and mutations - /// - public class ResolverMiddleware - { - private static readonly string _contextMetadata = "metadata"; - internal readonly FieldDelegate _next; - internal readonly IQueryEngineFactory _queryEngineFactory; - internal readonly IMutationEngineFactory _mutationEngineFactory; - internal readonly RuntimeConfigProvider _runtimeConfigProvider; - - public ResolverMiddleware(FieldDelegate next, - IQueryEngineFactory queryEngineFactory, - IMutationEngineFactory mutationEngineFactory, - RuntimeConfigProvider runtimeConfigProvider) - { - _next = next; - _queryEngineFactory = queryEngineFactory; - _mutationEngineFactory = mutationEngineFactory; - _runtimeConfigProvider = runtimeConfigProvider; - } - - /// - /// HotChocolate invokes this method when this ResolverMiddleware is utilized - /// in the request pipeline. - /// From this method, the Query and Mutation engines are executed, and the execution - /// results saved in the IMiddlewareContext's result property. - /// - /// - /// HotChocolate middleware context containing request metadata. - /// Does not explicitly return data. - public async Task InvokeAsync(IMiddlewareContext context) - { - JsonElement jsonElement; - string dataSourceName = GraphQLUtils.GetDataSourceNameFromGraphQLContext(context, _runtimeConfigProvider.GetConfig()); - DataSource ds = _runtimeConfigProvider.GetConfig().GetDataSourceFromDataSourceName(dataSourceName); - - IQueryEngine queryEngine = _queryEngineFactory.GetQueryEngine(ds.DatabaseType); - - if (context.ContextData.TryGetValue("HttpContext", out object? value)) - { - if (value is not null) - { - HttpContext httpContext = (HttpContext)value; - StringValues clientRoleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]; - context.ContextData.TryAdd(key: AuthorizationResolver.CLIENT_ROLE_HEADER, value: clientRoleHeader); - } - } - - if (context.Selection.Field.Coordinate.TypeName.Value == "Mutation") - { - IDictionary parameters = GetParametersFromContext(context); - // Only Stored-Procedure has ListType as returnType for Mutation - if (context.Selection.Type.IsListType()) - { - // Both Query and Mutation execute the same SQL statement for Stored Procedure. - Tuple, IMetadata?> result = await queryEngine.ExecuteListAsync(context, parameters, dataSourceName); - context.Result = GetListOfClonedElements(result.Item1); - SetNewMetadata(context, result.Item2); - } - else - { - IMutationEngine mutationEngine = _mutationEngineFactory.GetMutationEngine(ds.DatabaseType); - Tuple result = await mutationEngine.ExecuteAsync(context, parameters, dataSourceName); - SetContextResult(context, result.Item1); - SetNewMetadata(context, result.Item2); - } - } - else if (context.Selection.Field.Coordinate.TypeName.Value == "Query") - { - IDictionary parameters = GetParametersFromContext(context); - - if (context.Selection.Type.IsListType()) - { - Tuple, IMetadata?> result = await queryEngine.ExecuteListAsync(context, parameters, dataSourceName); - context.Result = GetListOfClonedElements(result.Item1); - SetNewMetadata(context, result.Item2); - } - else - { - Tuple result = await queryEngine.ExecuteAsync(context, parameters, dataSourceName); - SetContextResult(context, result.Item1); - SetNewMetadata(context, result.Item2); - } - } - else if (context.Selection.Field.Type.IsLeafType()) - { - // This means this field is a scalar, so we don't need to do - // anything for it. - if (TryGetPropertyFromParent(context, out jsonElement)) - { - context.Result = RepresentsNullValue(jsonElement) ? null : PreParseLeaf(context, jsonElement.ToString()); - } - } - else if (IsInnerObject(context)) - { - // This means it's a field that has another custom type as its - // type, so there is a full JSON object inside this key. For - // example such a JSON object could have been created by a - // One-To-Many join. - if (TryGetPropertyFromParent(context, out jsonElement)) - { - IMetadata metadata = GetMetadata(context); - using JsonDocument? innerObject = queryEngine.ResolveInnerObject(jsonElement, context.Selection.Field, ref metadata); - if (innerObject is not null) - { - context.Result = innerObject.RootElement.Clone(); - } - else - { - context.Result = null; - } - - SetNewMetadata(context, metadata); - } - } - else if (context.Selection.Type.IsListType()) - { - // This means the field is a list and HotChocolate requires - // that to be returned as a List of JsonDocuments. For example - // such a JSON list could have been created by a One-To-Many - // join. - if (TryGetPropertyFromParent(context, out jsonElement)) - { - IMetadata metadata = GetMetadata(context); - context.Result = queryEngine.ResolveListType(jsonElement, context.Selection.Field, ref metadata); - SetNewMetadata(context, metadata); - } - } - - await _next(context); - } - - /// - /// Set the context's result and dispose properly. If result is not null - /// clone root and dispose, otherwise set to null. - /// - /// Context to store result. - /// Result to store in context. - private static void SetContextResult(IMiddlewareContext context, JsonDocument? result) - { - if (result is not null) - { - context.Result = result.RootElement.Clone(); - result.Dispose(); - } - else - { - context.Result = null; - } - } - - /// - /// Create and return a list of cloned root elements from a collection of JsonDocuments. - /// Dispose of each JsonDocument after its root element is cloned. - /// - /// List of JsonDocuments to clone and dispose. - /// List of cloned root elements. - private static IEnumerable GetListOfClonedElements(IEnumerable docList) - { - List result = new(); - foreach (JsonDocument jsonDoc in docList) - { - result.Add(jsonDoc.RootElement.Clone()); - jsonDoc.Dispose(); - } - - return result; - } - - /// - /// Preparse a string extracted from the json result representing a leaf. - /// This is helpful in cases when HotChocolate's internal resolvers cannot appropriately - /// parse the result so we preparse the result so it can be appropriately handled by HotChocolate - /// later - /// - /// - /// e.g. "1" despite being a valid byte value is parsed improperly by HotChocolate so we preparse it - /// to an actual byte value then feed the result to HotChocolate - /// - private static object PreParseLeaf(IMiddlewareContext context, string leafJson) - { - IType leafType = context.Selection.Field.Type is NonNullType - ? context.Selection.Field.Type.NullableType() : context.Selection.Field.Type; - - return leafType switch - { - ByteType => byte.Parse(leafJson), - SingleType => Single.Parse(leafJson), - DateTimeType => DateTimeOffset.Parse(leafJson, DateTimeFormatInfo.InvariantInfo, DateTimeStyles.AssumeUniversal), - ByteArrayType => Convert.FromBase64String(leafJson), - LocalTimeType => LocalTimePattern.ExtendedIso.Parse(leafJson).Value, - _ => leafJson - }; - } - - public static bool RepresentsNullValue(JsonElement element) - { - return (string.IsNullOrEmpty(element.ToString()) && element.GetRawText() == "null"); - } - - protected static bool TryGetPropertyFromParent(IMiddlewareContext context, out JsonElement jsonElement) - { - JsonDocument result = JsonDocument.Parse(JsonSerializer.Serialize(context.Parent())); - if (result is null) - { - jsonElement = default; - return false; - } - - return result.RootElement.TryGetProperty(context.Selection.Field.Name.Value, out jsonElement); - } - - protected static bool IsInnerObject(IMiddlewareContext context) - { - return context.Selection.Field.Type.IsObjectType() && context.Parent() is not null; - } - - /// - /// Extracts the value from an IValueNode. That includes extracting the value of the variable - /// if the IValueNode is a variable and extracting the correct type from the IValueNode - /// - /// the IValueNode from which to extract the value - /// describes the schema of the argument that the IValueNode represents - /// the request context variable values needed to resolve value nodes represented as variables - public static object? ExtractValueFromIValueNode(IValueNode value, IInputField argumentSchema, IVariableValueCollection variables) - { - // extract value from the variable if the IValueNode is a variable - if (value.Kind == SyntaxKind.Variable) - { - string variableName = ((VariableNode)value).Name.Value; - IValueNode? variableValue = variables.GetVariable(variableName); - - if (variableValue is null) - { - return null; - } - - return ExtractValueFromIValueNode(variableValue, argumentSchema, variables); - } - - if (value is NullValueNode) - { - return null; - } - - return argumentSchema.Type.TypeName().Value switch - { - BYTE_TYPE => ((IntValueNode)value).ToByte(), - SHORT_TYPE => ((IntValueNode)value).ToInt16(), - INT_TYPE => ((IntValueNode)value).ToInt32(), - LONG_TYPE => ((IntValueNode)value).ToInt64(), - SINGLE_TYPE => ((FloatValueNode)value).ToSingle(), - FLOAT_TYPE => ((FloatValueNode)value).ToDouble(), - DECIMAL_TYPE => ((FloatValueNode)value).ToDecimal(), - // If we reach here, we can be sure that the value will not be null. - UUID_TYPE => Guid.TryParse(value.Value!.ToString(), out Guid guidValue) ? guidValue : value.Value, - _ => value.Value - }; - } - - /// - /// Extract parameters from the schema and the actual instance (query) of the field - /// Extracts default parameter values from the schema or null if no default - /// Overrides default values with actual values of parameters provided - /// Key: (string) argument field name - /// Value: (object) argument value - /// - public static IDictionary GetParametersFromSchemaAndQueryFields(IObjectField schema, FieldNode query, IVariableValueCollection variables) - { - IDictionary parameters = new Dictionary(); - - // Fill the parameters dictionary with the default argument values - IFieldCollection argumentSchemas = schema.Arguments; - foreach (IInputField argument in argumentSchemas) - { - if (argument.DefaultValue != null) - { - parameters.Add( - argument.Name.Value, - ExtractValueFromIValueNode( - value: argument.DefaultValue, - argumentSchema: argument, - variables: variables)); - } - } - - // Overwrite the default values with the passed in arguments - IReadOnlyList passedArguments = query.Arguments; - foreach (ArgumentNode argument in passedArguments) - { - string argumentName = argument.Name.Value; - IInputField argumentSchema = argumentSchemas[argumentName]; - - if (parameters.ContainsKey(argumentName)) - { - parameters[argumentName] = - ExtractValueFromIValueNode( - value: argument.Value, - argumentSchema: argumentSchema, - variables: variables); - } - else - { - parameters.Add( - argumentName, - ExtractValueFromIValueNode( - value: argument.Value, - argumentSchema: argumentSchema, - variables: variables)); - } - } - - return parameters; - } - - /// - /// InnerMostType is innermost type of the passed Graph QL type. - /// This strips all modifiers, such as List and Non-Null. - /// So the following GraphQL types would all have the underlyingType Book: - /// - Book - /// - [Book] - /// - Book! - /// - [Book]! - /// - [Book!]! - /// - internal static IType InnerMostType(IType type) - { - if (type.ToString() == type.InnerType().ToString()) - { - return type; - } - - return InnerMostType(type.InnerType()); - } - - public static InputObjectType InputObjectTypeFromIInputField(IInputField field) - { - return (InputObjectType)(InnerMostType(field.Type)); - } - - protected static IDictionary GetParametersFromContext(IMiddlewareContext context) - { - return GetParametersFromSchemaAndQueryFields(context.Selection.Field, context.Selection.SyntaxNode, context.Variables); - } - - /// - /// Get metadata from context - /// - private static IMetadata GetMetadata(IMiddlewareContext context) - { - return (IMetadata)context.ScopedContextData[_contextMetadata]!; - } - - /// - /// Set new metadata and reset the depth that the metadata has persisted - /// - private static void SetNewMetadata(IMiddlewareContext context, IMetadata? metadata) - { - context.ScopedContextData = context.ScopedContextData.SetItem(_contextMetadata, metadata); - } - } -} diff --git a/src/Core/Services/ResolverTypeInterceptor.cs b/src/Core/Services/ResolverTypeInterceptor.cs new file mode 100644 index 0000000000..9adc3069dd --- /dev/null +++ b/src/Core/Services/ResolverTypeInterceptor.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Service.Services; +using HotChocolate.Configuration; +using HotChocolate.Resolvers; +using HotChocolate.Types.Descriptors.Definitions; + +internal sealed class ResolverTypeInterceptor : TypeInterceptor +{ + private readonly FieldMiddlewareDefinition _queryMiddleware; + private readonly FieldMiddlewareDefinition _mutationMiddleware; + private readonly PureFieldDelegate _leafFieldResolver; + private readonly PureFieldDelegate _objectFieldResolver; + private readonly PureFieldDelegate _listFieldResolver; + + public ResolverTypeInterceptor(ExecutionHelper executionHelper) + { + _queryMiddleware = + new FieldMiddlewareDefinition( + next => async context => + { + await executionHelper.ExecuteQueryAsync(context).ConfigureAwait(false); + await next(context).ConfigureAwait(false); + }); + + _mutationMiddleware = + new FieldMiddlewareDefinition( + next => async context => + { + await executionHelper.ExecuteMutateAsync(context).ConfigureAwait(false); + await next(context).ConfigureAwait(false); + }); + + _leafFieldResolver = ctx => ExecutionHelper.ExecuteLeafField(ctx); + _objectFieldResolver = ctx => executionHelper.ExecuteObjectField(ctx); + _listFieldResolver = ctx => executionHelper.ExecuteListField(ctx); + } + + public override void OnBeforeCompleteType( + ITypeCompletionContext completionContext, + DefinitionBase? definition, + IDictionary contextData) + { + // We are only interested in object types here as only object types can have resolvers. + if (definition is not ObjectTypeDefinition objectTypeDef) + { + return; + } + + if (completionContext.IsQueryType ?? false) + { + foreach (ObjectFieldDefinition field in objectTypeDef.Fields) + { + field.MiddlewareDefinitions.Add(_queryMiddleware); + } + } + else if (completionContext.IsMutationType ?? false) + { + foreach (ObjectFieldDefinition field in objectTypeDef.Fields) + { + field.MiddlewareDefinitions.Add(_mutationMiddleware); + } + } + else if (completionContext.IsSubscriptionType ?? false) + { + throw new NotSupportedException(); + } + else + { + foreach (ObjectFieldDefinition field in objectTypeDef.Fields) + { + // In order to inspect the type we need to resolve the type reference on the definition. + // If it's null or cannot be resolved something is wrong, but we skip over this and let + // the type validation deal with schema errors. + if (field.Type is not null && + completionContext.TryGetType(field.Type, out IType? type)) + { + // Do not override a PureResolver when one is already set. + if (field.PureResolver is not null) + { + continue; + } + + if (type.IsLeafType()) + { + field.PureResolver = _leafFieldResolver; + } + else if (type.IsObjectType()) + { + field.PureResolver = _objectFieldResolver; + } + else if (type.IsListType()) + { + field.PureResolver = _listFieldResolver; + } + } + } + } + } +} diff --git a/src/Core/Services/RestService.cs b/src/Core/Services/RestService.cs index 5232f16008..e116641e98 100644 --- a/src/Core/Services/RestService.cs +++ b/src/Core/Services/RestService.cs @@ -202,7 +202,7 @@ RequestValidator requestValidator return await DispatchMutation(context, sqlMetadataProvider.GetDatabaseType()); default: throw new NotSupportedException("This operation is not yet supported."); - }; + } } /// @@ -486,7 +486,7 @@ public async Task AuthorizationCheckForRequirementAsync(object? resource, IAutho { if (requirement is not RoleContextPermissionsRequirement && resource is null) { - throw new ArgumentNullException(paramName: "resource", message: $"Resource can't be null for the requirement: {requirement.GetType}"); + throw new ArgumentNullException(paramName: "resource", message: $"Resource can't be null for the requirement: {requirement.GetType()}"); } AuthorizationResult authorizationResult = await _authorizationService.AuthorizeAsync( diff --git a/src/Service.GraphQLBuilder/GraphQLUtils.cs b/src/Service.GraphQLBuilder/GraphQLUtils.cs index b98e13349f..e02872b6e6 100644 --- a/src/Service.GraphQLBuilder/GraphQLUtils.cs +++ b/src/Service.GraphQLBuilder/GraphQLUtils.cs @@ -278,7 +278,8 @@ error is ArgumentException || /// Generates the datasource name from the GraphQL context. /// /// Middleware context. - public static string GetDataSourceNameFromGraphQLContext(IMiddlewareContext context, RuntimeConfig runtimeConfig) + /// Datasource name used to execute request. + public static string GetDataSourceNameFromGraphQLContext(IPureResolverContext context, RuntimeConfig runtimeConfig) { string rootNode = context.Selection.Field.Coordinate.TypeName.Value; string dataSourceName; @@ -317,7 +318,7 @@ public static string GetDataSourceNameFromGraphQLContext(IMiddlewareContext cont /// /// Get entity name from context object. /// - public static string GetEntityNameFromContext(IMiddlewareContext context) + public static string GetEntityNameFromContext(IPureResolverContext context) { string entityName = context.Selection.Field.Type.TypeName(); @@ -355,7 +356,7 @@ public static string GetEntityNameFromContext(IMiddlewareContext context) return entityName; } - private static string GenerateDataSourceNameKeyFromPath(IMiddlewareContext context) + private static string GenerateDataSourceNameKeyFromPath(IPureResolverContext context) { return $"{context.Path.ToList()[0]}"; } diff --git a/src/Service.Tests/Resolvers/ArrayPoolWriterTests.cs b/src/Service.Tests/Resolvers/ArrayPoolWriterTests.cs new file mode 100644 index 0000000000..9f22097d01 --- /dev/null +++ b/src/Service.Tests/Resolvers/ArrayPoolWriterTests.cs @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.Resolvers +{ + [TestClass] + public class ArrayPoolWriterTests + { + [TestMethod] + public void Constructor_ShouldInitializeProperly() + { + // Arrange & Act + using ArrayPoolWriter writer = new(); + + // Assert + Assert.AreEqual(0, writer.GetWrittenSpan().Length); + } + + [TestMethod] + public void GetWrittenMemory_ShouldReturnReadOnlyMemory() + { + // Arrange + using ArrayPoolWriter writer = new(); + + // Act + ReadOnlyMemory memory = writer.GetWrittenMemory(); + + // Assert + Assert.AreEqual(0, memory.Length); + } + + [TestMethod] + public void GetWrittenSpan_ShouldReturnReadOnlySpan() + { + // Arrange + using ArrayPoolWriter writer = new(); + + // Act + ReadOnlySpan span = writer.GetWrittenSpan(); + + // Assert + Assert.AreEqual(0, span.Length); + } + + [TestMethod] + public void Advance_ShouldAdvanceCorrectly() + { + // Arrange + using ArrayPoolWriter writer = new(); + writer.GetSpan(10); + + // Act + writer.Advance(5); + + // Assert + Assert.AreEqual(5, writer.GetWrittenSpan().Length); + } + + [TestMethod] + public void GetMemory_ShouldReturnMemoryWithCorrectSizeHint() + { + // Arrange + using ArrayPoolWriter writer = new(); + + // Act + Memory memory = writer.GetMemory(10); + + // Assert + Assert.IsTrue(memory.Length >= 10); + } + + [TestMethod] + public void GetSpan_ShouldReturnSpanWithCorrectSizeHint() + { + // Arrange + using ArrayPoolWriter writer = new(); + + // Act + Span span = writer.GetSpan(10); + + // Assert + Assert.IsTrue(span.Length >= 10); + } + + [TestMethod] + public void Dispose_ShouldDisposeCorrectly() + { + // Arrange + ArrayPoolWriter writer = new(); + + // Act + writer.Dispose(); + + // Assert + Assert.ThrowsException(() => writer.GetMemory()); + Assert.ThrowsException(() => writer.GetSpan()); + Assert.ThrowsException(() => writer.Advance(0)); + } + + [TestMethod] + public void Advance_ShouldThrowWhenDisposed() + { + // Arrange + ArrayPoolWriter writer = new(); + writer.Dispose(); + + // Act & Assert + Assert.ThrowsException(() => writer.Advance(0)); + } + + [TestMethod] + public void Advance_ShouldThrowWhenNegativeCount() + { + // Arrange + using ArrayPoolWriter writer = new(); + + // Act & Assert + Assert.ThrowsException(() => writer.Advance(-1)); + } + + [TestMethod] + public void Advance_ShouldThrowWhenCountGreaterThanCapacity() + { + // Arrange + using ArrayPoolWriter writer = new(); + + // Act & Assert + Assert.ThrowsException( + () => writer.Advance(1024)); + } + + [TestMethod] + public void GetMemory_ShouldThrowWhenDisposed() + { + // Arrange + ArrayPoolWriter writer = new(); + writer.Dispose(); + + // Act & Assert + Assert.ThrowsException(() => writer.GetMemory()); + } + + [TestMethod] + public void GetMemory_ShouldThrowWhenNegativeSizeHint() + { + // Arrange + using ArrayPoolWriter writer = new(); + + // Act & Assert + Assert.ThrowsException(() => writer.GetMemory(-1)); + } + + [TestMethod] + public void GetSpan_ShouldThrowWhenDisposed() + { + // Arrange + ArrayPoolWriter writer = new(); + writer.Dispose(); + + // Act & Assert + Assert.ThrowsException(() => writer.GetSpan()); + } + + [TestMethod] + public void GetSpan_ShouldThrowWhenNegativeSizeHint() + { + // Arrange + using ArrayPoolWriter writer = new(); + + // Act & Assert + Assert.ThrowsException(() => writer.GetSpan(-1)); + } + + [TestMethod] + public void WriteBytesToSpan_ShouldWriteCorrectly() + { + // Arrange + using ArrayPoolWriter writer = new(); + byte[] testData = { 1, 2, 3, 4 }; + + // Act + Span span = writer.GetSpan(4); + testData.CopyTo(span); + writer.Advance(4); + + // Assert + Assert.AreEqual(4, writer.GetWrittenSpan().Length); + ReadOnlySpan writtenSpan = writer.GetWrittenSpan(); + Assert.IsTrue(testData.SequenceEqual(writtenSpan.ToArray())); + } + + [TestMethod] + public void WriteBytesToMemory_ShouldWriteCorrectly() + { + // Arrange + using ArrayPoolWriter writer = new(); + byte[] testData = { 1, 2, 3, 4 }; + + // Act + Memory memory = writer.GetMemory(4); + testData.CopyTo(memory); + writer.Advance(4); + + // Assert + Assert.AreEqual(4, writer.GetWrittenSpan().Length); + ReadOnlyMemory writtenMemory = writer.GetWrittenMemory(); + Assert.IsTrue(testData.SequenceEqual(writtenMemory.ToArray())); + } + + [TestMethod] + public void WriteBytesExceedingInitialBufferSize_ShouldExpandAndWriteCorrectly() + { + // Arrange + using ArrayPoolWriter writer = new(); + byte[] testData = new byte[1024]; + + for (int i = 0; i < testData.Length; i++) + { + testData[i] = (byte)(i % 256); + } + + // Act + for (int i = 0; i < testData.Length; i += 128) + { + Span span = writer.GetSpan(128); + testData.AsSpan(i, 128).CopyTo(span); + writer.Advance(128); + } + + // Assert + Assert.AreEqual(1024, writer.GetWrittenSpan().Length); + ReadOnlySpan writtenSpan = writer.GetWrittenSpan(); + Assert.AreEqual(true, testData.SequenceEqual(writtenSpan.ToArray())); + } + } +} diff --git a/src/Service.Tests/Unittests/MultiSourceQueryExecutionUnitTests.cs b/src/Service.Tests/Unittests/MultiSourceQueryExecutionUnitTests.cs index ec2cb5d193..59a5182eaf 100644 --- a/src/Service.Tests/Unittests/MultiSourceQueryExecutionUnitTests.cs +++ b/src/Service.Tests/Unittests/MultiSourceQueryExecutionUnitTests.cs @@ -12,16 +12,15 @@ using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Resolvers; using Azure.DataApiBuilder.Core.Resolvers.Factories; -using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Service.GraphQLBuilder.Directives; using Azure.DataApiBuilder.Service.GraphQLBuilder.GraphQLTypes; +using Azure.DataApiBuilder.Service.Services; using Azure.DataApiBuilder.Service.Tests.GraphQLBuilder.Helpers; using Azure.Identity; using HotChocolate; using HotChocolate.Execution; using HotChocolate.Execution.Processing; using HotChocolate.Resolvers; -using HotChocolate.Types; using Microsoft.AspNetCore.Http; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Logging; @@ -134,7 +133,7 @@ public async Task TestMultiSourceQuery() .AddDirectiveType() .AddType() .AddType() - .Use((services, next) => new ResolverMiddleware(next, queryEngineFactory.Object, mutationEngineFactory.Object, provider)); + .TryAddTypeInterceptor(new ResolverTypeInterceptor(new ExecutionHelper(queryEngineFactory.Object, mutationEngineFactory.Object, provider))); ISchema schema = schemaBuilder.Create(); IExecutionResult result = await schema.MakeExecutable().ExecuteAsync(_query); diff --git a/src/Service/Azure.DataApiBuilder.Service.csproj b/src/Service/Azure.DataApiBuilder.Service.csproj index bcbb48a6e6..d4bf2601db 100644 --- a/src/Service/Azure.DataApiBuilder.Service.csproj +++ b/src/Service/Azure.DataApiBuilder.Service.csproj @@ -12,6 +12,10 @@ true + + + + diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index b0cdf0ac78..5a83ce3e7b 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -244,7 +244,9 @@ private void AddGraphQLService(IServiceCollection services) } return error; - }); + }) + .UseRequest() + .UseDefaultPipeline(); } // This method gets called by the runtime. Use this method to configure the HTTP request pipeline.