Skip to content

Commit

Permalink
add stuff chain and qa chain (#12)
Browse files Browse the repository at this point in the history
* cr

* cr
  • Loading branch information
hwchase17 authored Feb 15, 2023
1 parent cbeddea commit 7f23be1
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 6 deletions.
10 changes: 8 additions & 2 deletions langchain/chains/base.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import { LLMChain, SerializedLLMChain } from "./index";
import { LLMChain, StuffDocumentsChain } from "./index";
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ChainValues = Record<string, any>;

type SerializedBaseChain = SerializedLLMChain;
const chainClasses = [LLMChain, StuffDocumentsChain];

export type SerializedBaseChain = ReturnType<
InstanceType<(typeof chainClasses)[number]>["serialize"]
>;

export abstract class BaseChain {
abstract _call(values: ChainValues): Promise<ChainValues>;
Expand All @@ -24,6 +28,8 @@ export abstract class BaseChain {
switch (data._type) {
case "llm_chain":
return LLMChain.deserialize(data);
case "stuff_documents_chain":
return StuffDocumentsChain.deserialize(data);
default:
throw new Error(
`Invalid prompt type in config: ${
Expand Down
76 changes: 76 additions & 0 deletions langchain/chains/combine_docs_chain.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { BaseChain, ChainValues, LLMChain, SerializedLLMChain } from "./index";

import { Document } from "../document";

import { resolveConfigFromFile } from "../util";

export interface StuffDocumentsChainInput {
llmChain: LLMChain;
inputKey: string;
outputKey: string;
documentVariableName: string;
}

export type SerializedStuffDocumentsChain = {
_type: "stuff_documents_chain";
llm_chain?: SerializedLLMChain;
llm_chain_path?: string;
};

export class StuffDocumentsChain extends BaseChain implements StuffDocumentsChainInput {
llmChain: LLMChain;

inputKey = "input_documents";

outputKey = "output_text";

documentVariableName = "context";

constructor(fields: {
llmChain: LLMChain;
inputKey?: string;
outputKey?: string;
documentVariableName?: string;
}) {
super();
this.llmChain = fields.llmChain;
this.documentVariableName = fields.documentVariableName ?? this.documentVariableName;
this.inputKey = fields.inputKey ?? this.inputKey;
this.outputKey = fields.outputKey ?? this.outputKey;
}

async _call(values: ChainValues): Promise<ChainValues> {
if (!(this.inputKey in values)) {
throw new Error(`Document key ${ this.inputKey } not found.`);
}
const docs: Document[] = values[this.inputKey];
const texts = docs.map(({ pageContent }) => pageContent);
const text = texts.join("\n\n");
delete values[this.inputKey];
values[this.documentVariableName] = text;
const result = await this.llmChain.call(values);
return result;
}

_chainType() {
return "stuff_documents_chain" as const;
}

static async deserialize(data: SerializedStuffDocumentsChain) {
const SerializedLLMChain = resolveConfigFromFile<"llm_chain", SerializedLLMChain>(
"llm_chain",
data
);

return new StuffDocumentsChain({
llmChain: await LLMChain.deserialize(SerializedLLMChain),
});
}

serialize(): SerializedStuffDocumentsChain {
return {
_type: this._chainType(),
llm_chain: this.llmChain.serialize(),
};
}
}
2 changes: 2 additions & 0 deletions langchain/chains/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
export { BaseChain, ChainValues } from "./base";
export { SerializedLLMChain, LLMChain } from "./llm_chain";
export { SerializedStuffDocumentsChain, StuffDocumentsChain } from "./combine_docs_chain";
export { loadChain } from "./load";
export { loadQAChain } from "./question_answering/load";
11 changes: 11 additions & 0 deletions langchain/chains/question_answering/load.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { BaseLLM } from "../../llms";
import { LLMChain } from "../llm_chain";
import { StuffDocumentsChain } from "../combine_docs_chain";
import { prompt } from "./stuff_prompts";


export const loadQAChain = (llm: BaseLLM) => {
const llmChain = new LLMChain({ prompt, llm });
const chain = new StuffDocumentsChain({llmChain});
return chain;
};
8 changes: 8 additions & 0 deletions langchain/chains/question_answering/stuff_prompts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/* eslint-disable */
import { PromptTemplate } from "../../prompt";

export const prompt = new PromptTemplate({
template: "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:",
inputVariables: ["context", "question"],
});

12 changes: 12 additions & 0 deletions langchain/chains/question_answering/tests/load.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { test } from "@jest/globals";
import { OpenAI } from "../../../llms/openai";
import { loadQAChain } from "../load";
import { Document } from "../../../document";

test("Test loadQAChain", async () => {
const model = new OpenAI({});
const chain = loadQAChain(model);
const docs = [ new Document({pageContent: 'foo' }), new Document({pageContent: 'bar' }), new Document({pageContent: 'baz' }), ];
const res = await chain.call({ input_documents: docs, question: "Whats up" });
console.log({ res });
});
27 changes: 27 additions & 0 deletions langchain/chains/tests/combine_docs_chain.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { test } from "@jest/globals";
import { OpenAI } from "../../llms/openai";
import { PromptTemplate } from "../../prompt";
import { LLMChain } from "../llm_chain";
import { loadChain } from "../load";
import { StuffDocumentsChain } from "../combine_docs_chain";
import { Document } from "../../document";

test("Test StuffDocumentsChain", async () => {
const model = new OpenAI({});
const prompt = new PromptTemplate({
template: "Print {foo}",
inputVariables: ["foo"],
});
const llmChain = new LLMChain({ prompt, llm: model });
const chain = new StuffDocumentsChain({ llmChain, documentVariableName: "foo"});
const docs = [ new Document({pageContent: 'foo' }), new Document({pageContent: 'bar' }), new Document({pageContent: 'baz' }), ];
const res = await chain.call({ input_documents: docs });
console.log({ res });
});

test("Load chain from hub", async () => {
const chain = await loadChain("lc://chains/question_answering/stuff/chain.json");
const docs = [ new Document({pageContent: 'foo' }), new Document({pageContent: 'bar' }), new Document({pageContent: 'baz' }), ];
const res = await chain.call({ input_documents: docs, question: "what up" });
console.log({ res });
});
2 changes: 1 addition & 1 deletion langchain/text_splitter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ Bye!\n\n-H.`;
"Bye!\n\n-H.",
];
expect(output).toEqual(expectedOutput);
})
});
6 changes: 3 additions & 3 deletions langchain/text_splitter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ interface TextSplitterParams {
}

abstract class TextSplitter implements TextSplitterParams {
chunkSize: number = 1000;
chunkSize = 1000;

chunkOverlap: number = 200;
chunkOverlap = 200;

constructor(fields?: Partial<TextSplitterParams>) {
this.chunkSize = fields?.chunkSize ?? this.chunkSize;
Expand Down Expand Up @@ -97,7 +97,7 @@ export interface CharacterTextSplitterParams extends TextSplitterParams {
}

export class CharacterTextSplitter extends TextSplitter implements CharacterTextSplitterParams{
separator: string = "\n\n";
separator = "\n\n";

constructor(fields?: Partial<CharacterTextSplitterParams>) {
super(fields);
Expand Down

0 comments on commit 7f23be1

Please sign in to comment.