Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edit history viewer #3322

Merged
merged 4 commits into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions backend/oasst_backend/api/v1/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,22 @@ def edit_tx(session: deps.Session):
edit_tx()


@router.get("/{message_id}/history", response_model=list[protocol.MessageRevision])
def get_revision_history(
*,
message_id: UUID,
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
db: Session = Depends(deps.get_db),
):
"""
Get all revisions of this message sorted from oldest to most recent
"""
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
revisions = pr.fetch_message_revision_history(message_id)
return utils.prepare_message_revision_list(revisions)


@router.post("/{message_id}/emoji", status_code=HTTP_202_ACCEPTED)
def post_message_emoji(
*,
Expand Down
17 changes: 16 additions & 1 deletion backend/oasst_backend/api/v1/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from uuid import UUID

from oasst_backend.models import Message
from oasst_backend.models import Message, MessageRevision
from oasst_shared.schemas import protocol


Expand Down Expand Up @@ -66,6 +66,21 @@ def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
return protocol.MessageTree(id=tree_id, messages=tree_messages)


def prepare_message_revision(revision: MessageRevision) -> protocol.MessageRevision:
return protocol.MessageRevision(
id=revision.id,
text=revision.payload.payload.text,
message_id=revision.message_id,
user_id=revision.user_id,
created_date=revision.created_date,
user_is_author=revision._user_is_author,
)


def prepare_message_revision_list(revisions: list[MessageRevision]) -> list[protocol.MessageRevision]:
return [prepare_message_revision(revision) for revision in revisions]


split_uuid_pattern = re.compile(
r"^([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})\$(.*)$"
)
3 changes: 3 additions & 0 deletions backend/oasst_backend/models/message_revision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from pydantic import PrivateAttr
from sqlmodel import Field, SQLModel
from uuid_extensions import uuid7

Expand All @@ -23,3 +24,5 @@ class MessageRevision(SQLModel, table=True):
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
)

_user_is_author: Optional[bool] = PrivateAttr(default=None)
andreaskoepf marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 10 additions & 0 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,16 @@ def fetch_message_text_labels(self, message_id: UUID, user_id: Optional[UUID] =
query = query.filter(TextLabels.user_id == user_id)
return query.all()

def fetch_message_revision_history(self, message_id: UUID) -> list[MessageRevision]:
# the revisions are sorted by time using the uuid7 id
revisions: list[MessageRevision] = sorted(
self.db.query(MessageRevision).filter(MessageRevision.message_id == message_id).all(),
key=lambda revision: revision.id.int >> 80,
someone13574 marked this conversation as resolved.
Show resolved Hide resolved
)
for revision in revisions:
revision._user_is_author = self.user_id == revision.user_id
return revisions

@staticmethod
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
"""
Expand Down
9 changes: 9 additions & 0 deletions oasst-shared/oasst_shared/schemas/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ class Message(ConversationMessage):
user: Optional[FrontEndUser]


class MessageRevision(BaseModel):
id: UUID
text: str
message_id: UUID
user_id: Optional[UUID]
created_date: Optional[datetime]
user_is_author: Optional[bool]


class MessagePage(PageResult):
items: list[Message]

Expand Down
19 changes: 19 additions & 0 deletions website/src/components/Messages/MessageCreateDate.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { Text, useColorModeValue } from "@chakra-ui/react";
import { useCurrentLocale } from "src/hooks/locale/useCurrentLocale";

export const MessageCreateDate = ({ date }: { date: string }) => {
const locale = useCurrentLocale();
const createdDateColor = useColorModeValue("blackAlpha.600", "gray.400");

return (
<Text as="span" fontSize="small" color={createdDateColor} fontWeight="medium" me={{ base: 3, md: 6 }}>
{new Intl.DateTimeFormat(locale, {
hour: "2-digit",
minute: "2-digit",
year: "numeric",
month: "2-digit",
day: "2-digit",
}).format(new Date(date))}
</Text>
);
};
85 changes: 85 additions & 0 deletions website/src/components/Messages/MessageHistoryTable.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import { Badge, Flex, Stack, Tooltip } from "@chakra-ui/react";
import { boolean } from "boolean";
import { User } from "lucide-react";
import { useRouter } from "next/router";
import { useTranslation } from "next-i18next";
import { ROUTES } from "src/lib/routes";
import { Message, MessageRevision } from "src/types/Conversation";

import { BaseMessageEntry } from "./BaseMessageEntry";
import { MessageCreateDate } from "./MessageCreateDate";
import { BaseMessageEmojiButton } from "./MessageEmojiButton";
import { MessageInlineEmojiRow } from "./MessageInlineEmojiRow";

export interface MessageHistoryTableProps {
message: Message;
revisions: MessageRevision[];
}

export function MessageHistoryTable({ message, revisions }: MessageHistoryTableProps) {
const { t } = useTranslation(["message"]);
const router = useRouter();

return (
<Stack spacing={4}>
{(revisions.length === 0
? ([
{
text: message.text,
created_date: message.created_date,
user_id: message.user_id,
user_is_author: message.user_is_author,
},
] as Omit<MessageRevision, "id" | "message_id">[])
: (revisions.map((revision) => ({
text: revision.text,
created_date: revision.created_date,
user_id: revision.user_id,
user_is_author: revision.user_is_author,
})) as Omit<MessageRevision, "id" | "message_id">[])
).map(({ text, created_date, user_id, user_is_author }, index, array) => (
<BaseMessageEntry
key={`version-${index}`}
content={text}
avatarProps={{
name: `${boolean(message.is_assistant) ? "Assistant" : "User"}`,
src: `${boolean(message.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`,
}}
highlight={index === array.length - 1}
>
<Flex justifyContent={"space-between"} marginTop={2} alignItems={"center"}>
<MessageCreateDate date={created_date} />
<MessageInlineEmojiRow>
<BaseMessageEmojiButton
emoji={User}
label="Manage User"
onClick={() => router.push(ROUTES.ADMIN_USER_DETAIL(user_id))}
/>
</MessageInlineEmojiRow>
</Flex>
<Flex
position={"absolute"}
gap="2"
top="-2.5"
style={{
insetInlineEnd: "1.25rem",
}}
>
{index === 0 && (
<Tooltip label={"This is the original version of this message"} placement="top">
<Badge colorScheme={"blue"}>Original</Badge>
</Tooltip>
)}
{user_is_author && (
<Tooltip label={t("message_author_explain")} placement="top">
<Badge size="sm" colorScheme="green" textTransform="capitalize">
{t("message_author")}
</Badge>
</Tooltip>
)}
</Flex>
</BaseMessageEntry>
))}
</Stack>
);
}
23 changes: 2 additions & 21 deletions website/src/components/Messages/MessageTableEntry.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import {
MenuList,
Portal,
SimpleGrid,
Text,
Tooltip,
useColorModeValue,
useDisclosure,
Expand Down Expand Up @@ -39,7 +38,6 @@ import { LabelMessagePopup } from "src/components/Messages/LabelPopup";
import { MessageEmojiButton } from "src/components/Messages/MessageEmojiButton";
import { ReportPopup } from "src/components/Messages/ReportPopup";
import { useHasAnyRole } from "src/hooks/auth/useHasAnyRole";
import { useCurrentLocale } from "src/hooks/locale/useCurrentLocale";
import { useDeleteMessage } from "src/hooks/message/useDeleteMessage";
import { post, put } from "src/lib/api";
import { ROUTES } from "src/lib/routes";
Expand All @@ -52,6 +50,7 @@ import { useUndeleteMessage } from "../../hooks/message/useUndeleteMessage";
import { BaseMessageEntry } from "./BaseMessageEntry";
import { MessageInlineEmojiRow } from "./MessageInlineEmojiRow";
import { MessageSyntheticBadge } from "./MessageSyntheticBadge";
import { MessageCreateDate } from "./MessageCreateDate";

interface MessageTableEntryProps {
message: Message;
Expand Down Expand Up @@ -113,7 +112,7 @@ export const MessageTableEntry = forwardRef<HTMLDivElement, MessageTableEntryPro
>
<Flex justifyContent="space-between" mt="2" alignItems="center">
{showCreatedDate ? (
<MessageCreateDate date={message.created_date}></MessageCreateDate>
<MessageCreateDate date={message.created_date} />
) : (
// empty span is required to make emoji displayed at the end of row
<span></span>
Expand Down Expand Up @@ -180,24 +179,6 @@ export const MessageTableEntry = forwardRef<HTMLDivElement, MessageTableEntryPro
}
);

const me = { base: 3, md: 6 };

const MessageCreateDate = ({ date }: { date: string }) => {
const locale = useCurrentLocale();
const createdDateColor = useColorModeValue("blackAlpha.600", "gray.400");
return (
<Text as="span" fontSize="small" color={createdDateColor} fontWeight="medium" me={me}>
{new Intl.DateTimeFormat(locale, {
hour: "2-digit",
minute: "2-digit",
year: "numeric",
month: "2-digit",
day: "2-digit",
}).format(new Date(date))}
</Text>
);
};

const EmojiMenuItem = ({
emoji,
checked,
Expand Down
9 changes: 8 additions & 1 deletion website/src/lib/oasst_api_client.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { EmojiOp, FetchMessagesCursorResponse, Message } from "src/types/Conversation";
import type { EmojiOp, FetchMessagesCursorResponse, Message, MessageRevision } from "src/types/Conversation";
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
import { Stats } from "src/types/Stat";
import type { AvailableTasks } from "src/types/Task";
Expand Down Expand Up @@ -208,6 +208,13 @@ export class OasstApiClient {
}>(`/api/v1/messages/${message_id}/tree/state`);
}

/**
* Returns a list of revisions assoicated with `message_id`.
*/
async fetch_message_revision_history(message_id: string): Promise<MessageRevision[]> {
return this.get<MessageRevision[]>(`/api/v1/messages/${message_id}/history`);
}

/**
* Delete a message by its id
*/
Expand Down
21 changes: 19 additions & 2 deletions website/src/pages/admin/messages/[id].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ export { getServerSideProps } from "src/lib/defaultServerSideProps";
import { AdminArea } from "src/components/AdminArea";
import { JsonCard } from "src/components/JsonCard";
import { AdminLayout } from "src/components/Layout";
import { MessageHistoryTable } from "src/components/Messages/MessageHistoryTable";
import { MessageTree } from "src/components/Messages/MessageTree";
import { get } from "src/lib/api";
import { Message, MessageWithChildren } from "src/types/Conversation";
import { Message, MessageRevision, MessageWithChildren } from "src/types/Conversation";
import useSWRImmutable from "swr/immutable";

const MessageDetail = () => {
Expand All @@ -42,16 +43,24 @@ const MessageDetail = () => {
}>(`/api/admin/messages/${messageId}/tree`, get, {
keepPreviousData: true,
});
const {
data: revisions,
isLoading: revisionsLoading,
error: revisionError,
} = useSWRImmutable<MessageRevision[]>(`/api/admin/messages/${messageId}/history`, get, { keepPreviousData: true });

return (
<>
<Head>
<title>Open Assistant</title>
</Head>
<AdminArea>
{isLoading && !data && <CircularProgress isIndeterminate></CircularProgress>}
{(isLoading && !data) ||
(revisionsLoading && !revisions && <CircularProgress isIndeterminate></CircularProgress>)}
{error && "Unable to load message tree"}
{revisionError && "Unable to load message revision history"}
{data &&
revisions &&
(data.tree === null ? (
"Unable to build tree"
) : (
Expand All @@ -64,6 +73,14 @@ const MessageDetail = () => {
<JsonCard>{data.message}</JsonCard>
</CardBody>
</Card>
<Card>
<CardHeader fontWeight="bold" fontSize="xl" pb="0">
Message History
</CardHeader>
<CardBody>
<MessageHistoryTable message={data?.message} revisions={revisions} />
</CardBody>
</Card>
<Card>
<CardHeader fontWeight="bold" fontSize="xl" pb="0">
Tree {data.tree.id}
Expand Down
13 changes: 13 additions & 0 deletions website/src/pages/api/admin/messages/[id]/history.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { withAnyRole } from "src/lib/auth";
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore } from "src/lib/users";

export default withAnyRole(["moderator", "admin"], async (req, res, token) => {
const { id } = req.query;

const user = await getBackendUserCore(token.sub);
const client = createApiClientFromUser(user);

const revision_history = await client.fetch_message_revision_history(id as string);
res.status(200).json(revision_history);
});
9 changes: 9 additions & 0 deletions website/src/types/Conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ export interface Message extends MessageEmojis {
user: BackendUser | null;
}

export interface MessageRevision {
id: string;
text: string;
message_id: string;
user_id: string;
created_date: string;
user_is_author: boolean;
}

export interface Conversation {
messages: Message[];
}
Expand Down