diff --git a/chart/templates/secrets/metadata-connection-secret.yaml b/chart/templates/secrets/metadata-connection-secret.yaml index 031d52735d7..c4061fecead 100644 --- a/chart/templates/secrets/metadata-connection-secret.yaml +++ b/chart/templates/secrets/metadata-connection-secret.yaml @@ -18,12 +18,13 @@ ################################ ## Airflow Metadata Secret ################################# -{{- if (and .Values.data.metadataConnection (not .Values.data.metadataSecretName)) }} -{{- $postgresHost := .Values.data.metadataConnection.host | default (printf "%s-%s.%s.svc.cluster.local" .Release.Name "postgresql" .Release.Namespace) }} +{{- if not .Values.data.metadataSecretName }} +{{- $metadataHost := .Values.data.metadataConnection.host | default (printf "%s-%s.%s.svc.cluster.local" .Release.Name "postgresql" .Release.Namespace) }} {{- $pgbouncerHost := (printf "%s-%s.%s.svc.cluster.local" .Release.Name "pgbouncer" .Release.Namespace) }} -{{- $host := ternary $pgbouncerHost $postgresHost .Values.pgbouncer.enabled }} +{{- $host := ternary $pgbouncerHost $metadataHost .Values.pgbouncer.enabled }} {{- $port := ((ternary .Values.ports.pgbouncer .Values.data.metadataConnection.port .Values.pgbouncer.enabled) | toString) }} {{- $database := (ternary (printf "%s-%s" .Release.Name "metadata") .Values.data.metadataConnection.db .Values.pgbouncer.enabled) }} +{{- $extras := ternary (printf "?sslmode=%s" .Values.data.metadataConnection.sslmode) "" (eq .Values.data.metadataConnection.protocol "postgresql") }} kind: Secret apiVersion: v1 @@ -38,5 +39,5 @@ metadata: {{- end }} type: Opaque data: - connection: {{ (printf "postgresql://%s:%s@%s:%s/%s?sslmode=%s" .Values.data.metadataConnection.user .Values.data.metadataConnection.pass $host $port $database .Values.data.metadataConnection.sslmode) | b64enc | quote }} + connection: {{ (printf "%s://%s:%s@%s:%s/%s%s" .Values.data.metadataConnection.protocol .Values.data.metadataConnection.user .Values.data.metadataConnection.pass $host $port $database $extras) | b64enc | quote }} {{- end }} diff --git a/chart/templates/secrets/result-backend-connection-secret.yaml b/chart/templates/secrets/result-backend-connection-secret.yaml index 693404d1856..57835c9888c 100644 --- a/chart/templates/secrets/result-backend-connection-secret.yaml +++ b/chart/templates/secrets/result-backend-connection-secret.yaml @@ -18,8 +18,10 @@ ################################ ## Airflow Result Backend Secret ################################# -{{- if (and .Values.data.resultBackendConnection (not .Values.data.resultBackendSecretName)) }} +{{- if not .Values.data.resultBackendSecretName }} +{{- if or (eq .Values.executor "CeleryExecutor") (eq .Values.executor "CeleryKubernetesExecutor") }} {{- $host := .Values.data.resultBackendConnection.host | default (printf "%s-%s" .Release.Name "postgresql") }} +{{- $extras := ternary (printf "?sslmode=%s" .Values.data.resultBackendConnection.sslmode) "" (eq .Values.data.resultBackendConnection.protocol "postgresql") }} kind: Secret apiVersion: v1 metadata: @@ -33,5 +35,6 @@ metadata: {{- end }} type: Opaque data: - connection: {{ (printf "db+postgresql://%s:%s@%s:%s/%s?sslmode=%s" .Values.data.resultBackendConnection.user .Values.data.resultBackendConnection.pass (ternary (printf "%s-%s" .Release.Name "pgbouncer") $host .Values.pgbouncer.enabled) ((ternary .Values.ports.pgbouncer .Values.data.resultBackendConnection.port .Values.pgbouncer.enabled) | toString) (ternary (printf "%s-%s" .Release.Name "result-backend") .Values.data.resultBackendConnection.db .Values.pgbouncer.enabled) .Values.data.resultBackendConnection.sslmode) | b64enc | quote }} + connection: {{ (printf "db+%s://%s:%s@%s:%s/%s%s" .Values.data.resultBackendConnection.protocol .Values.data.resultBackendConnection.user .Values.data.resultBackendConnection.pass (ternary (printf "%s-%s" .Release.Name "pgbouncer") $host .Values.pgbouncer.enabled) ((ternary .Values.ports.pgbouncer .Values.data.resultBackendConnection.port .Values.pgbouncer.enabled) | toString) (ternary (printf "%s-%s" .Release.Name "result-backend") .Values.data.resultBackendConnection.db .Values.pgbouncer.enabled) $extras) | b64enc | quote }} +{{- end }} {{- end }} diff --git a/chart/tests/test_metadata_connection_secret.py b/chart/tests/test_metadata_connection_secret.py new file mode 100644 index 00000000000..fe90cc0bc25 --- /dev/null +++ b/chart/tests/test_metadata_connection_secret.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import base64 +import unittest + +import jmespath + +from tests.helm_template_generator import render_chart + + +class MetadataConnectionSecretTest(unittest.TestCase): + + non_chart_database_values = { + "user": "someuser", + "pass": "somepass", + "host": "somehost", + "port": 7777, + "db": "somedb", + } + + def test_should_not_generate_a_document_if_using_existing_secret(self): + docs = render_chart( + values={"data": {"metadataSecretName": "foo"}}, + show_only=["templates/secrets/metadata-connection-secret.yaml"], + ) + + assert 0 == len(docs) + + def _get_connection(self, values: dict) -> str: + docs = render_chart( + values=values, + show_only=["templates/secrets/metadata-connection-secret.yaml"], + ) + encoded_connection = jmespath.search("data.connection", docs[0]) + return base64.b64decode(encoded_connection).decode() + + def test_default_connection(self): + connection = self._get_connection({}) + + assert ( + "postgresql://postgres:postgres@RELEASE-NAME-postgresql.default.svc.cluster.local:5432" + "/postgres?sslmode=disable" == connection + ) + + def test_should_set_pgbouncer_overrides_when_enabled(self): + values = {"pgbouncer": {"enabled": True}} + connection = self._get_connection(values) + + # host, port, dbname get overridden + assert ( + "postgresql://postgres:postgres@RELEASE-NAME-pgbouncer.default.svc.cluster.local:6543" + "/RELEASE-NAME-metadata?sslmode=disable" == connection + ) + + def test_should_set_pgbouncer_overrides_with_non_chart_database_when_enabled(self): + values = { + "pgbouncer": {"enabled": True}, + "data": {"metadataConnection": {**self.non_chart_database_values}}, + } + connection = self._get_connection(values) + + # host, port, dbname still get overridden even with an non-chart db + assert ( + "postgresql://someuser:somepass@RELEASE-NAME-pgbouncer.default.svc.cluster.local:6543" + "/RELEASE-NAME-metadata?sslmode=disable" == connection + ) + + def test_should_correctly_use_non_chart_database(self): + values = { + "data": { + "metadataConnection": { + **self.non_chart_database_values, + "sslmode": "require", + } + } + } + connection = self._get_connection(values) + + assert "postgresql://someuser:somepass@somehost:7777/somedb?sslmode=require" == connection + + def test_should_support_non_postgres_db(self): + values = { + "data": { + "metadataConnection": { + **self.non_chart_database_values, + "protocol": "mysql", + } + } + } + connection = self._get_connection(values) + + # sslmode is only added for postgresql + assert "mysql://someuser:somepass@somehost:7777/somedb" == connection diff --git a/chart/tests/test_result_backend_connection_secret.py b/chart/tests/test_result_backend_connection_secret.py new file mode 100644 index 00000000000..74b61b6d25c --- /dev/null +++ b/chart/tests/test_result_backend_connection_secret.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import base64 +import unittest + +import jmespath +from parameterized import parameterized + +from tests.helm_template_generator import render_chart + + +class ResultBackendConnectionSecretTest(unittest.TestCase): + + non_chart_database_values = { + "user": "someuser", + "pass": "somepass", + "host": "somehost", + "port": 7777, + "db": "somedb", + } + + def test_should_not_generate_a_document_if_using_existing_secret(self): + docs = render_chart( + values={"data": {"resultBackendSecretName": "foo"}}, + show_only=["templates/secrets/result-backend-connection-secret.yaml"], + ) + + assert 0 == len(docs) + + @parameterized.expand( + [ + ("CeleryExecutor", 1), + ("CeleryKubernetesExecutor", 1), + ("LocalExecutor", 0), + ] + ) + def test_should_a_document_be_generated_for_executor(self, executor, expected_doc_count): + docs = render_chart( + values={"executor": executor}, + show_only=["templates/secrets/result-backend-connection-secret.yaml"], + ) + + assert expected_doc_count == len(docs) + + def _get_connection(self, values: dict) -> str: + docs = render_chart( + values=values, + show_only=["templates/secrets/result-backend-connection-secret.yaml"], + ) + encoded_connection = jmespath.search("data.connection", docs[0]) + return base64.b64decode(encoded_connection).decode() + + def test_default_connection(self): + connection = self._get_connection({}) + + assert ( + "db+postgresql://postgres:postgres@RELEASE-NAME-postgresql:5432/postgres?sslmode=disable" + == connection + ) + + def test_should_set_pgbouncer_overrides_when_enabled(self): + values = {"pgbouncer": {"enabled": True}} + connection = self._get_connection(values) + + # host, port, dbname get overridden + assert ( + "db+postgresql://postgres:postgres@RELEASE-NAME-pgbouncer:6543" + "/RELEASE-NAME-result-backend?sslmode=disable" == connection + ) + + def test_should_set_pgbouncer_overrides_with_non_chart_database_when_enabled(self): + values = { + "pgbouncer": {"enabled": True}, + "data": {"resultBackendConnection": {**self.non_chart_database_values}}, + } + connection = self._get_connection(values) + + # host, port, dbname still get overridden even with an non-chart db + assert ( + "db+postgresql://someuser:somepass@RELEASE-NAME-pgbouncer:6543" + "/RELEASE-NAME-result-backend?sslmode=disable" == connection + ) + + def test_should_correctly_use_non_chart_database(self): + values = { + "data": { + "resultBackendConnection": { + **self.non_chart_database_values, + "sslmode": "require", + } + } + } + connection = self._get_connection(values) + + assert "db+postgresql://someuser:somepass@somehost:7777/somedb?sslmode=require" == connection + + def test_should_support_non_postgres_db(self): + values = { + "data": { + "resultBackendConnection": { + **self.non_chart_database_values, + "protocol": "mysql", + } + } + } + connection = self._get_connection(values) + + # sslmode is only added for postgresql + assert "db+mysql://someuser:somepass@somehost:7777/somedb" == connection diff --git a/chart/values.schema.json b/chart/values.schema.json index d30810dd3d3..90d708102a4 100644 --- a/chart/values.schema.json +++ b/chart/values.schema.json @@ -435,6 +435,10 @@ "description": "The user's password.", "type": "string" }, + "protocol": { + "description": "The database protocol.", + "type": "string" + }, "host": { "description": "The database host.", "type": [ @@ -469,6 +473,10 @@ "description": "The database password.", "type": "string" }, + "protocol": { + "description": "The database protocol.", + "type": "string" + }, "host": { "description": "The database host.", "type": [ diff --git a/chart/values.yaml b/chart/values.yaml index d57b25d3f6f..7d02f13d200 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -226,6 +226,7 @@ data: metadataConnection: user: postgres pass: postgres + protocol: postgresql host: ~ port: 5432 db: postgres @@ -233,6 +234,7 @@ data: resultBackendConnection: user: postgres pass: postgres + protocol: postgresql host: ~ port: 5432 db: postgres