diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 98e311c42113..ce23e7b89329 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -442,8 +442,7 @@ def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = N instance = hydra.utils.instantiate(config=config) else: instance = None - imported_cls_tb = None - instance_init_error = None + prev_error = "" # Attempt class path resolution from config `target` class (if it exists) if 'target' in config: target_cls = config["target"] # No guarantee that this is a omegaconf class @@ -451,36 +450,23 @@ def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = N try: # try to import the target class imported_cls = import_class_by_path(target_cls) - except Exception: - imported_cls_tb = traceback.format_exc() - - # try instantiating model with target class - if imported_cls is not None: # if calling class (cls) is subclass of imported class, # use subclass instead if issubclass(cls, imported_cls): imported_cls = cls - - try: - accepts_trainer = Serialization._inspect_signature_for_trainer(imported_cls) - if accepts_trainer: - instance = imported_cls(cfg=config, trainer=trainer) - else: - instance = imported_cls(cfg=config) - - except Exception as e: - imported_cls_tb = traceback.format_exc() - instance_init_error = str(e) - instance = None + accepts_trainer = Serialization._inspect_signature_for_trainer(imported_cls) + if accepts_trainer: + instance = imported_cls(cfg=config, trainer=trainer) + else: + instance = imported_cls(cfg=config) + except Exception as e: + # record previous error + tb = traceback.format_exc() + prev_error = f"Model instantiation failed!\nTarget class:\t{target_cls}" f"\nError(s):\t{e}\n{tb}" + logging.debug(prev_error + "\nFalling back to `cls`.") # target class resolution was unsuccessful, fall back to current `cls` if instance is None: - if imported_cls_tb is not None: - logging.info( - f"Model instantiation from target class {target_cls} failed with following error.\n" - f"Falling back to `cls`.\n" - f"{imported_cls_tb}" - ) try: accepts_trainer = Serialization._inspect_signature_for_trainer(cls) if accepts_trainer: @@ -489,9 +475,9 @@ def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = N instance = cls(cfg=config) except Exception as e: - if imported_cls_tb is not None: - logging.error(f"Instance failed restore_from due to: {instance_init_error}") - logging.error(f"{imported_cls_tb}") + # report saved errors, if any, and raise + if prev_error: + logging.error(prev_error) raise e if not hasattr(instance, '_cfg'):