Skip to content

Commit

Permalink
Merge branch 'llm_tools' of https://github.com/srdas/jupyter-ai into …
Browse files Browse the repository at this point in the history
…llm_tools
  • Loading branch information
srdas committed Sep 20, 2024
2 parents 30bb099 + 1aaa56f commit 5328aab
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,16 @@
import numpy as np
from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from jupyter_core.paths import jupyter_config_dir
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import tool
from langgraph.graph import MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory

from .base import BaseChatHandler, SlashCommandRoutingType

from jupyter_core.paths import jupyter_config_dir
TOOLS_DIR = os.path.join(jupyter_config_dir(), "jupyter-ai", "tools")

PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Expand Down Expand Up @@ -85,10 +83,7 @@ def __init__(self, *args, **kwargs):
self.parser.add_argument("query", nargs=argparse.REMAINDER)
self.tools_file_path = None


def setup_llm(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
def setup_llm(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
"""Sets up the LLM before creating the LLM Chain"""
unified_parameters = {
"verbose": True,
Expand Down Expand Up @@ -124,7 +119,6 @@ def create_llm_chain(
)
self.llm_chain = runnable


def get_tool_files(self) -> list:
"""
Gets required tool files from TOOLS_DIR
Expand Down Expand Up @@ -179,7 +173,7 @@ def get_tools(file_paths: list) -> list:
for file_path in file_paths:
with open(file_path) as file:
exec(file.read())
try: # For each tool file, collect tool list
try: # For each tool file, collect tool list
with open(file_path) as file:
content = file.read()
tree = ast.parse(content)
Expand Down Expand Up @@ -246,19 +240,13 @@ async def process_message(self, message: HumanChatMessage):
return

if args.list:
tool_files = os.listdir(
os.path.join(Path.home(), TOOLS_DIR)
)
tool_files = os.listdir(os.path.join(Path.home(), TOOLS_DIR))
self.reply(f"The available tools files are: {tool_files}")
return
elif args.tools:
self.tools_file_path = os.path.join(
Path.home(), TOOLS_DIR, args.tools
)
self.tools_file_path = os.path.join(Path.home(), TOOLS_DIR, args.tools)
else:
self.tools_file_path = os.path.join(
Path.home(), TOOLS_DIR
)
self.tools_file_path = os.path.join(Path.home(), TOOLS_DIR)

query = " ".join(args.query)
if not query:
Expand Down

0 comments on commit 5328aab

Please sign in to comment.