diff --git a/netjsonconfig/backends/base/backend.py b/netjsonconfig/backends/base/backend.py index d95bb4c87..04bdff6fa 100644 --- a/netjsonconfig/backends/base/backend.py +++ b/netjsonconfig/backends/base/backend.py @@ -21,8 +21,9 @@ class BaseBackend(object): schema = None FILE_SECTION_DELIMITER = '# ---------- files ---------- #' intermediate_data = None + list_identifiers = [] - def __init__(self, config, templates=[], context={}): + def __init__(self, config, templates=None, context=None): """ :param config: ``dict`` containing valid **NetJSON DeviceConfiguration** :param templates: ``list`` containing **NetJSON** dictionaries that will be @@ -54,17 +55,17 @@ def _merge_config(self, config, templates): """ Merges config with templates """ + if not templates: + return config # type check if not isinstance(templates, list): raise TypeError('templates argument must be an instance of list') - # merge any present template with main configuration + # merge templates with main configuration base_config = {} - for template in templates: - template = self._load(template) - base_config = merge_config(base_config, template) + base_config = merge_config(base_config, self._load(template), self.list_identifiers) if base_config: - return merge_config(base_config, config) + return merge_config(base_config, config, self.list_identifiers) return config def _evaluate_vars(self, config, context): diff --git a/netjsonconfig/backends/openwisp/openwisp.py b/netjsonconfig/backends/openwisp/openwisp.py index 930f3106b..440d1116b 100644 --- a/netjsonconfig/backends/openwisp/openwisp.py +++ b/netjsonconfig/backends/openwisp/openwisp.py @@ -25,10 +25,11 @@ def _sanitize_radios(self): for radio in self.config.get('radios', []): radio.setdefault('disabled', False) - def _render_template(self, template, context={}): + def _render_template(self, template, context=None): openwisp_env = Environment(loader=PackageLoader(self.__module__, 'templates'), trim_blocks=True) template = openwisp_env.get_template(template) + context = context or {} return template.render(**context) def _add_unique_file(self, item): diff --git a/netjsonconfig/backends/openwrt/openwrt.py b/netjsonconfig/backends/openwrt/openwrt.py index 1ce5647e4..70433ceb6 100644 --- a/netjsonconfig/backends/openwrt/openwrt.py +++ b/netjsonconfig/backends/openwrt/openwrt.py @@ -28,6 +28,7 @@ class OpenWrt(BaseBackend): converters.Default, ] renderer = OpenWrtRenderer + list_identifiers = ['name', 'config_value', 'id'] def _generate_contents(self, tar): """ diff --git a/netjsonconfig/utils.py b/netjsonconfig/utils.py index d76665b23..1d4ad6462 100644 --- a/netjsonconfig/utils.py +++ b/netjsonconfig/utils.py @@ -5,7 +5,7 @@ import six -def merge_config(template, config): +def merge_config(template, config, list_identifiers=None): """ Merges ``config`` on top of ``template``. @@ -19,6 +19,7 @@ def merge_config(template, config): :param template: template ``dict`` :param config: config ``dict`` + :param list_identifiers: ``list`` or ``None`` :returns: merged ``dict`` """ result = template.copy() @@ -27,13 +28,13 @@ def merge_config(template, config): node = result.get(key, {}) result[key] = merge_config(node, value) elif isinstance(value, list) and isinstance(result.get(key), list): - result[key] = merge_list(result[key], value) + result[key] = merge_list(result[key], value, list_identifiers) else: result[key] = value return result -def merge_list(list1, list2, identifiers=['name', 'config_value', 'id']): +def merge_list(list1, list2, identifiers=None): """ Merges ``list2`` on top of ``list1``. @@ -43,10 +44,12 @@ def merge_list(list1, list2, identifiers=['name', 'config_value', 'id']): The remaining elements will be summed in order to create a list which contains elements of both lists. - :param list1: list from template - :param list2: list from config + :param list1: ``list`` from template + :param list2: ``list from config + :param identifiers: ``list`` or ``None`` :returns: merged ``list`` """ + identifiers = identifiers or [] dict_map = {'list1': OrderedDict(), 'list2': OrderedDict()} counter = 1 for list_ in [list1, list2]: diff --git a/tests/openwrt/test_backend.py b/tests/openwrt/test_backend.py index 098e3122a..499e3c667 100644 --- a/tests/openwrt/test_backend.py +++ b/tests/openwrt/test_backend.py @@ -209,7 +209,7 @@ def test_templates_type_error(self): } } with self.assertRaises(TypeError): - OpenWrt(config, templates={}) + OpenWrt(config, templates={'a': 'a'}) def test_templates_config_error(self): config = { diff --git a/tests/test_utils.py b/tests/test_utils.py index 17d73239b..f65c7adc3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -173,7 +173,7 @@ def test_evaluate_vars_one_char(self): def test_merge_list_override(self): template = [{"name": "test1", "tx": 1}] config = [{"name": "test1", "tx": 2}] - result = merge_list(template, config) + result = merge_list(template, config, ['name']) self.assertEqual(result, config) def test_merge_list_union_and_override(self): @@ -182,7 +182,7 @@ def test_merge_list_union_and_override(self): {"id": "test1", "a": "0", "b": "b"}, {"id": "test2", "c": "c"} ] - result = merge_list(template, config) + result = merge_list(template, config, ['id']) self.assertEqual(result, [ {"id": "test1", "a": "0", "b": "b"}, {"id": "test2", "c": "c"} @@ -191,7 +191,7 @@ def test_merge_list_union_and_override(self): def test_merge_list_config_value(self): template = [{"config_value": "test1", "tx": 1}] config = [{"config_value": "test1", "tx": 2}] - result = merge_list(template, config) + result = merge_list(template, config, ['config_value']) self.assertEqual(result, config) def test_get_copy_default(self):