From 503cf596770c6c5f2883f6573b48e4bd4798588c Mon Sep 17 00:00:00 2001 From: Timothy Makkison Date: Wed, 12 Jun 2024 20:05:37 +0100 Subject: [PATCH] feat: optimize `CachedRequestBuilder` --- Refit.Tests/CachedRequestBuilder.cs | 148 ++++++++++++++++++++ Refit/CachedRequestBuilderImplementation.cs | 108 ++++++++++---- 2 files changed, 232 insertions(+), 24 deletions(-) create mode 100644 Refit.Tests/CachedRequestBuilder.cs diff --git a/Refit.Tests/CachedRequestBuilder.cs b/Refit.Tests/CachedRequestBuilder.cs new file mode 100644 index 000000000..e9b88c8dd --- /dev/null +++ b/Refit.Tests/CachedRequestBuilder.cs @@ -0,0 +1,148 @@ +using System.Net; +using System.Net.Http; +using System.Reflection; + +using RichardSzalay.MockHttp; + +using Xunit; + +namespace Refit.Tests; + +public interface IGeneralRequests +{ + [Post("/foo")] + Task Empty(); + + [Post("/foo")] + Task SingleParameter(string id); + + [Post("/foo")] + Task MultiParameter(string id, string name); + + [Post("/foo")] + Task SingleGenericMultiParameter(string id, string name, TValue generic); +} + +public interface IDuplicateNames +{ + [Post("/foo")] + Task SingleParameter(string id); + + [Post("/foo")] + Task SingleParameter(int id); +} + +public class CachedRequestBuilderTests +{ + [Fact] + public async Task CacheHasCorrectNumberOfElementsTest() + { + var mockHttp = new MockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + var fixture = RestService.For("http://bar", settings); + + // get internal dictionary to check count + var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder"); + var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation; + + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .Respond(HttpStatusCode.OK); + await fixture.Empty(); + Assert.Single(requestBuilder.MethodDictionary); + + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .WithQueryString("id", "id") + .Respond(HttpStatusCode.OK); + await fixture.SingleParameter("id"); + Assert.Equal(2, requestBuilder.MethodDictionary.Count); + + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .WithQueryString("id", "id") + .WithQueryString("name", "name") + .Respond(HttpStatusCode.OK); + await fixture.MultiParameter("id", "name"); + Assert.Equal(3, requestBuilder.MethodDictionary.Count); + + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .WithQueryString("id", "id") + .WithQueryString("name", "name") + .WithQueryString("generic", "generic") + .Respond(HttpStatusCode.OK); + await fixture.SingleGenericMultiParameter("id", "name", "generic"); + Assert.Equal(4, requestBuilder.MethodDictionary.Count); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task NoDuplicateEntriesTest() + { + var mockHttp = new MockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + var fixture = RestService.For("http://bar", settings); + + // get internal dictionary to check count + var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder"); + var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation; + + // send the same request repeatedly to ensure that multiple dictionary entries are not created + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .WithQueryString("id", "id") + .Respond(HttpStatusCode.OK); + await fixture.SingleParameter("id"); + Assert.Single(requestBuilder.MethodDictionary); + + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .WithQueryString("id", "id") + .Respond(HttpStatusCode.OK); + await fixture.SingleParameter("id"); + Assert.Single(requestBuilder.MethodDictionary); + + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .WithQueryString("id", "id") + .Respond(HttpStatusCode.OK); + await fixture.SingleParameter("id"); + Assert.Single(requestBuilder.MethodDictionary); + + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task SameNameDuplicateEntriesTest() + { + var mockHttp = new MockHttpMessageHandler(); + var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp }; + + var fixture = RestService.For("http://bar", settings); + + // get internal dictionary to check count + var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder"); + var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation; + + // send the two different requests with the same name + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .WithQueryString("id", "id") + .Respond(HttpStatusCode.OK); + await fixture.SingleParameter("id"); + Assert.Single(requestBuilder.MethodDictionary); + + mockHttp + .Expect(HttpMethod.Post, "http://bar/foo") + .WithQueryString("id", "10") + .Respond(HttpStatusCode.OK); + await fixture.SingleParameter(10); + Assert.Equal(2, requestBuilder.MethodDictionary.Count); + + mockHttp.VerifyNoOutstandingExpectation(); + } +} diff --git a/Refit/CachedRequestBuilderImplementation.cs b/Refit/CachedRequestBuilderImplementation.cs index 951b08d9e..1a0908330 100644 --- a/Refit/CachedRequestBuilderImplementation.cs +++ b/Refit/CachedRequestBuilderImplementation.cs @@ -20,10 +20,10 @@ public CachedRequestBuilderImplementation(IRequestBuilder innerBuilder) } readonly IRequestBuilder innerBuilder; - readonly ConcurrentDictionary< - string, + internal readonly ConcurrentDictionary< + MethodTableKey, Func - > methodDictionary = new(); + > MethodDictionary = new(); public Func BuildRestResultFuncForMethod( string methodName, @@ -31,13 +31,22 @@ readonly ConcurrentDictionary< Type[]? genericArgumentTypes = null ) { - var cacheKey = GetCacheKey( + var cacheKey = new MethodTableKey( methodName, parameterTypes ?? Array.Empty(), genericArgumentTypes ?? Array.Empty() ); - var func = methodDictionary.GetOrAdd( - cacheKey, + + if (MethodDictionary.TryGetValue(cacheKey, out var methodFunc)) + { + return methodFunc; + } + + // use GetOrAdd with cloned array method table key. This prevents the array from being modified, breaking the dictionary. + var func = MethodDictionary.GetOrAdd( + new MethodTableKey(methodName, + parameterTypes?.ToArray() ?? Array.Empty(), + genericArgumentTypes?.ToArray() ?? Array.Empty()), _ => innerBuilder.BuildRestResultFuncForMethod( methodName, @@ -48,37 +57,88 @@ readonly ConcurrentDictionary< return func; } + } - static string GetCacheKey( - string methodName, - Type[] parameterTypes, - Type[] genericArgumentTypes - ) + /// + /// Represents a method composed of its name, generic arguments and parameters. + /// + internal readonly struct MethodTableKey : IEquatable + { + /// + /// Constructs an instance of . + /// + /// Represents the methods name. + /// Array containing the methods parameters. + /// Array containing the methods generic arguments. + public MethodTableKey (string methodName, Type[] parameters, Type[] genericArguments) { - var genericDefinition = GetGenericString(genericArgumentTypes); - var argumentString = GetArgumentString(parameterTypes); - - return $"{methodName}{genericDefinition}({argumentString})"; + MethodName = methodName; + Parameters = parameters; + GenericArguments = genericArguments; } - static string GetArgumentString(Type[] parameterTypes) + /// + /// The methods name. + /// + string MethodName { get; } + + /// + /// Array containing the methods parameters. + /// + Type[] Parameters { get; } + + /// + /// Array containing the methods generic arguments. + /// + Type[] GenericArguments { get; } + + public override int GetHashCode() { - if (parameterTypes == null || parameterTypes.Length == 0) + unchecked { - return ""; - } + var hashCode = MethodName.GetHashCode(); + + foreach (var argument in Parameters) + { + hashCode = (hashCode * 397) ^ argument.GetHashCode(); + } - return string.Join(", ", parameterTypes.Select(t => t.FullName)); + foreach (var genericArgument in GenericArguments) + { + hashCode = (hashCode * 397) ^ genericArgument.GetHashCode(); + } + return hashCode; + } } - static string GetGenericString(Type[] genericArgumentTypes) + public bool Equals(MethodTableKey other) { - if (genericArgumentTypes == null || genericArgumentTypes.Length == 0) + if (Parameters.Length != other.Parameters.Length + || GenericArguments.Length != other.GenericArguments.Length + || MethodName != other.MethodName) { - return ""; + return false; } - return "<" + string.Join(", ", genericArgumentTypes.Select(t => t.FullName)) + ">"; + for (var i = 0; i < Parameters.Length; i++) + { + if (Parameters[i] != other.Parameters[i]) + { + return false; + } + } + + for (var i = 0; i < GenericArguments.Length; i++) + { + if (GenericArguments[i] != other.GenericArguments[i]) + { + return false; + } + } + + return true; } + + public override bool Equals(object? obj) => obj is MethodTableKey other && Equals(other); } }