From bc0e2e2112bbbcd6ebc20d2db34d0ad4255ab549 Mon Sep 17 00:00:00 2001 From: Dmitry Figol Date: Tue, 24 Mar 2020 12:21:15 +0100 Subject: [PATCH] Do not store failed connections in host attributes (#497) * Tests for failed connections #350 * Do not store failed connections in host.connections, fix #350 * Replace sentinel object UNESTABLISHED_CONNECTION with None * Use variable name conn_obj instead of connection not to confuse mypy --- nornir/core/connections.py | 11 ++---- nornir/core/inventory.py | 63 +++++++++++++++++++--------------- tests/core/test_connections.py | 38 ++++++++++++++++++++ 3 files changed, 75 insertions(+), 37 deletions(-) diff --git a/nornir/core/connections.py b/nornir/core/connections.py index fa5b19c6..a554914e 100644 --- a/nornir/core/connections.py +++ b/nornir/core/connections.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, NoReturn, Optional, Type +from typing import Any, Dict, Optional, Type from nornir.core.configuration import Config @@ -23,7 +23,7 @@ class ConnectionPlugin(ABC): __slots__ = ("connection", "state") def __init__(self) -> None: - self.connection: Any = UnestablishedConnection() + self.connection: Any = None self.state: Dict[str, Any] = {} @abstractmethod @@ -49,13 +49,6 @@ def close(self) -> None: pass -class UnestablishedConnection(object): - def close(self) -> NoReturn: - raise ValueError("Connection not established") - - disconnect = close - - class Connections(Dict[str, ConnectionPlugin]): available: Dict[str, Type[ConnectionPlugin]] = {} diff --git a/nornir/core/inventory.py b/nornir/core/inventory.py index 3307ad98..d1c6cdd5 100644 --- a/nornir/core/inventory.py +++ b/nornir/core/inventory.py @@ -4,7 +4,10 @@ from nornir.core import deserializer from nornir.core.configuration import Config -from nornir.core.connections import ConnectionPlugin, Connections +from nornir.core.connections import ( + ConnectionPlugin, + Connections, +) from nornir.core.exceptions import ConnectionAlreadyOpen, ConnectionNotOpen @@ -336,39 +339,43 @@ def open_connection( Returns: An already established connection """ - if connection in self.connections: - raise ConnectionAlreadyOpen(connection) + conn_name = connection + existing_conn = self.connections.get(conn_name) + if existing_conn is not None: + raise ConnectionAlreadyOpen(conn_name) - self.connections[connection] = self.connections.get_plugin(connection)() + plugin = self.connections.get_plugin(conn_name) + conn_obj = plugin() if default_to_host_attributes: - conn_params = self.get_connection_parameters(connection) - self.connections[connection].open( - hostname=hostname if hostname is not None else conn_params.hostname, - username=username if username is not None else conn_params.username, - password=password if password is not None else conn_params.password, - port=port if port is not None else conn_params.port, - platform=platform if platform is not None else conn_params.platform, - extras=extras if extras is not None else conn_params.extras, - configuration=configuration, - ) - else: - self.connections[connection].open( - hostname=hostname, - username=username, - password=password, - port=port, - platform=platform, - extras=extras, - configuration=configuration, - ) - return self.connections[connection] + conn_params = self.get_connection_parameters(conn_name) + hostname = hostname if hostname is not None else conn_params.hostname + username = username if username is not None else conn_params.username + password = password if password is not None else conn_params.password + port = port if port is not None else conn_params.port + platform = platform if platform is not None else conn_params.platform + extras = extras if extras is not None else conn_params.extras + + conn_obj.open( + hostname=hostname, + username=username, + password=password, + port=port, + platform=platform, + extras=extras, + configuration=configuration, + ) + self.connections[conn_name] = conn_obj + return connection def close_connection(self, connection: str) -> None: """ Close the connection""" - if connection not in self.connections: - raise ConnectionNotOpen(connection) + conn_name = connection + if conn_name not in self.connections: + raise ConnectionNotOpen(conn_name) - self.connections.pop(connection).close() + conn_obj = self.connections.pop(conn_name) + if conn_obj is not None: + conn_obj.close() def close_connections(self) -> None: # Decouple deleting dictionary elements from iterating over connections dict diff --git a/tests/core/test_connections.py b/tests/core/test_connections.py index c9f22aed..a4ac1d7b 100644 --- a/tests/core/test_connections.py +++ b/tests/core/test_connections.py @@ -42,6 +42,31 @@ class AnotherDummyConnectionPlugin(DummyConnectionPlugin): pass +class FailedConnection(Exception): + pass + + +class FailedConnectionPlugin(ConnectionPlugin): + name = "fail" + + def open( + self, + hostname: Optional[str], + username: Optional[str], + password: Optional[str], + port: Optional[int], + platform: Optional[str], + extras: Optional[Dict[str, Any]] = None, + configuration: Optional[Config] = None, + ) -> None: + raise FailedConnection( + f"Failed to open connection to {self.hostname}:{self.port}" + ) + + def close(self) -> None: + pass + + def open_and_close_connection(task): task.host.open_connection("dummy", task.nornir.config) assert "dummy" in task.host.connections @@ -69,6 +94,10 @@ def close_not_opened_connection(task): assert "dummy" not in task.host.connections +def failed_connection(task): + task.host.open_connection(FailedConnectionPlugin.name, task.nornir.config) + + def a_task(task): task.host.get_connection("dummy", task.nornir.config) @@ -86,6 +115,7 @@ def setup_class(cls): Connections.register("dummy", DummyConnectionPlugin) Connections.register("dummy2", DummyConnectionPlugin) Connections.register("dummy_no_overrides", DummyConnectionPlugin) + Connections.register(FailedConnectionPlugin.name, FailedConnectionPlugin) def test_open_and_close_connection(self, nornir): nr = nornir.filter(name="dev2.group_1") @@ -105,6 +135,14 @@ def test_close_not_opened_connection(self, nornir): assert len(r) == 1 assert not r.failed + def test_failed_connection(self, nornir): + nr = nornir.filter(name="dev2.group_1") + nr.run(task=failed_connection, num_workers=1) + assert ( + FailedConnectionPlugin.name + not in nornir.inventory.hosts["dev2.group_1"].connections + ) + def test_context_manager(self, nornir): with nornir.filter(name="dev2.group_1") as nr: nr.run(task=a_task)