Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dns] Implement config and show commands for static DNS. #2737

Merged
merged 3 commits into from
Jun 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions config/dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@

import click
from swsscommon.swsscommon import ConfigDBConnector
from .validated_config_db_connector import ValidatedConfigDBConnector
import ipaddress


ADHOC_VALIDATION = True
NAMESERVERS_MAX_NUM = 3


def to_ip_address(address):
"""Check if the given IP address is valid"""
try:
ip = ipaddress.ip_address(address)

if ADHOC_VALIDATION:
if ip.is_reserved or ip.is_multicast or ip.is_loopback:
return

invalid_ips = [
ipaddress.IPv4Address('0.0.0.0'),
ipaddress.IPv4Address('255.255.255.255'),
ipaddress.IPv6Address("0::0"),
ipaddress.IPv6Address("0::1")
]
if ip in invalid_ips:
return

return ip
except Exception:
return


def get_nameservers(db):
nameservers = db.get_table('DNS_NAMESERVER')
return [ipaddress.ip_address(ip) for ip in nameservers]


# 'dns' group ('config dns ...')
@click.group()
@click.pass_context
def dns(ctx):
"""Static DNS configuration"""
config_db = ValidatedConfigDBConnector(ConfigDBConnector())
config_db.connect()
ctx.obj = {'db': config_db}


# dns nameserver config
@dns.group('nameserver')
@click.pass_context
def nameserver(ctx):
"""Static DNS nameservers configuration"""
pass


# dns nameserver add
@nameserver.command('add')
@click.argument('ip_address_str', metavar='<ip_address>', required=True)
@click.pass_context
def add_dns_nameserver(ctx, ip_address_str):
"""Add static DNS nameserver entry"""
ip_address = to_ip_address(ip_address_str)
if not ip_address:
ctx.fail(f"{ip_address_str} invalid nameserver ip address")

db = ctx.obj['db']

nameservers = get_nameservers(db)
if ip_address in nameservers:
ctx.fail(f"{ip_address} nameserver is already configured")

if len(nameservers) >= NAMESERVERS_MAX_NUM:
ctx.fail(f"The maximum number ({NAMESERVERS_MAX_NUM}) of nameservers exceeded.")

db.set_entry('DNS_NAMESERVER', ip_address, {})

# dns nameserver delete
@nameserver.command('del')
@click.argument('ip_address_str', metavar='<ip_address>', required=True)
@click.pass_context
def del_dns_nameserver(ctx, ip_address_str):
"""Delete static DNS nameserver entry"""

ip_address = to_ip_address(ip_address_str)
if not ip_address:
ctx.fail(f"{ip_address_str} invalid nameserver ip address")

db = ctx.obj['db']

nameservers = get_nameservers(db)
if ip_address not in nameservers:
ctx.fail(f"DNS nameserver {ip_address} is not configured")

db.set_entry('DNS_NAMESERVER', ip_address, None)
4 changes: 4 additions & 0 deletions config/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .config_mgmt import ConfigMgmtDPB, ConfigMgmt
from . import mclag
from . import syslog
from . import dns

# mock masic APIs for unit test
try:
Expand Down Expand Up @@ -1200,6 +1201,9 @@ def config(ctx):
# syslog module
config.add_command(syslog.syslog)

# DNS module
config.add_command(dns.dns)

@config.command()
@click.option('-y', '--yes', is_flag=True, callback=_abort_if_false,
expose_value=False, prompt='Existing files will be overwritten, continue?')
Expand Down
30 changes: 30 additions & 0 deletions show/dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import click
import utilities_common.cli as clicommon
from natsort import natsorted
from tabulate import tabulate

from swsscommon.swsscommon import ConfigDBConnector
from utilities_common.cli import pass_db


# 'dns' group ("show dns ...")
@click.group(cls=clicommon.AliasedGroup)
@click.pass_context
def dns(ctx):
"""Show details of the static DNS configuration """
config_db = ConfigDBConnector()
config_db.connect()
ctx.obj = {'db': config_db}


# 'nameserver' subcommand ("show dns nameserver")
@dns.command()
@click.pass_context
def nameserver(ctx):
""" Show static DNS configuration """
header = ["Nameserver"]
db = ctx.obj['db']

nameservers = db.get_table('DNS_NAMESERVER')

click.echo(tabulate([(ns,) for ns in nameservers.keys()], header, tablefmt='simple', stralign='right'))
2 changes: 2 additions & 0 deletions show/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from . import warm_restart
from . import plugins
from . import syslog
from . import dns

# Global Variables
PLATFORM_JSON = 'platform.json'
Expand Down Expand Up @@ -289,6 +290,7 @@ def cli(ctx):
cli.add_command(vxlan.vxlan)
cli.add_command(system_health.system_health)
cli.add_command(warm_restart.warm_restart)
cli.add_command(dns.dns)

# syslog module
cli.add_command(syslog.syslog)
Expand Down
193 changes: 193 additions & 0 deletions tests/dns_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import os
import pytest

from click.testing import CliRunner

import config.main as config
import show.main as show
from utilities_common.db import Db

test_path = os.path.dirname(os.path.abspath(__file__))

dns_show_nameservers_header = """\
Nameserver
------------
"""

dns_show_nameservers = """\
Nameserver
--------------------
1.1.1.1
2001:4860:4860::8888
"""

class TestDns(object):

valid_nameservers = (
("1.1.1.1",),
("1.1.1.1", "8.8.8.8", "10.10.10.10",),
("1.1.1.1", "2001:4860:4860::8888"),
("2001:4860:4860::8888", "2001:4860:4860::8844", "2001:4860:4860::8800")
)

invalid_nameservers = (
"0.0.0.0",
"255.255.255.255",
"224.0.0.0",
"0::0",
"0::1",
"1.1.1.x",
"2001:4860:4860.8888",
"ff02::1"
)

config_dns_ns_add = config.config.commands["dns"].commands["nameserver"].commands["add"]
config_dns_ns_del = config.config.commands["dns"].commands["nameserver"].commands["del"]
show_dns_ns = show.cli.commands["dns"].commands["nameserver"]

@classmethod
def setup_class(cls):
print("SETUP")
os.environ["UTILITIES_UNIT_TESTING"] = "1"

@classmethod
def teardown_class(cls):
os.environ['UTILITIES_UNIT_TESTING'] = "0"
print("TEARDOWN")

@pytest.mark.parametrize('nameservers', valid_nameservers)
def test_dns_config_nameserver_add_del_with_valid_ip_addresses(self, nameservers):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

for ip in nameservers:
# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip in db.cfgdb.get_table('DNS_NAMESERVER')

for ip in nameservers:
# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip not in db.cfgdb.get_table('DNS_NAMESERVER')

@pytest.mark.parametrize('nameserver', invalid_nameservers)
def test_dns_config_nameserver_add_del_with_invalid_ip_addresses(self, nameserver):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

# config dns nameserver add <nameserver>
result = runner.invoke(self.config_dns_ns_add, [nameserver], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "invalid nameserver ip address" in result.output

# config dns nameserver del <nameserver>
result = runner.invoke(self.config_dns_ns_del, [nameserver], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "invalid nameserver ip address" in result.output

@pytest.mark.parametrize('nameservers', valid_nameservers)
def test_dns_config_nameserver_add_existing_ip(self, nameservers):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

for ip in nameservers:
# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip in db.cfgdb.get_table('DNS_NAMESERVER')

# Execute command once more
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "nameserver is already configured" in result.output

# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0

@pytest.mark.parametrize('nameservers', valid_nameservers)
def test_dns_config_nameserver_del_unexisting_ip(self, nameservers):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

for ip in nameservers:
# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "is not configured" in result.output

def test_dns_config_nameserver_add_max_number(self):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

nameservers = ("1.1.1.1", "2.2.2.2", "3.3.3.3")
for ip in nameservers:
# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0

# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, ["4.4.4.4"], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code != 0
assert "nameservers exceeded" in result.output

for ip in nameservers:
# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0

def test_dns_show_nameserver_empty_table(self):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

# show dns nameserver
result = runner.invoke(self.show_dns_ns, [], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert result.output == dns_show_nameservers_header

def test_dns_show_nameserver(self):
db = Db()
runner = CliRunner()
obj = {'db': db.cfgdb}

nameservers = ("1.1.1.1", "2001:4860:4860::8888")

for ip in nameservers:
# config dns nameserver add <ip>
result = runner.invoke(self.config_dns_ns_add, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip in db.cfgdb.get_table('DNS_NAMESERVER')

# show dns nameserver
result = runner.invoke(self.show_dns_ns, [], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert result.output == dns_show_nameservers

for ip in nameservers:
# config dns nameserver del <ip>
result = runner.invoke(self.config_dns_ns_del, [ip], obj=obj)
print(result.exit_code, result.output)
assert result.exit_code == 0
assert ip not in db.cfgdb.get_table('DNS_NAMESERVER')