Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

add stuff chain and qa chain #12

Merged
merged 2 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions langchain/chains/base.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import { LLMChain, SerializedLLMChain } from "./index";
import { LLMChain, StuffDocumentsChain } from "./index";
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ChainValues = Record<string, any>;

type SerializedBaseChain = SerializedLLMChain;
// Every concrete chain class whose serialized form BaseChain can round-trip.
const chainClasses = [LLMChain, StuffDocumentsChain];

// Union of each chain class's `serialize()` return type, derived from the
// class list so the serialized union stays in sync automatically when a new
// chain class is registered above.
export type SerializedBaseChain = ReturnType<
InstanceType<(typeof chainClasses)[number]>["serialize"]
>;

export abstract class BaseChain {
abstract _call(values: ChainValues): Promise<ChainValues>;
Expand All @@ -24,6 +28,8 @@ export abstract class BaseChain {
switch (data._type) {
case "llm_chain":
return LLMChain.deserialize(data);
case "stuff_documents_chain":
return StuffDocumentsChain.deserialize(data);
default:
throw new Error(
`Invalid prompt type in config: ${
Expand Down
76 changes: 76 additions & 0 deletions langchain/chains/combine_docs_chain.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { BaseChain, ChainValues, LLMChain, SerializedLLMChain } from "./index";

import { Document } from "../document";

import { resolveConfigFromFile } from "../util";

/**
 * Constructor fields for {@link StuffDocumentsChain}.
 */
export interface StuffDocumentsChainInput {
  /** Chain that is called with the combined document text. */
  llmChain: LLMChain;
  /** Key in the call values that holds the input Document[]. */
  inputKey: string;
  /** Key under which the wrapped chain returns its output. */
  outputKey: string;
  /** Prompt variable name that receives the combined document text. */
  documentVariableName: string;
}

/**
 * Serialized representation of a StuffDocumentsChain. Carries the wrapped
 * LLM chain either inline (`llm_chain`) or as a path to a config file
 * (`llm_chain_path`) — presumably resolveConfigFromFile picks whichever is
 * present; confirm against its implementation.
 */
export type SerializedStuffDocumentsChain = {
  _type: "stuff_documents_chain";
  llm_chain?: SerializedLLMChain;
  llm_chain_path?: string;
};

export class StuffDocumentsChain extends BaseChain implements StuffDocumentsChainInput {
llmChain: LLMChain;

inputKey = "input_documents";

outputKey = "output_text";

documentVariableName = "context";

constructor(fields: {
llmChain: LLMChain;
inputKey?: string;
outputKey?: string;
documentVariableName?: string;
}) {
super();
this.llmChain = fields.llmChain;
this.documentVariableName = fields.documentVariableName ?? this.documentVariableName;
this.inputKey = fields.inputKey ?? this.inputKey;
this.outputKey = fields.outputKey ?? this.outputKey;
}

async _call(values: ChainValues): Promise<ChainValues> {
if (!(this.inputKey in values)) {
throw new Error(`Document key ${ this.inputKey } not found.`);
}
const docs: Document[] = values[this.inputKey];
const texts = docs.map(({ pageContent }) => pageContent);
const text = texts.join("\n\n");
delete values[this.inputKey];
values[this.documentVariableName] = text;
const result = await this.llmChain.call(values);
return result;
}

_chainType() {
return "stuff_documents_chain" as const;
}

static async deserialize(data: SerializedStuffDocumentsChain) {
const SerializedLLMChain = resolveConfigFromFile<"llm_chain", SerializedLLMChain>(
"llm_chain",
data
);

return new StuffDocumentsChain({
llmChain: await LLMChain.deserialize(SerializedLLMChain),
});
}

serialize(): SerializedStuffDocumentsChain {
return {
_type: this._chainType(),
llm_chain: this.llmChain.serialize(),
};
}
}
2 changes: 2 additions & 0 deletions langchain/chains/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Public surface of the chains package: base abstractions, concrete chains,
// and the loaders that reconstruct chains from serialized configs.
export { BaseChain, ChainValues } from "./base";
export { SerializedLLMChain, LLMChain } from "./llm_chain";
export { SerializedStuffDocumentsChain, StuffDocumentsChain } from "./combine_docs_chain";
export { loadChain } from "./load";
export { loadQAChain } from "./question_answering/load";
11 changes: 11 additions & 0 deletions langchain/chains/question_answering/load.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { BaseLLM } from "../../llms";
import { LLMChain } from "../llm_chain";
import { StuffDocumentsChain } from "../combine_docs_chain";
import { prompt } from "./stuff_prompts";


export const loadQAChain = (llm: BaseLLM) => {
const llmChain = new LLMChain({ prompt, llm });
const chain = new StuffDocumentsChain({llmChain});
return chain;
};
8 changes: 8 additions & 0 deletions langchain/chains/question_answering/stuff_prompts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/* eslint-disable */
import { PromptTemplate } from "../../prompt";

/**
 * Default prompt for the "stuff" QA chain: the combined document text is
 * injected as {context} and the user's query as {question}.
 */
const qaTemplate =
  "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";

export const prompt = new PromptTemplate({
  template: qaTemplate,
  inputVariables: ["context", "question"],
});

12 changes: 12 additions & 0 deletions langchain/chains/question_answering/tests/load.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { test } from "@jest/globals";
import { OpenAI } from "../../../llms/openai";
import { loadQAChain } from "../load";
import { Document } from "../../../document";

// Smoke test: loadQAChain should run end-to-end against a live OpenAI model.
test("Test loadQAChain", async () => {
  const llm = new OpenAI({});
  const qaChain = loadQAChain(llm);
  const inputDocs = ["foo", "bar", "baz"].map(
    (pageContent) => new Document({ pageContent })
  );
  const res = await qaChain.call({
    input_documents: inputDocs,
    question: "Whats up",
  });
  console.log({ res });
});
27 changes: 27 additions & 0 deletions langchain/chains/tests/combine_docs_chain.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { test } from "@jest/globals";
import { OpenAI } from "../../llms/openai";
import { PromptTemplate } from "../../prompt";
import { LLMChain } from "../llm_chain";
import { loadChain } from "../load";
import { StuffDocumentsChain } from "../combine_docs_chain";
import { Document } from "../../document";

test("Test StuffDocumentsChain", async () => {
const model = new OpenAI({});
const prompt = new PromptTemplate({
template: "Print {foo}",
inputVariables: ["foo"],
});
const llmChain = new LLMChain({ prompt, llm: model });
const chain = new StuffDocumentsChain({ llmChain, documentVariableName: "foo"});
const docs = [ new Document({pageContent: 'foo' }), new Document({pageContent: 'bar' }), new Document({pageContent: 'baz' }), ];
const res = await chain.call({ input_documents: docs });
console.log({ res });
});

// Smoke test: a stuff-QA chain loaded from the LangChain hub should run.
test("Load chain from hub", async () => {
  const chain = await loadChain(
    "lc://chains/question_answering/stuff/chain.json"
  );
  const docs = ["foo", "bar", "baz"].map(
    (pageContent) => new Document({ pageContent })
  );
  const res = await chain.call({
    input_documents: docs,
    question: "what up",
  });
  console.log({ res });
});
2 changes: 1 addition & 1 deletion langchain/text_splitter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ Bye!\n\n-H.`;
"Bye!\n\n-H.",
];
expect(output).toEqual(expectedOutput);
})
});
6 changes: 3 additions & 3 deletions langchain/text_splitter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ interface TextSplitterParams {
}

abstract class TextSplitter implements TextSplitterParams {
chunkSize: number = 1000;
chunkSize = 1000;

chunkOverlap: number = 200;
chunkOverlap = 200;

constructor(fields?: Partial<TextSplitterParams>) {
this.chunkSize = fields?.chunkSize ?? this.chunkSize;
Expand Down Expand Up @@ -97,7 +97,7 @@ export interface CharacterTextSplitterParams extends TextSplitterParams {
}

export class CharacterTextSplitter extends TextSplitter implements CharacterTextSplitterParams{
separator: string = "\n\n";
separator = "\n\n";

constructor(fields?: Partial<CharacterTextSplitterParams>) {
super(fields);
Expand Down