Skip to content

Commit

Permalink
inventory/aws_ec2: allow multi-entries for one host
Browse files Browse the repository at this point in the history
Add an option to allow multiple duplicated entry for on single instance.

Closes: ansible-collections#950
  • Loading branch information
goneri committed Sep 19, 2022
1 parent a079a0d commit bb1d76b
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 22 deletions.
67 changes: 59 additions & 8 deletions plugins/inventory/aws_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@
type: str
default: '_'
required: False
allow_duplicated_hosts:
description:
- By default, only the first name that back the I(hostnames) list is returned.
- Turn this flag on if you don't mind having duplicated entries in the inventory
and you want to get all the hostnames that match.
type: bool
default: False
filters:
description:
- A dictionary of filter value pairs.
Expand Down Expand Up @@ -176,6 +183,9 @@
separator: '-' # Hostname will be aws-test_literal
prefix: 'aws'
# Returns all the hostnames for a given instance
allow_duplicated_hosts: False
# Example using constructed features to create groups and set ansible_host
plugin: aws_ec2
regions:
Expand Down Expand Up @@ -626,7 +636,7 @@ def _sanitize_hostname(self, hostname):
else:
return to_text(hostname)

def _get_hostname(self, instance, hostnames):
def _get_preferred_hostname(self, instance, hostnames):
'''
:param instance: an instance dict returned by boto3 ec2 describe_instances()
:param hostnames: a list of hostname destination variables in order of preference
Expand All @@ -635,14 +645,47 @@ def _get_hostname(self, instance, hostnames):
if not hostnames:
hostnames = ['dns-name', 'private-dns-name']

hostname = None
for preference in hostnames:
if isinstance(preference, dict):
if 'name' not in preference:
raise AnsibleError("A 'name' key must be defined in a hostnames dictionary.")
hostname = self._get_preferred_hostname(instance, [preference["name"]])
hostname_from_prefix = self._get_preferred_hostname(instance, [preference["prefix"]])
separator = preference.get("separator", "_")
if hostname and hostname_from_prefix and 'prefix' in preference:
hostname = hostname_from_prefix + separator + hostname
elif preference.startswith('tag:'):
tags = self._get_tag_hostname(preference, instance)
hostname = tags[0] if tags else None
else:
hostname = self._get_boto_attr_chain(preference, instance)
if hostname:
break
if hostname:
if ':' in to_text(hostname):
return self._sanitize_group_name((to_text(hostname)))
else:
return to_text(hostname)


def get_all_hostnames(self, instance, hostnames):
'''
:param instance: an instance dict returned by boto3 ec2 describe_instances()
:param hostnames: a list of hostname destination variables
:return all the candidats matching the expectation
'''
if not hostnames:
hostnames = ['dns-name', 'private-dns-name']

hostname = None
hostname_list = []
for preference in hostnames:
if isinstance(preference, dict):
if 'name' not in preference:
raise AnsibleError("A 'name' key must be defined in a hostnames dictionary.")
hostname = self._get_hostname(instance, [preference["name"]])
hostname_from_prefix = self._get_hostname(instance, [preference["prefix"]])
hostname = self.get_all_hostnames(instance, [preference["name"]])
hostname_from_prefix = self.get_all_hostnames(instance, [preference["prefix"]])
separator = preference.get("separator", "_")
if hostname and hostname_from_prefix and 'prefix' in preference:
hostname = hostname_from_prefix[0] + separator + hostname[0]
Expand Down Expand Up @@ -689,20 +732,27 @@ def _query(self, regions, include_filters, exclude_filters, strict_permissions):

return {'aws_ec2': instances}

def _populate(self, groups, hostnames):
def _populate(self, groups, hostnames, allow_duplicated_hosts=False):
for group in groups:
group = self.inventory.add_group(group)
self._add_hosts(hosts=groups[group], group=group, hostnames=hostnames)
self._add_hosts(
hosts=groups[group],
group=group,
hostnames=hostnames,
allow_duplicated_hosts=allow_duplicated_hosts)
self.inventory.add_child('all', group)

def _add_hosts(self, hosts, group, hostnames):
def _add_hosts(self, hosts, group, hostnames, allow_duplicated_hosts=False):
'''
:param hosts: a list of hosts to be added to a group
:param group: the name of the group to which the hosts belong
:param hostnames: a list of hostname destination variables in order of preference
'''
for host in hosts:
hostname_list = self._get_hostname(host, hostnames)
if allow_duplicated_hosts:
hostname_list = self.get_all_hostnames(host, hostnames)
else:
hostname_list = [self._get_preferred_hostname(host, hostnames)]

host = camel_dict_to_snake_dict(host, ignore_list=['Tags'])
host['tags'] = boto3_tag_list_to_ansible_dict(host.get('tags', []))
Expand Down Expand Up @@ -820,6 +870,7 @@ def parse(self, inventory, loader, path, cache=True):
exclude_filters = self.get_option('exclude_filters')
hostnames = self.get_option('hostnames')
strict_permissions = self.get_option('strict_permissions')
allow_duplicated_hosts = self.get_option('allow_duplicated_hosts')

cache_key = self.get_cache_key(path)
# false when refresh_cache or --flush-cache is used
Expand All @@ -839,7 +890,7 @@ def parse(self, inventory, loader, path, cache=True):
if not cache or cache_needs_update:
results = self._query(regions, include_filters, exclude_filters, strict_permissions)

self._populate(results, hostnames)
self._populate(results, hostnames, allow_duplicated_hosts=allow_duplicated_hosts)

# If the cache has expired/doesn't exist or if refresh_inventory/flush cache is used
# when the user is using caching, update the cached inventory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
assert:
that:
- "'aws_ec2' in groups"
- groups['aws_ec2'] | length == 2
- groups['aws_ec2'] | length == 1
- "'Tag1_Test1' in groups['aws_ec2']"
- "'Tag2_Test2' in groups['aws_ec2']"
- "'Tag2_Test2' not in groups['aws_ec2']"
- "'Tag1_Test1' in hostvars"
- "'Tag2_Test2' in hostvars"
- "'Tag2_Test2' not in hostvars"

always:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
assert:
that:
- "'aws_ec2' in groups"
- groups['aws_ec2'] | length == 2
- groups['aws_ec2'] | length == 1
- "'Test1' in groups['aws_ec2']"
- "'Test2' in groups['aws_ec2']"
- "'Test2' not in groups['aws_ec2']"
- "'Test1' in hostvars"
- "'Test2' in hostvars"
- "'Test2' not in hostvars"

always:

Expand Down
116 changes: 108 additions & 8 deletions tests/unit/plugins/inventory/test_aws_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import pytest
import datetime
from unittest.mock import MagicMock

from ansible.errors import AnsibleError
from ansible.parsing.dataloader import DataLoader
Expand Down Expand Up @@ -108,9 +109,25 @@
}


@pytest.fixture(scope="module")
@pytest.fixture()
def inventory():
return InventoryModule()
inventory = InventoryModule()
inventory._options = {
"aws_profile": "first_precedence",
"aws_access_key": "test_access_key",
"aws_secret_key": "test_secret_key",
"aws_security_token": "test_security_token",
"iam_role_arn": None,
"use_contrib_script_compatible_ec2_tag_keys": False,
"hostvars_prefix": "",
"hostvars_suffix": "",
"strict": True,
"compose": {},
"groups": {},
"keyed_groups": [],
}
inventory.inventory = MagicMock()
return inventory


def test_compile_values(inventory):
Expand Down Expand Up @@ -139,21 +156,51 @@ def test_boto3_conn(inventory):
assert "Insufficient credentials found" in error_message


def test_get_hostname_default(inventory):
def testget_all_hostnames_default(inventory):
instance = instances['Instances'][0]
assert inventory.get_all_hostnames(instance, hostnames=None) == ["ec2-12-345-67-890.compute-1.amazonaws.com", "ip-098-76-54-321.ec2.internal"]


def testget_all_hostnames(inventory):
hostnames = ['ip-address', 'dns-name']
instance = instances['Instances'][0]
assert inventory.get_all_hostnames(instance, hostnames) == ["12.345.67.890", "ec2-12-345-67-890.compute-1.amazonaws.com"]


def testget_all_hostnames_dict(inventory):
hostnames = [{'name': 'private-ip-address', 'separator': '_', 'prefix': 'tag:Name'}]
instance = instances['Instances'][0]
assert inventory.get_all_hostnames(instance, hostnames) == ["aws_ec2_098.76.54.321"]


def testget_all_hostnames_with_2_tags(inventory):
hostnames = ['tag:ansible', 'tag:Name']
instance = instances['Instances'][0]
assert inventory.get_all_hostnames(instance, hostnames) == ["test", "aws_ec2"]



def test_get_preferred_hostname_default(inventory):
instance = instances['Instances'][0]
assert inventory._get_hostname(instance, hostnames=None)[0] == "ec2-12-345-67-890.compute-1.amazonaws.com"
assert inventory._get_preferred_hostname(instance, hostnames=None) == "ec2-12-345-67-890.compute-1.amazonaws.com"


def test_get_hostname(inventory):
def test_get_preferred_hostname(inventory):
hostnames = ['ip-address', 'dns-name']
instance = instances['Instances'][0]
assert inventory._get_hostname(instance, hostnames)[0] == "12.345.67.890"
assert inventory._get_preferred_hostname(instance, hostnames) == "12.345.67.890"


def test_get_hostname_dict(inventory):
def test_get_preferred_hostname_dict(inventory):
hostnames = [{'name': 'private-ip-address', 'separator': '_', 'prefix': 'tag:Name'}]
instance = instances['Instances'][0]
assert inventory._get_hostname(instance, hostnames)[0] == "aws_ec2_098.76.54.321"
assert inventory._get_preferred_hostname(instance, hostnames) == "aws_ec2_098.76.54.321"


def test_get_preferred_hostname_with_2_tags(inventory):
hostnames = ['tag:ansible', 'tag:Name']
instance = instances['Instances'][0]
assert inventory._get_preferred_hostname(instance, hostnames) == "test"


def test_set_credentials(inventory):
Expand Down Expand Up @@ -216,3 +263,56 @@ def test_include_filters_with_filter_and_include_filters(inventory):
assert inventory.build_include_filters() == [
{"from_filter": 1},
{"from_include_filter": "bar"}]


def test_add_host_empty_hostnames(inventory):
hosts = [
{
"Placement": {
"AvailabilityZone": "us-east-1a",
},
"PublicDnsName": "ip-10-85-0-4.ec2.internal"
},
]
inventory._add_hosts(hosts, "aws_ec2", [])
inventory.inventory.add_host.assert_called_with("ip-10-85-0-4.ec2.internal", group="aws_ec2")


def test_add_host_with_hostnames_and_one_criteria(inventory):
hosts = [{
"Placement": {
"AvailabilityZone": "us-east-1a",
},
"PublicDnsName": "sample-host",
}]

inventory._add_hosts(hosts, "aws_ec2", hostnames=["tag:Name", "private-dns-name", "dns-name"])
assert inventory.inventory.add_host.call_count == 1
inventory.inventory.add_host.assert_called_with("sample-host", group="aws_ec2")

def test_add_host_with_hostnames_and_two_matching_criteria(inventory):
hosts = [{
"Placement": {
"AvailabilityZone": "us-east-1a",
},
"PublicDnsName": "name-from-PublicDnsName",
"Tags": [{"Value": "name-from-tag-Name", "Key": "Name"}],
}]

inventory._add_hosts(hosts, "aws_ec2", hostnames=["tag:Name", "private-dns-name", "dns-name"])
assert inventory.inventory.add_host.call_count == 1
inventory.inventory.add_host.assert_called_with("name-from-tag-Name", group="aws_ec2")

def test_add_host_with_hostnames_and_two_matching_criteria_and_allow_duplicated_hosts(inventory):
hosts = [{
"Placement": {
"AvailabilityZone": "us-east-1a",
},
"PublicDnsName": "name-from-PublicDnsName",
"Tags": [{"Value": "name-from-tag-Name", "Key": "Name"}],
}]

inventory._add_hosts(hosts, "aws_ec2", hostnames=["tag:Name", "private-dns-name", "dns-name"], allow_duplicated_hosts=True)
assert inventory.inventory.add_host.call_count == 2
inventory.inventory.add_host.assert_any_call("name-from-PublicDnsName", group="aws_ec2")
inventory.inventory.add_host.assert_any_call("name-from-tag-Name", group="aws_ec2")

0 comments on commit bb1d76b

Please sign in to comment.