Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: stream smart queries #4044

Merged
merged 13 commits into from
Jan 7, 2025
85 changes: 85 additions & 0 deletions packages/shared/src/components/post/smartPrompts/CustomPrompt.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import React, { useCallback, useMemo } from 'react';
import type { ReactElement } from 'react';
import { Button, ButtonSize, ButtonVariant } from '../../buttons/Button';
import type { Post } from '../../../graphql/posts';
import { useSmartPrompt } from '../../../hooks/prompt/useSmartPrompt';
import { usePromptsQuery } from '../../../hooks/prompt/usePromptsQuery';
import { PromptDisplay } from '../../../graphql/prompt';
import { SearchProgressBar } from '../../search';
import { isNullOrUndefined } from '../../../lib/func';
import Alert, { AlertType } from '../../widgets/Alert';
import { labels } from '../../../lib';
import { RenderMarkdown } from '../../RenderMarkdown';

type CustomPromptProps = {
post: Post;
};

export const CustomPrompt = ({ post }: CustomPromptProps): ReactElement => {
const { data: prompts } = usePromptsQuery();
const prompt = useMemo(
() => prompts?.find((p) => p.id === PromptDisplay.CustomPrompt),
[prompts],
);
const { executePrompt, data, isPending } = useSmartPrompt({ post, prompt });
const onSubmitCustomPrompt = useCallback(
(e) => {
e.preventDefault();

executePrompt(e.target[0].value);
},
[executePrompt],
);

if (!data) {
return (
<form
className="rounded-14 bg-surface-float"
onSubmit={onSubmitCustomPrompt}
>
<textarea
className="min-h-[9.5rem] w-full bg-transparent p-3"
placeholder="Write your custom instruction to tailor the post to your needs."
/>
<div className="flex border-t border-t-border-subtlest-tertiary px-4 py-2">
<Button
className="ml-auto"
variant={ButtonVariant.Primary}
size={ButtonSize.Small}
>
Run prompt
</Button>
</div>
</form>
);
}

return (
<div>
{!!data?.chunks?.[0]?.steps && (
<div className="mb-4">
<SearchProgressBar
max={data?.chunks?.[0]?.steps}
progress={data?.chunks?.[0]?.progress}
/>
{!!data?.chunks?.[0]?.status && (
<div className="mt-2 text-text-tertiary typo-callout">
{data?.chunks?.[0]?.status}
</div>
)}
</div>
)}
{!isNullOrUndefined(data?.chunks?.[0]?.error?.code) && (
<Alert
className="mb-4"
type={AlertType.Error}
title={data?.chunks?.[0]?.error?.message || labels.error.generic}
/>
)}
<RenderMarkdown
isLoading={isPending}
content={data?.chunks?.[0]?.response || ''}
/>
</div>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { usePromptsQuery } from '../../../hooks/prompt/usePromptsQuery';
import { ElementPlaceholder } from '../../ElementPlaceholder';
import type { PromptFlags } from '../../../graphql/prompt';
import { PromptDisplay } from '../../../graphql/prompt';
import { usePromptButtons } from '../../../hooks/feed/usePromptButtons';
import { usePromptButtons } from '../../../hooks/prompt/usePromptButtons';
import { useViewSize, ViewSize } from '../../../hooks';
import { SimpleTooltip } from '../../tooltips';
import { promptColorMap, PromptIconMap } from './common';
Expand Down
30 changes: 5 additions & 25 deletions packages/shared/src/components/post/smartPrompts/SmartPrompt.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useCallback, useRef, useState } from 'react';
import React, { useRef, useState } from 'react';
import type { ReactElement } from 'react';
import type { Post } from '../../../graphql/posts';
import PostSummary from '../../cards/common/PostSummary';
Expand All @@ -9,7 +9,8 @@ import { PromptDisplay } from '../../../graphql/prompt';
import { PostUpgradeToPlus } from '../../plus/PostUpgradeToPlus';
import { TargetId } from '../../../lib/log';
import ShowMoreContent from '../../cards/common/ShowMoreContent';
import { Button, ButtonSize, ButtonVariant } from '../../buttons/Button';
import { SmartPromptResponse } from './SmartPromptResponse';
import { CustomPrompt } from './CustomPrompt';

export const SmartPrompt = ({ post }: { post: Post }): ReactElement => {
const { isPlus, showPlusSubscription } = usePlusSubscription();
Expand All @@ -20,10 +21,6 @@ export const SmartPrompt = ({ post }: { post: Post }): ReactElement => {
const elementRef = useRef<HTMLDivElement>(null);
const width = elementRef?.current?.getBoundingClientRect()?.width || 0;

const onSubmitCustomPrompt = useCallback((e) => {
e.preventDefault();
}, []);

const onSetActivePrompt = (prompt: string) => {
setActivePrompt(prompt);
if (!isPlus && prompt !== PromptDisplay.TLDR) {
Expand Down Expand Up @@ -69,28 +66,11 @@ export const SmartPrompt = ({ post }: { post: Post }): ReactElement => {
</Tab>

<Tab label={PromptDisplay.SmartPrompt}>
Smart prompt - {activePrompt}
<SmartPromptResponse post={post} activePrompt={activePrompt} />
</Tab>

<Tab label={PromptDisplay.CustomPrompt}>
<form
className="rounded-14 bg-surface-float"
onSubmit={onSubmitCustomPrompt}
>
<textarea
className="min-h-[9.5rem] w-full bg-transparent p-3"
placeholder="Write your custom instruction to tailor the post to your needs."
/>
<div className="flex border-t border-t-border-subtlest-tertiary px-4 py-2">
<Button
className="ml-auto"
variant={ButtonVariant.Primary}
size={ButtonSize.Small}
>
Run prompt
</Button>
</div>
</form>
<CustomPrompt post={post} />
</Tab>

<Tab label={PromptDisplay.UpgradeToPlus}>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import React, { useEffect, useMemo } from 'react';
import type { ReactElement } from 'react';
import type { Post } from '../../../graphql/posts';
import { RenderMarkdown } from '../../RenderMarkdown';
import { SearchProgressBar } from '../../search';
import Alert, { AlertType } from '../../widgets/Alert';
import { isNullOrUndefined } from '../../../lib/func';
import { labels } from '../../../lib';
import { usePromptsQuery } from '../../../hooks/prompt/usePromptsQuery';
import { useSmartPrompt } from '../../../hooks/prompt/useSmartPrompt';

type SmartPromptResponseProps = {
post: Post;
activePrompt: string;
};

export const SmartPromptResponse = ({
post,
activePrompt,
}: SmartPromptResponseProps): ReactElement => {
const { data: prompts } = usePromptsQuery();
const prompt = useMemo(
() => prompts?.find((p) => p.id === activePrompt),
[activePrompt, prompts],
);

const { executePrompt, data, isPending } = useSmartPrompt({ post, prompt });

useEffect(() => {
if (!prompt.prompt || data) {
return;
}

executePrompt(prompt.prompt + new Date().getTime());
omBratteng marked this conversation as resolved.
Show resolved Hide resolved
}, [prompt, executePrompt, data]);

return (
<div>
{!!data?.chunks?.[0]?.steps && (
<div className="mb-4">
<SearchProgressBar
max={data?.chunks?.[0]?.steps}
progress={data?.chunks?.[0]?.progress}
/>
{!!data?.chunks?.[0]?.status && (
<div className="mt-2 text-text-tertiary typo-callout">
{data?.chunks?.[0]?.status}
</div>
)}
</div>
)}
{!isNullOrUndefined(data?.chunks?.[0]?.error?.code) && (
<Alert
className="mb-4"
type={AlertType.Error}
title={data?.chunks?.[0]?.error?.message || labels.error.generic}
/>
)}
<RenderMarkdown
isLoading={isPending}
content={data?.chunks?.[0]?.response || ''}
/>
</div>
);
};
2 changes: 2 additions & 0 deletions packages/shared/src/graphql/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export type PromptFlags = {
export type Prompt = {
id: string;
label: string;
prompt: string;
description?: string;
createdAt: Date;
updatedAt: Date;
Expand All @@ -28,6 +29,7 @@ export const PROMPTS_QUERY = gql`
id
label
description
prompt
flags {
icon
color
Expand Down
26 changes: 25 additions & 1 deletion packages/shared/src/graphql/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,12 @@ export const getSearchUrl = (params: SearchUrlParams): string => {

export const searchQueryUrl = `${apiUrl}/search/query`;

export const sendPrompt = async (
params: URLSearchParams,
): Promise<EventSource> => {
return new EventSource(`${searchQueryUrl}?${params}`);
};

export const sendSearchQuery = async (
query: string,
token: string,
Expand All @@ -374,7 +380,25 @@ export const sendSearchQuery = async (
token,
});

return new EventSource(`${searchQueryUrl}?${params}`);
return sendPrompt(params);
};

export const sendSmartPromptQuery = async ({
query,
token,
post,
}: {
query: string;
token: string;
post: Post;
}): Promise<EventSource> => {
const params = new URLSearchParams({
prompt: query,
token,
postId: post.id,
});

return sendPrompt(params);
};

export type SearchSuggestion = {
Expand Down
3 changes: 1 addition & 2 deletions packages/shared/src/hooks/chat/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import type { QueryKey } from '@tanstack/react-query';
import type { MouseEvent } from 'react';
import type { Search, SearchChunkSource } from '../../graphql/search';

export interface UseChatProps {
Expand Down Expand Up @@ -32,7 +31,7 @@ export interface UseChat {
queryKey: QueryKey;
data: Search;
isLoading: boolean;
handleSubmit(event: MouseEvent, value: string): void;
handleSubmit(prompt: string): Promise<void>;
}

export interface CreatePayload {
Expand Down
3 changes: 2 additions & 1 deletion packages/shared/src/hooks/chat/useChatSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { useQueryClient, useQuery } from '@tanstack/react-query';
import { useAuthContext } from '../../contexts/AuthContext';
import type { Search } from '../../graphql/search';
import { getSearchSession } from '../../graphql/search';
import { generateQueryKey, RequestKey } from '../../lib/query';
import { generateQueryKey, RequestKey, StaleTime } from '../../lib/query';
import type { UseChatSessionProps, UseChatSession } from './types';

export const useChatSession = ({
Expand All @@ -27,6 +27,7 @@ export const useChatSession = ({
return getSearchSession(id);
},
enabled: !!id,
staleTime: StaleTime.OneHour,
});

return {
Expand Down
15 changes: 5 additions & 10 deletions packages/shared/src/hooks/chat/useChatStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ export const useChatStream = (): UseChatStream => {
const [sessionId, setSessionId] = useState<string>(null);

const executePrompt = useCallback(
async (value: string) => {
if (!value) {
async (prompt: string) => {
if (!prompt) {
return;
}

Expand Down Expand Up @@ -87,7 +87,7 @@ export const useChatStream = (): UseChatStream => {
...payload,
createdAt: new Date(),
status: data.status,
prompt: value,
prompt,
}),
);

Expand Down Expand Up @@ -164,7 +164,7 @@ export const useChatStream = (): UseChatStream => {
logErrorEvent(code);
};

const source = await sendSearchQuery(value, accessToken?.token);
const source = await sendSearchQuery(prompt, accessToken?.token);
source.addEventListener('message', onMessage);
source.addEventListener('error', onError);
sourceRef.current = source;
Expand All @@ -182,11 +182,6 @@ export const useChatStream = (): UseChatStream => {

return {
id: sessionId,
handleSubmit: useCallback(
(_, value: string) => {
executePrompt(value);
},
[executePrompt],
),
handleSubmit: executePrompt,
};
};
Loading
Loading