Skip to content

Commit

Permalink
Only allow users to stop streaming response to >their< last message
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski committed Oct 11, 2024
1 parent 01d1c8c commit 8a94b19
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 79 deletions.
14 changes: 10 additions & 4 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,15 @@ def get_chat_user(self) -> ChatUser:
# (`serverapp` is a property on all `JupyterHandler` subclasses)
assert self.serverapp
extensions = self.serverapp.extension_manager.extensions
collaborative = (
collaborative_legacy = (
"jupyter_collaboration" in extensions
and extensions["jupyter_collaboration"].enabled
)
collaborative_v3 = (
"jupyter_server_ydoc" in extensions
and extensions["jupyter_server_ydoc"].enabled
)
collaborative = collaborative_legacy or collaborative_v3

if collaborative:
names = self.current_user.name.split(" ", maxsplit=2)
Expand All @@ -180,7 +185,7 @@ def get_chat_user(self) -> ChatUser:
login = getpass.getuser()
initials = login[0].capitalize()
return ChatUser(
username=login,
username=self.current_user.username,
initials=initials,
name=login,
display_name=login,
Expand Down Expand Up @@ -310,7 +315,8 @@ async def on_message(self, message):
if isinstance(request, StopRequest):
for history_message in self.chat_history[::-1]:
if (
history_message.type == "agent-stream"
history_message.id == request.target
and history_message.type == "agent-stream"
and not history_message.complete
):
try:
Expand All @@ -319,7 +325,7 @@ async def on_message(self, message):
# do nothing if the message was already interrupted
# or stream got completed (thread-safe way!)
pass
self.broadcast_message(StopMessage(target=history_message.id))
self.broadcast_message(StopMessage(target=request.target))
return

chat_request = request
Expand Down
4 changes: 4 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class StopRequest(BaseModel):
"""Request sent by human asking to stop streaming/generating the response"""

type: Literal["stop"]
target: str
"""
Message ID of the agent chat message to stop streaming.
"""


class ClearRequest(BaseModel):
Expand Down
12 changes: 9 additions & 3 deletions packages/jupyter-ai/src/components/chat-input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ type ChatInputProps = {
* `'Jupyternaut'`, but can differ for custom providers.
*/
personaName: string;
streaming: boolean;
/**
* List of streaming messages that are owned by the current user.
*/
streaming: AiService.AgentStreamMessage[];
};

/**
Expand Down Expand Up @@ -273,8 +276,11 @@ export function ChatInput(props: ChatInputProps): JSX.Element {

const sendButtonProps: SendButtonProps = {
onSend,
onStop: () => {
props.chatHandler.sendMessage({ type: 'stop' });
onStop: (message: AiService.AgentStreamMessage) => {
props.chatHandler.sendMessage({
type: 'stop',
target: message.id
});
},
sendWithShiftEnter: props.sendWithShiftEnter,
streaming: props.streaming,
Expand Down
13 changes: 7 additions & 6 deletions packages/jupyter-ai/src/components/chat-input/send-button.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ const FIX_TOOLTIP = '/fix requires an active code cell with an error';

export type SendButtonProps = {
onSend: (selection?: AiService.Selection) => unknown;
onStop: () => void;
onStop: (message: AiService.AgentStreamMessage) => void;
sendWithShiftEnter: boolean;
streaming: boolean;
streaming: AiService.AgentStreamMessage[];
currSlashCommand: string | null;
inputExists: boolean;
activeCellHasError: boolean;
Expand All @@ -37,7 +37,8 @@ export function SendButton(props: SendButtonProps): JSX.Element {
setMenuOpen(false);
}, []);

let disabled = props.streaming ? false : !props.inputExists;
const isStreaming = props.streaming.length !== 0;
let disabled = isStreaming ? false : !props.inputExists;
if (props.currSlashCommand === '/fix' && !props.activeCellHasError) {
disabled = true;
}
Expand All @@ -60,7 +61,7 @@ export function SendButton(props: SendButtonProps): JSX.Element {
const tooltip =
props.currSlashCommand === '/fix' && !props.activeCellHasError
? FIX_TOOLTIP
: props.streaming
: isStreaming
? 'Stop streaming'
: !props.inputExists
? 'Message must not be empty'
Expand Down Expand Up @@ -101,7 +102,7 @@ export function SendButton(props: SendButtonProps): JSX.Element {
return (
<Box sx={{ display: 'flex', flexWrap: 'nowrap' }}>
<TooltippedButton
onClick={() => (props.streaming ? props.onStop() : props.onSend())}
onClick={() => (isStreaming ? props.onStop(props.streaming[props.streaming.length - 1]) : props.onSend())}
disabled={disabled}
tooltip={tooltip}
buttonProps={{
Expand All @@ -114,7 +115,7 @@ export function SendButton(props: SendButtonProps): JSX.Element {
borderRadius: '2px 0px 0px 2px'
}}
>
{props.streaming ? <StopIcon /> : <SendIcon />}
{isStreaming ? <StopIcon /> : <SendIcon />}
</TooltippedButton>
<TooltippedButton
onClick={e => {
Expand Down
138 changes: 74 additions & 64 deletions packages/jupyter-ai/src/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import AddIcon from '@mui/icons-material/Add';
import type { Awareness } from 'y-protocols/awareness';
import type { IThemeManager } from '@jupyterlab/apputils';
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import type { User } from '@jupyterlab/services';
import { ISignal } from '@lumino/signaling';

import { JlThemeProvider } from './jl-theme-provider';
Expand All @@ -28,6 +29,7 @@ import {
ActiveCellContextProvider,
ActiveCellManager
} from '../contexts/active-cell-context';
import { UserContextProvider, useUserContext } from '../contexts/user-context';
import { ScrollContainer } from './scroll-container';
import { TooltippedIconButton } from './mui-extras/tooltipped-icon-button';
import { TelemetryContextProvider } from '../contexts/telemetry-context';
Expand Down Expand Up @@ -76,6 +78,7 @@ function ChatBody({
getPersonaName(messages)
);
const [sendWithShiftEnter, setSendWithShiftEnter] = useState(true);
const user = useUserContext();

/**
* Effect: fetch config on initial render
Expand Down Expand Up @@ -143,6 +146,13 @@ function ChatBody({
);
}

const myHumanMessageIds = new Set(messages.filter(m => m.type === 'human' && m.client.username === user?.identity.username).map(m => m.id));

const myStreamedMessages = messages.filter(m =>
m.type === 'agent-stream' && !m.complete
&& myHumanMessageIds.has(m.reply_to)
) as AiService.AgentStreamMessage[];

return (
<>
<ScrollContainer sx={{ flexGrow: 1 }}>
Expand All @@ -157,10 +167,7 @@ function ChatBody({
<ChatInput
chatHandler={chatHandler}
focusInputSignal={focusInputSignal}
streaming={
messages.filter(m => m.type === 'agent-stream' && !m.complete)
.length !== 0
}
streaming={myStreamedMessages}
sx={{
paddingLeft: 4,
paddingRight: 4,
Expand Down Expand Up @@ -188,6 +195,7 @@ export type ChatProps = {
focusInputSignal: ISignal<unknown, void>;
messageFooter: IJaiMessageFooter | null;
telemetryHandler: IJaiTelemetryHandler | null;
userManager: User.IManager
};

enum ChatView {
Expand All @@ -212,69 +220,71 @@ export function Chat(props: ChatProps): JSX.Element {
activeCellManager={props.activeCellManager}
>
<TelemetryContextProvider telemetryHandler={props.telemetryHandler}>
<Box
// root box should not include padding as it offsets the vertical
// scrollbar to the left
sx={{
width: '100%',
height: '100%',
boxSizing: 'border-box',
background: 'var(--jp-layout-color0)',
display: 'flex',
flexDirection: 'column'
}}
>
{/* top bar */}
<Box sx={{ display: 'flex', justifyContent: 'space-between' }}>
{view !== ChatView.Chat ? (
<IconButton onClick={() => setView(ChatView.Chat)}>
<ArrowBackIcon />
</IconButton>
) : (
<Box />
)}
{view === ChatView.Chat ? (
<Box sx={{ display: 'flex' }}>
{!showWelcomeMessage && (
<TooltippedIconButton
onClick={() =>
props.chatHandler.sendMessage({ type: 'clear' })
}
tooltip="New chat"
>
<AddIcon />
</TooltippedIconButton>
)}
<IconButton onClick={() => openSettingsView()}>
<SettingsIcon />
<UserContextProvider userManager={props.userManager} >
<Box
// root box should not include padding as it offsets the vertical
// scrollbar to the left
sx={{
width: '100%',
height: '100%',
boxSizing: 'border-box',
background: 'var(--jp-layout-color0)',
display: 'flex',
flexDirection: 'column'
}}
>
{/* top bar */}
<Box sx={{ display: 'flex', justifyContent: 'space-between' }}>
{view !== ChatView.Chat ? (
<IconButton onClick={() => setView(ChatView.Chat)}>
<ArrowBackIcon />
</IconButton>
</Box>
) : (
<Box />
) : (
<Box />
)}
{view === ChatView.Chat ? (
<Box sx={{ display: 'flex' }}>
{!showWelcomeMessage && (
<TooltippedIconButton
onClick={() =>
props.chatHandler.sendMessage({ type: 'clear' })
}
tooltip="New chat"
>
<AddIcon />
</TooltippedIconButton>
)}
<IconButton onClick={() => openSettingsView()}>
<SettingsIcon />
</IconButton>
</Box>
) : (
<Box />
)}
</Box>
{/* body */}
{view === ChatView.Chat && (
<ChatBody
chatHandler={props.chatHandler}
openSettingsView={openSettingsView}
showWelcomeMessage={showWelcomeMessage}
setShowWelcomeMessage={setShowWelcomeMessage}
rmRegistry={props.rmRegistry}
focusInputSignal={props.focusInputSignal}
messageFooter={props.messageFooter}
/>
)}
{view === ChatView.Settings && (
<ChatSettings
rmRegistry={props.rmRegistry}
completionProvider={props.completionProvider}
openInlineCompleterSettings={
props.openInlineCompleterSettings
}
/>
)}
</Box>
{/* body */}
{view === ChatView.Chat && (
<ChatBody
chatHandler={props.chatHandler}
openSettingsView={openSettingsView}
showWelcomeMessage={showWelcomeMessage}
setShowWelcomeMessage={setShowWelcomeMessage}
rmRegistry={props.rmRegistry}
focusInputSignal={props.focusInputSignal}
messageFooter={props.messageFooter}
/>
)}
{view === ChatView.Settings && (
<ChatSettings
rmRegistry={props.rmRegistry}
completionProvider={props.completionProvider}
openInlineCompleterSettings={
props.openInlineCompleterSettings
}
/>
)}
</Box>
</UserContextProvider>
</TelemetryContextProvider>
</ActiveCellContextProvider>
</CollaboratorsContextProvider>
Expand Down
39 changes: 39 additions & 0 deletions packages/jupyter-ai/src/contexts/user-context.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import React, { useContext, useEffect, useState } from 'react';
import type { User } from '@jupyterlab/services';
import { PartialJSONObject } from '@lumino/coreutils';

const UserContext = React.createContext<
User.IUser | null
>(null);


export function useUserContext(): User.IUser | null {
return useContext(UserContext);
}

type UserContextProviderProps = {
userManager: User.IManager;
children: React.ReactNode;
};

export function UserContextProvider({
userManager,
children
}: UserContextProviderProps): JSX.Element {
const [user, setUser] = useState<User.IUser | null>(null);

useEffect(() => {
userManager.ready.then(() => {
setUser({identity: userManager.identity!, permissions: userManager.permissions as PartialJSONObject});
});
userManager.userChanged.connect((sender, newUser) => {
setUser(newUser);
});
}, []);

return (
<UserContext.Provider value={user}>
{children}
</UserContext.Provider>
);
}
1 change: 1 addition & 0 deletions packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ export namespace AiService {

export type StopRequest = {
type: 'stop';
target: string;
};

export type Collaborator = {
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ const plugin: JupyterFrontEndPlugin<IJaiCore> = {
activeCellManager,
focusInputSignal,
messageFooter,
telemetryHandler
telemetryHandler,
app.serviceManager.user
);
} catch (e) {
chatWidget = buildErrorWidget(themeManager);
Expand Down
Loading

0 comments on commit 8a94b19

Please sign in to comment.