diff --git a/actions.yaml b/actions.yaml index 0517e665ca..4ba1956917 100644 --- a/actions.yaml +++ b/actions.yaml @@ -22,3 +22,9 @@ set-password: password: type: string description: The password will be auto-generated if this option is not specified. +set-tls-private-key: + description: Set the private key, which will be used for certificate signing requests (CSR). Run for each unit separately. + params: + private-key: + type: string + description: The content of private key for communications with clients. Content will be auto-generated if this option is not specified. diff --git a/lib/charms/postgresql_k8s/v0/postgresql.py b/lib/charms/postgresql_k8s/v0/postgresql.py index d030c657f5..5b5e002703 100644 --- a/lib/charms/postgresql_k8s/v0/postgresql.py +++ b/lib/charms/postgresql_k8s/v0/postgresql.py @@ -31,7 +31,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 3 logger = logging.getLogger(__name__) @@ -178,6 +178,20 @@ def get_postgresql_version(self) -> str: logger.error(f"Failed to get PostgreSQL version: {e}") raise PostgreSQLGetPostgreSQLVersionError() + def is_tls_enabled(self) -> bool: + """Returns whether TLS is enabled. + + Returns: + whether TLS is enabled. + """ + try: + with self._connect_to_database() as connection, connection.cursor() as cursor: + cursor.execute("SHOW ssl;") + return "on" in cursor.fetchone()[0] + except psycopg2.Error as e: + logger.error(f"Failed to get check whether TLS is enabled: {e}") + return False + def update_user_password(self, username: str, password: str) -> None: """Update a user password. diff --git a/src/charm.py b/src/charm.py index 68ce1fd924..affc14f4b0 100755 --- a/src/charm.py +++ b/src/charm.py @@ -11,6 +11,8 @@ PostgreSQL, PostgreSQLUpdateUserPasswordError, ) +from charms.postgresql_k8s.v0.postgresql_tls import PostgreSQLTLS +from charms.rolling_ops.v0.rollingops import RollingOpsManager from lightkube import ApiError, Client, codecs from lightkube.resources.core_v1 import Endpoints, Pod, Service from ops.charm import ( @@ -25,11 +27,12 @@ from ops.model import ( ActiveStatus, BlockedStatus, + Container, MaintenanceStatus, Relation, WaitingStatus, ) -from ops.pebble import Layer +from ops.pebble import Layer, PathError, ProtocolError from requests import ConnectionError from tenacity import RetryError @@ -38,8 +41,13 @@ REPLICATION_PASSWORD_KEY, REPLICATION_USER, SYSTEM_USERS, + TLS_CA_FILE, + TLS_CERT_FILE, + TLS_KEY_FILE, USER, USER_PASSWORD_KEY, + WORKLOAD_OS_GROUP, + WORKLOAD_OS_USER, ) from patroni import NotReadyError, Patroni from relations.db import DbProvides @@ -78,6 +86,10 @@ def __init__(self, *args): self.postgresql_client_relation = PostgreSQLProvider(self) self.legacy_db_relation = DbProvides(self, admin=False) self.legacy_db_admin_relation = DbProvides(self, admin=True) + self.tls = PostgreSQLTLS(self, PEER) + self.restart_manager = RollingOpsManager( + charm=self, relation="restart", callback=self._restart + ) @property def app_peer_data(self) -> Dict: @@ -97,7 +109,7 @@ def unit_peer_data(self) -> Dict: return relation.data[self.unit] - def _get_secret(self, scope: str, key: str) -> Optional[str]: + def get_secret(self, scope: str, key: str) -> Optional[str]: """Get secret from the secret storage.""" if scope == "unit": return self.unit_peer_data.get(key, None) @@ -106,7 +118,7 @@ def _get_secret(self, scope: str, key: str) -> Optional[str]: else: raise RuntimeError("Unknown secret scope.") - def _set_secret(self, scope: str, key: str, value: Optional[str]) -> None: + def set_secret(self, scope: str, key: str, value: Optional[str]) -> None: """Get secret from the secret storage.""" if scope == "unit": if not value: @@ -127,7 +139,7 @@ def postgresql(self) -> PostgreSQL: return PostgreSQL( host=self.primary_endpoint, user=USER, - password=self._get_secret("app", f"{USER}-password"), + password=self.get_secret("app", f"{USER}-password"), database="postgres", ) @@ -150,6 +162,18 @@ def _build_service_name(self, service: str) -> str: """Build a full k8s service name based on the service name.""" return f"{self._name}-{service}.{self._namespace}.svc.cluster.local" + def get_hostname_by_unit(self, unit_name: str) -> str: + """Create a DNS name for a PostgreSQL unit. + + Args: + unit_name: the juju unit name, e.g. "postgre-sql/1". + + Returns: + A string representing the hostname of the PostgreSQL unit. + """ + unit_id = unit_name.split("/")[1] + return f"{self.app.name}-{unit_id}.{self.app.name}-endpoints" + def _get_endpoints_to_remove(self) -> List[str]: """List the endpoints that were part of the cluster but departed.""" old = self._endpoints @@ -194,7 +218,7 @@ def _on_peer_relation_changed(self, event: RelationChangedEvent) -> None: # Update the list of the cluster members in the replicas to make them know each other. # Update the cluster members in this unit (updating patroni configuration). - self._patroni.update_cluster_members() + self.update_config() # Validate the status of the member before setting an ActiveStatus. if not self._patroni.member_started: @@ -301,11 +325,11 @@ def _get_hostname_from_unit(self, member: str) -> str: def _on_leader_elected(self, event: LeaderElectedEvent) -> None: """Handle the leader-elected event.""" - if self._get_secret("app", USER_PASSWORD_KEY) is None: - self._set_secret("app", USER_PASSWORD_KEY, new_password()) + if self.get_secret("app", USER_PASSWORD_KEY) is None: + self.set_secret("app", USER_PASSWORD_KEY, new_password()) - if self._get_secret("app", REPLICATION_PASSWORD_KEY) is None: - self._set_secret("app", REPLICATION_PASSWORD_KEY, new_password()) + if self.get_secret("app", REPLICATION_PASSWORD_KEY) is None: + self.set_secret("app", REPLICATION_PASSWORD_KEY, new_password()) # Create resources and add labels needed for replication. self._create_resources() @@ -345,6 +369,13 @@ def _on_postgresql_pebble_ready(self, event: WorkloadEvent) -> None: event.defer() return + try: + self.push_tls_files_to_workload(container) + except (PathError, ProtocolError) as e: + logger.error("Cannot push TLS certificates: %r", e) + event.defer() + return + # Get the current layer. current_layer = container.get_plan() # Check if there are any changes to layer services. @@ -354,7 +385,6 @@ def _on_postgresql_pebble_ready(self, event: WorkloadEvent) -> None: logging.info("Added updated layer 'postgresql' to Pebble plan") # TODO: move this file generation to on config changed hook # when adding configs to this charm. - self._patroni.render_patroni_yml_file() # Restart it and report a new status to Juju. container.restart(self._postgresql_service) logging.info("Restarted postgresql service") @@ -447,9 +477,7 @@ def _on_get_password(self, event: ActionEvent) -> None: f" {', '.join(SYSTEM_USERS)} not {username}" ) return - event.set_results( - {f"{username}-password": self._get_secret("app", f"{username}-password")} - ) + event.set_results({f"{username}-password": self.get_secret("app", f"{username}-password")}) def _on_set_password(self, event: ActionEvent) -> None: """Set the password for the specified user.""" @@ -470,7 +498,7 @@ def _on_set_password(self, event: ActionEvent) -> None: if "password" in event.params: password = event.params["password"] - if password == self._get_secret("app", f"{username}-password"): + if password == self.get_secret("app", f"{username}-password"): event.log("The old and new passwords are equal.") event.set_results({f"{username}-password": password}) return @@ -495,12 +523,11 @@ def _on_set_password(self, event: ActionEvent) -> None: return # Update the password in the secret store. - self._set_secret("app", f"{username}-password", password) + self.set_secret("app", f"{username}-password", password) # Update and reload Patroni configuration in this unit to use the new password. # Other units Patroni configuration will be reloaded in the peer relation changed event. - self._patroni.render_patroni_yml_file() - self._patroni.reload_patroni_configuration() + self.update_config() event.set_results({f"{username}-password": password}) @@ -569,8 +596,8 @@ def _patroni(self): self._namespace, self.app.planned_units(), self._storage_path, - self._get_secret("app", USER_PASSWORD_KEY), - self._get_secret("app", REPLICATION_PASSWORD_KEY), + self.get_secret("app", USER_PASSWORD_KEY), + self.get_secret("app", REPLICATION_PASSWORD_KEY), ) @property @@ -649,6 +676,67 @@ def _peers(self) -> Relation: """ return self.model.get_relation(PEER) + def push_tls_files_to_workload(self, container: Container = None) -> None: + """Uploads TLS files to the workload container.""" + if container is None: + container = self.unit.get_container("postgresql") + + key, ca, cert = self.tls.get_tls_files() + if key is not None: + container.push( + f"{self._storage_path}/{TLS_KEY_FILE}", + key, + make_dirs=True, + permissions=0o400, + user=WORKLOAD_OS_USER, + group=WORKLOAD_OS_GROUP, + ) + if ca is not None: + container.push( + f"{self._storage_path}/{TLS_CA_FILE}", + ca, + make_dirs=True, + permissions=0o400, + user=WORKLOAD_OS_USER, + group=WORKLOAD_OS_GROUP, + ) + if cert is not None: + container.push( + f"{self._storage_path}/{TLS_CERT_FILE}", + cert, + make_dirs=True, + permissions=0o400, + user=WORKLOAD_OS_USER, + group=WORKLOAD_OS_GROUP, + ) + + self.update_config() + + def _restart(self, _) -> None: + """Restart PostgreSQL.""" + try: + self._patroni.restart_postgresql() + except RetryError as e: + logger.error("failed to restart PostgreSQL") + self.unit.status = BlockedStatus(f"failed to restart PostgreSQL with error {e}") + + def update_config(self) -> None: + """Updates Patroni config file based on the existence of the TLS files.""" + enable_tls = all(self.tls.get_tls_files()) + + # Update and reload configuration based on TLS files availability. + self._patroni.render_patroni_yml_file(enable_tls=enable_tls) + if not self._patroni.member_started: + return + + restart_postgresql = enable_tls != self.postgresql.is_tls_enabled() + self._patroni.reload_patroni_configuration() + + # Restart PostgreSQL if TLS configuration has changed + # (so the both old and new connections use the configuration). + if restart_postgresql: + self.on[self.restart_manager.name].acquire_lock.emit() + def _unit_name_to_pod_name(self, unit_name: str) -> str: """Converts unit name to pod name. diff --git a/src/constants.py b/src/constants.py index 860e98b296..bd7bae6662 100644 --- a/src/constants.py +++ b/src/constants.py @@ -7,7 +7,12 @@ PEER = "database-peers" REPLICATION_USER = "replication" REPLICATION_PASSWORD_KEY = "replication-password" +TLS_KEY_FILE = "key.pem" +TLS_CA_FILE = "ca.pem" +TLS_CERT_FILE = "cert.pem" USER = "operator" USER_PASSWORD_KEY = "operator-password" +WORKLOAD_OS_GROUP = "postgres" +WORKLOAD_OS_USER = "postgres" # List of system usernames needed for correct work of the charm/workload. SYSTEM_USERS = [REPLICATION_USER, USER] diff --git a/src/patroni.py b/src/patroni.py index df8e5c4a68..82424bf85c 100644 --- a/src/patroni.py +++ b/src/patroni.py @@ -135,13 +135,18 @@ def _render_file(self, path: str, content: str, mode: int) -> None: # Ignore non existing user error when it wasn't created yet. pass - def render_patroni_yml_file(self) -> None: - """Render the Patroni configuration file.""" + def render_patroni_yml_file(self, enable_tls: bool = False) -> None: + """Render the Patroni configuration file. + + Args: + enable_tls: whether to enable TLS. + """ # Open the template postgresql.conf file. with open("templates/patroni.yml.j2", "r") as file: template = Template(file.read()) # Render the template file with the correct values. rendered = template.render( + enable_tls=enable_tls, endpoint=self._endpoint, endpoints=self._endpoints, namespace=self._namespace, @@ -165,21 +170,12 @@ def render_postgresql_conf_file(self) -> None: ) self._render_file(f"{self._storage_path}/postgresql-k8s-operator.conf", rendered, 0o644) - def update_cluster_members(self) -> None: - """Update the list of members of the cluster.""" - # Update the members in the Patroni configuration. - self.render_patroni_yml_file() - - try: - if self.member_started: - # Make Patroni use the updated configuration. - self.reload_patroni_configuration() - except RetryError: - # Ignore retry errors that happen when the member has not started yet. - # The configuration will be loaded correctly when Patroni starts. - pass - @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10)) def reload_patroni_configuration(self) -> None: """Reloads the configuration after it was updated in the file.""" requests.post(f"http://{self._endpoint}:8008/reload") + + @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10)) + def restart_postgresql(self) -> None: + """Restart PostgreSQL.""" + requests.post(f"http://{self._endpoint}:8008/restart") diff --git a/templates/patroni.yml.j2 b/templates/patroni.yml.j2 index 54da2d62ef..d97ebc6ef9 100644 --- a/templates/patroni.yml.j2 +++ b/templates/patroni.yml.j2 @@ -9,8 +9,8 @@ bootstrap: - locale: en_US.UTF-8 - data-checksums pg_hba: - - host all all 0.0.0.0/0 md5 - - host replication replication 127.0.0.1/32 md5 + - {{ 'hostssl' if enable_tls else 'host' }} all all 0.0.0.0/0 md5 + - {{ 'hostssl' if enable_tls else 'host' }} replication replication 127.0.0.1/32 md5 bypass_api_service: true log: dir: /var/log/postgresql @@ -23,12 +23,19 @@ postgresql: custom_conf: {{ storage_path }}/postgresql-k8s-operator.conf data_dir: {{ storage_path }}/pgdata listen: 0.0.0.0:5432 + {%- if enable_tls %} + parameters: + ssl: on + ssl_ca_file: {{ storage_path }}/ca.pem + ssl_cert_file: {{ storage_path }}/cert.pem + ssl_key_file: {{ storage_path }}/key.pem + {%- endif %} pgpass: /tmp/pgpass pg_hba: - - host all all 0.0.0.0/0 md5 - - host replication replication 127.0.0.1/32 md5 + - {{ 'hostssl' if enable_tls else 'host' }} all all 0.0.0.0/0 md5 + - {{ 'hostssl' if enable_tls else 'host' }} replication replication 127.0.0.1/32 md5 {%- for endpoint in endpoints %} - - host replication replication {{ endpoint }}.{{ namespace }}.svc.cluster.local md5 + - {{ 'hostssl' if enable_tls else 'host' }} replication replication {{ endpoint }}.{{ namespace }}.svc.cluster.local md5 {%- endfor %} authentication: replication: diff --git a/tests/unit/key.pem b/tests/unit/key.pem new file mode 100644 index 0000000000..f9ee650010 --- /dev/null +++ b/tests/unit/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA90Itc633siO21qTxggqwoh/6wZ/+xcA75WgXE7zQzMApHPVh +gdX/yzcOtLE+/JZqpdXyhuMG7l55hBN2i8KKWrY869qO4lSAlcH+D6PqTRyOPpAf +NScC3kt4xX+ek+zDQNmD+6464beNmf8YgjH0SuKmVYuVlTP6Qzgev0mPe49fZOXD +yqFpnQ5A7Lw+nqxqn29adVANErUkgI5Fqz2Mj/PPHyrNVpK4udwFs7uYXBEpVGFd +bMnu1Kt9ofJYiQBZtHOlQrQajR6LnSVppl1qn9nKwjK6WDch4Y1oQnl9q0/WMril +A0cWd36YlX1ZQqE8GJl7WmvrweLlnIhXgn8vHQIDAQABAoIBAAYP9wjBWOqHg1r7 +WXmv6WlZWs7f9oYL2/5hqBzoWNGsW1WzkxK+KjLTn+ZlH9rNAwvScMtHASbJU1qL +8eeSMhbL3IA+MCvOvXPJ+UnbKwE0b79CFxQUQxJQObFCiOxAtm1sDHKquctu+4wb +DjfsIaE9jWQ0s10OS3toJyH70pjJfII+pu1MYl7d0AXyBGhbhRhI19fGWIxK99cx +Trubdsz54nBFXOMWfO3tHDQMWr9ulm+dMGNQydv84slLOnoP/TzX4NZoxshgH92v +F3rhZxzK3PremX1CY6QknbCNtIlDrLCjNkwSvbvB7o/rgU/05ZU4sfe2YIvGYpA7 +59ceI8ECgYEA+aEyJwtdR43y4RPWD3fBydTtzwKhNT0SIsx3UNSwNYQaYziJw7jN +KDm+3KxaXGj5tpDqvSrepmXRbTfS464UTVqPXWoWr4T1T3nwcW0HRf2E6mnrd1T7 +Td6LBvAPc/8YhIKGILxnRx7sIEG9ihFvc3EqlfgQ0P5gfKNI5GnhFfsCgYEA/ZF9 +toTWk1bQ5ll8Bc5wUG9ulqNnauV2plIhqlBgKl4J8nn1+VbzzyjwNlFlboKJlgI/ +o3RZV+B8Tsrwm7R5oiik7FottMWDUxBv3jynwpmkKVrNqxeN+8GaTe0vU5P3XzQJ +WuCQUZaIFhMlk3c3F4Uc6Ji97lu2v64cqTzO+8cCgYEArH1zL0GOChSO4HIZdwS0 +VmeYj3Nsu2Hgt0T7qVHeFIycwzTdFO4MbSBBvzAlHDe5XSqT2XTW9rniVYq+YW73 +PmA2MkFNPaks7OcAew/wd69veZ8JAqMpJyyAeqHEu81oPqAGWHZ3EtDOLpfehZn4 +nhdzar9Ht3Ieb+FQS+zRiKMCgYBnF4Dp9QYcbz6NeLJA1ha5zSREIHlKb7KWmmvP +h9AhSYx3xfgogJ6tifJn0x4PGQmBMLbY6NPuM3m2uzGpzG1rbWToJx3QEkF7QwKI +c1aterPQkHdv6SqzwZnPtu/35f+7+DcZeJWUZkQ73Vr4oo6GhHslYfxObYFWWx6R +/AQXPwKBgQDT8/ic+UfRW96PIp7mTBF9L7M2ax97PrJM0jZCxbXUyOfqk6tI2Sme +VkbUtDt2/W+HAwWSJYU51PlpB+GqbIng374bRJbngWmrcZ5TrcCO9KnMX3zfLmox +XcaOTQKF9Owvc+JD8N9//aWBVzunKb35pup0cZ0EdQUNdR6C3BN5Zg== +-----END RSA PRIVATE KEY----- diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index f3d54dda6d..1293281268 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -79,11 +79,11 @@ def test_on_leader_elected(self, _, __, _render_postgresql_conf_file, ___): @patch("charm.Patroni.render_postgresql_conf_file") @patch_network_get(private_address="1.1.1.1") @patch("charm.Patroni.member_started") - @patch("charm.Patroni.render_patroni_yml_file") + @patch("charm.PostgresqlOperatorCharm.push_tls_files_to_workload") @patch("charm.PostgresqlOperatorCharm._patch_pod_labels") - @patch("charm.PostgresqlOperatorCharm._create_resources") + @patch("charm.PostgresqlOperatorCharm._on_leader_elected") def test_on_postgresql_pebble_ready( - self, _, __, _render_patroni_yml_file, _member_started, ___, ____ + self, _, __, _push_tls_files_to_workload, _member_started, ___, ____ ): # Check that the initial plan is empty. plan = self.harness.get_container_pebble_plan(self._postgresql_container) @@ -103,7 +103,7 @@ def test_on_postgresql_pebble_ready( self.assertEqual(self.harness.model.unit.status, ActiveStatus()) container = self.harness.model.unit.get_container(self._postgresql_container) self.assertEqual(container.get_service(self._postgresql_service).is_running(), True) - _render_patroni_yml_file.assert_called_once() + _push_tls_files_to_workload.assert_called_once() def test_on_get_password(self): # Create a mock event and set passwords in peer relation data. @@ -138,8 +138,8 @@ def test_on_get_password(self): ) @patch("charm.Patroni.reload_patroni_configuration") - @patch("charm.Patroni.render_patroni_yml_file") - @patch("charm.PostgresqlOperatorCharm._set_secret") + @patch("charm.PostgresqlOperatorCharm.update_config") + @patch("charm.PostgresqlOperatorCharm.set_secret") @patch("charm.PostgresqlOperatorCharm.postgresql") @patch("charm.Patroni.are_all_members_ready") @patch("charm.PostgresqlOperatorCharm._on_leader_elected") @@ -149,7 +149,7 @@ def test_on_set_password( _are_all_members_ready, _postgresql, _set_secret, - _render_patroni_yml_file, + _update_config, _reload_patroni_configuration, ): # Create a mock event. @@ -329,18 +329,18 @@ def test_get_secret(self, _, __, ___): self.harness.set_leader() # Test application scope. - assert self.charm._get_secret("app", "password") is None + assert self.charm.get_secret("app", "password") is None self.harness.update_relation_data( self.rel_id, self.charm.app.name, {"password": "test-password"} ) - assert self.charm._get_secret("app", "password") == "test-password" + assert self.charm.get_secret("app", "password") == "test-password" # Test unit scope. - assert self.charm._get_secret("unit", "password") is None + assert self.charm.get_secret("unit", "password") is None self.harness.update_relation_data( self.rel_id, self.charm.unit.name, {"password": "test-password"} ) - assert self.charm._get_secret("unit", "password") == "test-password" + assert self.charm.get_secret("unit", "password") == "test-password" @patch("charm.Patroni.reload_patroni_configuration") @patch("charm.Patroni.render_postgresql_conf_file") @@ -350,7 +350,7 @@ def test_set_secret(self, _, __, ___): # Test application scope. assert "password" not in self.harness.get_relation_data(self.rel_id, self.charm.app.name) - self.charm._set_secret("app", "password", "test-password") + self.charm.set_secret("app", "password", "test-password") assert ( self.harness.get_relation_data(self.rel_id, self.charm.app.name)["password"] == "test-password" @@ -358,7 +358,7 @@ def test_set_secret(self, _, __, ___): # Test unit scope. assert "password" not in self.harness.get_relation_data(self.rel_id, self.charm.unit.name) - self.charm._set_secret("unit", "password", "test-password") + self.charm.set_secret("unit", "password", "test-password") assert ( self.harness.get_relation_data(self.rel_id, self.charm.unit.name)["password"] == "test-password" diff --git a/tests/unit/test_patroni.py b/tests/unit/test_patroni.py index f023a10cf6..c2e7a78a8f 100644 --- a/tests/unit/test_patroni.py +++ b/tests/unit/test_patroni.py @@ -94,7 +94,7 @@ def test_render_patroni_yml_file(self, _render_file): # Patch the `open` method with our mock. with patch("builtins.open", mock, create=True): # Call the method - self.patroni.render_patroni_yml_file() + self.patroni.render_patroni_yml_file(enable_tls=False) # Check the template is opened read-only in the call to open. self.assertEqual(mock.call_args_list[0][0], ("templates/patroni.yml.j2", "r")) @@ -105,6 +105,40 @@ def test_render_patroni_yml_file(self, _render_file): 0o644, ) + # Then test the rendering of the file with TLS enabled. + _render_file.reset_mock() + expected_content_with_tls = template.render( + enable_tls=True, + endpoint=self.patroni._endpoint, + endpoints=self.patroni._endpoints, + namespace=self.patroni._namespace, + storage_path=self.patroni._storage_path, + superuser_password=self.patroni._superuser_password, + replication_password=self.patroni._replication_password, + ) + self.assertNotEqual(expected_content_with_tls, expected_content) + + # Patch the `open` method with our mock. + with patch("builtins.open", mock, create=True): + # Call the method + self.patroni.render_patroni_yml_file(enable_tls=True) + + # Ensure the correct rendered template is sent to _render_file method. + _render_file.assert_called_once_with( + f"{STORAGE_PATH}/patroni.yml", + expected_content_with_tls, + 0o644, + ) + + # Also, ensure the right parameters are in the expected content + # (as it was already validated with the above render file call). + self.assertIn("ssl: on", expected_content_with_tls) + self.assertIn("ssl_ca_file: /var/lib/postgresql/data/ca.pem", expected_content_with_tls) + self.assertIn( + "ssl_cert_file: /var/lib/postgresql/data/cert.pem", expected_content_with_tls + ) + self.assertIn("ssl_key_file: /var/lib/postgresql/data/key.pem", expected_content_with_tls) + @patch("charm.Patroni._render_file") def test_render_postgresql_conf_file(self, _render_file): # Get the expected content from a file. diff --git a/tests/unit/test_postgresql_tls.py b/tests/unit/test_postgresql_tls.py new file mode 100644 index 0000000000..1acc58ea16 --- /dev/null +++ b/tests/unit/test_postgresql_tls.py @@ -0,0 +1,200 @@ +# Copyright 2022 Canonical Ltd. +# See LICENSE file for licensing details. +import base64 +import socket +import unittest +from unittest.mock import MagicMock, patch + +from ops.testing import Harness + +from charm import PostgresqlOperatorCharm +from constants import PEER +from tests.helpers import patch_network_get + +RELATION_NAME = "certificates" +SCOPE = "unit" + + +class TestPostgreSQLTLS(unittest.TestCase): + def delete_secrets(self) -> None: + # Delete TLS secrets from the secret store. + self.charm.set_secret(SCOPE, "ca", None) + self.charm.set_secret(SCOPE, "cert", None) + self.charm.set_secret(SCOPE, "chain", None) + + def emit_certificate_available_event(self) -> None: + self.charm.tls.certs.on.certificate_available.emit( + certificate_signing_request="test-csr", + certificate="test-cert", + ca="test-ca", + chain="test-chain", + ) + + def emit_certificate_expiring_event(self) -> None: + self.charm.tls.certs.on.certificate_expiring.emit(certificate="test-cert", expiry=None) + + @staticmethod + def get_content_from_file(filename: str) -> str: + with open(filename, "r") as file: + content = file.read() + return content + + def no_secrets(self, include_certificate: bool = True) -> bool: + # Check whether there is no TLS secrets in the secret store. + secrets = [self.charm.get_secret(SCOPE, "ca"), self.charm.get_secret(SCOPE, "chain")] + if include_certificate: + secrets.append(self.charm.get_secret(SCOPE, "cert")) + return all(secret is None for secret in secrets) + + def relate_to_tls_certificates_operator(self) -> int: + # Relate the charm to the TLS certificates operator. + rel_id = self.harness.add_relation(RELATION_NAME, "tls-certificates-operator") + self.harness.add_relation_unit(rel_id, "tls-certificates-operator/0") + return rel_id + + def set_secrets(self) -> None: + # Set some TLS secrets in the secret store. + self.charm.set_secret(SCOPE, "ca", "test-ca") + self.charm.set_secret(SCOPE, "cert", "test-cert") + self.charm.set_secret(SCOPE, "chain", "test-chain") + + def setUp(self): + self.harness = Harness(PostgresqlOperatorCharm) + self.addCleanup(self.harness.cleanup) + + # Set up the initial relation and hooks. + self.peer_rel_id = self.harness.add_relation(PEER, "postgresql-k8s") + self.harness.add_relation_unit(self.peer_rel_id, "postgresql-k8s/0") + self.harness.begin() + self.charm = self.harness.charm + + @patch("charms.postgresql_k8s.v0.postgresql_tls.PostgreSQLTLS._request_certificate") + def test_on_set_tls_private_key(self, _request_certificate): + # Create a mock event. + mock_event = MagicMock(params={}) + + # Test without providing a private key. + self.charm.tls._on_set_tls_private_key(mock_event) + _request_certificate.assert_called_once_with(None) + + # Test providing the private key. + mock_event.params["private-key"] = "test-key" + _request_certificate.reset_mock() + self.charm.tls._on_set_tls_private_key(mock_event) + _request_certificate.assert_called_once_with("test-key") + + @patch_network_get(private_address="1.1.1.1") + @patch( + "charms.tls_certificates_interface.v1.tls_certificates.TLSCertificatesRequiresV1.request_certificate_creation" + ) + def test_request_certificate(self, _request_certificate_creation): + # Test without an established relation. + self.delete_secrets() + self.charm.tls._request_certificate(None) + self.assertIsNotNone(self.charm.get_secret(SCOPE, "key")) + self.assertIsNotNone(self.charm.get_secret(SCOPE, "csr")) + _request_certificate_creation.assert_not_called() + + # Test without providing a private key. + with self.harness.hooks_disabled(): + self.relate_to_tls_certificates_operator() + self.charm.tls._request_certificate(None) + self.assertIsNotNone(self.charm.get_secret(SCOPE, "key")) + self.assertIsNotNone(self.charm.get_secret(SCOPE, "csr")) + _request_certificate_creation.assert_called_once() + + # Test providing a private key. + _request_certificate_creation.reset_mock() + key = self.get_content_from_file(filename="tests/unit/key.pem") + self.charm.tls._request_certificate(key) + self.assertIsNotNone(self.charm.get_secret(SCOPE, "key")) + self.assertIsNotNone(self.charm.get_secret(SCOPE, "csr")) + _request_certificate_creation.assert_called_once() + + def test_parse_tls_file(self): + # Test with a plain text key. + key = self.get_content_from_file(filename="tests/unit/key.pem") + parsed_key = self.charm.tls._parse_tls_file(key) + self.assertEqual(parsed_key, key.encode("utf-8")) + + # Test with a base64 encoded key. + key = self.get_content_from_file(filename="tests/unit/key.pem") + parsed_key = self.charm.tls._parse_tls_file( + base64.b64encode(key.encode("utf-8")).decode("utf-8") + ) + self.assertEqual(parsed_key, key.encode("utf-8")) + + @patch("charms.postgresql_k8s.v0.postgresql_tls.PostgreSQLTLS._request_certificate") + def test_on_tls_relation_joined(self, _request_certificate): + self.relate_to_tls_certificates_operator() + _request_certificate.assert_called_once_with(None) + + @patch_network_get(private_address="1.1.1.1") + @patch("charm.PostgresqlOperatorCharm.update_config") + def test_on_tls_relation_broken(self, _update_config): + _update_config.reset_mock() + rel_id = self.relate_to_tls_certificates_operator() + self.harness.remove_relation(rel_id) + _update_config.assert_called_once() + self.assertTrue(self.no_secrets()) + + @patch("charm.PostgresqlOperatorCharm.push_tls_files_to_workload") + def test_on_certificate_available(self, _push_tls_files_to_workload): + # Test with no provided or invalid CSR. + self.emit_certificate_available_event() + self.assertTrue(self.no_secrets()) + _push_tls_files_to_workload.assert_not_called() + + # Test providing CSR. + self.charm.set_secret(SCOPE, "csr", "test-csr") + self.emit_certificate_available_event() + self.assertEqual(self.charm.get_secret(SCOPE, "ca"), "test-ca") + self.assertEqual(self.charm.get_secret(SCOPE, "cert"), "test-cert") + self.assertEqual(self.charm.get_secret(SCOPE, "chain"), "test-chain") + _push_tls_files_to_workload.assert_called_once() + + @patch_network_get(private_address="1.1.1.1") + @patch( + "charms.tls_certificates_interface.v1.tls_certificates.TLSCertificatesRequiresV1.request_certificate_renewal" + ) + def test_on_certificate_expiring(self, _request_certificate_renewal): + # Test with no provided or invalid certificate. + self.emit_certificate_expiring_event() + self.assertTrue(self.no_secrets()) + + # Test providing a certificate. + self.charm.set_secret( + SCOPE, "key", self.get_content_from_file(filename="tests/unit/key.pem") + ) + self.charm.set_secret(SCOPE, "cert", "test-cert") + self.charm.set_secret(SCOPE, "csr", "test-csr") + self.emit_certificate_expiring_event() + self.assertTrue(self.no_secrets(include_certificate=False)) + _request_certificate_renewal.assert_called_once() + + @patch_network_get(private_address="1.1.1.1") + def test_get_sans(self): + sans = self.charm.tls._get_sans() + self.assertEqual(sans, ["postgresql-k8s-0", socket.getfqdn(), "1.1.1.1"]) + + def test_get_tls_extensions(self): + extensions = self.charm.tls._get_tls_extensions() + self.assertEqual(len(extensions), 1) + self.assertEqual(extensions[0].ca, True) + self.assertIsNone(extensions[0].path_length) + + def test_get_tls_files(self): + # Test with no TLS files available. + key, ca, certificate = self.charm.tls.get_tls_files() + self.assertIsNone(key) + self.assertIsNone(ca) + self.assertIsNone(certificate) + + # Test with TLS files available. + self.charm.set_secret(SCOPE, "key", "test-key") + self.charm.set_secret(SCOPE, "ca", "test-ca") + self.charm.set_secret(SCOPE, "cert", "test-cert") + key, ca, certificate = self.charm.tls.get_tls_files() + self.assertEqual(key, "test-key") + self.assertEqual(ca, "test-ca") + self.assertEqual(certificate, "test-cert")