Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow to use any port as TLS port #145

Merged
merged 7 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(
username=None,
password=None,
client_id=None,
is_ssl=True,
is_ssl=None,
keep_alive=60,
recv_timeout=10,
socket_pool=None,
Expand Down Expand Up @@ -220,13 +220,19 @@ def __init__(
): # [MQTT-3.1.3.5]
raise MMQTTException("Password length is too large.")

# The connection will be insecure unless is_ssl is set to True.
# If the port is not specified, the security will be set based on the is_ssl parameter.
# If the port is specified, the is_ssl parameter will be honored.
self.port = MQTT_TCP_PORT
if is_ssl:
if is_ssl is None:
is_ssl = False
self._is_ssl = is_ssl
if self._is_ssl:
self.port = MQTT_TLS_PORT
if port:
self.port = port

# define client identifer
# define client identifier
if client_id:
# user-defined client_id MAY allow client_id's > 23 bytes or
# non-alpha-numeric characters
Expand Down Expand Up @@ -282,12 +288,12 @@ def _get_connect_socket(self, host, port, *, timeout=1):
if not isinstance(port, int):
raise RuntimeError("Port must be an integer")

if port == MQTT_TLS_PORT and not self._ssl_context:
if self._is_ssl and not self._ssl_context:
raise RuntimeError(
"ssl_context must be set before using adafruit_mqtt for secure MQTT."
)

if port == MQTT_TLS_PORT:
if self._is_ssl:
self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}")
else:
self.logger.info(f"Establishing an INSECURE connection to {host}:{port}")
Expand All @@ -306,7 +312,7 @@ def _get_connect_socket(self, host, port, *, timeout=1):
raise TemporaryError from exc

connect_host = addr_info[-1][0]
if port == MQTT_TLS_PORT:
if self._is_ssl:
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout)
Expand Down
9 changes: 6 additions & 3 deletions examples/cpython/minimqtt_adafruitio_cpython.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: 2021 ladyada for Adafruit Industries
# SPDX-License-Identifier: MIT

import time
import socket
import ssl
import time

import adafruit_minimqtt.adafruit_minimqtt as MQTT

### Secrets File Setup ###
Expand Down Expand Up @@ -46,11 +48,12 @@ def message(client, topic, message):

# Set up a MiniMQTT Client
mqtt_client = MQTT.MQTT(
broker=secrets["broker"],
port=1883,
broker="io.adafruit.com",
username=secrets["aio_username"],
password=secrets["aio_key"],
socket_pool=socket,
is_ssl=True,
ssl_context=ssl.create_default_context(),
)

# Setup the callback methods above
Expand Down
6 changes: 5 additions & 1 deletion examples/ethernet/minimqtt_simpletest_eth.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ def publish(client, userdata, topic, pid):
MQTT.set_socket(socket, eth)

# Set up a MiniMQTT Client
# NOTE: We'll need to connect insecurely for ethernet configurations.
client = MQTT.MQTT(
broker=secrets["broker"], username=secrets["user"], password=secrets["pass"]
broker=secrets["broker"],
username=secrets["user"],
password=secrets["pass"],
is_ssl=False,
)

# Connect callback handlers to client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def message(client, topic, message):

# Set up a MiniMQTT Client
mqtt_client = MQTT.MQTT(
broker=secrets["broker"],
broker="io.adafruit.com",
port=secrets["port"],
username=secrets["aio_username"],
password=secrets["aio_key"],
Expand Down
124 changes: 124 additions & 0 deletions tests/test_port_ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
#
# SPDX-License-Identifier: Unlicense

"""tests that verify the connect behavior w.r.t. port number and TLS"""

import socket
import ssl
from unittest import TestCase, main
from unittest.mock import Mock, call, patch

import adafruit_minimqtt.adafruit_minimqtt as MQTT


class PortSslSetup(TestCase):
"""This class contains tests that verify how host/port and TLS is set for connect().
These tests assume that there is no MQTT broker running on the hosts/ports they connect to.
"""

def test_default_port(self) -> None:
"""verify default port value and that TLS is not used"""
host = "127.0.0.1"
port = 1883

with patch.object(socket.socket, "connect") as connect_mock:
ssl_context = ssl.create_default_context()
mqtt_client = MQTT.MQTT(
broker=host,
socket_pool=socket,
ssl_context=ssl_context,
connect_retries=1,
)

ssl_mock = Mock()
ssl_context.wrap_socket = ssl_mock

with self.assertRaises(MQTT.MMQTTException):
expected_port = port
mqtt_client.connect()

ssl_mock.assert_not_called()
connect_mock.assert_called()
# Assuming the repeated calls will have the same arguments.
connect_mock.assert_has_calls([call((host, expected_port))])

def test_connect_override(self):
"""Test that connect() can override host and port."""
host = "127.0.0.1"
port = 1883

with patch.object(socket.socket, "connect") as connect_mock:
connect_mock.side_effect = OSError("artificial error")
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
socket_pool=socket,
connect_retries=1,
)

with self.assertRaises(MQTT.MMQTTException):
expected_host = "127.0.0.2"
expected_port = 1884
self.assertNotEqual(expected_port, port, "port override should differ")
self.assertNotEqual(expected_host, host, "host override should differ")
mqtt_client.connect(host=expected_host, port=expected_port)

connect_mock.assert_called()
# Assuming the repeated calls will have the same arguments.
connect_mock.assert_has_calls([call((expected_host, expected_port))])

def test_tls_port(self) -> None:
"""verify that when is_ssl=True is set, the default port is 8883
and the socket is TLS wrapped. Also test that the TLS port can be overridden."""
host = "127.0.0.1"

for port in [None, 8884]:
if port is None:
expected_port = 8883
else:
expected_port = port
with self.subTest():
ssl_mock = Mock()
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
socket_pool=socket,
is_ssl=True,
ssl_context=ssl_mock,
connect_retries=1,
)

socket_mock = Mock()
connect_mock = Mock(side_effect=OSError)
socket_mock.connect = connect_mock
ssl_mock.wrap_socket = Mock(return_value=socket_mock)

with self.assertRaises(MQTT.MMQTTException):
mqtt_client.connect()

ssl_mock.wrap_socket.assert_called()

connect_mock.assert_called()
# Assuming the repeated calls will have the same arguments.
connect_mock.assert_has_calls([call((host, expected_port))])

def test_tls_without_ssl_context(self) -> None:
"""verify that when is_ssl=True is set, the code will check that ssl_context is not None"""
host = "127.0.0.1"

mqtt_client = MQTT.MQTT(
broker=host,
socket_pool=socket,
is_ssl=True,
ssl_context=None,
connect_retries=1,
)

with self.assertRaises(RuntimeError) as context:
mqtt_client.connect()
self.assertTrue("ssl_context must be set" in str(context))


if __name__ == "__main__":
main()
11 changes: 11 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
#
# SPDX-License-Identifier: MIT

[tox]
envlist = py39

[testenv]
changedir = {toxinidir}/tests
deps = pytest==6.2.5
commands = pytest -v