Skip to content

Commit e6aea6b

Browse files
authored
Merge pull request #139 from vanna-ai/training-plan-base
add generic training plan
2 parents 8364f3d + 9f33650 commit e6aea6b

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

src/vanna/base/base.py

+59
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,65 @@ def _get_information_schema_tables(self, database: str) -> pd.DataFrame:
612612

613613
return df_tables
614614

615+
def get_training_plan_generic(self, df) -> TrainingPlan:
616+
# For each of the following, we look at the df columns to see if there's a match:
617+
database_column = df.columns[
618+
df.columns.str.lower().str.contains("database")
619+
| df.columns.str.lower().str.contains("table_catalog")
620+
].to_list()[0]
621+
schema_column = df.columns[
622+
df.columns.str.lower().str.contains("table_schema")
623+
].to_list()[0]
624+
table_column = df.columns[
625+
df.columns.str.lower().str.contains("table_name")
626+
].to_list()[0]
627+
column_column = df.columns[
628+
df.columns.str.lower().str.contains("column_name")
629+
].to_list()[0]
630+
data_type_column = df.columns[
631+
df.columns.str.lower().str.contains("data_type")
632+
].to_list()[0]
633+
634+
plan = TrainingPlan([])
635+
636+
for database in df[database_column].unique().tolist():
637+
for schema in (
638+
df.query(f'{database_column} == "{database}"')[schema_column]
639+
.unique()
640+
.tolist()
641+
):
642+
for table in (
643+
df.query(
644+
f'{database_column} == "{database}" and {schema_column} == "{schema}"'
645+
)[table_column]
646+
.unique()
647+
.tolist()
648+
):
649+
df_columns_filtered_to_table = df.query(
650+
f'{database_column} == "{database}" and {schema_column} == "{schema}" and {table_column} == "{table}"'
651+
)
652+
doc = f"The following columns are in the {table} table in the {database} database:\n\n"
653+
doc += df_columns_filtered_to_table[
654+
[
655+
database_column,
656+
schema_column,
657+
table_column,
658+
column_column,
659+
data_type_column,
660+
]
661+
].to_markdown()
662+
663+
plan._plan.append(
664+
TrainingPlanItem(
665+
item_type=TrainingPlanItem.ITEM_TYPE_IS,
666+
item_group=f"{database}.{schema}",
667+
item_name=table,
668+
item_value=doc,
669+
)
670+
)
671+
672+
return plan
673+
615674
def get_training_plan_snowflake(
616675
self,
617676
filter_databases: Union[List[str], None] = None,

0 commit comments

Comments
 (0)