diff --git a/cogdl/experiments.py b/cogdl/experiments.py
index befe6dfb..6daa3a72 100644
--- a/cogdl/experiments.py
+++ b/cogdl/experiments.py
@@ -68,18 +68,6 @@ def run(self):
         return self.best_results
 
 
-def examine_link_prediction(args, dataset):
-    if "link_prediction" in args.mw:
-        args.num_entities = dataset.data.num_nodes
-        # args.num_entities = len(torch.unique(self.data.edge_index))
-        if dataset.data.edge_attr is not None:
-            args.num_rels = len(torch.unique(dataset.data.edge_attr))
-            args.monitor = "mrr"
-        else:
-            args.monitor = "auc"
-    return args
-
-
 def set_best_config(args):
     configs = BEST_CONFIGS[args.task]
     if args.model not in configs:
@@ -104,19 +92,20 @@ def train(args):  # noqa: C901
     set_random_seed(args.seed)
 
     model_name = args.model if isinstance(args.model, str) else args.model.model_name
+    dw_name = args.dw if isinstance(args.dw, str) else args.dw.__name__
+    mw_name = args.mw if isinstance(args.mw, str) else args.mw.__name__
 
     print(
         f"""
-|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(args.dw) + len(args.mw))}|
-    *** Running (`{args.dataset}`, `{model_name}`, `{args.dw}`, `{args.mw}`)
-|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(args.dw) + len(args.mw))}|"""
+|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|
+    *** Running (`{args.dataset}`, `{model_name}`, `{dw_name}`, `{mw_name}`)
+|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|"""
     )
 
     if getattr(args, "use_best_config", False):
         args = set_best_config(args)
 
     # setup dataset and specify `num_features` and `num_classes` for model
-    args.monitor = "val_acc"
     if isinstance(args.dataset, Dataset):
         dataset = args.dataset
     else:
@@ -142,8 +131,6 @@ def train(args):  # noqa: C901
         if hasattr(args, key) and key != "model":
             model_wrapper_args[key] = getattr(args, key)
 
-    args = examine_link_prediction(args, dataset)
-
     # setup data_wrapper
     dataset_wrapper = dw_class(dataset, **data_wrapper_args)
 
@@ -186,7 +173,7 @@ def train(args):  # noqa: C901
         optimizer_cfg["hidden_size"] = args.hidden_size
 
     # setup model_wrapper
-    if "embedding" in args.mw:
+    if isinstance(args.mw, str) and "embedding" in args.mw:
         model_wrapper = mw_class(model, **model_wrapper_args)
     else:
         model_wrapper = mw_class(model, optimizer_cfg, **model_wrapper_args)
@@ -201,7 +188,6 @@ def train(args):  # noqa: C901
         save_emb_path=args.save_emb_path,
         load_emb_path=args.load_emb_path,
         cpu_inference=args.cpu_inference,
-        # monitor=args.monitor,
         progress_bar=args.progress_bar,
         distributed_training=args.distributed,
         checkpoint_path=args.checkpoint_path,
diff --git a/cogdl/options.py b/cogdl/options.py
index 3b0b8423..dec5c802 100644
--- a/cogdl/options.py
+++ b/cogdl/options.py
@@ -37,7 +37,6 @@ def get_parser():
     parser.add_argument("--devices", default=[0], type=int, nargs="+", help="which GPU to use")
     parser.add_argument("--cpu", action="store_true", help="use CPU instead of CUDA")
     parser.add_argument("--cpu-inference", action="store_true", help="do validation and test in cpu")
-    # parser.add_argument("--monitor", type=str, default="val_acc")
     parser.add_argument("--distributed", action="store_true")
     parser.add_argument("--progress-bar", type=str, default="epoch")
     parser.add_argument("--local_rank", type=int, default=0)
diff --git a/cogdl/trainer/trainer.py b/cogdl/trainer/trainer.py
index 0e34d0a8..cdfd3feb 100644
--- a/cogdl/trainer/trainer.py
+++ b/cogdl/trainer/trainer.py
@@ -57,7 +57,6 @@ def __init__(
         distributed_inference: bool = False,
         master_addr: str = "localhost",
         master_port: int = 10086,
-        # monitor: str = "val_acc",
         early_stopping: bool = True,
         patience: int = 100,
         eval_step: int = 1,
diff --git a/cogdl/wrappers/data_wrapper/__init__.py b/cogdl/wrappers/data_wrapper/__init__.py
index f5aa1183..e764dab6 100644
--- a/cogdl/wrappers/data_wrapper/__init__.py
+++ b/cogdl/wrappers/data_wrapper/__init__.py
@@ -19,6 +19,8 @@ def register_data_wrapper_cls(cls):
 
 
 def fetch_data_wrapper(name):
+    if isinstance(name, type):
+        return name
     if name in SUPPORTED_DW:
         path = ".".join(SUPPORTED_DW[name].split(".")[:-1])
         module = importlib.import_module(path)
diff --git a/cogdl/wrappers/model_wrapper/__init__.py b/cogdl/wrappers/model_wrapper/__init__.py
index 2727c80c..782579d5 100644
--- a/cogdl/wrappers/model_wrapper/__init__.py
+++ b/cogdl/wrappers/model_wrapper/__init__.py
@@ -20,6 +20,8 @@ def register_model_wrapper_cls(cls):
 
 
 def fetch_model_wrapper(name):
+    if isinstance(name, type):
+        return name
     if name in SUPPORTED_MW:
         path = ".".join(SUPPORTED_MW[name].split(".")[:-1])
         module = importlib.import_module(path)
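
Note on the wrapper-resolution change above: fetch_data_wrapper and fetch_model_wrapper now accept a wrapper class object as well as a registered name, which is why train(args) no longer assumes args.dw / args.mw are strings. A minimal sketch of what this enables, assuming a hypothetical user-defined class MyDataWrapper (not part of cogdl), with the registered name in the last call shown only as an illustration:

    from cogdl.wrappers.data_wrapper import fetch_data_wrapper

    class MyDataWrapper:
        # hypothetical stand-in for a custom DataWrapper subclass
        pass

    # With this patch, a class object is returned unchanged ...
    assert fetch_data_wrapper(MyDataWrapper) is MyDataWrapper

    # ... while a registered name (string) is still resolved through
    # SUPPORTED_DW exactly as before.
    dw_cls = fetch_data_wrapper("node_classification_dw")

fetch_model_wrapper mirrors this behavior, so callers can pass either names or wrapper classes for both dw and mw.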