From 13a1925b1e2082da5e4d0b1e99121bfa2255c5b9 Mon Sep 17 00:00:00 2001 From: Vadim Smirnov Date: Sat, 2 Nov 2024 10:45:18 +0000 Subject: [PATCH] Custom nn params fix --- lightautoml/ml_algo/dl_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightautoml/ml_algo/dl_model.py b/lightautoml/ml_algo/dl_model.py index e36ad870..be2b5d6e 100644 --- a/lightautoml/ml_algo/dl_model.py +++ b/lightautoml/ml_algo/dl_model.py @@ -277,7 +277,7 @@ def _infer_params(self): params["metric"] = self.task.losses["torch"].metric_func if params["bert_name"] is None and params["use_text"]: - params["bert_name"] = _model_name_by_lang[params["lang"]] + params["bert_name"] = _model_name_by_lang[params.get("lang", "en")] is_text = (len(params["text_features"]) > 0) and (params["use_text"]) and (params["device"].type == "cuda") is_cat = (len(params["cat_features"]) > 0) and (params["use_cat"]) @@ -309,7 +309,7 @@ def _infer_params(self): net_params={ "task": self.task, "cont_embedder_": cont_embedder_by_name.get(params["cont_embedder"], LinearEmbedding) - if input_type_by_name[params["model"]] == "seq" and is_cont + if input_type_by_name.get(params["model"], "flat") == "seq" and is_cont else cont_embedder_by_name_flat.get(params["cont_embedder"], ContEmbedder) if is_cont else None, @@ -323,7 +323,7 @@ def _infer_params(self): if is_cont else None, "cat_embedder_": cat_embedder_by_name.get(params["cat_embedder"], BasicCatEmbedding) - if input_type_by_name[params["model"]] == "seq" and is_cat + if input_type_by_name.get(params["model"], "flat") == "seq" and is_cat else cat_embedder_by_name_flat.get(params["cat_embedder"], CatEmbedder) if is_cat else None,