Skip to content

Commit

Permalink
fix(ai): update useAIGeneration to add graphqlErrors before return (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
thaddmt authored Oct 11, 2024
1 parent 7413872 commit 5af986f
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 18 deletions.
5 changes: 5 additions & 0 deletions .changeset/dirty-cups-chew.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@aws-amplify/ui-react-ai": patch
---

fix(ai): update useAIGeneration to manage its own date state
33 changes: 33 additions & 0 deletions packages/react-ai/src/hooks/__tests__/createAIHooks.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,38 @@ describe('createAIHooks', () => {
const [awaitedState] = hookResult.current;
expect(awaitedState.data).toStrictEqual(expectedResult);
});

it('returns a result with graphqlErrors', async () => {
const client = new mockClient();
const expectedResult = {
recipe: 'This is a recipe for chocolate cake that tastes bad',
};
const generateReturn = {
data: expectedResult,
errors: ['this is just one error'],
};
generateRecipeMock.mockResolvedValueOnce(generateReturn);
const { useAIGeneration } = createAIHooks(client);

const { result: hookResult, waitForNextUpdate } = renderHook(() =>
useAIGeneration('generateRecipe')
);

const [_result, generate] = hookResult.current;
act(() => {
generate({
description: 'I want a recipe for a gluten-free chocolate cake.',
});
});

const [loadingState] = hookResult.current;
expect(loadingState.isLoading).toBeTruthy();

await waitForNextUpdate();

const [awaitedState] = hookResult.current;
expect(awaitedState.data).toStrictEqual(expectedResult);
expect(awaitedState.graphqlErrors).toHaveLength(1);
});
});
});
44 changes: 26 additions & 18 deletions packages/react-ai/src/hooks/useAIGeneration.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export interface UseAIGenerationHookWrapper<
useAIGeneration: <U extends Key>(
routeName: U
) => [
Awaited<DataState<Schema[U]['returnType']>>,
Awaited<GenerateState<Schema[U]['returnType']>>,
(input: Schema[U]['args']) => void,
];
}
Expand All @@ -20,7 +20,7 @@ export type UseAIGenerationHook<
> = (
routeName: Key
) => [
Awaited<DataState<Schema[Key]['returnType']>>,
Awaited<GenerateState<Schema[Key]['returnType']>>,
(input: Schema[Key]['args']) => void,
];

Expand All @@ -37,6 +37,15 @@ interface GraphQLFormattedError {
};
}

type SingularReturnValue<T> = {
data: T | null;
errors?: GraphQLFormattedError[];
};

type GenerateState<T> = DataState<T> & {
graphqlErrors?: GraphQLFormattedError[];
};

export function createUseAIGeneration<
Client extends Record<'generations' | 'conversations', Record<string, any>>,
Schema extends getSchema<Client>,
Expand All @@ -45,31 +54,30 @@ export function createUseAIGeneration<
Key extends keyof AIGenerationClient<Schema>['generations'],
>(
routeName: Key
) => {
): [
state: GenerateState<Schema[Key]['returnType']>,
handleAction: (input: Schema[Key]['args']) => void,
] => {
const handleGenerate = (
client.generations as AIGenerationClient<Schema>['generations']
)[routeName];

const updateAIGenerationStateAction = async (
_prev: Schema[Key]['returnType'],
input: Schema[Key]['args']
): Promise<
Schema[Key]['returnType'] & { graphqlErrors?: GraphQLFormattedError[] }
> => {
const result = await handleGenerate(input);

// handleGenerate returns a Promised wrapper around Schema[Key]['returnType'] which includes data, errors, and clientExtensions
// The type of data is Schema[Key]['returnType'] which useDataState also wraps in a data return
// TODO: follow up with how to type handleGenerate to properly return the promise wrapper shape
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
const data = (result as any).data as Schema[Key]['returnType'];
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-member-access
const graphqlErrors = (result as any).errors;
// eslint-disable-next-line @typescript-eslint/no-unsafe-return, @typescript-eslint/no-unsafe-assignment
return { ...data, ...(graphqlErrors ? { graphqlErrors } : {}) };
): Promise<Schema[Key]['returnType']> => {
return await handleGenerate(input);
};

return useDataState(updateAIGenerationStateAction, {});
const [result, handler] = useDataState(
updateAIGenerationStateAction,
undefined
);

const { data, errors } =
(result?.data as SingularReturnValue<Schema[Key]['returnType']>) ?? {};

return [{ ...result, data, graphqlErrors: errors }, handler];
};

return useAIGeneration;
Expand Down

0 comments on commit 5af986f

Please sign in to comment.