@@ -130,7 +130,7 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) ->
130
130
question_sql_list = question_sql_list ,
131
131
ddl_list = ddl_list ,
132
132
doc_list = doc_list ,
133
- ** kwargs ,
133
+ ** kwargs
134
134
)
135
135
self .log (title = "SQL Prompt" , message = prompt )
136
136
llm_response = self .submit_prompt (prompt , ** kwargs )
@@ -153,7 +153,7 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) ->
153
153
question_sql_list = question_sql_list ,
154
154
ddl_list = ddl_list ,
155
155
doc_list = doc_list + [f"The following is a pandas DataFrame with the results of the intermediate SQL query { intermediate_sql } : \n " + df .to_markdown ()],
156
- ** kwargs ,
156
+ ** kwargs
157
157
)
158
158
self .log (title = "Final SQL Prompt" , message = prompt )
159
159
llm_response = self .submit_prompt (prompt , ** kwargs )
@@ -529,7 +529,7 @@ def get_sql_prompt(
529
529
question_sql_list : list ,
530
530
ddl_list : list ,
531
531
doc_list : list ,
532
- ** kwargs ,
532
+ ** kwargs
533
533
):
534
534
"""
535
535
Example:
@@ -599,7 +599,7 @@ def get_followup_questions_prompt(
599
599
question_sql_list : list ,
600
600
ddl_list : list ,
601
601
doc_list : list ,
602
- ** kwargs ,
602
+ ** kwargs
603
603
) -> list :
604
604
initial_prompt = f"The user initially asked the question: '{ question } ': \n \n "
605
605
@@ -655,7 +655,7 @@ def generate_question(self, sql: str, **kwargs) -> str:
655
655
),
656
656
self .user_message (sql ),
657
657
],
658
- ** kwargs ,
658
+ ** kwargs
659
659
)
660
660
661
661
return response
@@ -718,6 +718,7 @@ def connect_to_snowflake(
718
718
database : str ,
719
719
role : Union [str , None ] = None ,
720
720
warehouse : Union [str , None ] = None ,
721
+ ** kwargs
721
722
):
722
723
try :
723
724
snowflake = __import__ ("snowflake.connector" )
@@ -764,7 +765,8 @@ def connect_to_snowflake(
764
765
password = password ,
765
766
account = account ,
766
767
database = database ,
767
- client_session_keep_alive = True
768
+ client_session_keep_alive = True ,
769
+ ** kwargs
768
770
)
769
771
770
772
def run_sql_snowflake (sql : str ) -> pd .DataFrame :
@@ -831,6 +833,7 @@ def connect_to_postgres(
831
833
user : str = None ,
832
834
password : str = None ,
833
835
port : int = None ,
836
+ ** kwargs
834
837
):
835
838
"""
836
839
Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
@@ -900,6 +903,7 @@ def connect_to_postgres(
900
903
user = user ,
901
904
password = password ,
902
905
port = port ,
906
+ ** kwargs
903
907
)
904
908
except psycopg2 .Error as e :
905
909
raise ValidationError (e )
@@ -931,12 +935,13 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
931
935
932
936
933
937
def connect_to_mysql (
934
- self ,
935
- host : str = None ,
936
- dbname : str = None ,
937
- user : str = None ,
938
- password : str = None ,
939
- port : int = None ,
938
+ self ,
939
+ host : str = None ,
940
+ dbname : str = None ,
941
+ user : str = None ,
942
+ password : str = None ,
943
+ port : int = None ,
944
+ ** kwargs
940
945
):
941
946
942
947
try :
@@ -980,12 +985,15 @@ def connect_to_mysql(
980
985
conn = None
981
986
982
987
try :
983
- conn = pymysql .connect (host = host ,
984
- user = user ,
985
- password = password ,
986
- database = dbname ,
987
- port = port ,
988
- cursorclass = pymysql .cursors .DictCursor )
988
+ conn = pymysql .connect (
989
+ host = host ,
990
+ user = user ,
991
+ password = password ,
992
+ database = dbname ,
993
+ port = port ,
994
+ cursorclass = pymysql .cursors .DictCursor ,
995
+ ** kwargs
996
+ )
989
997
except pymysql .Error as e :
990
998
raise ValidationError (e )
991
999
@@ -1015,12 +1023,13 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
1015
1023
self .run_sql = run_sql_mysql
1016
1024
1017
1025
def connect_to_clickhouse (
1018
- self ,
1019
- host : str = None ,
1020
- dbname : str = None ,
1021
- user : str = None ,
1022
- password : str = None ,
1023
- port : int = None ,
1026
+ self ,
1027
+ host : str = None ,
1028
+ dbname : str = None ,
1029
+ user : str = None ,
1030
+ password : str = None ,
1031
+ port : int = None ,
1032
+ ** kwargs
1024
1033
):
1025
1034
1026
1035
try :
@@ -1070,6 +1079,7 @@ def connect_to_clickhouse(
1070
1079
username = user ,
1071
1080
password = password ,
1072
1081
database = dbname ,
1082
+ ** kwargs
1073
1083
)
1074
1084
print (conn )
1075
1085
except Exception as e :
@@ -1087,15 +1097,16 @@ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
1087
1097
1088
1098
except Exception as e :
1089
1099
raise e
1090
-
1100
+
1091
1101
self .run_sql_is_set = True
1092
1102
self .run_sql = run_sql_clickhouse
1093
1103
1094
1104
def connect_to_oracle (
1095
- self ,
1096
- user : str = None ,
1097
- password : str = None ,
1098
- dsn : str = None ,
1105
+ self ,
1106
+ user : str = None ,
1107
+ password : str = None ,
1108
+ dsn : str = None ,
1109
+ ** kwargs
1099
1110
):
1100
1111
1101
1112
"""
@@ -1148,7 +1159,8 @@ def connect_to_oracle(
1148
1159
user = user ,
1149
1160
password = password ,
1150
1161
dsn = dsn ,
1151
- )
1162
+ ** kwargs
1163
+ )
1152
1164
except oracledb .Error as e :
1153
1165
raise ValidationError (e )
1154
1166
@@ -1180,7 +1192,12 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
1180
1192
self .run_sql_is_set = True
1181
1193
self .run_sql = run_sql_oracle
1182
1194
1183
- def connect_to_bigquery (self , cred_file_path : str = None , project_id : str = None ):
1195
+ def connect_to_bigquery (
1196
+ self ,
1197
+ cred_file_path : str = None ,
1198
+ project_id : str = None ,
1199
+ ** kwargs
1200
+ ):
1184
1201
"""
1185
1202
Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
1186
1203
**Example:**
@@ -1242,7 +1259,11 @@ def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None
1242
1259
)
1243
1260
1244
1261
try :
1245
- conn = bigquery .Client (project = project_id , credentials = credentials )
1262
+ conn = bigquery .Client (
1263
+ project = project_id ,
1264
+ credentials = credentials ,
1265
+ ** kwargs
1266
+ )
1246
1267
except :
1247
1268
raise ImproperlyConfigured (
1248
1269
"Could not connect to bigquery please correct credentials"
@@ -1265,7 +1286,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
1265
1286
self .run_sql_is_set = True
1266
1287
self .run_sql = run_sql_bigquery
1267
1288
1268
- def connect_to_duckdb (self , url : str , init_sql : str = None ):
1289
+ def connect_to_duckdb (self , url : str , init_sql : str = None , ** kwargs ):
1269
1290
"""
1270
1291
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
1271
1292
@@ -1303,7 +1324,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
1303
1324
f .write (response .content )
1304
1325
1305
1326
# Connect to the database
1306
- conn = duckdb .connect (path )
1327
+ conn = duckdb .connect (path , ** kwargs )
1307
1328
if init_sql :
1308
1329
conn .query (init_sql )
1309
1330
@@ -1314,7 +1335,7 @@ def run_sql_duckdb(sql: str):
1314
1335
self .run_sql = run_sql_duckdb
1315
1336
self .run_sql_is_set = True
1316
1337
1317
- def connect_to_mssql (self , odbc_conn_str : str ):
1338
+ def connect_to_mssql (self , odbc_conn_str : str , ** kwargs ):
1318
1339
"""
1319
1340
Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
1320
1341
@@ -1347,7 +1368,7 @@ def connect_to_mssql(self, odbc_conn_str: str):
1347
1368
1348
1369
from sqlalchemy import create_engine
1349
1370
1350
- engine = create_engine (connection_url )
1371
+ engine = create_engine (connection_url , ** kwargs )
1351
1372
1352
1373
def run_sql_mssql (sql : str ):
1353
1374
# Execute the SQL statement and return the result as a pandas DataFrame
@@ -1362,16 +1383,17 @@ def run_sql_mssql(sql: str):
1362
1383
self .run_sql = run_sql_mssql
1363
1384
self .run_sql_is_set = True
1364
1385
def connect_to_presto (
1365
- self ,
1366
- host : str ,
1367
- catalog : str = 'hive' ,
1368
- schema : str = 'default' ,
1369
- user : str = None ,
1370
- password : str = None ,
1371
- port : int = None ,
1372
- combined_pem_path : str = None ,
1373
- protocol : str = 'https' ,
1374
- requests_kwargs : dict = None
1386
+ self ,
1387
+ host : str ,
1388
+ catalog : str = 'hive' ,
1389
+ schema : str = 'default' ,
1390
+ user : str = None ,
1391
+ password : str = None ,
1392
+ port : int = None ,
1393
+ combined_pem_path : str = None ,
1394
+ protocol : str = 'https' ,
1395
+ requests_kwargs : dict = None ,
1396
+ ** kwargs
1375
1397
):
1376
1398
"""
1377
1399
Connect to a Presto database using the specified parameters.
@@ -1444,7 +1466,8 @@ def connect_to_presto(
1444
1466
schema = schema ,
1445
1467
port = port ,
1446
1468
protocol = protocol ,
1447
- requests_kwargs = requests_kwargs )
1469
+ requests_kwargs = requests_kwargs ,
1470
+ ** kwargs )
1448
1471
except presto .Error as e :
1449
1472
raise ValidationError (e )
1450
1473
@@ -1477,13 +1500,14 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
1477
1500
self .run_sql = run_sql_presto
1478
1501
1479
1502
def connect_to_hive (
1480
- self ,
1481
- host : str = None ,
1482
- dbname : str = 'default' ,
1483
- user : str = None ,
1484
- password : str = None ,
1485
- port : int = None ,
1486
- auth : str = 'CUSTOM'
1503
+ self ,
1504
+ host : str = None ,
1505
+ dbname : str = 'default' ,
1506
+ user : str = None ,
1507
+ password : str = None ,
1508
+ port : int = None ,
1509
+ auth : str = 'CUSTOM' ,
1510
+ ** kwargs
1487
1511
):
1488
1512
"""
1489
1513
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
0 commit comments