diff --git a/.changeset/dirty-cups-chew.md b/.changeset/dirty-cups-chew.md new file mode 100644 index 00000000000..d3216a309f8 --- /dev/null +++ b/.changeset/dirty-cups-chew.md @@ -0,0 +1,5 @@ +--- +"@aws-amplify/ui-react-ai": patch +--- + +fix(ai): update useAIGeneration to manage its own date state diff --git a/packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx b/packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx index 346e19cb650..a2aa63fd3b0 100644 --- a/packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx +++ b/packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx @@ -171,5 +171,38 @@ describe('createAIHooks', () => { const [awaitedState] = hookResult.current; expect(awaitedState.data).toStrictEqual(expectedResult); }); + + it('returns a result with graphqlErrors', async () => { + const client = new mockClient(); + const expectedResult = { + recipe: 'This is a recipe for chocolate cake that tastes bad', + }; + const generateReturn = { + data: expectedResult, + errors: ['this is just one error'], + }; + generateRecipeMock.mockResolvedValueOnce(generateReturn); + const { useAIGeneration } = createAIHooks(client); + + const { result: hookResult, waitForNextUpdate } = renderHook(() => + useAIGeneration('generateRecipe') + ); + + const [_result, generate] = hookResult.current; + act(() => { + generate({ + description: 'I want a recipe for a gluten-free chocolate cake.', + }); + }); + + const [loadingState] = hookResult.current; + expect(loadingState.isLoading).toBeTruthy(); + + await waitForNextUpdate(); + + const [awaitedState] = hookResult.current; + expect(awaitedState.data).toStrictEqual(expectedResult); + expect(awaitedState.graphqlErrors).toHaveLength(1); + }); }); }); diff --git a/packages/react-ai/src/hooks/useAIGeneration.tsx b/packages/react-ai/src/hooks/useAIGeneration.tsx index cfa9addd075..c9aa4c139ef 100644 --- a/packages/react-ai/src/hooks/useAIGeneration.tsx +++ b/packages/react-ai/src/hooks/useAIGeneration.tsx @@ -9,7 +9,7 @@ export interface UseAIGenerationHookWrapper< useAIGeneration: ( routeName: U ) => [ - Awaited>, + Awaited>, (input: Schema[U]['args']) => void, ]; } @@ -20,7 +20,7 @@ export type UseAIGenerationHook< > = ( routeName: Key ) => [ - Awaited>, + Awaited>, (input: Schema[Key]['args']) => void, ]; @@ -37,6 +37,15 @@ interface GraphQLFormattedError { }; } +type SingularReturnValue = { + data: T | null; + errors?: GraphQLFormattedError[]; +}; + +type GenerateState = DataState & { + graphqlErrors?: GraphQLFormattedError[]; +}; + export function createUseAIGeneration< Client extends Record<'generations' | 'conversations', Record>, Schema extends getSchema, @@ -45,7 +54,10 @@ export function createUseAIGeneration< Key extends keyof AIGenerationClient['generations'], >( routeName: Key - ) => { + ): [ + state: GenerateState, + handleAction: (input: Schema[Key]['args']) => void, + ] => { const handleGenerate = ( client.generations as AIGenerationClient['generations'] )[routeName]; @@ -53,23 +65,19 @@ export function createUseAIGeneration< const updateAIGenerationStateAction = async ( _prev: Schema[Key]['returnType'], input: Schema[Key]['args'] - ): Promise< - Schema[Key]['returnType'] & { graphqlErrors?: GraphQLFormattedError[] } - > => { - const result = await handleGenerate(input); - - // handleGenerate returns a Promised wrapper around Schema[Key]['returnType'] which includes data, errors, and clientExtensions - // The type of data is Schema[Key]['returnType'] which useDataState also wraps in a data return - // TODO: follow up with how to type handleGenerate to properly return the promise wrapper shape - // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - const data = (result as any).data as Schema[Key]['returnType']; - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-member-access - const graphqlErrors = (result as any).errors; - // eslint-disable-next-line @typescript-eslint/no-unsafe-return, @typescript-eslint/no-unsafe-assignment - return { ...data, ...(graphqlErrors ? { graphqlErrors } : {}) }; + ): Promise => { + return await handleGenerate(input); }; - return useDataState(updateAIGenerationStateAction, {}); + const [result, handler] = useDataState( + updateAIGenerationStateAction, + undefined + ); + + const { data, errors } = + (result?.data as SingularReturnValue) ?? {}; + + return [{ ...result, data, graphqlErrors: errors }, handler]; }; return useAIGeneration;