diff --git a/qds_sdk/engine.py b/qds_sdk/engine.py index ec041b44..3070f699 100644 --- a/qds_sdk/engine.py +++ b/qds_sdk/engine.py @@ -28,6 +28,8 @@ def set_engine_config(self, dbtap_id=None, fernet_key=None, overrides=None, + airflow_version=None, + airflow_python_version=None, is_ha=None, enable_rubix=None): ''' @@ -59,6 +61,10 @@ def set_engine_config(self, overrides: Airflow configuration to override the default settings.Use the following syntax for overrides:
.=\n
.=... + airflow_version: The airflow version. + + airflow_python_version: The python version for the environment on the cluster. + is_ha: Enabling HA config for cluster is_deeplearning : this is a deeplearning cluster config enable_rubix: Enable rubix on the cluster @@ -68,7 +74,7 @@ def set_engine_config(self, self.set_hadoop_settings(custom_hadoop_config, use_qubole_placement_policy, is_ha, fairscheduler_config_xml, default_pool, enable_rubix) self.set_presto_settings(presto_version, custom_presto_config) self.set_spark_settings(spark_version, custom_spark_config) - self.set_airflow_settings(dbtap_id, fernet_key, overrides) + self.set_airflow_settings(dbtap_id, fernet_key, overrides, airflow_version, airflow_python_version) def set_fairscheduler_settings(self, fairscheduler_config_xml=None, @@ -106,10 +112,14 @@ def set_spark_settings(self, def set_airflow_settings(self, dbtap_id=None, fernet_key=None, - overrides=None): + overrides=None, + airflow_version="1.10.0", + airflow_python_version="2.7"): self.airflow_settings['dbtap_id'] = dbtap_id self.airflow_settings['fernet_key'] = fernet_key self.airflow_settings['overrides'] = overrides + self.airflow_settings['version'] = airflow_version + self.airflow_settings['airflow_python_version'] = airflow_python_version def set_engine_config_settings(self, arguments): custom_hadoop_config = util._read_file(arguments.custom_hadoop_config_file) @@ -128,6 +138,8 @@ def set_engine_config_settings(self, arguments): dbtap_id=arguments.dbtap_id, fernet_key=arguments.fernet_key, overrides=arguments.overrides, + airflow_version=arguments.airflow_version, + airflow_python_version=arguments.airflow_python_version, enable_rubix=arguments.enable_rubix) @staticmethod @@ -215,4 +227,12 @@ def engine_parser(argparser): dest="overrides", default=None, help="overrides for airflow cluster", ) + airflow_settings_group.add_argument("--airflow-version", + dest="airflow_version", + default=None, + help="airflow version for airflow cluster", ) + airflow_settings_group.add_argument("--airflow-python-version", + dest="airflow_python_version", + default=None, + help="python environment version for airflow cluster", ) diff --git a/tests/test_clusterv2.py b/tests/test_clusterv2.py index 1b33f4e2..2c3ebaae 100644 --- a/tests/test_clusterv2.py +++ b/tests/test_clusterv2.py @@ -444,6 +444,27 @@ def test_spark_engine_config(self): 'custom_spark_config': 'spark-overrides'}}, 'cluster_info': {'label': ['test_label'],}}) + def test_airflow_engine_config(self): + with tempfile.NamedTemporaryFile() as temp: + temp.write("config.properties:\na=1\nb=2".encode("utf8")) + temp.flush() + sys.argv = ['qds.py', '--version', 'v2', 'cluster', 'create', '--label', 'test_label', + '--flavour', 'airflow', '--dbtap-id', '1', '--fernet-key', '-1', '--overrides', 'airflow_overrides', '--airflow-version', '1.10.0', '--airflow-python-version', '2.7'] + Qubole.cloud = None + print_command() + Connection._api_call = Mock(return_value={}) + qds.main() + Connection._api_call.assert_called_with('POST', 'clusters', + {'engine_config': + {'flavour': 'airflow', + 'airflow_settings': { + 'dbtap_id': '1', + 'fernet_key': '-1', + 'overrides': 'airflow_overrides', + 'version': '1.10.0', + 'airflow_python_version': '2.7' + }}, + 'cluster_info': {'label': ['test_label'],}}) def test_persistent_security_groups_v2(self): sys.argv = ['qds.py', '--version', 'v2', 'cluster', 'create', '--label', 'test_label',