Commit: mulish
AyaanZaveri committed Dec 16, 2023
1 parent bb3eaed commit d02ca53
Showing 3 changed files with 68 additions and 98 deletions.
148 changes: 60 additions & 88 deletions app/api/cog/completions/route.ts
@@ -12,7 +12,7 @@ import { CallbackManager, ConsoleCallbackHandler } from "langchain/callbacks";
import { AIMessage, HumanMessage } from "langchain/schema";
import { prompts } from "@/lib/prompts";
import { db } from "@/lib/db";
- import { Prisma } from '@prisma/client';
+ import { Prisma } from "@prisma/client";
import { HuggingFaceInferenceEmbeddings } from "langchain/embeddings/hf";
import {
BytesOutputParser,
@@ -69,7 +69,7 @@ const embeddingsModel = new HuggingFaceInferenceEmbeddings({
const serializeDocs = (docs: Array<Document>) =>
docs.map((doc) => doc.pageContent).join("\n\n");

- const runLLMChain = async (style: string, messages: any, id: string) => {
+ const getStuff = async (currentMessageContent: string, id: string) => {
const encoder = new TextEncoder();

const transformStream = new TransformStream();
@@ -105,37 +105,6 @@ const runLLMChain = async (style: string, messages: any, id: string) => {

console.log(additionalContext);

- const model = new ChatOpenAI(
- {
- streaming: true,
- verbose: true,
- callbacks: [
- {
- async handleLLMNewToken(token) {
- await writer.ready;
- await writer.write(encoder.encode(`${token}`));
- },
- async handleLLMEnd() {
- await writer.ready;
- await writer.close();
- },
- },
- ],
- temperature: 0.7,
- openAIApiKey: process.env.NEXT_PUBLIC_OPENAI_API_KEY_CHAT,
- // topP: 0.75,
- maxTokens: 4000,
- // modelName: "huggingfaceh4/zephyr-7b-beta",
- modelName: "mistralai/mixtral-8x7b-instruct",
- },
- {
- basePath: process.env.NEXT_PUBLIC_OPENAI_ENDPOINT_CHAT,
- defaultHeaders: {
- "HTTP-Referer": process.env.NEXTAUTH_URL,
- },
- }
- );

const questionModel = new ChatOpenAI(
{
temperature: 0.5,
@@ -153,65 +122,18 @@

console.log("Created models");

- const retriever = vectorStore.asRetriever();
-
- const standaloneQuestionChain = RunnableSequence.from([
- {
- question: (input: ConversationalRetrievalQAChainInput) => input.question,
- chat_history: (input: ConversationalRetrievalQAChainInput) =>
- formatVercelMessages(input.chat_history),
- },
- PromptTemplate.fromTemplate(prompts[style].condense),
- questionModel,
- new StringOutputParser(),
- ]);
-
- const answerChain = RunnableSequence.from([
- {
- question: new RunnablePassthrough(),
- context: retriever.pipe(combineDocumentsFn),
- },
- PromptTemplate.fromTemplate(prompts[style].qa),
- model,
- new BytesOutputParser(),
- ]);
-
- const chain = standaloneQuestionChain.pipe(answerChain);
-
- // const chain = ConversationalRetrievalQAChain.fromLLM(
- // streamingModel,
- // vectorStore.asRetriever(),
- // {
- // verbose: true,
- // returnSourceDocuments: true,
- // qaChainOptions: {
- // type: "stuff",
- // prompt: PromptTemplate.fromTemplate(prompts[style].qa),
- // },
- // questionGeneratorChainOptions: {
- // template: prompts[style].condense,
- // llm: nonStreamingModel,
- // },
- // }
- // );
-
- const currentMessageContent = messages[messages.length - 1].content;
- const previousMessages = messages.slice(0, -1);
-
- chain.invoke({
- question: currentMessageContent,
- chat_history: previousMessages,
- });
+ const similarDocs = await vectorStore.similaritySearch(
+ currentMessageContent,
+ 7
+ );

- return transformStream.readable;
+ return similarDocs;
};

export async function POST(req: Request) {
try {
const { id, messages, style } = await req.json();

- const { stream, handlers } = LangChainStream();

console.log("Created vector store");

// teach me more
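With this commit, getStuff no longer builds the LLM chain at all: it only embeds the latest message and pulls the closest chunks back from the vector store. Below is a minimal, self-contained sketch of that retrieval pattern; MemoryVectorStore and the env var name are stand-ins, since the app's real store setup is outside this diff.

```ts
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { HuggingFaceInferenceEmbeddings } from "langchain/embeddings/hf";

// MemoryVectorStore stands in for the app's actual vector store, whose
// construction is not shown in this diff; the env var name is an assumption.
const embeddingsModel = new HuggingFaceInferenceEmbeddings({
  apiKey: process.env.HUGGINGFACE_API_KEY,
});

async function demo(currentMessageContent: string) {
  const vectorStore = await MemoryVectorStore.fromTexts(
    ["First chunk of the source document.", "Second chunk of the source document."],
    [{ chunk: 1 }, { chunk: 2 }],
    embeddingsModel
  );

  // As in getStuff: return the 7 chunks most similar to the latest message.
  return vectorStore.similaritySearch(currentMessageContent, 7);
}
```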
@@ -225,14 +147,64 @@ export async function POST(req: Request) {

console.log("Calling chain");

- const theStream = await runLLMChain(style, messages, id);
+ const currentMessageContent = messages[messages.length - 1].content;
+
+ const similarDocs = await getStuff(currentMessageContent, id);
+
+ const model = new ChatOpenAI(
+ {
+ streaming: true,
+ verbose: true,
+ temperature: 0.7,
+ openAIApiKey: process.env.NEXT_PUBLIC_OPENAI_API_KEY_CHAT,
+ maxTokens: 4000,
+ modelName: "mistralai/mixtral-8x7b-instruct",
+ },
+ {
+ basePath: process.env.NEXT_PUBLIC_OPENAI_ENDPOINT_CHAT,
+ defaultHeaders: {
+ "HTTP-Referer": process.env.NEXTAUTH_URL,
+ },
+ }
+ );
+
+ const previousMessages = messages.slice(0, -1);
+
+ const proompt = `
+ {previousMessages}
+ Here is some context from a document, along with a question related to it.
+ <context>
+ {docs}
+ </context>
+ Question: {message}
+ Carefully heed the user's instructions.
+ Respond using Markdown.
+ Bold important words using **bold**.
+ Answer:
+ `;
+
+ const prompt = PromptTemplate.fromTemplate(proompt);
+
+ const outputParser = new StringOutputParser();
+
+ const chain = prompt.pipe(model).pipe(outputParser);
+
+ console.log("Called chain");
+
+ const stream = await chain.stream({
+ message: currentMessageContent,
+ previousMessages: previousMessages,
+ docs: combineDocumentsFn(similarDocs),
+ });
+
+ return new StreamingTextResponse(stream);

- // return new StreamingTextResponse(stream);
- // return stream as readable stream

- return new Response(theStream);
} catch (error) {
// get the first 2000 characters of the error
const errorString = error!.toString().substring(0, 2000);
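The POST handler now assembles the whole pipeline inline: retrieved docs and chat history are interpolated into a single template, piped through a streaming ChatOpenAI client pointed at an OpenAI-compatible endpoint, and streamed back to the client. A condensed sketch of that flow follows; the model name and env vars mirror the diff, while the request shape and docs string are illustrative.

```ts
import { ChatOpenAI } from "langchain/chat_models/openai";
import { PromptTemplate } from "langchain/prompts";
import { StringOutputParser } from "langchain/schema/output_parser";
import { StreamingTextResponse } from "ai";

export async function POST(req: Request) {
  const { messages } = await req.json();
  const question = messages[messages.length - 1].content;

  const model = new ChatOpenAI(
    {
      streaming: true,
      temperature: 0.7,
      maxTokens: 4000,
      modelName: "mistralai/mixtral-8x7b-instruct",
      openAIApiKey: process.env.NEXT_PUBLIC_OPENAI_API_KEY_CHAT,
    },
    // Second arg points the OpenAI client at a compatible endpoint (e.g. OpenRouter).
    { basePath: process.env.NEXT_PUBLIC_OPENAI_ENDPOINT_CHAT }
  );

  const prompt = PromptTemplate.fromTemplate(
    "<context>\n{docs}\n</context>\nQuestion: {message}\nAnswer:"
  );

  // PromptTemplate, ChatOpenAI, and StringOutputParser are all Runnables,
  // so .pipe() composes them and .stream() yields tokens incrementally.
  const chain = prompt.pipe(model).pipe(new StringOutputParser());

  const stream = await chain.stream({
    message: question,
    docs: "...retrieved chunks would go here...", // from getStuff in the real route
  });

  return new StreamingTextResponse(stream);
}
```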
11 changes: 8 additions & 3 deletions components/Chat.tsx
@@ -6,7 +6,7 @@ import ReactMarkdown from "react-markdown";
import { useChat } from "ai/react";
import ChatBox from "./ChatBox";
import { FormEvent, useState } from "react";
- import { Space_Grotesk } from "next/font/google";
+ import { Mulish, Space_Grotesk } from "next/font/google";
import { ChatRequestOptions } from "ai";
import { cn } from "@/lib/utils";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
@@ -17,6 +17,11 @@ const space_grotesk = Space_Grotesk({
subsets: ["latin"],
});

+ const mulish = Mulish({
+ weight: ["300", "400", "500", "600", "700"],
+ subsets: ["latin"],
+ })

export default function Chat({ id }: { id: string }) {
const [isThinking, setIsThinking] = useState(false);
const [isStreaming, setIsStreaming] = useState(false);
@@ -98,7 +103,7 @@ export default function Chat({ id }: { id: string }) {
</Tabs>
</div>
)}
- <div className="flex w-full flex-col gap-5 px-8">
+ <div className={`flex w-full flex-col gap-5 px-8 ${mulish.className}`}>
{messages.map((message) => (
<div
key={message.id}
@@ -111,7 +116,7 @@
>
<span className="prose transition-all">
<ReactMarkdown remarkPlugins={[remarkGfm]}>
- {message.content}
+ {message.content.replace(/\\n/g, "\n").replace(/<\/s>/g, "")}
</ReactMarkdown>
</span>
</div>
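Besides the Mulish font, the one behavioral change in Chat.tsx is the cleanup applied to each message before Markdown rendering: literal "\n" escape sequences become real newlines, and the model's end-of-sequence marker is stripped. A tiny sketch, with an invented sample string:

```ts
// Invented sample of a raw streamed completion; the replacements mirror Chat.tsx.
const raw = "Here is the answer.\\nIt has **bold** words.</s>";

const clean = raw
  .replace(/\\n/g, "\n")  // turn escaped "\n" sequences into real line breaks
  .replace(/<\/s>/g, ""); // drop the "</s>" end-of-sequence marker

console.log(clean); // "Here is the answer." + newline + "It has **bold** words."
```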
7 changes: 0 additions & 7 deletions lib/prompts.ts
@@ -40,13 +40,6 @@ export const prompts: { [key: string]: Prompt } = {
qa: `
Carefully heed the user's instructions.
Respond using Markdown.
- Here is some context from a document, along with a question related to it.
- {context}
- Question: {question}
- =========
- Answer:
`,
},
focused: {
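After this commit the qa template keeps only the style instructions; the document context and question are interpolated in route.ts instead. Roughly, an entry now looks like the sketch below, where the Prompt shape and key name are assumptions inferred from the surrounding code (the full file is not shown in this diff).

```ts
// Shape assumed from usage elsewhere in the diff (prompts[style].condense / .qa).
type Prompt = { condense: string; qa: string };

const examplePrompt: Prompt = {
  condense: "...", // condense template unchanged by this commit, elided here
  qa: `
Carefully heed the user's instructions.
Respond using Markdown.
`,
};
```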
