diff --git a/easybuild/framework/easyblock.py b/easybuild/framework/easyblock.py index 99471eb48a..d6599e166e 100644 --- a/easybuild/framework/easyblock.py +++ b/easybuild/framework/easyblock.py @@ -540,6 +540,10 @@ def fetch_extension_sources(self, skip_checksums=False): 'options': ext_options, } + # if a particular easyblock is specified, make sure it's used + # (this is picked up by init_ext_instances) + ext_src['easyblock'] = ext_options.get('easyblock', None) + # construct dictionary with template values; # inherited from parent, except for name/version templates which are specific to this extension template_values = copy.deepcopy(self.cfg.template_values) @@ -2295,15 +2299,23 @@ def init_ext_instances(self): ext_name = ext['name'] self.log.debug("Creating class instance for extension %s...", ext_name) + # if a specific easyblock is specified for this extension, honor it; + # just passing this to get_easyblock_class is sufficient + easyblock = ext.get('easyblock', None) + if easyblock: + class_name = easyblock + mod_path = get_module_path(class_name) + else: + class_name = encode_class_name(ext_name) + mod_path = get_module_path(class_name, generic=False) + cls, inst = None, None - class_name = encode_class_name(ext_name) - mod_path = get_module_path(class_name, generic=False) - # try instantiating extension-specific class + # try instantiating extension-specific class, or honor specified easyblock try: # no error when importing class fails, in case we run into an existing easyblock # with a similar name (e.g., Perl Extension 'GO' vs 'Go' for which 'EB_Go' is available) - cls = get_easyblock_class(None, name=ext_name, error_on_failed_import=False, + cls = get_easyblock_class(easyblock, name=ext_name, error_on_failed_import=False, error_on_missing_easyblock=False) self.log.debug("Obtained class %s for extension %s", cls, ext_name) if cls is not None: diff --git a/test/framework/easyblock.py b/test/framework/easyblock.py index d6b2a1d57a..ed44889feb 100644 --- a/test/framework/easyblock.py +++ b/test/framework/easyblock.py @@ -993,6 +993,30 @@ def test_extensions_step(self): eb.close_log() os.remove(eb.logfile) + def test_init_extensions(self): + """Test creating extension instances.""" + + testdir = os.path.abspath(os.path.dirname(__file__)) + toy_ec_file = os.path.join(testdir, 'easyconfigs', 'test_ecs', 't', 'toy', 'toy-0.0-gompi-2018a-test.eb') + toy_ec_txt = read_file(toy_ec_file) + + test_ec = os.path.join(self.test_prefix, 'test.eb') + test_ec_txt = toy_ec_txt.replace("('barbar', '0.0', {", "('barbar', '0.0', {'easyblock': 'DummyExtension',") + write_file(test_ec, test_ec_txt) + ec = process_easyconfig(test_ec)[0] + eb = get_easyblock_instance(ec) + + eb.prepare_for_extensions() + eb.init_ext_instances() + ext_inst_class_names = [x.__class__.__name__ for x in eb.ext_instances] + expected = [ + 'Toy_Extension', # 'ls' extension + 'Toy_Extension', # 'bar' extension + 'DummyExtension', # 'barbar' extension + 'EB_toy', # 'toy' extension + ] + self.assertEqual(ext_inst_class_names, expected) + def test_skip_extensions_step(self): """Test the skip_extensions_step"""