From 4e0d1d8ac843b53722baafb18149ad360814c052 Mon Sep 17 00:00:00 2001 From: Rinchin Date: Sat, 14 Dec 2024 17:28:49 +0000 Subject: [PATCH] add asserts for targets --- scripts/experiments/run_tabular.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/scripts/experiments/run_tabular.py b/scripts/experiments/run_tabular.py index e43a741c..82fbf8b1 100644 --- a/scripts/experiments/run_tabular.py +++ b/scripts/experiments/run_tabular.py @@ -41,6 +41,22 @@ def main(dataset_name: str, cpu_limit: int, memory_limit: int, save_model: bool) train = pd.read_csv(os.path.join(dataset_local_path, "train.csv")) test = pd.read_csv(os.path.join(dataset_local_path, "test.csv")) + if task_type == "multilabel": + target_name = [x for x in test.columns if x.startswith("target")] + else: + target_name = test.columns[-1] + + if task_type in ["binary", "multiclass", "multilabel"]: + assert ( + train[target_name].nunique() == test[target_name].nunique() + ), "train and test has different unique values." + + assert min(train[target_name].nunique() > 1) is True, "Only one class present in train target." + assert min(test[target_name].nunique() > 1) is True, "Only one class present in test target." + + assert train.isnull().values.any() is False, "train has nans in target." + assert test.isnull().values.any() is False, "test has nans in target." + task = Task(task_type) # =================================== automl config: @@ -50,9 +66,9 @@ def main(dataset_name: str, cpu_limit: int, memory_limit: int, save_model: bool) cpu_limit=cpu_limit, memory_limit=memory_limit, timeout=15 * 60, - general_params={ - # "use_algos": [["mlp"]] - }, # ['nn', 'mlp', 'dense', 'denselight', 'resnet', 'snn', 'node', 'autoint', 'fttransformer'] or custom torch model + # general_params={ + # "use_algos": [["mlp"]] + # }, # ['nn', 'mlp', 'dense', 'denselight', 'resnet', 'snn', 'node', 'autoint', 'fttransformer'] or custom torch model # nn_params={"n_epochs": 10, "bs": 512, "num_workers": 0, "path_to_save": None, "freeze_defaults": True}, # nn_pipeline_params={"use_qnt": True, "use_te": False}, reader_params={ @@ -65,11 +81,6 @@ def main(dataset_name: str, cpu_limit: int, memory_limit: int, save_model: bool) cml_task.connect(automl) - if task_type == "multilabel": - target_name = [x for x in test.columns if x.startswith("target")] - else: - target_name = test.columns[-1] - kwargs = {} if save_model: kwargs["path_to_save"] = "model"