Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pyinfra/connectors: Fix overwriting of users known_hosts file (#1209) #1251

Closed
wants to merge 8 commits into from
Closed
40 changes: 26 additions & 14 deletions pyinfra/connectors/sshuserclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SSHException,
)
from paramiko.agent import AgentRequestHandler
from paramiko.hostkeys import HostKeyEntry

from pyinfra import logger
from pyinfra.api.util import memoize
Expand All @@ -31,6 +32,28 @@ def missing_host_key(self, client, hostname, key):
)


def append_hostkey(client, hostname, key):
"""Append hostname to the clients host_keys_file"""

with HOST_KEYS_LOCK:
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_key_entry = HostKeyEntry([hostname], key)
if host_key_entry is None:
raise SSHException(
"Append Hostkey: Failed to parse host {0}, could not append to hostfile".format(
hostname
),
)
with open(client._host_keys_filename, "a") as host_keys_file:
hk_entry = host_key_entry.to_line()
if hk_entry is None:
raise SSHException(f"Append Hostkey: Failed to append hostkey ({host_key_entry})")

host_keys_file.write(hk_entry)


class AcceptNewPolicy(MissingHostKeyPolicy):
def missing_host_key(self, client, hostname, key):
logger.warning(
Expand All @@ -40,13 +63,8 @@ def missing_host_key(self, client, hostname, key):
),
)

with HOST_KEYS_LOCK:
host_keys = client.get_host_keys()
host_keys.add(hostname, key.get_name(), key)
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_keys.save(client._host_keys_filename)
append_hostkey(client, hostname, key)
logger.warning("Added host key for {0} to known_hosts".format(hostname))


class AskPolicy(MissingHostKeyPolicy):
Expand All @@ -60,13 +78,7 @@ def missing_host_key(self, client, hostname, key):
raise SSHException(
"AskPolicy: No host key for {0} found in known_hosts".format(hostname),
)
with HOST_KEYS_LOCK:
host_keys = client.get_host_keys()
host_keys.add(hostname, key.get_name(), key)
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_keys.save(client._host_keys_filename)
append_hostkey(client, hostname, key)
logger.warning("Added host key for {0} to known_hosts".format(hostname))
return

Expand Down
69 changes: 68 additions & 1 deletion tests/test_connectors/test_sshuserclient.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from base64 import b64decode
from unittest import TestCase
from unittest.mock import mock_open, patch

from paramiko import ProxyCommand
from paramiko import PKey, ProxyCommand, SSHException

from pyinfra.connectors.sshuserclient import SSHClient
from pyinfra.connectors.sshuserclient.client import AskPolicy, get_ssh_config
Expand Down Expand Up @@ -41,6 +42,30 @@
Include other_file
"""

# To ensure that we don't remove things from users hostfiles
# we should test that all modifications only append to the
# hostfile, and don't delete any data or comments.
EXAMPLE_KEY_1 = (
"AAAAB3NzaC1yc2EAAAADAQABAAABgQCj7ndNxQowgcQnjshcLrqPEiiphnt+"
"VTTvDP6mHBL9j1aNUkY4Ue1gvwnGLVlOhGeYrnZaMgRK6+PKCUXaDbC7qtbW8gIkhL7aGCsOr/"

Check warning on line 50 in tests/test_connectors/test_sshuserclient.py

View workflow job for this annotation

GitHub Actions / spell-check

"Ue" should be "Use" or "Due".
"C56SJMy/BCZfxd1nWzAOxSDPgVsmerOBYfNqltV9/hWCqBywINIR+5dIg6JTJ72pcEpEjcYgXk"
"E2YEFXV1JHnsKgbLWNlhScqb2UmyRkQyytRLtL+38TGxkxCflmO+5Z8CSSNY7GidjMIZ7Q4zMj"
"A2n1nGrlTDkzwDCsw+wqFPGQA179cnfGWOWRVruj16z6XyvxvjJwbz0wQZ75XK5tKSb7FNyeIE"
"s4TT4jk+S4dhPeAUC5y+bDYirYgM4GC7uEnztnZyaVWQ7B381AK4Qdrwt51ZqExKbQpTUNn+Ej"
"qoTwvqNj4kqx5QUCI0ThS/YkOxJCXmPUWZbhjpCg56i+2aB6CmK2JGhn57K5mj0MNdBXA4/Wnw"
"H6XoPWJzK5Nyu2zB3nAZp+S5hpQs+p1vN1/wsjk="
)

KNOWN_HOSTS_EXAMPLE_DATA = f"""
# this is an important comment

# another comment after the newline

@cert-authority example-domain.lan ssh-rsa {EXAMPLE_KEY_1}

192.168.1.222 ssh-rsa {EXAMPLE_KEY_1}
"""


class TestSSHUserConfigMissing(TestCase):
def setUp(self):
Expand Down Expand Up @@ -199,3 +224,45 @@
port=22,
test="kwarg",
)

def test_missing_hostkey(self):
client = SSHClient()
policy = AskPolicy()
example_hostname = "new_host"
example_keytype = "ecdsa-sha2-nistp256"
example_key = (
"AAAAE2VjZHNhLXNoYTItbmlzdHAyNT"
"YAAAAIbmlzdHAyNTYAAABBBHNp1NM"
"ZjxPBuuKwIPfkVJqWaH3oUtW137kIW"
"P4PlCyACt8zVIIimFhIpwRUidcf7jw"
"VWPAJvfBjEPqewDApnZQ="
)

key = PKey.from_type_string(
example_keytype,
b64decode(example_key),
)

# Check if AskPolicy respects not importing and properly raises SSHException
with self.subTest("Check user 'no'"):
with patch("builtins.input", return_value="n"):
self.assertRaises(
SSHException, lambda: policy.missing_host_key(client, example_hostname, key)
)

# Check if AskPolicy properly appends to hostfile
with self.subTest("Check user 'yes'"):
mock_data = mock_open(read_data=KNOWN_HOSTS_EXAMPLE_DATA)
# Read mock hostfile
with patch("pyinfra.connectors.sshuserclient.client.open", mock_data):
with patch("paramiko.hostkeys.open", mock_data):
with patch("builtins.input", return_value="y"):
policy.missing_host_key(client, "new_host", key)

# Assert that we appended correctly to the file
write_call_args = mock_data.return_value.write.call_args
# Ensure we only wrote once and then closed the handle.
assert len(write_call_args) == 2
# Ensure we wrote the correct content
correct_output = f"{example_hostname} {example_keytype} {example_key}\n"
assert write_call_args[0][0] == correct_output
Loading