Skip to content

Commit

Permalink
feat: Parallel state (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisDobby authored Dec 28, 2024
1 parent 46d38bf commit 8df7e94
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 12 deletions.
2 changes: 1 addition & 1 deletion demo/src/mockResponse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ describe("mock response tests", () => {
afterAll(mockResponseTearDown)

const httpTask = {
Type: "Task",
Type: "Task" as const,
Resource: "arn:aws:states:::http:invoke",
Parameters: {
ApiEndpoint: "https://67346234723.execute-api.eu-west-1.amazonaws.com/prod/test",
Expand Down
155 changes: 155 additions & 0 deletions demo/src/parallel.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import { testSingleState, testFunction } from "@chrisdobby/step-by-step"

describe("parallel tests", () => {
it("should test a parallel state", async () => {
const result = await testSingleState({
stateDefinition: {
Type: "Parallel",
End: true,
Branches: [
{
StartAt: "State1",
States: {
State1: {
QueryLanguage: "JSONata",
Type: "Pass",
Output: { branch: 1 },
End: true,
},
},
},
{
StartAt: "State2",
States: {
State2: {
QueryLanguage: "JSONata",
Type: "Pass",
Output: { branch: 2 },
End: true,
},
},
},
],
},
})

expect(result.status).toBe("SUCCEEDED")
expect(result.output).toEqual([{ branch: 1 }, { branch: 2 }])
expect(result.stack).toHaveLength(2)
})

it("should test nested parallel states", async () => {
const result = await testSingleState({
stateDefinition: {
Type: "Parallel",
End: true,
Branches: [
{
StartAt: "State1",
States: {
State1: {
QueryLanguage: "JSONata",
Type: "Pass",
Output: { branch: 1 },
End: true,
},
},
},
{
StartAt: "State2",
States: {
State2: {
Type: "Parallel",
End: true,
Branches: [
{
StartAt: "State2-1",
States: {
"State2-1": {
QueryLanguage: "JSONata",
Type: "Pass",
Output: { branch: 2, subBranch: 1 },
End: true,
},
},
},
{
StartAt: "State2-2",
States: {
"State2-2": {
QueryLanguage: "JSONata",
Type: "Pass",
Output: { branch: 2, subBranch: 2 },
End: true,
},
},
},
],
},
},
},
],
},
})

expect(result.status).toBe("SUCCEEDED")
expect(result.output).toEqual([
{ branch: 1 },
[
{ branch: 2, subBranch: 1 },
{ branch: 2, subBranch: 2 },
],
])
expect(result.stack).toHaveLength(4)
})

it("should test a parallel state as part of a function", async () => {
const result = await testFunction({
functionDefinition: {
QueryLanguage: "JSONata",
StartAt: "ParallelState",
States: {
ParallelState: {
Type: "Parallel",
Next: "CombineResults",
Branches: [
{
StartAt: "State1",
States: {
State1: {
Type: "Pass",
Output: { branch: 1 },
End: true,
},
},
},
{
StartAt: "State2",
States: {
State2: {
Type: "Pass",
Output: { branch: 2 },
End: true,
},
},
},
],
},
CombineResults: {
Type: "Pass",
End: true,
QueryLanguage: "JSONPath",
Parameters: {
"output1.$": "$[0]",
"output2.$": "$[1]",
},
},
},
},
})

expect(result.status).toBe("SUCCEEDED")
expect(result.stack).toHaveLength(4)
expect(result.output).toEqual({ output1: { branch: 1 }, output2: { branch: 2 } })
})
})
49 changes: 44 additions & 5 deletions lib/src/testStates.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { SFNClient, TestStateCommand, TestStateCommandOutput, SFNServiceException } from "@aws-sdk/client-sfn"
import {
ParallelState,
TestFunctionInput,
TestFunctionOutput,
TestSingleStateInput,
Expand Down Expand Up @@ -48,13 +49,45 @@ const testState = async (
}
}

const testParallelState = async ({
stateDefinition,
input,
}: {
stateDefinition: ParallelState
input: TestSingleStateInput["input"]
}): Promise<TestSingleStateOutput> => {
const branchOutputs = await Promise.all(
(stateDefinition as ParallelState).Branches.map(branch =>
testFunction({ functionDefinition: { ...branch, QueryLanguage: stateDefinition.QueryLanguage }, input })
)
)

return {
status: "SUCCEEDED",
nextState: stateDefinition.Next,
output: branchOutputs.map(({ output }) => output) as Record<string, unknown>[],
stack: branchOutputs.map(({ stack }) => stack).flat(),
}
}

export const testSingleState = async ({
state = "step-by-step-single-state",
stateDefinition,
input,
mockedResult,
}: TestSingleStateInput): Promise<TestSingleStateOutput> =>
mockedResult || testState(transformState(responseMocks.transformState(state, stateDefinition)), input)
}: TestSingleStateInput): Promise<TestSingleStateOutput> => {
if (mockedResult) {
return mockedResult
}

switch (stateDefinition.Type) {
case "Parallel":
return testParallelState({ stateDefinition: stateDefinition as ParallelState, input })

default:
return testState(transformState(responseMocks.transformState(state, stateDefinition)), input)
}
}

const execute = async ({
functionDefinition,
Expand All @@ -67,14 +100,20 @@ const execute = async ({
stack?: TestFunctionOutput["stack"]
endState?: string
}): Promise<TestFunctionOutput> => {
const stateDefinition = functionDefinition.States[state]
const result = await testSingleState({
const stateDefinition = functionDefinition.QueryLanguage
? {
...functionDefinition.States[state],
QueryLanguage: functionDefinition.States[state].QueryLanguage || functionDefinition.QueryLanguage,
}
: functionDefinition.States[state]

const { stack: singleStateStack, ...result } = await testSingleState({
state,
stateDefinition,
input,
mockedResult: stateMocks.mockedResult(state, functionDefinition.States[state].Next),
})
const updatedStack = [...stack, { ...result, stateName: state }]
const updatedStack = [...stack, ...(singleStateStack || []), { ...result, stateName: state }]

return stateDefinition.End || state === endState || result.status === "FAILED"
? { ...result, stack: updatedStack }
Expand Down
33 changes: 27 additions & 6 deletions lib/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
import { TestExecutionStatus } from "@aws-sdk/client-sfn"

export type StateType = "Task" | "Pass" | "Wait" | "Choice" | "Succeed" | "Fail" | "Parallel" | "Map"
type QueryLanguage = "JSONata" | "JSONPath"
type StateInputOutput = Record<string, unknown> | Record<string, unknown>[]
export type State = {
QueryLanguage?: QueryLanguage
End?: boolean
Next?: string
} & Record<string, unknown>

export type ParallelState = State & {
Type: "Parallel"
Branches: TestFunctionInput["functionDefinition"][]
Next?: string
End?: boolean
QueryLanguage?: QueryLanguage
}

type StateDefinition = State & { Type: Omit<StateType, "Parallel"> }

export type TestSingleStateInput = {
state?: string
stateDefinition: Record<string, unknown>
input?: Record<string, unknown>
stateDefinition: StateDefinition | ParallelState
input?: StateInputOutput
mockedResult?: TestSingleStateOutput | null
}

Expand All @@ -14,15 +33,17 @@ export type TestSingleStateOutput = {
}
status?: TestExecutionStatus
nextState?: string
output?: Record<string, unknown>
output?: StateInputOutput
stack?: (TestSingleStateOutput & { stateName: string })[]
}

export type TestFunctionInput = {
functionDefinition: {
QueryLanguage?: QueryLanguage
StartAt: string
States: Record<string, { Type: string; End?: boolean; Next?: string } & Record<string, unknown>>
States: Record<string, { Type: StateType; End?: boolean; Next?: string } & Record<string, unknown>>
}
input?: Record<string, unknown>
input?: StateInputOutput
}

type OutputError = {
Expand All @@ -32,7 +53,7 @@ type OutputError = {
export type TestFunctionOutput = {
error?: OutputError
status?: TestExecutionStatus
output?: Record<string, unknown>
output?: StateInputOutput
stack: (TestSingleStateOutput & { stateName: string })[]
}

Expand Down

0 comments on commit 8df7e94

Please sign in to comment.