-
Notifications
You must be signed in to change notification settings - Fork 8
/
demo.py
48 lines (39 loc) · 1.77 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Copyright (C) 2024 Andrew Wason
# SPDX-License-Identifier: MIT
import asyncio
import pathlib
import sys
import typing as t
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import BaseTool
from langchain_groq import ChatGroq
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from langchain_mcp import MCPToolkit
async def run(tools: list[BaseTool], prompt: str) -> str:
model = ChatGroq(model="llama-3.1-8b-instant", stop_sequences=None) # requires GROQ_API_KEY
tools_map = {tool.name: tool for tool in tools}
tools_model = model.bind_tools(tools)
messages: list[BaseMessage] = [HumanMessage(prompt)]
ai_message = t.cast(AIMessage, await tools_model.ainvoke(messages))
messages.append(ai_message)
for tool_call in ai_message.tool_calls:
selected_tool = tools_map[tool_call["name"].lower()]
tool_msg = await selected_tool.ainvoke(tool_call)
messages.append(tool_msg)
return await (tools_model | StrOutputParser()).ainvoke(messages)
async def main(prompt: str) -> None:
server_params = StdioServerParameters(
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", str(pathlib.Path(__file__).parent.parent)],
)
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
toolkit = MCPToolkit(session=session)
await toolkit.initialize()
response = await run(toolkit.get_tools(), prompt)
print(response)
if __name__ == "__main__":
prompt = sys.argv[1] if len(sys.argv) > 1 else "Read and summarize the file ./LICENSE"
asyncio.run(main(prompt))