Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hostcfgd] Move hostcfgd back to ConfigDBConnector for subscribing to updates #10168

Merged
merged 13 commits into from
Apr 7, 2022
Merged
147 changes: 84 additions & 63 deletions src/sonic-host-services/scripts/hostcfgd
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
alexrallen marked this conversation as resolved.
Show resolved Hide resolved
#!/usr/bin/python3

import ast
import copy
Expand All @@ -11,8 +11,7 @@ import signal

import jinja2
from sonic_py_common import device_info
from swsscommon.swsscommon import SubscriberStateTable, DBConnector, Select
from swsscommon.swsscommon import ConfigDBConnector, TableConsumable
from swsscommon.swsscommon import ConfigDBConnector

# FILE
PAM_AUTH_CONF = "/etc/pam.d/common-auth-sonic"
Expand Down Expand Up @@ -195,6 +194,20 @@ class FeatureHandler(object):
else:
self.resync_feature_state(self._cached_config[feature_name])

def update_all_features_config(self, feature_table):
for feature_name in feature_table.keys():
if not feature_name:
syslog.syslog(syslog.LOG_WARNING, "Feature is None")
continue

feature = Feature(feature_name, feature_table[feature_name], self._device_config)
self._cached_config.setdefault(feature_name, feature)

self.update_feature_auto_restart(feature, feature_name)

self.update_feature_state(feature)
self.resync_feature_state(feature)

def sync_state_field(self):
"""
Summary:
Expand Down Expand Up @@ -386,6 +399,10 @@ class Iptables(object):
'''
return (isinstance(key, tuple))

def load(self, lpbk_table):
for row in lpbk_table:
self.iptables_handler(row, lpbk_table[row])

def command(self, chain, ip, ver, op):
cmd = 'iptables' if ver == '4' else 'ip6tables'
cmd += ' -t mangle --{} {} -p tcp --tcp-flags SYN SYN'.format(op, chain)
Expand Down Expand Up @@ -892,6 +909,15 @@ class NtpCfg(object):
self.ntp_global = {}
self.ntp_servers = set()

def load(self, ntp_global_conf, ntp_server_conf):
syslog.syslog(syslog.LOG_INFO, "NtpCfg load ...")

for row in ntp_global_conf:
self.ntp_global_update(row, ntp_global_conf[row], load=True)

# Force reload on init
self.ntp_server_update(0, None, load=True)

def handle_ntp_source_intf_chg(self, intf_name):
# if no ntp server configured, do nothing
if not self.ntp_servers:
Expand All @@ -905,7 +931,7 @@ class NtpCfg(object):
cmd = 'systemctl restart ntp-config'
run_cmd(cmd)

def ntp_global_update(self, key, data):
def ntp_global_update(self, key, data, load=False):
alexrallen marked this conversation as resolved.
Show resolved Hide resolved
syslog.syslog(syslog.LOG_INFO, 'NTP GLOBAL Update')
orig_src = self.ntp_global.get('src_intf', '')
orig_src_set = set(orig_src.split(";"))
Expand All @@ -918,6 +944,9 @@ class NtpCfg(object):
# Update the Local Cache
self.ntp_global = data

# If initial load don't restart daemon
if load: return

# check if ntp server configured, if not, do nothing
if not self.ntp_servers:
syslog.syslog(syslog.LOG_INFO, "No ntp server when global config change, do nothing")
Expand All @@ -934,16 +963,19 @@ class NtpCfg(object):
cmd = 'service ntp restart'
run_cmd(cmd)

def ntp_server_update(self, key, op):
def ntp_server_update(self, key, op, load=False):
alexrallen marked this conversation as resolved.
Show resolved Hide resolved
syslog.syslog(syslog.LOG_INFO, 'ntp server update key {}'.format(key))

restart_config = False
if op == "SET" and key not in self.ntp_servers:
restart_config = True
self.ntp_servers.add(key)
elif op == "DEL" and key in self.ntp_servers:
if not load:
if op == "SET" and key not in self.ntp_servers:
restart_config = True
self.ntp_servers.add(key)
elif op == "DEL" and key in self.ntp_servers:
restart_config = True
self.ntp_servers.remove(key)
else:
restart_config = True
self.ntp_servers.remove(key)

if restart_config:
cmd = 'systemctl restart ntp-config'
Expand All @@ -956,14 +988,8 @@ class HostConfigDaemon:
# before moving forward
self.config_db = ConfigDBConnector()
self.config_db.connect(wait_for_init=True, retry_on=True)
self.dbconn = DBConnector(CFG_DB, 0)
self.selector = Select()
syslog.syslog(syslog.LOG_INFO, 'ConfigDB connect success')

self.select = Select()
self.callbacks = dict()
self.subscriber_map = dict()

# Load DEVICE metadata configurations
self.device_config = {}
self.device_config['DEVICE_METADATA'] = self.config_db.get_table('DEVICE_METADATA')
Expand All @@ -987,15 +1013,30 @@ class HostConfigDaemon:
# Initialize AAACfg
self.hostname_cache=""
self.aaacfg = AaaCfg()

# Initialize cache
self.cache = {}


def load(self):
features = self.config_db.get_table('FEATURE')
aaa = self.config_db.get_table('AAA')
tacacs_global = self.config_db.get_table('TACPLUS')
tacacs_server = self.config_db.get_table('TACPLUS_SERVER')
radius_global = self.config_db.get_table('RADIUS')
radius_server = self.config_db.get_table('RADIUS_SERVER')
lpbk_table = self.config_db.get_table('LOOPBACK_INTERFACE')
ntp_server = self.config_db.get_table('NTP_SERVER')
ntp_global = self.config_db.get_table('NTP')

self.cache = {'FEATURES': features, 'AAA': aaa, 'TACPLUS': tacacs_global, 'TACPLUS_SERVER': tacacs_server,
'RADIUS': radius_global, 'RADIUS_SERVER': radius_server, 'LOOPBACK_INTERFACE': lpbk_table, 'NTP_SERVER': ntp_server,
'NTP': ntp_global}

self.feature_handler.update_all_features_config(features)
alexrallen marked this conversation as resolved.
Show resolved Hide resolved
self.aaacfg.load(aaa, tacacs_global, tacacs_server, radius_global, radius_server)
self.iptables.load(lpbk_table)
self.ntpcfg.load(ntp_global, ntp_server)

try:
dev_meta = self.config_db.get_table('DEVICE_METADATA')
Expand Down Expand Up @@ -1097,68 +1138,48 @@ class HostConfigDaemon:
systemctl_cmd = "sudo systemctl is-system-running --wait --quiet"
subprocess.call(systemctl_cmd, shell=True)

def subscribe(self, table, callback, pri):
try:
if table not in self.callbacks:
self.callbacks[table] = []
subscriber = SubscriberStateTable(self.dbconn, table, TableConsumable.DEFAULT_POP_BATCH_SIZE, pri)
self.selector.addSelectable(subscriber) # Add to the Selector
self.subscriber_map[subscriber.getFd()] = (subscriber, table) # Maintain a mapping b/w subscriber & fd
def register_callbacks(self):

self.callbacks[table].append(callback)
except Exception as err:
syslog.syslog(syslog.LOG_ERR, "Subscribe to table {} failed with error {}".format(table, err))
def make_callback(func):
def callback(table, key, data):
if data is None:
op = "DEL"
else:
op = "SET"
return func(key, op, data)
return callback

def register_callbacks(self):
self.subscribe('KDUMP', lambda table, key, op, data: self.kdump_handler(key, op, data), HOSTCFGD_MAX_PRI)
self.config_db.subscribe('KDUMP', make_callback(self.kdump_handler))
# Handle FEATURE updates before other tables
self.subscribe('FEATURE', lambda table, key, op, data: self.feature_handler.handle(key, op, data), HOSTCFGD_MAX_PRI-1)
self.config_db.subscribe('FEATURE', make_callback(self.feature_handler.handle))
# Handle AAA, TACACS and RADIUS related tables
self.subscribe('AAA', lambda table, key, op, data: self.aaa_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.subscribe('TACPLUS', lambda table, key, op, data: self.tacacs_global_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.subscribe('TACPLUS_SERVER', lambda table, key, op, data: self.tacacs_server_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.subscribe('RADIUS', lambda table, key, op, data: self.radius_global_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.subscribe('RADIUS_SERVER', lambda table, key, op, data: self.radius_server_handler(key, op, data), HOSTCFGD_MAX_PRI-2)
self.config_db.subscribe('AAA', make_callback(self.aaa_handler))
self.config_db.subscribe('TACPLUS', make_callback(self.tacacs_global_handler))
self.config_db.subscribe('TACPLUS_SERVER', make_callback(self.tacacs_server_handler))
self.config_db.subscribe('RADIUS', make_callback(self.radius_global_handler))
self.config_db.subscribe('RADIUS_SERVER', make_callback(self.radius_server_handler))
# Handle IPTables configuration
self.subscribe('LOOPBACK_INTERFACE', lambda table, key, op, data: self.lpbk_handler(key, op, data), HOSTCFGD_MAX_PRI-3)
self.config_db.subscribe('LOOPBACK_INTERFACE', make_callback(self.lpbk_handler))
# Handle NTP & NTP_SERVER updates
self.subscribe('NTP', lambda table, key, op, data: self.ntp_global_handler(key, op, data), HOSTCFGD_MAX_PRI-4)
self.subscribe('NTP_SERVER', lambda table, key, op, data: self.ntp_server_handler(key, op, data), HOSTCFGD_MAX_PRI-4)
self.config_db.subscribe('NTP', make_callback(self.ntp_global_handler))
self.config_db.subscribe('NTP_SERVER', make_callback(self.ntp_server_handler))
# Handle updates to src intf changes in radius
self.subscribe('MGMT_INTERFACE', lambda table, key, op, data: self.mgmt_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.subscribe('VLAN_INTERFACE', lambda table, key, op, data: self.vlan_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.subscribe('VLAN_SUB_INTERFACE', lambda table, key, op, data: self.vlan_sub_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.subscribe('PORTCHANNEL_INTERFACE', lambda table, key, op, data: self.portchannel_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.subscribe('INTERFACE', lambda table, key, op, data: self.phy_intf_handler(key, op, data), HOSTCFGD_MAX_PRI-5)
self.config_db.subscribe('MGMT_INTERFACE', make_callback(self.mgmt_intf_handler))
self.config_db.subscribe('VLAN_INTERFACE', make_callback(self.vlan_intf_handler))
self.config_db.subscribe('VLAN_SUB_INTERFACE', make_callback(self.vlan_sub_intf_handler))
self.config_db.subscribe('PORTCHANNEL_INTERFACE', make_callback(self.portchannel_intf_handler))
self.config_db.subscribe('INTERFACE', make_callback(self.phy_intf_handler))

syslog.syslog(syslog.LOG_INFO,
"Waiting for systemctl to finish initialization")
self.wait_till_system_init_done()
syslog.syslog(syslog.LOG_INFO,
"systemctl has finished initialization -- proceeding ...")

def start(self):
while True:
state, selectable_ = self.selector.select(DEFAULT_SELECT_TIMEOUT)
if state == self.selector.TIMEOUT:
continue
elif state == self.selector.ERROR:
syslog.syslog(syslog.LOG_ERR,
"error returned by select")
continue
self.config_db.listen(start=False)

fd = selectable_.getFd()
# Get the Corresponding subscriber & table
subscriber, table = self.subscriber_map.get(fd, (None, ""))
if not subscriber:
syslog.syslog(syslog.LOG_ERR,
"No Subscriber object found for fd: {}, subscriber map: {}".format(fd, subscriber_map))
continue
key, op, fvs = subscriber.pop()
# Get the registered callback
cbs = self.callbacks.get(table, None)
for callback in cbs:
callback(table, key, op, dict(fvs))
def start(self):
self.config_db.process(cache=self.cache)


def main():
Expand Down
79 changes: 9 additions & 70 deletions src/sonic-host-services/tests/common/mock_configdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ class MockConfigDb(object):
"""
STATE_DB = None
CONFIG_DB = None
event_queue = []

def __init__(self, **kwargs):
pass
self.handlers = {}

@staticmethod
def set_config_db(test_config_db):
Expand Down Expand Up @@ -44,75 +45,13 @@ def set_entry(self, key, field, data):
def get_table(self, table_name):
return MockConfigDb.CONFIG_DB[table_name]

def subscribe(self, table_name, callback):
self.handlers[table_name] = callback

class MockSelect():

event_queue = []

@staticmethod
def set_event_queue(Q):
MockSelect.event_queue = Q

@staticmethod
def get_event_queue():
return MockSelect.event_queue

@staticmethod
def reset_event_queue():
MockSelect.event_queue = []

def __init__(self):
self.sub_map = {}
self.TIMEOUT = "TIMEOUT"
self.ERROR = "ERROR"

def addSelectable(self, subscriber):
self.sub_map[subscriber.table] = subscriber

def select(self, TIMEOUT):
if not MockSelect.get_event_queue():
raise TimeoutError
table, key = MockSelect.get_event_queue().pop(0)
self.sub_map[table].nextKey(key)
return "OBJECT", self.sub_map[table]


class MockSubscriberStateTable():

FD_INIT = 0

@staticmethod
def generate_fd():
curr = MockSubscriberStateTable.FD_INIT
MockSubscriberStateTable.FD_INIT = curr + 1
return curr

@staticmethod
def reset_fd():
MockSubscriberStateTable.FD_INIT = 0

def __init__(self, conn, table, pop, pri):
self.fd = MockSubscriberStateTable.generate_fd()
self.next_key = ''
self.table = table

def getFd(self):
return self.fd

def nextKey(self, key):
self.next_key = key

def pop(self):
table = MockConfigDb.CONFIG_DB.get(self.table, {})
if self.next_key not in table:
op = "DEL"
fvs = {}
else:
op = "SET"
fvs = table.get(self.next_key, {})
return self.next_key, op, fvs
def listen(self, start):
pass

def process(self, cache):
for e in MockConfigDb.event_queue:
self.handlers[e[0]](e[0], e[1], self.get_entry(e[0], e[1]))

class MockDBConnector():
def __init__(self, db, val):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from parameterized import parameterized
from unittest import TestCase, mock
from tests.hostcfgd.test_radius_vectors import HOSTCFGD_TEST_RADIUS_VECTOR
from tests.common.mock_configdb import MockConfigDb, MockSubscriberStateTable
from tests.common.mock_configdb import MockSelect, MockDBConnector
from tests.common.mock_configdb import MockConfigDb


test_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
Expand All @@ -33,9 +32,6 @@

# Mock swsscommon classes
hostcfgd.ConfigDBConnector = MockConfigDb
hostcfgd.SubscriberStateTable = MockSubscriberStateTable
hostcfgd.Select = MockSelect
hostcfgd.DBConnector = MockDBConnector


class TestHostcfgdRADIUS(TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from parameterized import parameterized
from unittest import TestCase, mock
from tests.hostcfgd.test_tacacs_vectors import HOSTCFGD_TEST_TACACS_VECTOR
from tests.common.mock_configdb import MockConfigDb, MockSubscriberStateTable
from tests.common.mock_configdb import MockSelect, MockDBConnector
from tests.common.mock_configdb import MockConfigDb

test_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
modules_path = os.path.dirname(test_path)
Expand All @@ -32,9 +31,6 @@

# Mock swsscommon classes
hostcfgd.ConfigDBConnector = MockConfigDb
hostcfgd.SubscriberStateTable = MockSubscriberStateTable
hostcfgd.Select = MockSelect
hostcfgd.DBConnector = MockDBConnector

class TestHostcfgdTACACS(TestCase):
"""
Expand Down
Loading