Skip to content

Commit 2a40e9d

Browse files
committed
vanna-ai#548: Added support for additional db connect options.
1 parent 8cc20fb commit 2a40e9d

File tree

1 file changed

+78
-54
lines changed

1 file changed

+78
-54
lines changed

src/vanna/base/base.py

+78-54
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) ->
130130
question_sql_list=question_sql_list,
131131
ddl_list=ddl_list,
132132
doc_list=doc_list,
133-
**kwargs,
133+
**kwargs
134134
)
135135
self.log(title="SQL Prompt", message=prompt)
136136
llm_response = self.submit_prompt(prompt, **kwargs)
@@ -153,7 +153,7 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) ->
153153
question_sql_list=question_sql_list,
154154
ddl_list=ddl_list,
155155
doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
156-
**kwargs,
156+
**kwargs
157157
)
158158
self.log(title="Final SQL Prompt", message=prompt)
159159
llm_response = self.submit_prompt(prompt, **kwargs)
@@ -529,7 +529,7 @@ def get_sql_prompt(
529529
question_sql_list: list,
530530
ddl_list: list,
531531
doc_list: list,
532-
**kwargs,
532+
**kwargs
533533
):
534534
"""
535535
Example:
@@ -599,7 +599,7 @@ def get_followup_questions_prompt(
599599
question_sql_list: list,
600600
ddl_list: list,
601601
doc_list: list,
602-
**kwargs,
602+
**kwargs
603603
) -> list:
604604
initial_prompt = f"The user initially asked the question: '{question}': \n\n"
605605

@@ -655,7 +655,7 @@ def generate_question(self, sql: str, **kwargs) -> str:
655655
),
656656
self.user_message(sql),
657657
],
658-
**kwargs,
658+
**kwargs
659659
)
660660

661661
return response
@@ -718,6 +718,7 @@ def connect_to_snowflake(
718718
database: str,
719719
role: Union[str, None] = None,
720720
warehouse: Union[str, None] = None,
721+
**kwargs
721722
):
722723
try:
723724
snowflake = __import__("snowflake.connector")
@@ -764,7 +765,8 @@ def connect_to_snowflake(
764765
password=password,
765766
account=account,
766767
database=database,
767-
client_session_keep_alive=True
768+
client_session_keep_alive=True,
769+
**kwargs
768770
)
769771

770772
def run_sql_snowflake(sql: str) -> pd.DataFrame:
@@ -831,6 +833,7 @@ def connect_to_postgres(
831833
user: str = None,
832834
password: str = None,
833835
port: int = None,
836+
**kwargs
834837
):
835838
"""
836839
Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
@@ -900,6 +903,7 @@ def connect_to_postgres(
900903
user=user,
901904
password=password,
902905
port=port,
906+
**kwargs
903907
)
904908
except psycopg2.Error as e:
905909
raise ValidationError(e)
@@ -931,12 +935,13 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
931935

932936

933937
def connect_to_mysql(
934-
self,
935-
host: str = None,
936-
dbname: str = None,
937-
user: str = None,
938-
password: str = None,
939-
port: int = None,
938+
self,
939+
host: str = None,
940+
dbname: str = None,
941+
user: str = None,
942+
password: str = None,
943+
port: int = None,
944+
**kwargs
940945
):
941946

942947
try:
@@ -980,12 +985,15 @@ def connect_to_mysql(
980985
conn = None
981986

982987
try:
983-
conn = pymysql.connect(host=host,
984-
user=user,
985-
password=password,
986-
database=dbname,
987-
port=port,
988-
cursorclass=pymysql.cursors.DictCursor)
988+
conn = pymysql.connect(
989+
host=host,
990+
user=user,
991+
password=password,
992+
database=dbname,
993+
port=port,
994+
cursorclass=pymysql.cursors.DictCursor,
995+
**kwargs
996+
)
989997
except pymysql.Error as e:
990998
raise ValidationError(e)
991999

@@ -1015,12 +1023,13 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
10151023
self.run_sql = run_sql_mysql
10161024

10171025
def connect_to_clickhouse(
1018-
self,
1019-
host: str = None,
1020-
dbname: str = None,
1021-
user: str = None,
1022-
password: str = None,
1023-
port: int = None,
1026+
self,
1027+
host: str = None,
1028+
dbname: str = None,
1029+
user: str = None,
1030+
password: str = None,
1031+
port: int = None,
1032+
**kwargs
10241033
):
10251034

10261035
try:
@@ -1070,6 +1079,7 @@ def connect_to_clickhouse(
10701079
username=user,
10711080
password=password,
10721081
database=dbname,
1082+
**kwargs
10731083
)
10741084
print(conn)
10751085
except Exception as e:
@@ -1087,15 +1097,16 @@ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
10871097

10881098
except Exception as e:
10891099
raise e
1090-
1100+
10911101
self.run_sql_is_set = True
10921102
self.run_sql = run_sql_clickhouse
10931103

10941104
def connect_to_oracle(
1095-
self,
1096-
user: str = None,
1097-
password: str = None,
1098-
dsn: str = None,
1105+
self,
1106+
user: str = None,
1107+
password: str = None,
1108+
dsn: str = None,
1109+
**kwargs
10991110
):
11001111

11011112
"""
@@ -1148,7 +1159,8 @@ def connect_to_oracle(
11481159
user=user,
11491160
password=password,
11501161
dsn=dsn,
1151-
)
1162+
**kwargs
1163+
)
11521164
except oracledb.Error as e:
11531165
raise ValidationError(e)
11541166

@@ -1180,7 +1192,12 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
11801192
self.run_sql_is_set = True
11811193
self.run_sql = run_sql_oracle
11821194

1183-
def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
1195+
def connect_to_bigquery(
1196+
self,
1197+
cred_file_path: str = None,
1198+
project_id: str = None,
1199+
**kwargs
1200+
):
11841201
"""
11851202
Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
11861203
**Example:**
@@ -1242,7 +1259,11 @@ def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None
12421259
)
12431260

12441261
try:
1245-
conn = bigquery.Client(project=project_id, credentials=credentials)
1262+
conn = bigquery.Client(
1263+
project=project_id,
1264+
credentials=credentials,
1265+
**kwargs
1266+
)
12461267
except:
12471268
raise ImproperlyConfigured(
12481269
"Could not connect to bigquery please correct credentials"
@@ -1265,7 +1286,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
12651286
self.run_sql_is_set = True
12661287
self.run_sql = run_sql_bigquery
12671288

1268-
def connect_to_duckdb(self, url: str, init_sql: str = None):
1289+
def connect_to_duckdb(self, url: str, init_sql: str = None, **kwargs):
12691290
"""
12701291
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
12711292
@@ -1303,7 +1324,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
13031324
f.write(response.content)
13041325

13051326
# Connect to the database
1306-
conn = duckdb.connect(path)
1327+
conn = duckdb.connect(path, **kwargs)
13071328
if init_sql:
13081329
conn.query(init_sql)
13091330

@@ -1314,7 +1335,7 @@ def run_sql_duckdb(sql: str):
13141335
self.run_sql = run_sql_duckdb
13151336
self.run_sql_is_set = True
13161337

1317-
def connect_to_mssql(self, odbc_conn_str: str):
1338+
def connect_to_mssql(self, odbc_conn_str: str, **kwargs):
13181339
"""
13191340
Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
13201341
@@ -1347,7 +1368,7 @@ def connect_to_mssql(self, odbc_conn_str: str):
13471368

13481369
from sqlalchemy import create_engine
13491370

1350-
engine = create_engine(connection_url)
1371+
engine = create_engine(connection_url, **kwargs)
13511372

13521373
def run_sql_mssql(sql: str):
13531374
# Execute the SQL statement and return the result as a pandas DataFrame
@@ -1362,16 +1383,17 @@ def run_sql_mssql(sql: str):
13621383
self.run_sql = run_sql_mssql
13631384
self.run_sql_is_set = True
13641385
def connect_to_presto(
1365-
self,
1366-
host: str,
1367-
catalog: str = 'hive',
1368-
schema: str = 'default',
1369-
user: str = None,
1370-
password: str = None,
1371-
port: int = None,
1372-
combined_pem_path: str = None,
1373-
protocol: str = 'https',
1374-
requests_kwargs: dict = None
1386+
self,
1387+
host: str,
1388+
catalog: str = 'hive',
1389+
schema: str = 'default',
1390+
user: str = None,
1391+
password: str = None,
1392+
port: int = None,
1393+
combined_pem_path: str = None,
1394+
protocol: str = 'https',
1395+
requests_kwargs: dict = None,
1396+
**kwargs
13751397
):
13761398
"""
13771399
Connect to a Presto database using the specified parameters.
@@ -1444,7 +1466,8 @@ def connect_to_presto(
14441466
schema=schema,
14451467
port=port,
14461468
protocol=protocol,
1447-
requests_kwargs=requests_kwargs)
1469+
requests_kwargs=requests_kwargs,
1470+
**kwargs)
14481471
except presto.Error as e:
14491472
raise ValidationError(e)
14501473

@@ -1477,13 +1500,14 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
14771500
self.run_sql = run_sql_presto
14781501

14791502
def connect_to_hive(
1480-
self,
1481-
host: str = None,
1482-
dbname: str = 'default',
1483-
user: str = None,
1484-
password: str = None,
1485-
port: int = None,
1486-
auth: str = 'CUSTOM'
1503+
self,
1504+
host: str = None,
1505+
dbname: str = 'default',
1506+
user: str = None,
1507+
password: str = None,
1508+
port: int = None,
1509+
auth: str = 'CUSTOM',
1510+
**kwargs
14871511
):
14881512
"""
14891513
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

0 commit comments

Comments
 (0)