Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not store failed connections in host attributes #497

Merged
merged 4 commits into from
Mar 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions nornir/core/connections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, NoReturn, Optional, Type
from typing import Any, Dict, Optional, Type


from nornir.core.configuration import Config
Expand All @@ -23,7 +23,7 @@ class ConnectionPlugin(ABC):
__slots__ = ("connection", "state")

def __init__(self) -> None:
self.connection: Any = UnestablishedConnection()
self.connection: Any = None
self.state: Dict[str, Any] = {}

@abstractmethod
Expand All @@ -49,13 +49,6 @@ def close(self) -> None:
pass


class UnestablishedConnection(object):
def close(self) -> NoReturn:
raise ValueError("Connection not established")

disconnect = close


class Connections(Dict[str, ConnectionPlugin]):
available: Dict[str, Type[ConnectionPlugin]] = {}

Expand Down
63 changes: 35 additions & 28 deletions nornir/core/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from nornir.core import deserializer
from nornir.core.configuration import Config
from nornir.core.connections import ConnectionPlugin, Connections
from nornir.core.connections import (
ConnectionPlugin,
Connections,
)
from nornir.core.exceptions import ConnectionAlreadyOpen, ConnectionNotOpen


Expand Down Expand Up @@ -336,39 +339,43 @@ def open_connection(
Returns:
An already established connection
"""
if connection in self.connections:
raise ConnectionAlreadyOpen(connection)
conn_name = connection
existing_conn = self.connections.get(conn_name)
if existing_conn is not None:
raise ConnectionAlreadyOpen(conn_name)

self.connections[connection] = self.connections.get_plugin(connection)()
plugin = self.connections.get_plugin(conn_name)
conn_obj = plugin()
if default_to_host_attributes:
conn_params = self.get_connection_parameters(connection)
self.connections[connection].open(
hostname=hostname if hostname is not None else conn_params.hostname,
username=username if username is not None else conn_params.username,
password=password if password is not None else conn_params.password,
port=port if port is not None else conn_params.port,
platform=platform if platform is not None else conn_params.platform,
extras=extras if extras is not None else conn_params.extras,
configuration=configuration,
)
else:
self.connections[connection].open(
hostname=hostname,
username=username,
password=password,
port=port,
platform=platform,
extras=extras,
configuration=configuration,
)
return self.connections[connection]
conn_params = self.get_connection_parameters(conn_name)
hostname = hostname if hostname is not None else conn_params.hostname
username = username if username is not None else conn_params.username
password = password if password is not None else conn_params.password
port = port if port is not None else conn_params.port
platform = platform if platform is not None else conn_params.platform
extras = extras if extras is not None else conn_params.extras

conn_obj.open(
hostname=hostname,
username=username,
password=password,
port=port,
platform=platform,
extras=extras,
configuration=configuration,
)
self.connections[conn_name] = conn_obj
return connection

def close_connection(self, connection: str) -> None:
""" Close the connection"""
if connection not in self.connections:
raise ConnectionNotOpen(connection)
conn_name = connection
if conn_name not in self.connections:
raise ConnectionNotOpen(conn_name)

self.connections.pop(connection).close()
conn_obj = self.connections.pop(conn_name)
if conn_obj is not None:
conn_obj.close()

def close_connections(self) -> None:
# Decouple deleting dictionary elements from iterating over connections dict
Expand Down
38 changes: 38 additions & 0 deletions tests/core/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,31 @@ class AnotherDummyConnectionPlugin(DummyConnectionPlugin):
pass


class FailedConnection(Exception):
pass


class FailedConnectionPlugin(ConnectionPlugin):
name = "fail"

def open(
self,
hostname: Optional[str],
username: Optional[str],
password: Optional[str],
port: Optional[int],
platform: Optional[str],
extras: Optional[Dict[str, Any]] = None,
configuration: Optional[Config] = None,
) -> None:
raise FailedConnection(
f"Failed to open connection to {self.hostname}:{self.port}"
)

def close(self) -> None:
pass


def open_and_close_connection(task):
task.host.open_connection("dummy", task.nornir.config)
assert "dummy" in task.host.connections
Expand Down Expand Up @@ -69,6 +94,10 @@ def close_not_opened_connection(task):
assert "dummy" not in task.host.connections


def failed_connection(task):
task.host.open_connection(FailedConnectionPlugin.name, task.nornir.config)


def a_task(task):
task.host.get_connection("dummy", task.nornir.config)

Expand All @@ -86,6 +115,7 @@ def setup_class(cls):
Connections.register("dummy", DummyConnectionPlugin)
Connections.register("dummy2", DummyConnectionPlugin)
Connections.register("dummy_no_overrides", DummyConnectionPlugin)
Connections.register(FailedConnectionPlugin.name, FailedConnectionPlugin)

def test_open_and_close_connection(self, nornir):
nr = nornir.filter(name="dev2.group_1")
Expand All @@ -105,6 +135,14 @@ def test_close_not_opened_connection(self, nornir):
assert len(r) == 1
assert not r.failed

def test_failed_connection(self, nornir):
nr = nornir.filter(name="dev2.group_1")
nr.run(task=failed_connection, num_workers=1)
assert (
FailedConnectionPlugin.name
not in nornir.inventory.hosts["dev2.group_1"].connections
)

def test_context_manager(self, nornir):
with nornir.filter(name="dev2.group_1") as nr:
nr.run(task=a_task)
Expand Down