From 3b36a33ee2103e5a3ad4fa34a5347a6d6b43b809 Mon Sep 17 00:00:00 2001 From: Giorgos Vasileiou Date: Tue, 3 Dec 2024 21:15:10 +0200 Subject: [PATCH] feat(community): Add property extraction for Nodes and Relationships in LLMGraphTransformer (#7256) Co-authored-by: Jacob Lee --- .../graph_transformers/llm.int.test.ts | 91 +++++++++-- .../experimental/graph_transformers/llm.ts | 149 ++++++++++++++---- 2 files changed, 199 insertions(+), 41 deletions(-) diff --git a/libs/langchain-community/src/experimental/graph_transformers/llm.int.test.ts b/libs/langchain-community/src/experimental/graph_transformers/llm.int.test.ts index c85fad11e763..2423c8cef210 100644 --- a/libs/langchain-community/src/experimental/graph_transformers/llm.int.test.ts +++ b/libs/langchain-community/src/experimental/graph_transformers/llm.int.test.ts @@ -10,7 +10,7 @@ import { test.skip("convertToGraphDocuments", async () => { const model = new ChatOpenAI({ temperature: 0, - modelName: "gpt-4-turbo-preview", + modelName: "gpt-4o-mini", }); const llmGraphTransformer = new LLMGraphTransformer({ @@ -22,14 +22,12 @@ test.skip("convertToGraphDocuments", async () => { const result = await llmGraphTransformer.convertToGraphDocuments([ new Document({ pageContent: "Elon Musk is suing OpenAI" }), ]); - - // console.log(result); }); test("convertToGraphDocuments with allowed", async () => { const model = new ChatOpenAI({ temperature: 0, - modelName: "gpt-4-turbo-preview", + modelName: "gpt-4o-mini", }); const llmGraphTransformer = new LLMGraphTransformer({ @@ -42,8 +40,6 @@ test("convertToGraphDocuments with allowed", async () => { new Document({ pageContent: "Elon Musk is suing OpenAI" }), ]); - // console.log(JSON.stringify(result)); - expect(result).toEqual([ new GraphDocument({ nodes: [ @@ -68,7 +64,7 @@ test("convertToGraphDocuments with allowed", async () => { test("convertToGraphDocuments with allowed lowercased", async () => { const model = new ChatOpenAI({ temperature: 0, - modelName: "gpt-4-turbo-preview", + modelName: "gpt-4o-mini", }); const llmGraphTransformer = new LLMGraphTransformer({ @@ -81,8 +77,6 @@ test("convertToGraphDocuments with allowed lowercased", async () => { new Document({ pageContent: "Elon Musk is suing OpenAI" }), ]); - // console.log(JSON.stringify(result)); - expect(result).toEqual([ new GraphDocument({ nodes: [ @@ -103,3 +97,82 @@ test("convertToGraphDocuments with allowed lowercased", async () => { }), ]); }); + +test("convertToGraphDocuments with node properties", async () => { + const model = new ChatOpenAI({ + temperature: 0, + modelName: "gpt-4o-mini", + }); + + const llmGraphTransformer = new LLMGraphTransformer({ + llm: model, + allowedNodes: ["Person"], + allowedRelationships: ["KNOWS"], + nodeProperties: ["age", "country"], + }); + + const result = await llmGraphTransformer.convertToGraphDocuments([ + new Document({ pageContent: "John is 30 years old and lives in Spain" }), + ]); + + expect(result).toEqual([ + new GraphDocument({ + nodes: [ + new Node({ + id: "John", + type: "Person", + properties: { + age: "30", + country: "Spain", + }, + }), + ], + relationships: [], + source: new Document({ + pageContent: "John is 30 years old and lives in Spain", + metadata: {}, + }), + }), + ]); +}); + +test("convertToGraphDocuments with relationship properties", async () => { + const model = new ChatOpenAI({ + temperature: 0, + modelName: "gpt-4o-mini", + }); + + const llmGraphTransformer = new LLMGraphTransformer({ + llm: model, + allowedNodes: ["Person"], + allowedRelationships: ["KNOWS"], + relationshipProperties: ["since"], + }); + + const result = await llmGraphTransformer.convertToGraphDocuments([ + new Document({ pageContent: "John has known Mary since 2020" }), + ]); + + expect(result).toEqual([ + new GraphDocument({ + nodes: [ + new Node({ id: "John", type: "Person" }), + new Node({ id: "Mary", type: "Person" }), + ], + relationships: [ + new Relationship({ + source: new Node({ id: "John", type: "Person" }), + target: new Node({ id: "Mary", type: "Person" }), + type: "KNOWS", + properties: { + since: "2020", + }, + }), + ], + source: new Document({ + pageContent: "John has known Mary since 2020", + metadata: {}, + }), + }), + ]); +}); diff --git a/libs/langchain-community/src/experimental/graph_transformers/llm.ts b/libs/langchain-community/src/experimental/graph_transformers/llm.ts index 41167e09ad6f..53155ede9866 100644 --- a/libs/langchain-community/src/experimental/graph_transformers/llm.ts +++ b/libs/langchain-community/src/experimental/graph_transformers/llm.ts @@ -47,6 +47,11 @@ interface OptionalEnumFieldProps { fieldKwargs?: object; } +interface SchemaProperty { + key: string; + value: string; +} + function toTitleCase(str: string): string { return str .split(" ") @@ -86,50 +91,112 @@ function createOptionalEnumType({ return schema; } -function createSchema(allowedNodes: string[], allowedRelationships: string[]) { +function createNodeSchema(allowedNodes: string[], nodeProperties: string[]) { + const nodeSchema = z.object({ + id: z.string(), + type: createOptionalEnumType({ + enumValues: allowedNodes, + description: "The type or label of the node.", + }), + }); + + return nodeProperties.length > 0 + ? nodeSchema.extend({ + properties: z + .array( + z.object({ + key: createOptionalEnumType({ + enumValues: nodeProperties, + description: "Property key.", + }), + value: z.string().describe("Extracted value."), + }) + ) + .describe(`List of node properties`), + }) + : nodeSchema; +} + +function createRelationshipSchema( + allowedNodes: string[], + allowedRelationships: string[], + relationshipProperties: string[] +) { + const relationshipSchema = z.object({ + sourceNodeId: z.string(), + sourceNodeType: createOptionalEnumType({ + enumValues: allowedNodes, + description: "The source node of the relationship.", + }), + relationshipType: createOptionalEnumType({ + enumValues: allowedRelationships, + description: "The type of the relationship.", + isRel: true, + }), + targetNodeId: z.string(), + targetNodeType: createOptionalEnumType({ + enumValues: allowedNodes, + description: "The target node of the relationship.", + }), + }); + + return relationshipProperties.length > 0 + ? relationshipSchema.extend({ + properties: z + .array( + z.object({ + key: createOptionalEnumType({ + enumValues: relationshipProperties, + description: "Property key.", + }), + value: z.string().describe("Extracted value."), + }) + ) + .describe(`List of relationship properties`), + }) + : relationshipSchema; +} + +function createSchema( + allowedNodes: string[], + allowedRelationships: string[], + nodeProperties: string[], + relationshipProperties: string[] +) { + const nodeSchema = createNodeSchema(allowedNodes, nodeProperties); + const relationshipSchema = createRelationshipSchema( + allowedNodes, + allowedRelationships, + relationshipProperties + ); + const dynamicGraphSchema = z.object({ - nodes: z - .array( - z.object({ - id: z.string(), - type: createOptionalEnumType({ - enumValues: allowedNodes, - description: "The type or label of the node.", - }), - }) - ) - .describe("List of nodes"), + nodes: z.array(nodeSchema).describe("List of nodes"), relationships: z - .array( - z.object({ - sourceNodeId: z.string(), - sourceNodeType: createOptionalEnumType({ - enumValues: allowedNodes, - description: "The source node of the relationship.", - }), - relationshipType: createOptionalEnumType({ - enumValues: allowedRelationships, - description: "The type of the relationship.", - isRel: true, - }), - targetNodeId: z.string(), - targetNodeType: createOptionalEnumType({ - enumValues: allowedNodes, - description: "The target node of the relationship.", - }), - }) - ) + .array(relationshipSchema) .describe("List of relationships."), }); return dynamicGraphSchema; } +function convertPropertiesToRecord( + properties: SchemaProperty[] +): Record { + return properties.reduce((accumulator: Record, prop) => { + accumulator[prop.key] = prop.value; + return accumulator; + }, {}); +} + // eslint-disable-next-line @typescript-eslint/no-explicit-any function mapToBaseNode(node: any): Node { return new Node({ id: node.id, type: node.type ? toTitleCase(node.type) : "", + properties: node.properties + ? convertPropertiesToRecord(node.properties) + : {}, }); } @@ -149,6 +216,9 @@ function mapToBaseRelationship(relationship: any): Relationship { : "", }), type: relationship.relationshipType.replace(" ", "_").toUpperCase(), + properties: relationship.properties + ? convertPropertiesToRecord(relationship.properties) + : {}, }); } @@ -158,6 +228,8 @@ export interface LLMGraphTransformerProps { allowedRelationships?: string[]; prompt?: ChatPromptTemplate; strictMode?: boolean; + nodeProperties?: string[]; + relationshipProperties?: string[]; } export class LLMGraphTransformer { @@ -170,12 +242,18 @@ export class LLMGraphTransformer { strictMode: boolean; + nodeProperties: string[]; + + relationshipProperties: string[]; + constructor({ llm, allowedNodes = [], allowedRelationships = [], prompt = DEFAULT_PROMPT, strictMode = true, + nodeProperties = [], + relationshipProperties = [], }: LLMGraphTransformerProps) { if (typeof llm.withStructuredOutput !== "function") { throw new Error( @@ -186,9 +264,16 @@ export class LLMGraphTransformer { this.allowedNodes = allowedNodes; this.allowedRelationships = allowedRelationships; this.strictMode = strictMode; + this.nodeProperties = nodeProperties; + this.relationshipProperties = relationshipProperties; // Define chain - const schema = createSchema(allowedNodes, allowedRelationships); + const schema = createSchema( + allowedNodes, + allowedRelationships, + nodeProperties, + relationshipProperties + ); const structuredLLM = llm.withStructuredOutput(zodToJsonSchema(schema)); this.chain = prompt.pipe(structuredLLM); }