diff --git a/src/GovUk.Education.ExploreEducationStatistics.Admin.Tests/Controllers/Api/Public.Data/DataSetsControllerTests.cs b/src/GovUk.Education.ExploreEducationStatistics.Admin.Tests/Controllers/Api/Public.Data/DataSetsControllerTests.cs index d02e478ee98..be4d733b359 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Admin.Tests/Controllers/Api/Public.Data/DataSetsControllerTests.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Admin.Tests/Controllers/Api/Public.Data/DataSetsControllerTests.cs @@ -31,7 +31,7 @@ namespace GovUk.Education.ExploreEducationStatistics.Admin.Tests.Controllers.Api.Public.Data; -public class DataSetsControllerTests(TestApplicationFactory testApp) : IntegrationTestFixture(testApp) +public abstract class DataSetsControllerTests(TestApplicationFactory testApp) : IntegrationTestFixture(testApp) { private const string BaseUrl = "api/public-data/data-sets"; diff --git a/src/GovUk.Education.ExploreEducationStatistics.Admin.Tests/Fixture/TestApplicationFactory.cs b/src/GovUk.Education.ExploreEducationStatistics.Admin.Tests/Fixture/TestApplicationFactory.cs index af92fb96cca..92417373437 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Admin.Tests/Fixture/TestApplicationFactory.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Admin.Tests/Fixture/TestApplicationFactory.cs @@ -48,7 +48,10 @@ protected override IHostBuilder CreateHostBuilder() .ConfigureServices(services => { services.AddDbContext<PublicDataDbContext>( - options => options.UseNpgsql(_postgreSqlContainer.GetConnectionString())); + options => options + .UseNpgsql( + _postgreSqlContainer.GetConnectionString(), + psqlOptions => psqlOptions.EnableRetryOnFailure())); using var serviceScope = services.BuildServiceProvider() .GetRequiredService<IServiceScopeFactory>() diff --git a/src/GovUk.Education.ExploreEducationStatistics.Common.Tests/Extensions/DbContextTransactionExtensionsTests.cs b/src/GovUk.Education.ExploreEducationStatistics.Common.Tests/Extensions/DbContextTransactionExtensionsTests.cs new file mode 100644 index 00000000000..4ab994aba15 --- /dev/null +++ b/src/GovUk.Education.ExploreEducationStatistics.Common.Tests/Extensions/DbContextTransactionExtensionsTests.cs @@ -0,0 +1,382 @@ +#nullable enable +using System; +using System.Linq; +using System.Threading.Tasks; +using GovUk.Education.ExploreEducationStatistics.Common.Extensions; +using GovUk.Education.ExploreEducationStatistics.Common.Model; +using GovUk.Education.ExploreEducationStatistics.Common.Tests.Fixtures; +using Microsoft.AspNetCore.Mvc.Testing; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Testcontainers.PostgreSql; +using Xunit; + +namespace GovUk.Education.ExploreEducationStatistics.Common.Tests.Extensions; + +/// <summary> +/// +/// This test suite covers convenience extension methods for removing the verbosity of +/// creating shared transactions between different DbContexts, and the behaviours around +/// how they interact and some of the issues to look out for when using. +/// +/// From the tests covered below, we establish that: +/// +/// * multiple DbContexts are able to coordinate under the same transaction boundary and +/// all roll back if a failure occurs. +/// * Transactions can be nested within each other and the parent has control of completing +/// or failing the transaction. If the child transaction fails, the parent needs to +/// acknowledge that failure in order to fail itself (e.g. rethrow an exception, return a +/// failing Either etc). +/// * only a single DbContext that supports RetryOnFailure need be the one to create an +/// ExecutionContext, and thereafter all DbContexts supporting RetryOnFailure will +/// operate successfully. +/// * If a DbContext that doesn't support RetryOnFailure is used as the one that creates +/// the ExecutionStrategy and subsequently a DbContext that *does* support RetryOnFailure +/// is used, it will throw an InvalidOperationException, showing therefore that it is +/// best to use a RetryOnFailure-supporting DbContext if possible to originally create the +/// transaction. +/// +/// </summary> +public abstract class DbContextTransactionExtensionsTests(TestApplicationFactory<TestStartup> testApp) : + IClassFixture<TestApplicationFactory<TestStartup>>, + IAsyncLifetime +{ + private readonly PostgreSqlContainer[] _postgreSqlContainers = + Enumerable.Range(0, 3).Select(_ => new PostgreSqlBuilder() + .WithImage("postgres:16.1-alpine") + .Build()) + .ToArray(); + + public async Task InitializeAsync() => + await Task.WhenAll(_postgreSqlContainers.SelectAsync(async container => + { + await container.StartAsync(); + await container.ExecScriptAsync("ALTER SYSTEM SET max_prepared_transactions = 100"); + await container.StopAsync(); + await container.StartAsync(); + await container.ExecScriptAsync("""CREATE TABLE IF NOT EXISTS "Entities" ("Id" int PRIMARY KEY)"""); + return Task.CompletedTask; + })); + + public async Task DisposeAsync() => + await Task.WhenAll(_postgreSqlContainers.Select(async container => + await container.DisposeAsync())); + + public class NoTransactionTests(TestApplicationFactory<TestStartup> testApp) + : DbContextTransactionExtensionsTests(testApp) + { + [Fact] + public async Task SucceedWithoutTransaction() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await service.SucceedWithoutTransaction(); + + AssertEntitiesInAllDbContexts(app); + } + } + + private WebApplicationFactory<TestStartup> BuildApp() => + testApp.ConfigureServices( + services => services + .AddTransient<TestService>() + .AddDbContext<TestDbContext1>(options => + options.UseNpgsql(_postgreSqlContainers[0].GetConnectionString(), + psqlOptions => psqlOptions.EnableRetryOnFailure())) + .AddDbContext<TestDbContext2>(options => + options.UseNpgsql(_postgreSqlContainers[1].GetConnectionString(), + psqlOptions => psqlOptions.EnableRetryOnFailure())) + .AddDbContext<TestDbContext3WithoutRetry>(options => + options.UseNpgsql(_postgreSqlContainers[2].GetConnectionString())) + ); + + public class RequireTransactionTests(TestApplicationFactory<TestStartup> testApp) + : DbContextTransactionExtensionsTests(testApp) + { + public class FlatTransactionTests(TestApplicationFactory<TestStartup> testApp) + : RequireTransactionTests(testApp) + { + [Fact] + public async Task Succeed() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await service.SucceedWithinFlatTransaction(); + + AssertEntitiesInAllDbContexts(app); + } + + [Fact] + public async Task SucceedWithEither() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await service.SucceedWithinFlatTransactionWithEither(); + + AssertEntitiesInAllDbContexts(app); + } + + [Fact] + public async Task Fail() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await Assert.ThrowsAsync<SimulateFailureException>(service.FailWithinFlatTransaction); + + AssertNoEntitiesInAnyDbContexts(app); + } + + [Fact] + public async Task FailWithEither() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await service.FailWithinFlatTransaction_WithEither(); + + AssertNoEntitiesInAnyDbContexts(app); + } + } + + public class NestedTransactionTests(TestApplicationFactory<TestStartup> testApp) + : RequireTransactionTests(testApp) + { + [Fact] + public async Task SucceedWithinNestedTransaction() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await service.SucceedWithinNestedTransaction(); + + AssertEntitiesInAllDbContexts(app, 1); + AssertEntitiesInAllDbContexts(app, 2); + } + + [Fact] + public async Task SucceedWithinNestedTransaction_MultipleContextsRequestTransaction() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await service.SucceedWithinNestedTransaction_MultipleContextsRequestTransaction(); + + AssertEntitiesInAllDbContexts(app, 1); + AssertEntitiesInAllDbContexts(app, 2); + AssertEntitiesInAllDbContexts(app, 3); + } + + [Fact] + public async Task TransactionInitiatedByNonRetryDbContext_ThrowsException() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await Assert.ThrowsAsync<InvalidOperationException>(service.TransactionInitiatedByNonRetryDbContext); + + AssertNoEntitiesInAnyDbContexts(app); + } + + [Fact] + public async Task FailWithinNestedTransaction() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await Assert.ThrowsAsync<SimulateFailureException>(service.FailWithinNestedTransaction); + + AssertNoEntitiesInAnyDbContexts(app); + } + + [Fact] + public async Task FailWithinNestedTransaction_WithEither() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await service.FailWithinNestedTransaction_WithEither(); + + AssertNoEntitiesInAnyDbContexts(app); + } + + [Fact] + public async Task FailAtTopLevelWithNestedTransaction() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await Assert.ThrowsAsync<SimulateFailureException>(service.FailAtTopLevelWithNestedTransaction); + + AssertNoEntitiesInAnyDbContexts(app); + } + + [Fact] + public async Task FailAtTopLevelWithNestedTransaction_WithEither() + { + var app = BuildApp(); + var service = app.Services.GetRequiredService<TestService>(); + await service.FailAtTopLevelWithNestedTransaction_WithEither(); + + AssertNoEntitiesInAnyDbContexts(app); + } + } + } + + internal class TestService( + TestDbContext1 dbContext1, + TestDbContext2 dbContext2, + TestDbContext3WithoutRetry dbContext3WithoutRetry) + { + public async Task SucceedWithoutTransaction() => await WriteToAllDbContexts(); + + public async Task SucceedWithinFlatTransaction() => + await dbContext1.RequireTransaction(() => WriteToAllDbContexts()); + + public async Task SucceedWithinFlatTransactionWithEither() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + return new Either<int, string>("success!"); + }); + + public async Task FailWithinFlatTransaction() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + throw new SimulateFailureException(); + }); + + public async Task FailWithinFlatTransaction_WithEither() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + return new Either<int, string>(1); + }); + + public async Task SucceedWithinNestedTransaction() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + await dbContext1.RequireTransaction(() => + WriteToAllDbContexts(2)); + }); + + public async Task SucceedWithinNestedTransaction_MultipleContextsRequestTransaction() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + await dbContext2.RequireTransaction(async () => + { + await WriteToAllDbContexts(2); + await dbContext3WithoutRetry.RequireTransaction(async () => + { + await WriteToAllDbContexts(3); + }); + }); + }); + + public async Task TransactionInitiatedByNonRetryDbContext() => + await dbContext3WithoutRetry.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + await dbContext2.RequireTransaction(async () => + { + await WriteToAllDbContexts(2); + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(3); + }); + }); + }); + + public async Task FailWithinNestedTransaction() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(2); + throw new SimulateFailureException(); + }); + }); + + public async Task FailWithinNestedTransaction_WithEither() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + return await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(2); + return new Either<int, string>(1); + }); + }); + + public async Task FailAtTopLevelWithNestedTransaction() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + await dbContext1.RequireTransaction(() => WriteToAllDbContexts(2)); + + throw new SimulateFailureException(); + }); + + public async Task FailAtTopLevelWithNestedTransaction_WithEither() => + await dbContext1.RequireTransaction(async () => + { + await WriteToAllDbContexts(); + await dbContext1.RequireTransaction(() => WriteToAllDbContexts(2)); + return new Either<int, string>(1); + }); + + + private async Task WriteToAllDbContexts(int id = 1) + { + await dbContext1.Entities.AddAsync(new TestEntity { Id = id }); + await dbContext1.SaveChangesAsync(); + + await dbContext2.Entities.AddAsync(new TestEntity { Id = id }); + await dbContext2.SaveChangesAsync(); + + await dbContext3WithoutRetry.Entities.AddAsync(new TestEntity { Id = id }); + await dbContext3WithoutRetry.SaveChangesAsync(); + } + } + + private static void AssertEntitiesInAllDbContexts( + WebApplicationFactory<TestStartup> app, int expectedId = 1) + { + var dbContext1 = app.Services.GetRequiredService<TestDbContext1>(); + var dbContext2 = app.Services.GetRequiredService<TestDbContext2>(); + var dbContext3 = app.Services.GetRequiredService<TestDbContext3WithoutRetry>(); + + Assert.NotNull(dbContext1.Entities.SingleOrDefaultAsync(entity => entity.Id == expectedId)); + Assert.NotNull(dbContext2.Entities.SingleOrDefaultAsync(entity => entity.Id == expectedId)); + Assert.NotNull(dbContext3.Entities.SingleOrDefaultAsync(entity => entity.Id == expectedId)); + } + + private static void AssertNoEntitiesInAnyDbContexts(WebApplicationFactory<TestStartup> app) + { + var dbContext1 = app.Services.GetRequiredService<TestDbContext1>(); + var dbContext2 = app.Services.GetRequiredService<TestDbContext2>(); + var dbContext3 = app.Services.GetRequiredService<TestDbContext3WithoutRetry>(); + + Assert.Empty(dbContext1.Entities); + Assert.Empty(dbContext2.Entities); + Assert.Empty(dbContext3.Entities); + } + + internal class TestEntity + { + public int Id { get; set; } + } + + internal class TestDbContext1(DbContextOptions<TestDbContext1> options) : DbContext(options) + { + public DbSet<TestEntity> Entities { get; init; } = null!; + } + + internal class TestDbContext2(DbContextOptions<TestDbContext2> options) : DbContext(options) + { + public DbSet<TestEntity> Entities { get; init; } = null!; + } + + internal class TestDbContext3WithoutRetry(DbContextOptions<TestDbContext3WithoutRetry> options) : DbContext(options) + { + public DbSet<TestEntity> Entities { get; init; } = null!; + } + + private class SimulateFailureException : Exception; +} diff --git a/src/GovUk.Education.ExploreEducationStatistics.Common.Tests/GovUk.Education.ExploreEducationStatistics.Common.Tests.csproj b/src/GovUk.Education.ExploreEducationStatistics.Common.Tests/GovUk.Education.ExploreEducationStatistics.Common.Tests.csproj index 0dfbb9e71d4..3093d690753 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Common.Tests/GovUk.Education.ExploreEducationStatistics.Common.Tests.csproj +++ b/src/GovUk.Education.ExploreEducationStatistics.Common.Tests/GovUk.Education.ExploreEducationStatistics.Common.Tests.csproj @@ -1,43 +1,46 @@ <Project Sdk="Microsoft.NET.Sdk"> - <PropertyGroup> - <TargetFramework>net8.0</TargetFramework> - <IsPackable>false</IsPackable> - <EnforceCodeStyleInBuild>true</EnforceCodeStyleInBuild> - </PropertyGroup> + <PropertyGroup> + <TargetFramework>net8.0</TargetFramework> + <IsPackable>false</IsPackable> + <EnforceCodeStyleInBuild>true</EnforceCodeStyleInBuild> + </PropertyGroup> - <ItemGroup> - <CompilerVisibleProperty Include="RootNamespace" /> - <CompilerVisibleProperty Include="ProjectDir" /> - </ItemGroup> + <ItemGroup> + <CompilerVisibleProperty Include="RootNamespace"/> + <CompilerVisibleProperty Include="ProjectDir"/> + </ItemGroup> - <ItemGroup> - <PackageReference Include="AspectInjector" Version="2.8.2" /> - <PackageReference Include="Bogus" Version="35.5.1" /> - <PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.4" PrivateAssets="all" /> - <PackageReference Include="CompareNETObjects" Version="4.83.0" /> - <PackageReference Include="Microsoft.AspNetCore.Mvc.Testing" Version="8.0.6" /> - <PackageReference Include="Microsoft.Azure.Functions.Worker" Version="1.22.0" /> - <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" /> - <PackageReference Include="Moq" Version="4.20.70" /> - <PackageReference Include="Snapshooter.Xunit" Version="0.14.1" /> - <PackageReference Include="xunit" Version="2.7.1" /> - <PackageReference Include="xunit.runner.visualstudio" Version="2.8.0"> - <PrivateAssets>all</PrivateAssets> - <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> - </PackageReference> - <PackageReference Include="System.Net.NameResolution" Version="4.3.0" /> - </ItemGroup> + <ItemGroup> + <PackageReference Include="AspectInjector" Version="2.8.2"/> + <PackageReference Include="Bogus" Version="35.5.1"/> + <PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.4" PrivateAssets="all"/> + <PackageReference Include="CompareNETObjects" Version="4.83.0"/> + <PackageReference Include="Microsoft.AspNetCore.Mvc.Testing" Version="8.0.6"/> + <PackageReference Include="Microsoft.Azure.Functions.Worker" Version="1.22.0"/> + <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0"/> + <PackageReference Include="Moq" Version="4.20.70"/> + <PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL" Version="8.0.2"/> + <PackageReference Include="Snapshooter.Xunit" Version="0.14.1"/> + <PackageReference Include="xunit" Version="2.7.1"/> + <PackageReference Include="xunit.runner.visualstudio" Version="2.8.0"> + <PrivateAssets>all</PrivateAssets> + <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> + </PackageReference> + <PackageReference Include="System.Net.NameResolution" Version="4.3.0"/> + <PackageReference Include="Testcontainers" Version="3.8.0"/> + <PackageReference Include="Testcontainers.PostgreSql" Version="3.8.0"/> + </ItemGroup> - <ItemGroup> - <ProjectReference Include="..\GovUk.Education.ExploreEducationStatistics.Common\GovUk.Education.ExploreEducationStatistics.Common.csproj" /> - </ItemGroup> + <ItemGroup> + <ProjectReference Include="..\GovUk.Education.ExploreEducationStatistics.Common\GovUk.Education.ExploreEducationStatistics.Common.csproj"/> + </ItemGroup> - <ItemGroup> - <None Remove="Resources\**" /> - <Content Include="Resources\**"> - <CopyToOutputDirectory>Always</CopyToOutputDirectory> - </Content> - </ItemGroup> + <ItemGroup> + <None Remove="Resources\**"/> + <Content Include="Resources\**"> + <CopyToOutputDirectory>Always</CopyToOutputDirectory> + </Content> + </ItemGroup> </Project> diff --git a/src/GovUk.Education.ExploreEducationStatistics.Common/Extensions/DbContextTransactionExtensions.cs b/src/GovUk.Education.ExploreEducationStatistics.Common/Extensions/DbContextTransactionExtensions.cs new file mode 100644 index 00000000000..b76f22bcbaa --- /dev/null +++ b/src/GovUk.Education.ExploreEducationStatistics.Common/Extensions/DbContextTransactionExtensions.cs @@ -0,0 +1,84 @@ +using System; +using System.Threading.Tasks; +using System.Transactions; +using GovUk.Education.ExploreEducationStatistics.Common.Model; +using Microsoft.EntityFrameworkCore; + +namespace GovUk.Education.ExploreEducationStatistics.Common.Extensions; + +public static class DbContextTransactionExtensions +{ + public static async Task<TResult> RequireTransaction<TDbContext, TResult>( + this TDbContext context, + Func<Task<TResult>> transactionalUnit) + where TDbContext : DbContext + { + var strategy = context.Database.CreateExecutionStrategy(); + + return await strategy.ExecuteAsync( + async () => + { + using var transactionScope = new TransactionScope( + TransactionScopeOption.Required, + new TransactionOptions {IsolationLevel = IsolationLevel.ReadCommitted}, + TransactionScopeAsyncFlowOption.Enabled); + + var result = await transactionalUnit.Invoke(); + transactionScope.Complete(); + return result; + }); + } + + public static async Task<Either<TFailure, TResult>> RequireTransaction<TDbContext, TFailure, TResult>( + this TDbContext context, + Func<Task<Either<TFailure, TResult>>> transactionalUnit) + where TDbContext : DbContext + { + var strategy = context.Database.CreateExecutionStrategy(); + + return await strategy.ExecuteAsync( + async () => + { + using var transactionScope = new TransactionScope( + TransactionScopeOption.Required, + new TransactionOptions {IsolationLevel = IsolationLevel.ReadCommitted}, + TransactionScopeAsyncFlowOption.Enabled); + + return await transactionalUnit + .Invoke() + .OnSuccessDo(transactionScope.Complete); + }); + } + + public static async Task RequireTransaction<TDbContext>( + this TDbContext context, + Func<Task> transactionalUnit) + where TDbContext : DbContext + { + await RequireTransaction(context, async () => + { + await transactionalUnit.Invoke(); + return Unit.Instance; + }); + } + + public static Task RequireTransaction<TDbContext>( + this TDbContext context, + Action transactionalUnit) + where TDbContext : DbContext + { + return RequireTransaction(context, () => + { + transactionalUnit.Invoke(); + return Task.CompletedTask; + }); + } + + public static Task<TResult> RequireTransaction<TDbContext, TResult>( + this TDbContext context, + Func<TResult> transactionalUnit) + where TDbContext : DbContext + { + return RequireTransaction(context, () => Task.FromResult(transactionalUnit.Invoke())); + } +} diff --git a/src/GovUk.Education.ExploreEducationStatistics.Common/Extensions/ServiceCollectionExtensions.cs b/src/GovUk.Education.ExploreEducationStatistics.Common/Extensions/ServiceCollectionExtensions.cs index fa0eacd4456..0f9e34bf515 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Common/Extensions/ServiceCollectionExtensions.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Common/Extensions/ServiceCollectionExtensions.cs @@ -76,7 +76,9 @@ private static IServiceCollection AddDevelopmentPsqlDbContext<TDbContext>( services.AddDbContext<TDbContext>(options => { options - .UseNpgsql(dataSource) + .UseNpgsql( + dataSource, + psqlOptions => psqlOptions.EnableRetryOnFailure()) .EnableSensitiveDataLogging(); optionsConfiguration?.Invoke(options); @@ -111,7 +113,10 @@ private static IServiceCollection RegisterManagedIdentityPsqlDbContext<TDbContex services.AddDbContext<TDbContext>(options => { - options.UseNpgsql(dataSource); + options + .UseNpgsql( + dataSource, + psqlOptions => psqlOptions.EnableRetryOnFailure()); optionsConfiguration?.Invoke(options); }); diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Api.Tests/Fixture/TestApplicationFactory.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Api.Tests/Fixture/TestApplicationFactory.cs index aa724eaeca6..2cf4d952c00 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Api.Tests/Fixture/TestApplicationFactory.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Api.Tests/Fixture/TestApplicationFactory.cs @@ -51,7 +51,10 @@ protected override IHostBuilder CreateHostBuilder() .ConfigureServices(services => { services.AddDbContext<PublicDataDbContext>( - options => options.UseNpgsql(_postgreSqlContainer.GetConnectionString())); + options => options + .UseNpgsql( + _postgreSqlContainer.GetConnectionString(), + psqlOptions => psqlOptions.EnableRetryOnFailure())); }); } } diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Requests/Validators/ValidationMessages.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Requests/Validators/ValidationMessages.cs index efeee36b90c..634d31d140c 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Requests/Validators/ValidationMessages.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Requests/Validators/ValidationMessages.cs @@ -34,6 +34,11 @@ public static class ValidationMessages Code: nameof(DataSetVersionCanNotBeDeleted), Message: $"The data set version is not in a '{DataSetVersionStatus.Draft}' status, so cannot be deleted." ); + + public static readonly LocalizableMessage DataSetMustHaveNoExistingVersions = new( + Code: nameof(DataSetMustHaveNoExistingVersions), + Message: "The data set must have no existing versions when creating the initial version." + ); public static readonly LocalizableMessage DataSetNotFound = new( Code: nameof(DataSetNotFound), diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/Functions/CreateDataSetFunctionTests.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/Functions/CreateDataSetFunctionTests.cs index cbead5d99b1..30e423306a8 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/Functions/CreateDataSetFunctionTests.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/Functions/CreateDataSetFunctionTests.cs @@ -27,8 +27,6 @@ public abstract class CreateDataSetFunctionTests(ProcessorFunctionsIntegrationTe public class CreateDataSetTests(ProcessorFunctionsIntegrationTestFixture fixture) : CreateDataSetFunctionTests(fixture) { - private const string DurableTaskClientName = "TestClient"; - [Fact] public async Task Success() { @@ -48,7 +46,7 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.AddRange(releaseFile, releaseMetaFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); + var durableTaskClientMock = new Mock<DurableTaskClient>(MockBehavior.Strict, "TestClient"); ProcessDataSetVersionContext? processInitialDataSetVersionContext = null; StartOrchestrationOptions? startOrchestrationOptions = null; @@ -125,13 +123,7 @@ await AddTestData<ContentDbContext>(context => [Fact] public async Task ReleaseFileIdIsEmpty_ReturnsValidationProblem() { - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - - var result = await CreateDataSet( - releaseFileId: Guid.Empty, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + var result = await CreateDataSet(releaseFileId: Guid.Empty); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -142,13 +134,7 @@ public async Task ReleaseFileIdIsEmpty_ReturnsValidationProblem() [Fact] public async Task ReleaseFileIdIsNotFound_ReturnsValidationProblem() { - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - - var result = await CreateDataSet( - releaseFileId: Guid.NewGuid(), - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + var result = await CreateDataSet(releaseFileId: Guid.NewGuid()); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -182,13 +168,7 @@ await AddTestData<PublicDataDbContext>(context => context.DataSetVersions.Add(dataSetVersion); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - - var result = await CreateDataSet( - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + var result = await CreateDataSet(releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -215,13 +195,7 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.AddRange(releaseFile, releaseMetaFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - - var result = await CreateDataSet( - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + var result = await CreateDataSet(releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -244,13 +218,7 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.Add(releaseFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - - var result = await CreateDataSet( - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + var result = await CreateDataSet(releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -272,13 +240,7 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.Add(releaseFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - - var result = await CreateDataSet( - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + var result = await CreateDataSet(releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -290,11 +252,11 @@ await AddTestData<ContentDbContext>(context => private async Task<IActionResult> CreateDataSet( Guid releaseFileId, - DurableTaskClient durableTaskClient) + DurableTaskClient? durableTaskClient = null) { var function = GetRequiredService<CreateDataSetFunction>(); return await function.CreateInitialDataSetVersion(new DataSetCreateRequest {ReleaseFileId = releaseFileId}, - durableTaskClient, + durableTaskClient ?? new Mock<DurableTaskClient>(MockBehavior.Strict, "TestClient").Object, CancellationToken.None); } } diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/Functions/CreateNextDataSetVersionFunctionTests.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/Functions/CreateNextDataSetVersionFunctionTests.cs index b085260b152..72ec245a79d 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/Functions/CreateNextDataSetVersionFunctionTests.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/Functions/CreateNextDataSetVersionFunctionTests.cs @@ -27,15 +27,13 @@ public abstract class CreateNextDataSetVersionFunctionTests(ProcessorFunctionsIn public class CreateNextDataSetVersionTests(ProcessorFunctionsIntegrationTestFixture fixture) : CreateNextDataSetVersionFunctionTests(fixture) { - private const string DurableTaskClientName = "TestClient"; - [Fact] public async Task Success() { var (dataSet, liveDataSetVersion) = await AddDataSetAndLatestLiveVersion(); var (nextReleaseFile, _) = await AddDataAndMetadataFiles(dataSet.PublicationId); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); + var durableTaskClientMock = new Mock<DurableTaskClient>(MockBehavior.Strict, "TestClient"); ProcessDataSetVersionContext? processNextDataSetVersionContext = null; StartOrchestrationOptions? startOrchestrationOptions = null; @@ -124,14 +122,9 @@ public async Task Success() [Fact] public async Task ReleaseFileIdIsEmpty_ReturnsValidationProblem() { - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: Guid.NewGuid(), - releaseFileId: Guid.Empty, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: Guid.Empty); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -142,14 +135,9 @@ public async Task ReleaseFileIdIsEmpty_ReturnsValidationProblem() [Fact] public async Task DataSetIdIsEmpty_ReturnsValidationProblem() { - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: Guid.Empty, - releaseFileId: Guid.NewGuid(), - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: Guid.NewGuid()); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -176,14 +164,9 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.AddRange(releaseFile, releaseMetaFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: Guid.NewGuid(), - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -197,14 +180,9 @@ public async Task ReleaseFileIdIsNotFound_ReturnsValidationProblem() { var (dataSet, _) = await AddDataSetAndLatestLiveVersion(); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: dataSet.Id, - releaseFileId: Guid.NewGuid(), - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: Guid.NewGuid()); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -241,14 +219,9 @@ await AddTestData<PublicDataDbContext>(context => context.DataSetVersions.Add(otherDataSetVersion); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: dataSet.Id, - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -277,14 +250,9 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.AddRange(releaseFile, releaseMetaFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: dataSet.Id, - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -309,14 +277,9 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.Add(releaseFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: dataSet.Id, - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -340,14 +303,9 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.Add(releaseFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: dataSet.Id, - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -365,14 +323,9 @@ public async Task DataSetAndReleaseFileFromDifferentPublications_ReturnsValidati // Add ReleaseFiles for a different Publication. var (releaseFile, _) = await AddDataAndMetadataFiles(publicationId: Guid.NewGuid()); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: dataSet.Id, - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -393,14 +346,9 @@ public async Task DataSetWithoutLiveVersion_ReturnsValidationProblem() var (releaseFile, _) = await AddDataAndMetadataFiles(dataSet.PublicationId); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: dataSet.Id, - releaseFileId: releaseFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: releaseFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -445,14 +393,9 @@ await AddTestData<ContentDbContext>(context => context.ReleaseFiles.AddRange(nextDataFile, nextMetaFile); }); - var durableTaskClientMock = new Mock<DurableTaskClient>(DurableTaskClientName); - var result = await CreateNextDataSetVersion( dataSetId: dataSet.Id, - releaseFileId: nextDataFile.Id, - durableTaskClientMock.Object); - - VerifyAllMocks(durableTaskClientMock); + releaseFileId: nextDataFile.Id); var validationProblem = result.AssertBadRequestWithValidationProblem(); @@ -529,7 +472,7 @@ await AddTestData<ContentDbContext>(context => private async Task<IActionResult> CreateNextDataSetVersion( Guid dataSetId, Guid releaseFileId, - DurableTaskClient durableTaskClient) + DurableTaskClient? durableTaskClient = null) { var function = GetRequiredService<CreateNextDataSetVersionFunction>(); return await function.CreateNextDataSetVersion(new NextDataSetVersionCreateRequest @@ -537,7 +480,7 @@ private async Task<IActionResult> CreateNextDataSetVersion( DataSetId = dataSetId, ReleaseFileId = releaseFileId }, - durableTaskClient, + durableTaskClient ?? new Mock<DurableTaskClient>(MockBehavior.Strict, "TestClient").Object, CancellationToken.None); } } diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/ProcessorFunctionsIntegrationTest.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/ProcessorFunctionsIntegrationTest.cs index fd4ebabdbd9..5516c292daa 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/ProcessorFunctionsIntegrationTest.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Tests/ProcessorFunctionsIntegrationTest.cs @@ -149,7 +149,9 @@ public override IHostBuilder ConfigureTestHostBuilder() services.UseInMemoryDbContext<ContentDbContext>(databaseName: Guid.NewGuid().ToString()); services.AddDbContext<PublicDataDbContext>( - options => options.UseNpgsql(_postgreSqlContainer.GetConnectionString())); + options => options.UseNpgsql( + _postgreSqlContainer.GetConnectionString(), + psqlOptions => psqlOptions.EnableRetryOnFailure())); using var serviceScope = services.BuildServiceProvider() .GetRequiredService<IServiceScopeFactory>() diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Functions/CreateDataSetFunction.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Functions/CreateDataSetFunction.cs index e71a68459b0..f0647cd1ab1 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Functions/CreateDataSetFunction.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Functions/CreateDataSetFunction.cs @@ -30,7 +30,7 @@ public async Task<IActionResult> CreateInitialDataSetVersion( var instanceId = Guid.NewGuid(); return await requestValidator.Validate(request, cancellationToken) - .OnSuccess(() => dataSetService.CreateInitialDataSetVersion( + .OnSuccess(() => dataSetService.CreateDataSet( request, instanceId, cancellationToken: cancellationToken diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Functions/CreateNextDataSetVersionFunction.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Functions/CreateNextDataSetVersionFunction.cs index fa0a21420f3..edb131d2ebb 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Functions/CreateNextDataSetVersionFunction.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Functions/CreateNextDataSetVersionFunction.cs @@ -16,7 +16,7 @@ namespace GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Funct public class CreateNextDataSetVersionFunction( ILogger<CreateNextDataSetVersionFunction> logger, - IDataSetService dataSetService, + IDataSetVersionService dataSetVersionService, IValidator<NextDataSetVersionCreateRequest> requestValidator) { [Function(nameof(CreateNextDataSetVersion))] @@ -30,8 +30,9 @@ public async Task<IActionResult> CreateNextDataSetVersion( var instanceId = Guid.NewGuid(); return await requestValidator.Validate(request, cancellationToken) - .OnSuccess(() => dataSetService.CreateNextDataSetVersion( - request, + .OnSuccess(() => dataSetVersionService.CreateNextVersion( + dataSetId: request.DataSetId, + releaseFileId: request.ReleaseFileId, instanceId, cancellationToken: cancellationToken )) diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/DataSetService.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/DataSetService.cs index ec638318b73..fa635b22a05 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/DataSetService.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/DataSetService.cs @@ -1,6 +1,6 @@ -using System.Transactions; using GovUk.Education.ExploreEducationStatistics.Common.Extensions; using GovUk.Education.ExploreEducationStatistics.Common.Model; +using GovUk.Education.ExploreEducationStatistics.Common.Services.Interfaces; using GovUk.Education.ExploreEducationStatistics.Common.Validators; using GovUk.Education.ExploreEducationStatistics.Common.Validators.ErrorDetails; using GovUk.Education.ExploreEducationStatistics.Common.ViewModels; @@ -19,202 +19,23 @@ namespace GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Servi public class DataSetService( ContentDbContext contentDbContext, - PublicDataDbContext publicDataDbContext + PublicDataDbContext publicDataDbContext, + IDataSetVersionService dataSetVersionService ) : IDataSetService { - public async Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateInitialDataSetVersion( + public async Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateDataSet( DataSetCreateRequest request, Guid instanceId, CancellationToken cancellationToken = default) { - return await CreateDataSetVersion( - releaseFileId: request.ReleaseFileId, - instanceId: instanceId, - dataSetSupplier: async releaseFile => await CreateDataSet(releaseFile, cancellationToken), - cancellationToken: cancellationToken); - } - - public async Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateNextDataSetVersion( - NextDataSetVersionCreateRequest request, - Guid instanceId, - CancellationToken cancellationToken = default) - { - return await ValidateDataSet(request.DataSetId, cancellationToken) - .OnSuccess(dataSet => CreateDataSetVersion( - releaseFileId: request.ReleaseFileId, - instanceId: instanceId, - dataSetSupplier: _ => Task.FromResult(dataSet), - cancellationToken: cancellationToken)); - } - - private async Task<Either<ActionResult, DataSet>> ValidateDataSet( - Guid dataSetId, - CancellationToken cancellationToken = default) - { - var dataSet = await publicDataDbContext - .DataSets - .Include(dataSet => dataSet.LatestLiveVersion) - .Include(dataSet => dataSet.Versions) - .FirstOrDefaultAsync(dataSet => dataSet.Id == dataSetId, cancellationToken); - - if (dataSet is null) - { - return ValidationUtils.ValidationResult(CreateDataSetIdError( - message: ValidationMessages.FileNotFound, - dataSetId: dataSetId - )); - } - - if (dataSet.LatestLiveVersionId is null) - { - return ValidationUtils.ValidationResult(CreateDataSetIdError( - message: ValidationMessages.DataSetMustHaveLiveDataSetVersion, - dataSetId: dataSet.Id)); - } - - return dataSet; - } - - private async Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateDataSetVersion( - Guid releaseFileId, - Guid instanceId, - Func<ReleaseFile, Task<DataSet>> dataSetSupplier, - CancellationToken cancellationToken = default) - { - var strategy = contentDbContext.Database.CreateExecutionStrategy(); - - return await strategy.ExecuteAsync(async () => - { - using var transactionScope = new TransactionScope( - TransactionScopeOption.Required, - new TransactionOptions {IsolationLevel = IsolationLevel.ReadCommitted}, - TransactionScopeAsyncFlowOption.Enabled); - - return await GetReleaseFile(releaseFileId, cancellationToken) - .OnSuccess(async releaseFile => await ValidateReleaseFile(releaseFile, cancellationToken) - .OnSuccess(() => dataSetSupplier.Invoke(releaseFile)) - .OnSuccessDo(async dataSet => - await ValidateReleaseFileAndDataSet(releaseFile, dataSet, cancellationToken)) - .OnSuccess(async dataSet => - await CreateDataSetVersion(dataSet, releaseFile, cancellationToken)) - .OnSuccessDo(async dataSetVersion => - await CreateDataSetVersionImport(dataSetVersion, instanceId, cancellationToken)) - .OnSuccessDo(async dataSetVersion => - await UpdateFilePublicDataSetVersionId(releaseFile, dataSetVersion, cancellationToken)) - .OnSuccessDo(transactionScope.Complete) - .OnSuccess(dataSetVersion => - (dataSetId: dataSetVersion.DataSetId, dataSetVersionId: dataSetVersion.Id))); - }); - } - - private async Task<Either<ActionResult, ReleaseFile>> GetReleaseFile( - Guid releaseFileId, - CancellationToken cancellationToken) - { - var releaseFile = await contentDbContext.ReleaseFiles - .Include(rf => rf.File) - .Include(rf => rf.ReleaseVersion) - .FirstOrDefaultAsync(rf => rf.Id == releaseFileId, cancellationToken); - - return releaseFile is null - ? ValidationUtils.ValidationResult(CreateReleaseFileIdError( - message: ValidationMessages.FileNotFound, - releaseFileId: releaseFileId - )) - : releaseFile; - } - - private async Task<Either<ActionResult, Unit>> ValidateReleaseFile( - ReleaseFile releaseFile, - CancellationToken cancellationToken) - { - // ReleaseFile must not already have a DataSetVersion - if (await publicDataDbContext.DataSetVersions.AnyAsync( - dsv => dsv.ReleaseFileId == releaseFile.Id, - cancellationToken: cancellationToken)) - { - return ValidationUtils.ValidationResult( - [ - CreateReleaseFileIdError( - message: ValidationMessages.FileHasApiDataSetVersion, - releaseFileId: releaseFile.Id) - ]); - } - - // ReleaseFile must relate to a ReleaseVersion in Draft approval status - if (releaseFile.ReleaseVersion.ApprovalStatus != ReleaseApprovalStatus.Draft) - { - return ValidationUtils.ValidationResult( - [ - CreateReleaseFileIdError( - message: ValidationMessages.FileReleaseVersionNotDraft, - releaseFileId: releaseFile.Id) - ]); - } - - List<ErrorViewModel> errors = []; - - // ReleaseFile must relate to a File of type Data - if (releaseFile.File.Type != FileType.Data) - { - errors.Add(CreateReleaseFileIdError( - message: ValidationMessages.FileTypeNotData, - releaseFileId: releaseFile.Id)); - } - - // There must be a ReleaseFile related to the same ReleaseVersion and Subject with File of type Metadata - if (!await contentDbContext.ReleaseFiles - .Where(rf => rf.ReleaseVersionId == releaseFile.ReleaseVersionId) - .Where(rf => rf.File.SubjectId == releaseFile.File.SubjectId) - .Where(rf => rf.File.Type == FileType.Metadata) - .AnyAsync(cancellationToken: cancellationToken)) - { - errors.Add(CreateReleaseFileIdError( - message: ValidationMessages.NoMetadataFile, - releaseFileId: releaseFile.Id)); - } - - return errors.Count == 0 ? Unit.Instance : ValidationUtils.ValidationResult(errors); - } - - private async Task<Either<ActionResult, Unit>> ValidateReleaseFileAndDataSet( - ReleaseFile releaseFile, - DataSet dataSet, - CancellationToken cancellationToken) - { - List<ErrorViewModel> errors = []; - - if (releaseFile.ReleaseVersion.PublicationId != dataSet.PublicationId) - { - errors.Add(CreateReleaseFileIdError( - message: ValidationMessages.NextReleaseFileMustBeForSamePublicationAsDataSet, - releaseFileId: releaseFile.Id)); - } - - var historicReleaseFileIds = dataSet - .Versions - .Select(version => version.ReleaseFileId) - .ToList(); - - var historicalReleaseIds = await GetReleaseIdsForReleaseFiles( - contentDbContext, - historicReleaseFileIds, - cancellationToken); - - var selectedReleaseFileReleaseId = (await GetReleaseIdsForReleaseFiles( - contentDbContext, - [releaseFile.Id], - cancellationToken)) - .Single(); - - if (historicalReleaseIds.Contains(selectedReleaseFileReleaseId)) - { - errors.Add(CreateReleaseFileIdError( - message: ValidationMessages.ReleaseFileMustBeFromDifferentReleaseToHistoricalVersions, - releaseFileId: releaseFile.Id)); - } - - return errors.Count == 0 ? Unit.Instance : ValidationUtils.ValidationResult(errors); + return await publicDataDbContext.RequireTransaction(async () => + await GetReleaseFile(request.ReleaseFileId, cancellationToken) + .OnSuccess(releaseFile => CreateDataSet(releaseFile, cancellationToken)) + .OnSuccess(dataSet => dataSetVersionService.CreateInitialVersion( + dataSetId: dataSet.Id, + releaseFileId: request.ReleaseFileId, + instanceId: instanceId, + cancellationToken))); } private async Task<DataSet> CreateDataSet( @@ -235,90 +56,23 @@ private async Task<DataSet> CreateDataSet( return dataSet; } - private async Task<DataSetVersion> CreateDataSetVersion( - DataSet dataSet, - ReleaseFile releaseFile, - CancellationToken cancellationToken) - { - var dataSetVersion = new DataSetVersion - { - ReleaseFileId = releaseFile.Id, - DataSetId = dataSet.Id, - Status = DataSetVersionStatus.Processing, - Notes = "", - VersionMajor = dataSet.LatestLiveVersion?.VersionMajor ?? 1, - VersionMinor = dataSet.LatestLiveVersion?.VersionMinor + 1 ?? 0 - }; - - dataSet.Versions.Add(dataSetVersion); - dataSet.LatestDraftVersion = dataSetVersion; - - publicDataDbContext.DataSets.Update(dataSet); - await publicDataDbContext.SaveChangesAsync(cancellationToken); - - return dataSetVersion; - } - - private async Task CreateDataSetVersionImport( - DataSetVersion dataSetVersion, - Guid instanceId, - CancellationToken cancellationToken) - { - var dataSetVersionImport = new DataSetVersionImport - { - DataSetVersionId = dataSetVersion.Id, InstanceId = instanceId, Stage = DataSetVersionImportStage.Pending - }; - - publicDataDbContext.DataSetVersionImports.Add(dataSetVersionImport); - await publicDataDbContext.SaveChangesAsync(cancellationToken); - } - - private async Task UpdateFilePublicDataSetVersionId( - ReleaseFile releaseFile, - DataSetVersion dataSetVersion, - CancellationToken cancellationToken) - { - releaseFile.File.PublicApiDataSetId = dataSetVersion.DataSetId; - releaseFile.File.PublicApiDataSetVersion = dataSetVersion.FullSemanticVersion(); - await contentDbContext.SaveChangesAsync(cancellationToken); - } - - private static async Task<List<Guid>> GetReleaseIdsForReleaseFiles( - ContentDbContext contentDbContext, - List<Guid> releaseFileIds, + private async Task<Either<ActionResult, ReleaseFile>> GetReleaseFile( + Guid releaseFileId, CancellationToken cancellationToken) { - return await contentDbContext - .ReleaseFiles - .Include(releaseFile => releaseFile.ReleaseVersion) - .Where(releaseFile => releaseFileIds.Contains(releaseFile.Id)) - .Select(releaseFile => releaseFile.ReleaseVersion.ReleaseId) - .ToListAsync(cancellationToken); - } - - private static ErrorViewModel CreateReleaseFileIdError( - LocalizableMessage message, - Guid releaseFileId) - { - return new ErrorViewModel - { - Code = message.Code, - Message = message.Message, - Path = nameof(DataSetCreateRequest.ReleaseFileId).ToLowerFirst(), - Detail = new InvalidErrorDetail<Guid>(releaseFileId) - }; - } + var releaseFile = await contentDbContext.ReleaseFiles + .Include(rf => rf.File) + .Include(rf => rf.ReleaseVersion) + .FirstOrDefaultAsync(rf => rf.Id == releaseFileId, cancellationToken); - private static ErrorViewModel CreateDataSetIdError( - LocalizableMessage message, - Guid dataSetId) - { - return new ErrorViewModel - { - Code = message.Code, - Message = message.Message, - Path = nameof(NextDataSetVersionCreateRequest.DataSetId).ToLowerFirst(), - Detail = new InvalidErrorDetail<Guid>(dataSetId) - }; + return releaseFile is null + ? ValidationUtils.ValidationResult(new ErrorViewModel + { + Code = ValidationMessages.FileNotFound.Code, + Message = ValidationMessages.FileNotFound.Message, + Path = nameof(DataSetCreateRequest.ReleaseFileId).ToLowerFirst(), + Detail = new InvalidErrorDetail<Guid>(releaseFileId) + }) + : releaseFile; } } diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/DataSetVersionService.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/DataSetVersionService.cs index ad755598861..d4fb49ab14d 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/DataSetVersionService.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/DataSetVersionService.cs @@ -7,11 +7,11 @@ using GovUk.Education.ExploreEducationStatistics.Content.Model.Database; using GovUk.Education.ExploreEducationStatistics.Public.Data.Model; using GovUk.Education.ExploreEducationStatistics.Public.Data.Model.Database; +using GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Requests; using GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Services.Interfaces; using GovUk.Education.ExploreEducationStatistics.Public.Data.Services.Interfaces; using Microsoft.AspNetCore.Mvc; using Microsoft.EntityFrameworkCore; -using System.Transactions; using ValidationMessages = GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Requests.Validators.ValidationMessages; @@ -23,31 +23,50 @@ internal class DataSetVersionService( IDataSetVersionPathResolver dataSetVersionPathResolver ) : IDataSetVersionService { + public async Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateInitialVersion( + Guid dataSetId, + Guid releaseFileId, + Guid instanceId, + CancellationToken cancellationToken = default) + { + return await GetDataSet(dataSetId, cancellationToken) + .OnSuccess(ValidateInitialDataSet) + .OnSuccess(dataSet => CreateDataSetVersion( + releaseFileId: releaseFileId, + instanceId: instanceId, + dataSet: dataSet, + cancellationToken: cancellationToken)); + } + + public async Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateNextVersion( + Guid dataSetId, + Guid releaseFileId, + Guid instanceId, + CancellationToken cancellationToken = default) + { + return await GetDataSet(dataSetId, cancellationToken) + .OnSuccess(ValidateNextDataSet) + .OnSuccess(dataSet => CreateDataSetVersion( + releaseFileId: releaseFileId, + instanceId: instanceId, + dataSet: dataSet, + cancellationToken: cancellationToken)); + } + public async Task<Either<ActionResult, Unit>> DeleteVersion(Guid dataSetVersionId, CancellationToken cancellationToken = default) { - var strategy = contentDbContext.Database.CreateExecutionStrategy(); - - return await strategy.ExecuteAsync( - async () => - { - using var transactionScope = new TransactionScope( - TransactionScopeOption.Required, - new TransactionOptions { IsolationLevel = IsolationLevel.ReadCommitted }, - TransactionScopeAsyncFlowOption.Enabled); - - return await GetDataSetVersion(dataSetVersionId, cancellationToken) - .OnSuccessDo(CheckCanDeleteDataSetVersion) - .OnSuccessDo(async dataSetVersion => await GetReleaseFile(dataSetVersion, cancellationToken) - .OnSuccessVoid(async releaseFile => - await UpdateFilePublicApiDataSetId(releaseFile, cancellationToken)) - .OnFailureDo(_ => - throw new KeyNotFoundException( - $"The expected 'ReleaseFile', with ID '{dataSetVersion.ReleaseFileId}', was not found."))) - .OnSuccessDo(async dataSetVersion => await DeleteDataSetVersion(dataSetVersion, cancellationToken)) - .OnSuccessVoid(DeleteParquetFiles) - .OnSuccessVoid(transactionScope.Complete); - }); + return await publicDataDbContext.RequireTransaction(() => + GetDataSetVersion(dataSetVersionId, cancellationToken) + .OnSuccessDo(CheckCanDeleteDataSetVersion) + .OnSuccessDo(async dataSetVersion => await GetReleaseFile(dataSetVersion, cancellationToken) + .OnSuccessVoid(async releaseFile => + await UpdateFilePublicApiDataSetId(releaseFile, cancellationToken)) + .OnFailureDo(_ => + throw new KeyNotFoundException( + $"The expected 'ReleaseFile', with ID '{dataSetVersion.ReleaseFileId}', was not found."))) + .OnSuccessDo(async dataSetVersion => await DeleteDataSetVersion(dataSetVersion, cancellationToken)) + .OnSuccessVoid(DeleteParquetFiles)); } private async Task<Either<ActionResult, DataSetVersion>> GetDataSetVersion(Guid dataSetVersionId, @@ -123,4 +142,264 @@ private void DeleteParquetFiles(DataSetVersion dataSetVersion) Directory.Delete(directory, true); } + + private async Task<Either<ActionResult, DataSet>> GetDataSet( + Guid dataSetId, + CancellationToken cancellationToken = default) + { + var dataSet = await publicDataDbContext + .DataSets + .Include(dataSet => dataSet.LatestLiveVersion) + .Include(dataSet => dataSet.Versions) + .FirstOrDefaultAsync(dataSet => dataSet.Id == dataSetId, cancellationToken); + + return dataSet is null + ? ValidationUtils.ValidationResult(CreateDataSetIdError( + message: ValidationMessages.FileNotFound, + dataSetId: dataSetId + )) + : dataSet; + } + + private Either<ActionResult, DataSet> ValidateInitialDataSet(DataSet dataSet) + { + if (dataSet.Versions.Count > 0) + { + return ValidationUtils.ValidationResult(CreateDataSetIdError( + message: ValidationMessages.DataSetMustHaveNoExistingVersions, + dataSetId: dataSet.Id)); + } + + return dataSet; + } + + private Either<ActionResult, DataSet> ValidateNextDataSet(DataSet dataSet) + { + if (dataSet.LatestLiveVersionId is null) + { + return ValidationUtils.ValidationResult(CreateDataSetIdError( + message: ValidationMessages.DataSetMustHaveLiveDataSetVersion, + dataSetId: dataSet.Id)); + } + + return dataSet; + } + + private async Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateDataSetVersion( + Guid releaseFileId, + Guid instanceId, + DataSet dataSet, + CancellationToken cancellationToken = default) + { + return await publicDataDbContext.RequireTransaction(async () => + await GetReleaseFile(releaseFileId, cancellationToken) + .OnSuccess(async releaseFile => await ValidateReleaseFile(releaseFile, cancellationToken) + .OnSuccessDo(async () => + await ValidateReleaseFileAndDataSet(releaseFile, dataSet, cancellationToken)) + .OnSuccess(async () => + await CreateDataSetVersion(dataSet, releaseFile, cancellationToken)) + .OnSuccessDo(async dataSetVersion => + await CreateDataSetVersionImport(dataSetVersion, instanceId, cancellationToken)) + .OnSuccessDo(async dataSetVersion => + await UpdateFilePublicDataSetVersionId(releaseFile, dataSetVersion, cancellationToken)) + .OnSuccess(dataSetVersion => + (dataSetId: dataSetVersion.DataSetId, dataSetVersionId: dataSetVersion.Id)))); + } + + private async Task<Either<ActionResult, ReleaseFile>> GetReleaseFile( + Guid releaseFileId, + CancellationToken cancellationToken) + { + var releaseFile = await contentDbContext.ReleaseFiles + .Include(rf => rf.File) + .Include(rf => rf.ReleaseVersion) + .FirstOrDefaultAsync(rf => rf.Id == releaseFileId, cancellationToken); + + return releaseFile is null + ? ValidationUtils.ValidationResult(CreateReleaseFileIdError( + message: ValidationMessages.FileNotFound, + releaseFileId: releaseFileId + )) + : releaseFile; + } + + private async Task<Either<ActionResult, Unit>> ValidateReleaseFile( + ReleaseFile releaseFile, + CancellationToken cancellationToken) + { + // ReleaseFile must not already have a DataSetVersion + if (await publicDataDbContext.DataSetVersions.AnyAsync( + dsv => dsv.ReleaseFileId == releaseFile.Id, + cancellationToken: cancellationToken)) + { + return ValidationUtils.ValidationResult( + [ + CreateReleaseFileIdError( + message: ValidationMessages.FileHasApiDataSetVersion, + releaseFileId: releaseFile.Id) + ]); + } + + // ReleaseFile must relate to a ReleaseVersion in Draft approval status + if (releaseFile.ReleaseVersion.ApprovalStatus != ReleaseApprovalStatus.Draft) + { + return ValidationUtils.ValidationResult( + [ + CreateReleaseFileIdError( + message: ValidationMessages.FileReleaseVersionNotDraft, + releaseFileId: releaseFile.Id) + ]); + } + + List<ErrorViewModel> errors = []; + + // ReleaseFile must relate to a File of type Data + if (releaseFile.File.Type != FileType.Data) + { + errors.Add(CreateReleaseFileIdError( + message: ValidationMessages.FileTypeNotData, + releaseFileId: releaseFile.Id)); + } + + // There must be a ReleaseFile related to the same ReleaseVersion and Subject with File of type Metadata + if (!await contentDbContext.ReleaseFiles + .Where(rf => rf.ReleaseVersionId == releaseFile.ReleaseVersionId) + .Where(rf => rf.File.SubjectId == releaseFile.File.SubjectId) + .Where(rf => rf.File.Type == FileType.Metadata) + .AnyAsync(cancellationToken: cancellationToken)) + { + errors.Add(CreateReleaseFileIdError( + message: ValidationMessages.NoMetadataFile, + releaseFileId: releaseFile.Id)); + } + + return errors.Count == 0 ? Unit.Instance : ValidationUtils.ValidationResult(errors); + } + + private async Task<Either<ActionResult, Unit>> ValidateReleaseFileAndDataSet( + ReleaseFile releaseFile, + DataSet dataSet, + CancellationToken cancellationToken) + { + List<ErrorViewModel> errors = []; + + if (releaseFile.ReleaseVersion.PublicationId != dataSet.PublicationId) + { + errors.Add(CreateReleaseFileIdError( + message: ValidationMessages.NextReleaseFileMustBeForSamePublicationAsDataSet, + releaseFileId: releaseFile.Id)); + } + + var historicReleaseFileIds = dataSet + .Versions + .Select(version => version.ReleaseFileId) + .ToList(); + + var historicalReleaseIds = await GetReleaseIdsForReleaseFiles( + contentDbContext, + historicReleaseFileIds, + cancellationToken); + + var selectedReleaseFileReleaseId = (await GetReleaseIdsForReleaseFiles( + contentDbContext, + [releaseFile.Id], + cancellationToken)) + .Single(); + + if (historicalReleaseIds.Contains(selectedReleaseFileReleaseId)) + { + errors.Add(CreateReleaseFileIdError( + message: ValidationMessages.ReleaseFileMustBeFromDifferentReleaseToHistoricalVersions, + releaseFileId: releaseFile.Id)); + } + + return errors.Count == 0 ? Unit.Instance : ValidationUtils.ValidationResult(errors); + } + + private async Task<DataSetVersion> CreateDataSetVersion( + DataSet dataSet, + ReleaseFile releaseFile, + CancellationToken cancellationToken) + { + var dataSetVersion = new DataSetVersion + { + ReleaseFileId = releaseFile.Id, + DataSetId = dataSet.Id, + Status = DataSetVersionStatus.Processing, + Notes = "", + VersionMajor = dataSet.LatestLiveVersion?.VersionMajor ?? 1, + VersionMinor = dataSet.LatestLiveVersion?.VersionMinor + 1 ?? 0 + }; + + dataSet.Versions.Add(dataSetVersion); + dataSet.LatestDraftVersion = dataSetVersion; + + publicDataDbContext.DataSets.Update(dataSet); + await publicDataDbContext.SaveChangesAsync(cancellationToken); + + return dataSetVersion; + } + + private async Task CreateDataSetVersionImport( + DataSetVersion dataSetVersion, + Guid instanceId, + CancellationToken cancellationToken) + { + var dataSetVersionImport = new DataSetVersionImport + { + DataSetVersionId = dataSetVersion.Id, InstanceId = instanceId, Stage = DataSetVersionImportStage.Pending + }; + + publicDataDbContext.DataSetVersionImports.Add(dataSetVersionImport); + await publicDataDbContext.SaveChangesAsync(cancellationToken); + } + + private async Task UpdateFilePublicDataSetVersionId( + ReleaseFile releaseFile, + DataSetVersion dataSetVersion, + CancellationToken cancellationToken) + { + releaseFile.File.PublicApiDataSetId = dataSetVersion.DataSetId; + releaseFile.File.PublicApiDataSetVersion = dataSetVersion.FullSemanticVersion(); + await contentDbContext.SaveChangesAsync(cancellationToken); + } + + private static async Task<List<Guid>> GetReleaseIdsForReleaseFiles( + ContentDbContext contentDbContext, + List<Guid> releaseFileIds, + CancellationToken cancellationToken) + { + return await contentDbContext + .ReleaseFiles + .Include(releaseFile => releaseFile.ReleaseVersion) + .Where(releaseFile => releaseFileIds.Contains(releaseFile.Id)) + .Select(releaseFile => releaseFile.ReleaseVersion.ReleaseId) + .ToListAsync(cancellationToken); + } + + private static ErrorViewModel CreateReleaseFileIdError( + LocalizableMessage message, + Guid releaseFileId) + { + return new ErrorViewModel + { + Code = message.Code, + Message = message.Message, + Path = nameof(DataSetCreateRequest.ReleaseFileId).ToLowerFirst(), + Detail = new InvalidErrorDetail<Guid>(releaseFileId) + }; + } + + private static ErrorViewModel CreateDataSetIdError( + LocalizableMessage message, + Guid dataSetId) + { + return new ErrorViewModel + { + Code = message.Code, + Message = message.Message, + Path = nameof(NextDataSetVersionCreateRequest.DataSetId).ToLowerFirst(), + Detail = new InvalidErrorDetail<Guid>(dataSetId) + }; + } } diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/Interfaces/IDataSetService.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/Interfaces/IDataSetService.cs index d63c756318f..c35b4500ae8 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/Interfaces/IDataSetService.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/Interfaces/IDataSetService.cs @@ -6,13 +6,8 @@ namespace GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Servi public interface IDataSetService { - Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateInitialDataSetVersion( + Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateDataSet( DataSetCreateRequest request, Guid instanceId, CancellationToken cancellationToken = default); - - Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateNextDataSetVersion( - NextDataSetVersionCreateRequest request, - Guid instanceId, - CancellationToken cancellationToken = default); } diff --git a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/Interfaces/IDataSetVersionService.cs b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/Interfaces/IDataSetVersionService.cs index 49e7c99f75d..cf874f335ac 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/Interfaces/IDataSetVersionService.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Public.Data.Processor/Services/Interfaces/IDataSetVersionService.cs @@ -5,6 +5,18 @@ namespace GovUk.Education.ExploreEducationStatistics.Public.Data.Processor.Servi public interface IDataSetVersionService { + Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateInitialVersion( + Guid dataSetId, + Guid releaseFileId, + Guid instanceId, + CancellationToken cancellationToken = default); + + Task<Either<ActionResult, (Guid dataSetId, Guid dataSetVersionId)>> CreateNextVersion( + Guid dataSetId, + Guid releaseFileId, + Guid instanceId, + CancellationToken cancellationToken = default); + Task<Either<ActionResult, Unit>> DeleteVersion( Guid dataSetVersionId, CancellationToken cancellationToken = default); diff --git a/src/GovUk.Education.ExploreEducationStatistics.Publisher.Tests/PublisherFunctionsIntegrationTest.cs b/src/GovUk.Education.ExploreEducationStatistics.Publisher.Tests/PublisherFunctionsIntegrationTest.cs index 588b7c8d4ec..820bc914606 100644 --- a/src/GovUk.Education.ExploreEducationStatistics.Publisher.Tests/PublisherFunctionsIntegrationTest.cs +++ b/src/GovUk.Education.ExploreEducationStatistics.Publisher.Tests/PublisherFunctionsIntegrationTest.cs @@ -55,7 +55,10 @@ public override IHostBuilder ConfigureTestHostBuilder() services.UseInMemoryDbContext<ContentDbContext>(); services.AddDbContext<PublicDataDbContext>( - options => options.UseNpgsql(_postgreSqlContainer.GetConnectionString())); + options => options + .UseNpgsql( + _postgreSqlContainer.GetConnectionString(), + psqlOptions => psqlOptions.EnableRetryOnFailure())); using var serviceScope = services.BuildServiceProvider() .GetRequiredService<IServiceScopeFactory>()