Skip to content

Commit

Permalink
add asserts for targets
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-rinchin committed Dec 14, 2024
1 parent 78e2362 commit 4e0d1d8
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions scripts/experiments/run_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ def main(dataset_name: str, cpu_limit: int, memory_limit: int, save_model: bool)
train = pd.read_csv(os.path.join(dataset_local_path, "train.csv"))
test = pd.read_csv(os.path.join(dataset_local_path, "test.csv"))

if task_type == "multilabel":
target_name = [x for x in test.columns if x.startswith("target")]
else:
target_name = test.columns[-1]

if task_type in ["binary", "multiclass", "multilabel"]:
assert (
train[target_name].nunique() == test[target_name].nunique()
), "train and test has different unique values."

assert min(train[target_name].nunique() > 1) is True, "Only one class present in train target."
assert min(test[target_name].nunique() > 1) is True, "Only one class present in test target."

assert train.isnull().values.any() is False, "train has nans in target."
assert test.isnull().values.any() is False, "test has nans in target."

task = Task(task_type)

# =================================== automl config:
Expand All @@ -50,9 +66,9 @@ def main(dataset_name: str, cpu_limit: int, memory_limit: int, save_model: bool)
cpu_limit=cpu_limit,
memory_limit=memory_limit,
timeout=15 * 60,
general_params={
# "use_algos": [["mlp"]]
}, # ['nn', 'mlp', 'dense', 'denselight', 'resnet', 'snn', 'node', 'autoint', 'fttransformer'] or custom torch model
# general_params={
# "use_algos": [["mlp"]]
# }, # ['nn', 'mlp', 'dense', 'denselight', 'resnet', 'snn', 'node', 'autoint', 'fttransformer'] or custom torch model
# nn_params={"n_epochs": 10, "bs": 512, "num_workers": 0, "path_to_save": None, "freeze_defaults": True},
# nn_pipeline_params={"use_qnt": True, "use_te": False},
reader_params={
Expand All @@ -65,11 +81,6 @@ def main(dataset_name: str, cpu_limit: int, memory_limit: int, save_model: bool)

cml_task.connect(automl)

if task_type == "multilabel":
target_name = [x for x in test.columns if x.startswith("target")]
else:
target_name = test.columns[-1]

kwargs = {}
if save_model:
kwargs["path_to_save"] = "model"
Expand Down

0 comments on commit 4e0d1d8

Please sign in to comment.