-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
community[major]: DeepInfra llm and chat #5672
Changes from 25 commits
d48eb29
f361789
09b9fee
c778987
b50fa6b
945f7b9
87e5977
d50c600
76242e8
29707e7
e2f2f50
e495ed1
69bb8ff
ce10418
90ee4fd
fafea9b
08929fe
936bdd4
69ef80a
49ac1a2
699ec9e
4e2355f
98f34c5
a791836
1ce74c4
765e75e
f5059dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import { DeepInfraLLM } from "@langchain/community/llms/deepinfra"; | ||
|
||
const apiKey = process.env.DEEPINFRA_API_TOKEN; | ||
const model = "meta-llama/Meta-Llama-3-70B-Instruct"; | ||
|
||
export const run = async () => { | ||
const llm = new DeepInfraLLM({ | ||
temperature: 0.7, | ||
maxTokens: 20, | ||
model, | ||
apiKey, | ||
maxRetries: 5, | ||
}); | ||
const res = await llm.invoke( | ||
"Question: What is the next step in the process of making a good game?\nAnswer:" | ||
); | ||
console.log({ res }); | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import { ChatDeepInfra } from "@langchain/community/chat_models/deepinfra"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey there! I've reviewed the code changes, and it looks like the addition of accessing the environment variable |
||
import { HumanMessage } from "@langchain/core/messages"; | ||
|
||
const apiKey = process.env.DEEPINFRA_API_TOKEN; | ||
|
||
const model = "meta-llama/Meta-Llama-3-70B-Instruct"; | ||
|
||
const chat = new ChatDeepInfra({ | ||
model, | ||
apiKey, | ||
}); | ||
|
||
const messages = [new HumanMessage("Hello")]; | ||
|
||
chat.invoke(messages).then((response: any) => { | ||
console.log(response); | ||
}); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import { DeepInfraLLM } from "@langchain/community/llms/deepinfra"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey there! I've reviewed the code and flagged a change related to accessing environment variables for the maintainers to review. Please take a look at the comment for more details. Let me know if you need further assistance! |
||
|
||
const apiKey = process.env.DEEPINFRA_API_TOKEN; | ||
const model = "meta-llama/Meta-Llama-3-70B-Instruct"; | ||
|
||
const llm = new DeepInfraLLM({ | ||
maxTokens: 20, | ||
model, | ||
apiKey, | ||
}); | ||
const res = await llm.invoke( | ||
"What is the next step in the process of making a good game?" | ||
); | ||
console.log({ res }); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey team, I've reviewed the code and noticed that the new changes introduce a net-new HTTP request using the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey team, I've reviewed the code changes and flagged a specific section for your attention. The added code appears to explicitly access and require an environment variable using the |
||
BaseChatModel, | ||
type BaseChatModelParams, | ||
} from "@langchain/core/language_models/chat_models"; | ||
import { AIMessage, type BaseMessage } from "@langchain/core/messages"; | ||
import { type ChatResult } from "@langchain/core/outputs"; | ||
import { getEnvironmentVariable } from "@langchain/core/utils/env"; | ||
|
||
// Model used when the caller does not specify one.
export const DEFAULT_MODEL = "meta-llama/Meta-Llama-3-70B-Instruct";

// Roles accepted by DeepInfra's OpenAI-compatible chat endpoint.
export type DeepInfraMessageRole = "system" | "assistant" | "user";

// DeepInfra's OpenAI-compatible chat completions endpoint.
export const API_BASE_URL =
  "https://api.deepinfra.com/v1/openai/chat/completions";

// Name of the environment variable holding the API token.
export const ENV_VARIABLE_API_KEY = "DEEPINFRA_API_TOKEN";

// One chat message in the request payload.
interface DeepInfraMessage {
  role: DeepInfraMessageRole;
  content: string;
}

// Request body for the completions endpoint (snake_case per the wire format).
interface ChatCompletionRequest {
  model: string;
  messages?: DeepInfraMessage[];
  stream?: boolean;
  max_tokens?: number | null;
  temperature?: number | null;
}

// In-band error envelope: `code`/`message` are set when the API reports a
// failure (checked in `_generate`).
interface BaseResponse {
  code?: string;
  message?: string;
}

// Message payload inside a response choice.
interface ChoiceMessage {
  role: string;
  content: string;
}

// One completion choice. NOTE(review): `delta` is presumably the streaming
// variant and `message` the non-streaming one — confirm against the API docs.
interface ResponseChoice {
  index: number;
  finish_reason: "stop" | "length" | "null" | null;
  delta: ChoiceMessage;
  message: ChoiceMessage;
}

// Full completion response, including token usage and the flattened `output`
// shape that `_generate` attaches before consuming the result.
interface ChatCompletionResponse extends BaseResponse {
  choices: ResponseChoice[];
  usage: {
    completion_tokens: number;
    prompt_tokens: number;
    total_tokens: number;
  };
  output: {
    text: string;
    finish_reason: "stop" | "length" | "null" | null;
  };
}

// Constructor options for ChatDeepInfra.
export interface ChatDeepInfraParams {
  model: string;
  apiKey?: string;
  temperature?: number;
  maxTokens?: number;
}
|
||
function messageToRole(message: BaseMessage): DeepInfraMessageRole { | ||
const type = message._getType(); | ||
switch (type) { | ||
case "ai": | ||
return "assistant"; | ||
case "human": | ||
return "user"; | ||
case "system": | ||
return "system"; | ||
default: | ||
throw new Error(`Unknown message type: ${type}`); | ||
} | ||
} | ||
|
||
export class ChatDeepInfra | ||
extends BaseChatModel | ||
implements ChatDeepInfraParams | ||
{ | ||
static lc_name() { | ||
return "ChatDeepInfra"; | ||
} | ||
|
||
get callKeys() { | ||
return ["stop", "signal", "options"]; | ||
} | ||
|
||
apiKey?: string; | ||
|
||
model: string; | ||
|
||
apiUrl: string; | ||
|
||
maxTokens?: number; | ||
|
||
temperature?: number; | ||
|
||
constructor(fields: Partial<ChatDeepInfraParams> & BaseChatModelParams = {}) { | ||
super(fields); | ||
|
||
this.apiKey = | ||
fields?.apiKey ?? getEnvironmentVariable(ENV_VARIABLE_API_KEY); | ||
if (!this.apiKey) { | ||
throw new Error( | ||
"API key is required, set `DEEPINFRA_API_TOKEN` environment variable or pass it as a parameter" | ||
); | ||
} | ||
|
||
this.apiUrl = API_BASE_URL; | ||
this.model = fields.model ?? DEFAULT_MODEL; | ||
this.temperature = fields.temperature ?? 0; | ||
this.maxTokens = fields.maxTokens; | ||
} | ||
|
||
invocationParams(): Omit<ChatCompletionRequest, "messages"> { | ||
return { | ||
model: this.model, | ||
stream: false, | ||
temperature: this.temperature, | ||
max_tokens: this.maxTokens, | ||
}; | ||
} | ||
|
||
identifyingParams(): Omit<ChatCompletionRequest, "messages"> { | ||
return this.invocationParams(); | ||
} | ||
|
||
async _generate( | ||
messages: BaseMessage[], | ||
options?: this["ParsedCallOptions"] | ||
): Promise<ChatResult> { | ||
const parameters = this.invocationParams(); | ||
|
||
const messagesMapped: DeepInfraMessage[] = messages.map((message) => ({ | ||
role: messageToRole(message), | ||
content: message.content as string, | ||
})); | ||
|
||
const data = await this.completionWithRetry( | ||
{ ...parameters, messages: messagesMapped }, | ||
false, | ||
options?.signal | ||
).then<ChatCompletionResponse>((data) => { | ||
|
||
if (data?.code) { | ||
throw new Error(data?.message); | ||
} | ||
const { finish_reason, message } = data.choices[0]; | ||
const text = message.content; | ||
return { | ||
...data, | ||
output: { text, finish_reason }, | ||
}; | ||
}); | ||
|
||
const { | ||
prompt_tokens = 0, | ||
completion_tokens = 0, | ||
total_tokens = 0, | ||
} = data.usage ?? {}; | ||
|
||
const { text } = data.output; | ||
|
||
return { | ||
generations: [{ text, message: new AIMessage(text) }], | ||
llmOutput: { | ||
tokenUsage: { | ||
promptTokens: prompt_tokens, | ||
completionTokens: completion_tokens, | ||
totalTokens: total_tokens, | ||
}, | ||
}, | ||
}; | ||
} | ||
|
||
async completionWithRetry( | ||
request: ChatCompletionRequest, | ||
stream: boolean, | ||
signal?: AbortSignal | ||
) { | ||
|
||
const body = { | ||
temperature: this.temperature, | ||
max_tokens: this.maxTokens, | ||
...request, | ||
model: this.model, | ||
}; | ||
|
||
const makeCompletionRequest = async () => { | ||
const response = await fetch(this.apiUrl, { | ||
method: "POST", | ||
headers: { | ||
Authorization: `Bearer ${this.apiKey}`, | ||
"Content-Type": "application/json", | ||
}, | ||
body: JSON.stringify(body), | ||
signal, | ||
}); | ||
|
||
if (!stream) { | ||
return response.json(); | ||
} | ||
}; | ||
|
||
return this.caller.call(makeCompletionRequest); | ||
} | ||
|
||
_llmType(): string { | ||
return "DeepInfra"; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import { test } from "@jest/globals"; | ||
import { HumanMessage } from "@langchain/core/messages"; | ||
import { ChatDeepInfra } from "../deepinfra.js"; | ||
|
||
describe("ChatDeepInfra", () => { | ||
test("call", async () => { | ||
const deepInfraChat = new ChatDeepInfra({ maxTokens: 20 }); | ||
const message = new HumanMessage("1 + 1 = "); | ||
const res = await deepInfraChat.invoke([message]); | ||
console.log({ res }); | ||
}); | ||
|
||
test("generate", async () => { | ||
const deepInfraChat = new ChatDeepInfra({ maxTokens: 20 }); | ||
const message = new HumanMessage("1 + 1 = "); | ||
const res = await deepInfraChat.generate([[message]]); | ||
console.log(JSON.stringify(res, null, 2)); | ||
}); | ||
}); |
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Hey team, just a heads up that I've flagged a change in the PR for review. The added code accesses an environment variable using `process.env`, so it's important to ensure proper handling and security of environment variables.