Skip to content

Commit

Permalink
Merge pull request #195 from crewAIInc/bugfix/make-tooling-optional
Browse files Browse the repository at this point in the history
Fix for Gui
  • Loading branch information
bhancockio authored Jan 24, 2025
2 parents f98b768 + 89ca736 commit f3805ad
Show file tree
Hide file tree
Showing 4 changed files with 717 additions and 698 deletions.
18 changes: 12 additions & 6 deletions crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from typing import Any, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

# Type checking import
if TYPE_CHECKING:
from firecrawl import FirecrawlApp


try:
from firecrawl import FirecrawlApp

FIRECRAWL_AVAILABLE = True
except ImportError:
FirecrawlApp = Any
FIRECRAWL_AVAILABLE = False


class FirecrawlSearchToolSchema(BaseModel):
Expand Down Expand Up @@ -51,9 +56,10 @@ def __init__(self, api_key: Optional[str] = None, **kwargs):

def _initialize_firecrawl(self) -> None:
try:
from firecrawl import FirecrawlApp # type: ignore

self.firecrawl = FirecrawlApp(api_key=self.api_key)
if FIRECRAWL_AVAILABLE:
self._firecrawl = FirecrawlApp(api_key=self.api_key)
else:
raise ImportError
except ImportError:
import click

Expand Down
109 changes: 54 additions & 55 deletions crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import Any, Type
from typing import TYPE_CHECKING, Any, Type

from crewai.tools import BaseTool
from patronus import Client
from pydantic import BaseModel, Field

if TYPE_CHECKING:
from patronus import Client, EvaluationResult

try:
from patronus import Client
import patronus

PYPATRONUS_AVAILABLE = True
except ImportError:
PYPATRONUS_AVAILABLE = False
Client = Any


class FixedLocalEvaluatorToolSchema(BaseModel):
Expand All @@ -31,53 +32,59 @@ class FixedLocalEvaluatorToolSchema(BaseModel):

class PatronusLocalEvaluatorTool(BaseTool):
name: str = "Patronus Local Evaluator Tool"
evaluator: str = "The registered local evaluator"
evaluated_model_gold_answer: str = "The agent's gold answer"
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
client: Any = None
description: str = (
"This tool is used to evaluate the model input and output using custom function evaluators."
)
args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema
client: "Client" = None
evaluator: str
evaluated_model_gold_answer: str

class Config:
arbitrary_types_allowed = True

def __init__(
self,
patronus_client: Client,
evaluator: str,
evaluated_model_gold_answer: str,
**kwargs: Any,
):
def __init__(
self,
patronus_client: Client,
evaluator: str,
evaluated_model_gold_answer: str,
patronus_client: "Client" = None,
evaluator: str = "",
evaluated_model_gold_answer: str = "",
**kwargs: Any,
):
super().__init__(**kwargs)
if PYPATRONUS_AVAILABLE:
self.client = patronus_client
if evaluator:
self.evaluator = evaluator
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}"
self._generate_description()
print(
f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
)
else:
self.evaluator = evaluator
self.evaluated_model_gold_answer = evaluated_model_gold_answer
self._initialize_patronus(patronus_client)

def _initialize_patronus(self, patronus_client: "Client") -> None:
try:
if PYPATRONUS_AVAILABLE:
self.client = patronus_client
self._generate_description()
print(
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
)
else:
raise ImportError
except ImportError:
import click

if click.confirm(
"You are missing the 'patronus' package. Would you like to install it?"
):
import subprocess

subprocess.run(["uv", "add", "patronus"], check=True)
try:
subprocess.run(["uv", "add", "patronus"], check=True)
self.client = patronus_client
self._generate_description()
print(
f"Updating evaluator and gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
)
except subprocess.CalledProcessError:
raise ImportError("Failed to install 'patronus' package")
else:
raise ImportError(
"You are missing the patronus package. Would you like to install it?"
"`patronus` package not found, please run `uv add patronus`"
)

def _run(
Expand All @@ -92,30 +99,22 @@ def _run(
evaluated_model_gold_answer = self.evaluated_model_gold_answer
evaluator = self.evaluator

result = self.client.evaluate(
result: "EvaluationResult" = self.client.evaluate(
evaluator=evaluator,
evaluated_model_input=(
evaluated_model_input
if isinstance(evaluated_model_input, str)
else evaluated_model_input.get("description")
),
evaluated_model_output=(
evaluated_model_output
if isinstance(evaluated_model_output, str)
else evaluated_model_output.get("description")
),
evaluated_model_retrieved_context=(
evaluated_model_retrieved_context
if isinstance(evaluated_model_retrieved_context, str)
else evaluated_model_retrieved_context.get("description")
),
evaluated_model_gold_answer=(
evaluated_model_gold_answer
if isinstance(evaluated_model_gold_answer, str)
else evaluated_model_gold_answer.get("description")
),
tags={}, # Optional metadata, supports arbitrary kv pairs
tags={}, # Optional metadata, supports arbitrary kv pairs
evaluated_model_input=evaluated_model_input,
evaluated_model_output=evaluated_model_output,
evaluated_model_retrieved_context=evaluated_model_retrieved_context,
evaluated_model_gold_answer=evaluated_model_gold_answer,
tags={}, # Optional metadata, supports arbitrary key-value pairs
)
output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
return output


try:
# Only rebuild if the class hasn't been initialized yet
if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"):
PatronusLocalEvaluatorTool.model_rebuild()
PatronusLocalEvaluatorTool._model_rebuilt = True
except Exception:
pass
109 changes: 88 additions & 21 deletions crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

import snowflake.connector
from crewai.tools.base_tool import BaseTool
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError

if TYPE_CHECKING:
# Import types for type checking only
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError

try:
import snowflake.connector
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

SNOWFLAKE_AVAILABLE = True
except ImportError:
SNOWFLAKE_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -83,24 +92,70 @@ class SnowflakeSearchTool(BaseTool):
default=True, description="Enable query result caching"
)

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
)

_connection_pool: Optional[List["SnowflakeConnection"]] = None
_pool_lock: Optional[asyncio.Lock] = None
_thread_pool: Optional[ThreadPoolExecutor] = None
_model_rebuilt: bool = False

def __init__(self, **data):
"""Initialize SnowflakeSearchTool."""
super().__init__(**data)
self._connection_pool: List[SnowflakeConnection] = []
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
self._initialize_snowflake()

def _initialize_snowflake(self) -> None:
try:
if SNOWFLAKE_AVAILABLE:
self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
else:
raise ImportError
except ImportError:
import click

if click.confirm(
"You are missing the 'snowflake-connector-python' package. Would you like to install it?"
):
import subprocess

try:
subprocess.run(
[
"uv",
"add",
"cryptography",
"snowflake-connector-python",
"snowflake-sqlalchemy",
],
check=True,
)

self._connection_pool = []
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
except subprocess.CalledProcessError:
raise ImportError("Failed to install Snowflake dependencies")
else:
raise ImportError(
"Snowflake dependencies not found. Please install them by running "
"`uv add cryptography snowflake-connector-python snowflake-sqlalchemy`"
)

async def _get_connection(self) -> SnowflakeConnection:
async def _get_connection(self) -> "SnowflakeConnection":
"""Get a connection from the pool or create a new one."""
async with self._pool_lock:
if not self._connection_pool:
conn = self._create_connection()
conn = await asyncio.get_event_loop().run_in_executor(
self._thread_pool, self._create_connection
)
self._connection_pool.append(conn)
return self._connection_pool.pop()

def _create_connection(self) -> SnowflakeConnection:
def _create_connection(self) -> "SnowflakeConnection":
"""Create a new Snowflake connection."""
conn_params = {
"account": self.config.account,
Expand All @@ -114,7 +169,7 @@ def _create_connection(self) -> SnowflakeConnection:

if self.config.password:
conn_params["password"] = self.config.password.get_secret_value()
elif self.config.private_key_path:
elif self.config.private_key_path and serialization:
with open(self.config.private_key_path, "rb") as key_file:
p_key = serialization.load_pem_private_key(
key_file.read(), password=None, backend=default_backend()
Expand All @@ -131,6 +186,7 @@ async def _execute_query(
self, query: str, timeout: int = 300
) -> List[Dict[str, Any]]:
"""Execute a query with retries and return results."""

if self.enable_caching:
cache_key = self._get_cache_key(query, timeout)
if cache_key in _query_cache:
Expand Down Expand Up @@ -174,6 +230,7 @@ async def _run(
**kwargs: Any,
) -> Any:
"""Execute the search query."""

try:
# Override database/schema if provided
if database:
Expand All @@ -190,12 +247,22 @@ async def _run(
def __del__(self):
"""Cleanup connections on deletion."""
try:
for conn in getattr(self, "_connection_pool", []):
try:
conn.close()
except:
pass
if hasattr(self, "_thread_pool"):
if self._connection_pool:
for conn in self._connection_pool:
try:
conn.close()
except Exception:
pass
if self._thread_pool:
self._thread_pool.shutdown()
except:
except Exception:
pass


try:
# Only rebuild if the class hasn't been initialized yet
if not hasattr(SnowflakeSearchTool, "_model_rebuilt"):
SnowflakeSearchTool.model_rebuild()
SnowflakeSearchTool._model_rebuilt = True
except Exception:
pass
Loading

0 comments on commit f3805ad

Please sign in to comment.