Skip to content

Commit

Permalink
feat: #2 #8 add stop and retry button
Browse files Browse the repository at this point in the history
  • Loading branch information
Yidadaa committed Mar 26, 2023
1 parent a5ec152 commit 86507fa
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 11 deletions.
60 changes: 52 additions & 8 deletions app/components/home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Locale from "../locales";

import dynamic from "next/dynamic";
import { REPO_URL } from "../constant";
import { ControllerPool } from "../requests";

export function Loading(props: { noLogo?: boolean }) {
return (
Expand Down Expand Up @@ -146,28 +147,67 @@ function useSubmitHandler() {
export function Chat(props: { showSideBar?: () => void }) {
type RenderMessage = Message & { preview?: boolean };

const session = useChatStore((state) => state.currentSession());
const [session, sessionIndex] = useChatStore((state) => [
state.currentSession(),
state.currentSessionIndex,
]);
const [userInput, setUserInput] = useState("");
const [isLoading, setIsLoading] = useState(false);
const { submitKey, shouldSubmit } = useSubmitHandler();

const onUserInput = useChatStore((state) => state.onUserInput);

// submit user input
const onUserSubmit = () => {
if (userInput.length <= 0) return;
setIsLoading(true);
onUserInput(userInput).then(() => setIsLoading(false));
setUserInput("");
};

// stop response
const onUserStop = (messageIndex: number) => {
console.log(ControllerPool, sessionIndex, messageIndex);
ControllerPool.stop(sessionIndex, messageIndex);
};

// check if should send message
const onInputKeyDown = (e: KeyboardEvent) => {
if (shouldSubmit(e)) {
onUserSubmit();
e.preventDefault();
}
};
const onRightClick = (e: any, message: Message) => {
// auto fill user input
if (message.role === "user") {
setUserInput(message.content);
}

// copy to clipboard
if (selectOrCopy(e.currentTarget, message.content)) {
e.preventDefault();
}
};

const onResend = (botIndex: number) => {
// find last user input message and resend
for (let i = botIndex; i >= 0; i -= 1) {
if (messages[i].role === "user") {
setIsLoading(true);
onUserInput(messages[i].content).then(() => setIsLoading(false));
return;
}
}
};

// for auto-scroll
const latestMessageRef = useRef<HTMLDivElement>(null);

// wont scroll while hovering messages
const [hoveringMessage, setHoveringMessage] = useState(false);

// preview messages
const messages = (session.messages as RenderMessage[])
.concat(
isLoading
Expand All @@ -194,6 +234,7 @@ export function Chat(props: { showSideBar?: () => void }) {
: []
);

// auto scroll
useLayoutEffect(() => {
setTimeout(() => {
const dom = latestMessageRef.current;
Expand Down Expand Up @@ -283,13 +324,20 @@ export function Chat(props: { showSideBar?: () => void }) {
<div className={styles["chat-message-item"]}>
{!isUser && (
<div className={styles["chat-message-top-actions"]}>
{message.streaming && (
{message.streaming ? (
<div
className={styles["chat-message-top-action"]}
onClick={() => showToast(Locale.WIP)}
onClick={() => onUserStop(i)}
>
{Locale.Chat.Actions.Stop}
</div>
) : (
<div
className={styles["chat-message-top-action"]}
onClick={() => onResend(i)}
>
{Locale.Chat.Actions.Retry}
</div>
)}

<div
Expand All @@ -306,11 +354,7 @@ export function Chat(props: { showSideBar?: () => void }) {
) : (
<div
className="markdown-body"
onContextMenu={(e) => {
if (selectOrCopy(e.currentTarget, message.content)) {
e.preventDefault();
}
}}
onContextMenu={(e) => onRightClick(e, message)}
>
<Markdown content={message.content} />
</div>
Expand Down
1 change: 1 addition & 0 deletions app/locales/cn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const cn = {
Export: "导出聊天记录",
Copy: "复制",
Stop: "停止",
Retry: "重试",
},
Typing: "正在输入…",
Input: (submitKey: string) => `输入消息,${submitKey} 发送`,
Expand Down
1 change: 1 addition & 0 deletions app/locales/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const en: LocaleType = {
Export: "Export All Messages as Markdown",
Copy: "Copy",
Stop: "Stop",
Retry: "Retry",
},
Typing: "Typing…",
Input: (submitKey: string) =>
Expand Down
36 changes: 34 additions & 2 deletions app/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ export async function requestChatStream(
modelConfig?: ModelConfig;
onMessage: (message: string, done: boolean) => void;
onError: (error: Error) => void;
onController?: (controller: AbortController) => void;
}
) {
const req = makeRequestParam(messages, {
Expand Down Expand Up @@ -96,12 +97,12 @@ export async function requestChatStream(
controller.abort();
};

console.log(res);

if (res.ok) {
const reader = res.body?.getReader();
const decoder = new TextDecoder();

options?.onController?.(controller);

while (true) {
// handle time out, will stop if no response in 10 secs
const resTimeoutId = setTimeout(() => finish(), TIME_OUT_MS);
Expand Down Expand Up @@ -146,3 +147,34 @@ export async function requestWithPrompt(messages: Message[], prompt: string) {

return res.choices.at(0)?.message?.content ?? "";
}

// To store message streaming controller
export const ControllerPool = {
controllers: {} as Record<string, AbortController>,

addController(
sessionIndex: number,
messageIndex: number,
controller: AbortController
) {
const key = this.key(sessionIndex, messageIndex);
this.controllers[key] = controller;
return key;
},

stop(sessionIndex: number, messageIndex: number) {
const key = this.key(sessionIndex, messageIndex);
const controller = this.controllers[key];
console.log(controller);
controller?.abort();
},

remove(sessionIndex: number, messageIndex: number) {
const key = this.key(sessionIndex, messageIndex);
delete this.controllers[key];
},

key(sessionIndex: number, messageIndex: number) {
return `${sessionIndex},${messageIndex}`;
},
};
20 changes: 19 additions & 1 deletion app/store/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { create } from "zustand";
import { persist } from "zustand/middleware";

import { type ChatCompletionResponseMessage } from "openai";
import { requestChatStream, requestWithPrompt } from "../requests";
import {
ControllerPool,
requestChatStream,
requestWithPrompt,
} from "../requests";
import { trimTopic } from "../utils";

import Locale from "../locales";
Expand Down Expand Up @@ -296,20 +300,25 @@ export const useChatStore = create<ChatStore>()(
// get recent messages
const recentMessages = get().getMessagesWithMemory();
const sendMessages = recentMessages.concat(userMessage);
const sessionIndex = get().currentSessionIndex;
const messageIndex = get().currentSession().messages.length + 1;

// save user's and bot's message
get().updateCurrentSession((session) => {
session.messages.push(userMessage);
session.messages.push(botMessage);
});

// make request
console.log("[User Input] ", sendMessages);
requestChatStream(sendMessages, {
onMessage(content, done) {
// stream response
if (done) {
botMessage.streaming = false;
botMessage.content = content;
get().onNewMessage(botMessage);
ControllerPool.remove(sessionIndex, messageIndex);
} else {
botMessage.content = content;
set(() => ({}));
Expand All @@ -319,6 +328,15 @@ export const useChatStore = create<ChatStore>()(
botMessage.content += "\n\n" + Locale.Store.Error;
botMessage.streaming = false;
set(() => ({}));
ControllerPool.remove(sessionIndex, messageIndex);
},
onController(controller) {
// collect controller for stop/retry
ControllerPool.addController(
sessionIndex,
messageIndex,
controller
);
},
filterBot: !get().config.sendBotMessages,
modelConfig: get().config.modelConfig,
Expand Down

0 comments on commit 86507fa

Please sign in to comment.