From 8df7e947eae91c1715303f8199a86845043f50fa Mon Sep 17 00:00:00 2001 From: Chris Dobson Date: Sat, 28 Dec 2024 14:22:16 +0000 Subject: [PATCH] feat: Parallel state (#4) --- demo/src/mockResponse.test.ts | 2 +- demo/src/parallel.test.ts | 155 ++++++++++++++++++++++++++++++++++ lib/src/testStates.ts | 49 +++++++++-- lib/src/types.ts | 33 ++++++-- 4 files changed, 227 insertions(+), 12 deletions(-) create mode 100644 demo/src/parallel.test.ts diff --git a/demo/src/mockResponse.test.ts b/demo/src/mockResponse.test.ts index 72d770a..234d6a1 100644 --- a/demo/src/mockResponse.test.ts +++ b/demo/src/mockResponse.test.ts @@ -5,7 +5,7 @@ describe("mock response tests", () => { afterAll(mockResponseTearDown) const httpTask = { - Type: "Task", + Type: "Task" as const, Resource: "arn:aws:states:::http:invoke", Parameters: { ApiEndpoint: "https://67346234723.execute-api.eu-west-1.amazonaws.com/prod/test", diff --git a/demo/src/parallel.test.ts b/demo/src/parallel.test.ts new file mode 100644 index 0000000..a4e29a7 --- /dev/null +++ b/demo/src/parallel.test.ts @@ -0,0 +1,155 @@ +import { testSingleState, testFunction } from "@chrisdobby/step-by-step" + +describe("parallel tests", () => { + it("should test a parallel state", async () => { + const result = await testSingleState({ + stateDefinition: { + Type: "Parallel", + End: true, + Branches: [ + { + StartAt: "State1", + States: { + State1: { + QueryLanguage: "JSONata", + Type: "Pass", + Output: { branch: 1 }, + End: true, + }, + }, + }, + { + StartAt: "State2", + States: { + State2: { + QueryLanguage: "JSONata", + Type: "Pass", + Output: { branch: 2 }, + End: true, + }, + }, + }, + ], + }, + }) + + expect(result.status).toBe("SUCCEEDED") + expect(result.output).toEqual([{ branch: 1 }, { branch: 2 }]) + expect(result.stack).toHaveLength(2) + }) + + it("should test nested parallel states", async () => { + const result = await testSingleState({ + stateDefinition: { + Type: "Parallel", + End: true, + Branches: [ + { + StartAt: "State1", + States: { + State1: { + QueryLanguage: "JSONata", + Type: "Pass", + Output: { branch: 1 }, + End: true, + }, + }, + }, + { + StartAt: "State2", + States: { + State2: { + Type: "Parallel", + End: true, + Branches: [ + { + StartAt: "State2-1", + States: { + "State2-1": { + QueryLanguage: "JSONata", + Type: "Pass", + Output: { branch: 2, subBranch: 1 }, + End: true, + }, + }, + }, + { + StartAt: "State2-2", + States: { + "State2-2": { + QueryLanguage: "JSONata", + Type: "Pass", + Output: { branch: 2, subBranch: 2 }, + End: true, + }, + }, + }, + ], + }, + }, + }, + ], + }, + }) + + expect(result.status).toBe("SUCCEEDED") + expect(result.output).toEqual([ + { branch: 1 }, + [ + { branch: 2, subBranch: 1 }, + { branch: 2, subBranch: 2 }, + ], + ]) + expect(result.stack).toHaveLength(4) + }) + + it("should test a parallel state as part of a function", async () => { + const result = await testFunction({ + functionDefinition: { + QueryLanguage: "JSONata", + StartAt: "ParallelState", + States: { + ParallelState: { + Type: "Parallel", + Next: "CombineResults", + Branches: [ + { + StartAt: "State1", + States: { + State1: { + Type: "Pass", + Output: { branch: 1 }, + End: true, + }, + }, + }, + { + StartAt: "State2", + States: { + State2: { + Type: "Pass", + Output: { branch: 2 }, + End: true, + }, + }, + }, + ], + }, + CombineResults: { + Type: "Pass", + End: true, + QueryLanguage: "JSONPath", + Parameters: { + "output1.$": "$[0]", + "output2.$": "$[1]", + }, + }, + }, + }, + }) + + expect(result.status).toBe("SUCCEEDED") + expect(result.stack).toHaveLength(4) + expect(result.output).toEqual({ output1: { branch: 1 }, output2: { branch: 2 } }) + }) +}) diff --git a/lib/src/testStates.ts b/lib/src/testStates.ts index 0d14400..bc05a25 100644 --- a/lib/src/testStates.ts +++ b/lib/src/testStates.ts @@ -1,5 +1,6 @@ import { SFNClient, TestStateCommand, TestStateCommandOutput, SFNServiceException } from "@aws-sdk/client-sfn" import { + ParallelState, TestFunctionInput, TestFunctionOutput, TestSingleStateInput, @@ -48,13 +49,45 @@ const testState = async ( } } +const testParallelState = async ({ + stateDefinition, + input, +}: { + stateDefinition: ParallelState + input: TestSingleStateInput["input"] +}): Promise => { + const branchOutputs = await Promise.all( + (stateDefinition as ParallelState).Branches.map(branch => + testFunction({ functionDefinition: { ...branch, QueryLanguage: stateDefinition.QueryLanguage }, input }) + ) + ) + + return { + status: "SUCCEEDED", + nextState: stateDefinition.Next, + output: branchOutputs.map(({ output }) => output) as Record[], + stack: branchOutputs.map(({ stack }) => stack).flat(), + } +} + export const testSingleState = async ({ state = "step-by-step-single-state", stateDefinition, input, mockedResult, -}: TestSingleStateInput): Promise => - mockedResult || testState(transformState(responseMocks.transformState(state, stateDefinition)), input) +}: TestSingleStateInput): Promise => { + if (mockedResult) { + return mockedResult + } + + switch (stateDefinition.Type) { + case "Parallel": + return testParallelState({ stateDefinition: stateDefinition as ParallelState, input }) + + default: + return testState(transformState(responseMocks.transformState(state, stateDefinition)), input) + } +} const execute = async ({ functionDefinition, @@ -67,14 +100,20 @@ const execute = async ({ stack?: TestFunctionOutput["stack"] endState?: string }): Promise => { - const stateDefinition = functionDefinition.States[state] - const result = await testSingleState({ + const stateDefinition = functionDefinition.QueryLanguage + ? { + ...functionDefinition.States[state], + QueryLanguage: functionDefinition.States[state].QueryLanguage || functionDefinition.QueryLanguage, + } + : functionDefinition.States[state] + + const { stack: singleStateStack, ...result } = await testSingleState({ state, stateDefinition, input, mockedResult: stateMocks.mockedResult(state, functionDefinition.States[state].Next), }) - const updatedStack = [...stack, { ...result, stateName: state }] + const updatedStack = [...stack, ...(singleStateStack || []), { ...result, stateName: state }] return stateDefinition.End || state === endState || result.status === "FAILED" ? { ...result, stack: updatedStack } diff --git a/lib/src/types.ts b/lib/src/types.ts index f62742b..53989c4 100644 --- a/lib/src/types.ts +++ b/lib/src/types.ts @@ -1,9 +1,28 @@ import { TestExecutionStatus } from "@aws-sdk/client-sfn" +export type StateType = "Task" | "Pass" | "Wait" | "Choice" | "Succeed" | "Fail" | "Parallel" | "Map" +type QueryLanguage = "JSONata" | "JSONPath" +type StateInputOutput = Record | Record[] +export type State = { + QueryLanguage?: QueryLanguage + End?: boolean + Next?: string +} & Record + +export type ParallelState = State & { + Type: "Parallel" + Branches: TestFunctionInput["functionDefinition"][] + Next?: string + End?: boolean + QueryLanguage?: QueryLanguage +} + +type StateDefinition = State & { Type: Omit } + export type TestSingleStateInput = { state?: string - stateDefinition: Record - input?: Record + stateDefinition: StateDefinition | ParallelState + input?: StateInputOutput mockedResult?: TestSingleStateOutput | null } @@ -14,15 +33,17 @@ export type TestSingleStateOutput = { } status?: TestExecutionStatus nextState?: string - output?: Record + output?: StateInputOutput + stack?: (TestSingleStateOutput & { stateName: string })[] } export type TestFunctionInput = { functionDefinition: { + QueryLanguage?: QueryLanguage StartAt: string - States: Record> + States: Record> } - input?: Record + input?: StateInputOutput } type OutputError = { @@ -32,7 +53,7 @@ type OutputError = { export type TestFunctionOutput = { error?: OutputError status?: TestExecutionStatus - output?: Record + output?: StateInputOutput stack: (TestSingleStateOutput & { stateName: string })[] }