Skip to content
This repository has been archived by the owner on Nov 19, 2024. It is now read-only.

Commit

Permalink
Edge plotly tools (#8)
Browse files Browse the repository at this point in the history
* Add edge and plotly tools

* Bump minimal tilores SDK version

* Prepare release

* Bump minimal tilores SDK version

* Add missing dependency
  • Loading branch information
stefan-berkner-tilotech authored Oct 28, 2024
1 parent 2006168 commit ae9bbed
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
26 changes: 26 additions & 0 deletions examples/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os

# LangChain
import langchain
from langchain.tools import BaseTool
from langchain_core.messages import AnyMessage, HumanMessage
from langchain_openai import ChatOpenAI
Expand All @@ -24,6 +25,10 @@
import chainlit as cl
from chainlit.sync import run_sync

# Plotly
import plotly.graph_objects as go
from plotly.io import from_json


class HumanInputChainlit(BaseTool):
"""Tool that adds the capability to ask user for input."""
Expand Down Expand Up @@ -72,7 +77,9 @@ def start():
tools = [
HumanInputChainlit(),
tilores_tools.search_tool(),
tilores_tools.edge_tool(),
pdf_tool,
plotly_tool,
]
# Use MemorySaver to use the full conversation
memory = MemorySaver()
Expand All @@ -97,6 +104,12 @@ async def main(message: cl.Message):
ui_message = cl.Message(content="")
await ui_message.send()
async for event in runnable.astream_events(state, version="v1", config={'configurable': {'thread_id': 'thread-1'}}):
if event["event"] == "on_tool_end":
if event["data"].get('output') and event["data"].get('output').artifact:
fig = from_json(event["data"].get("output").artifact)
chart = cl.Plotly(name="chart", figure=fig, display="inline")
ui_message.elements.append(chart)

if event["event"] == "on_chat_model_stream":
c = event["data"]["chunk"].content
if c and len(c) > 0 and isinstance(c[0], dict) and c[0]["type"] == "text":
Expand Down Expand Up @@ -128,4 +141,17 @@ def load_pdf_from_url(url: str):
name = "load_pdf",
func=load_pdf_from_url,
description="useful for when you need to download and process a PDF file from a given URL"
)

def render_plotly_graph(figureCode: str):
local_vars = {}
exec(figureCode, {"go": go}, local_vars)
fig = local_vars.get("fig")
return "generated a chart from the provided figure", fig.to_json()

plotly_tool = Tool(
name = "plotly_tool",
func=render_plotly_graph,
description="useful for when you need to render a graph using plotly; the figureCode must only import plotly.graph_objects as go and must provide a local variable named fig as a result",
response_format='content_and_artifact'
)
3 changes: 2 additions & 1 deletion examples/chat/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
langchain-tilores[all]>=0.1.0
langchain-tilores[all]>=0.3.0
langgraph==0.2.22
langchain_openai
langchain_aws
Expand All @@ -8,3 +8,4 @@ chainlit
unstructured
pdfminer.six
unstructured[pdf]
plotly==5.24.1
14 changes: 12 additions & 2 deletions langchain_tilores/tilores_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from tilores.helpers import PydanticFactory
from functools import cached_property
from langchain.tools import StructuredTool
from pydantic import create_model

class TiloresTools:
"""
Expand All @@ -21,8 +22,8 @@ def references(self):

def all(self):
return [
# self.record_fields_tool,
self.search_tool
self.search_tool,
self.edge_tool
]

def search_tool(self):
Expand All @@ -33,6 +34,15 @@ def search_tool(self):
'return_direct': True,
'func': self.tilores_api.search
})

def edge_tool(self):
return StructuredTool.from_function(**{
'name': 'tilores_entity_edges',
'description': 'useful for when you need to provide details about why certain records of an entity are matching; a single edge contains two record IDs and a rule ID representing the reason why those two records belong to the same entity',
'args_schema': create_model("EntityEdgesArgs", entityID=(str, ...)),
'return_direct': True,
'func': self.tilores_api.entity_edges
})

def static_value(val):
def wrapper():
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
[project]
name = "langchain-tilores"
version = "0.2.2"
version = "0.3.0"
authors = [
{ name="Lukas Rieder", email="lukas@parlant.co" },
{ name="Stefan Berkner", email="stefan.berkner@tilores.io" },
]
description = "This package contains tools to work with Tilores entity resolution database within Langchain."
readme = "README.md"
Expand All @@ -15,7 +16,7 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"tilores-sdk>=0.1.0",
"tilores-sdk>=0.3.0",
]
[project.optional-dependencies]
all = ["langchain-tilores[langchain]"]
Expand Down

0 comments on commit ae9bbed

Please sign in to comment.