From 3da781a62bc2044561c85f0d68bbe724c89ad9c8 Mon Sep 17 00:00:00 2001 From: Christian Clauss Date: Tue, 14 Sep 2021 06:13:27 +0200 Subject: [PATCH] Fix undefined names in Python code (#599) * Update pytorch_tabnet.py $ `flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics` ``` ./qlib/qlib/contrib/model/pytorch_tabnet.py:567:38: F821 undefined name 'inp' self.independ.append(GLU(inp, out_dim, vbs=vbs)) ^ ./qlib/examples/model_rolling/task_manager_rolling.py:75:18: F821 undefined name 'task_train' run_task(task_train, self.task_pool, experiment_name=self.experiment_name) ^ 2 F821 undefined name 'task_train' 2 ``` * Fix undefined names in Python code * from qlib.model.trainer import task_train --- examples/model_rolling/task_manager_rolling.py | 2 +- qlib/contrib/model/pytorch_tabnet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 844f181980..091a87862f 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -17,7 +17,7 @@ from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import TrainerRM +from qlib.model.trainer import TrainerRM, task_train from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index b05d9a026d..bd8f085ec1 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -564,7 +564,7 @@ def __init__(self, inp_dim, out_dim, shared, n_ind, vbs): self.shared = None self.independ = nn.ModuleList() if first: - self.independ.append(GLU(inp, out_dim, vbs=vbs)) + self.independ.append(GLU(inp_dim, out_dim, vbs=vbs)) for x in range(first, n_ind): self.independ.append(GLU(out_dim, out_dim, vbs=vbs)) self.scale = float(np.sqrt(0.5))