Fix tests
VladaZakharova committed Feb 13, 2024
1 parent 88ad53e commit ff69bca
Showing 3 changed files with 55 additions and 61 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -296,7 +296,7 @@ def get_cluster(

def check_cluster_autoscaling_ability(self, cluster: Cluster | dict):
"""
Helper method to check if the specified Cluster has ability to autoscale.
Check if the specified Cluster has ability to autoscale.
Cluster should be Autopilot, with Node Auto-provisioning or regular auto-scaled node pools.
Returns True if the Cluster supports autoscaling, otherwise returns False.
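For context, the autoscaling check described in that docstring can be sketched roughly as follows, assuming the google.cloud.container_v1 Cluster message (the dict form the hook also accepts is omitted); this is an illustration, not the hook's literal code:

from google.cloud.container_v1.types import Cluster


def supports_autoscaling(cluster: Cluster) -> bool:
    # Autopilot clusters manage node scaling themselves.
    if cluster.autopilot.enabled:
        return True
    # Node Auto-provisioning creates node pools on demand.
    if cluster.autoscaling.enable_node_autoprovisioning:
        return True
    # Otherwise at least one existing node pool must have autoscaling enabled.
    return any(pool.autoscaling.enabled for pool in cluster.node_pools)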
5 changes: 2 additions & 3 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -76,7 +76,7 @@ def __init__(
self._ssl_ca_cert = None

def fetch_cluster_info(self) -> tuple[str, str | None]:
"""Fetches cluster info for connecting to it."""
"""Fetch cluster info for connecting to it."""
cluster = self.cluster_hook.get_cluster(
name=self.cluster_name,
project_id=self.project_id,
@@ -530,7 +530,7 @@ def pod_hook(self) -> GKEPodHook:

@staticmethod
def _get_yaml_content_from_file(kueue_yaml_url) -> list[dict]:
"""Helper method to download content of YAML file and separate it into several dictionaries."""
"""Download content of YAML file and separate it into several dictionaries."""
response = requests.get(kueue_yaml_url, allow_redirects=True)
yaml_dicts = []
if response.status_code == 200:
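Only the start of that helper is visible in this hunk; conceptually it downloads the manifest and splits it on YAML document boundaries, roughly like this sketch (name and error handling are illustrative):

import requests
import yaml


def yaml_documents_from_url(url: str) -> list[dict]:
    # Fetch the manifest and split it into one dict per YAML document ("---").
    response = requests.get(url, allow_redirects=True)
    response.raise_for_status()
    return [doc for doc in yaml.safe_load_all(response.text) if doc]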
@@ -733,7 +733,6 @@ def hook(self) -> GKEPodHook:
return hook

def execute(self, context: Context):
"""Executes process of creating pod and executing provided command inside it."""
"""Execute process of creating pod and executing provided command inside it."""
self.fetch_cluster_info()
return super().execute(context)
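The fetch_cluster_info helper touched above is what the reworked tests below call directly; based on the cluster fields those tests mock, its result can be pictured roughly like this (URL scheme and function name are illustrative, not the operator's literal code):

from google.cloud.container_v1.types import Cluster


def cluster_connection_details(cluster: Cluster, use_internal_ip: bool) -> tuple[str, str]:
    # Private clusters are reached through the private endpoint when internal IPs are requested.
    endpoint = (
        cluster.private_cluster_config.private_endpoint
        if use_internal_ip
        else cluster.endpoint
    )
    # The cluster CA certificate is needed to build a kubeconfig for the pod operator.
    return f"https://{endpoint}", cluster.master_auth.cluster_ca_certificate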
109 changes: 52 additions & 57 deletions tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -30,7 +30,6 @@
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
from airflow.providers.google.cloud.operators.kubernetes_engine import (
GKEClusterAuthDetails,
GKECreateClusterOperator,
GKEDeleteClusterOperator,
GKEStartKueueInsideClusterOperator,
@@ -74,6 +73,9 @@
KUB_OPERATOR_EXEC = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute"
TEMP_FILE = "tempfile.NamedTemporaryFile"
GKE_OP_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator"
GKE_CREATE_CLUSTER_PATH = (
"airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator"
)
GKE_CLUSTER_AUTH_DETAILS_PATH = (
"airflow.providers.google.cloud.operators.kubernetes_engine.GKEClusterAuthDetails"
)
@@ -100,8 +102,8 @@ class TestGoogleCloudPlatformContainerOperator:
PROJECT_BODY_CREATE_CLUSTER_NODE_POOLS,
],
)
@mock.patch(f"{GKE_HOOK_PATH}.create_cluster")
@mock.patch(GKE_HOOK_PATH)
@mock.patch(f"{GKE_OP_PATH}.get_cluster")
def test_create_execute(self, mock_hook, mock_cluster_hook, body):
operator = GKECreateClusterOperator(
project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, body=body, task_id=PROJECT_TASK_ID
@@ -259,21 +261,15 @@ def setup_method(self):
name=TASK_NAME,
namespace=NAMESPACE,
)
self.gke_op._cluster_url = CLUSTER_URL
self.gke_op._ssl_ca_cert = SSL_CA_CERT

def test_template_fields(self):
assert set(KubernetesPodOperator.template_fields).issubset(GKEStartPodOperator.template_fields)

@mock.patch.dict(os.environ, {})
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
return_value=Connection(conn_id="test_conn"),
)
@mock.patch(KUB_OPERATOR_EXEC)
@mock.patch(TEMP_FILE)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
def test_execute(self, fetch_cluster_info_mock, file_mock, mock_get_conn, exec_mock):
def test_execute(self, fetch_cluster_info_mock, file_mock, exec_mock):
self.gke_op.execute(context=mock.MagicMock())
fetch_cluster_info_mock.assert_called_once()

@@ -297,10 +293,9 @@ def test_config_file_throws_error(self):
)
@mock.patch(KUB_OPERATOR_EXEC)
@mock.patch(TEMP_FILE)
@mock.patch(GKE_HOOK_PATH)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
def test_execute_with_impersonation_service_account(
self, fetch_cluster_info_mock, mock_cluster_hook, file_mock, exec_mock, get_con_mock
self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock
):
self.gke_op.impersonation_chain = "test_account@example.com"
self.gke_op.execute(context=mock.MagicMock())
@@ -313,10 +308,9 @@ def test_execute_with_impersonation_service_account(
)
@mock.patch(KUB_OPERATOR_EXEC)
@mock.patch(TEMP_FILE)
@mock.patch(GKE_HOOK_PATH)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
def test_execute_with_impersonation_service_chain_one_element(
self, fetch_cluster_info_mock, mock_cluster_hook, file_mock, exec_mock, get_con_mock
self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock
):
self.gke_op.impersonation_chain = ["test_account@example.com"]
self.gke_op.execute(context=mock.MagicMock())
@@ -325,52 +319,52 @@ def test_execute_with_impersonation_service_chain_one_element(

@pytest.mark.db_test
@pytest.mark.parametrize("use_internal_ip", [True, False])
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
@mock.patch(f"{GKE_OP_PATH}.get_cluster")
def test_cluster_info(self, get_cluster_mock, mock_get_credentials, use_internal_ip):
@mock.patch(f"{GKE_HOOK_PATH}.get_cluster")
def test_cluster_info(self, get_cluster_mock, use_internal_ip):
get_cluster_mock.return_value = mock.MagicMock(
**{
"endpoint": "test-host",
"private_cluster_config.private_endpoint": "test-private-host",
"master_auth.cluster_ca_certificate": SSL_CA_CERT,
}
)
self.gke_op.execute(context=mock.MagicMock())
cluster_info = GKEClusterAuthDetails(
gke_op = GKEStartPodOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
cluster_name=CLUSTER_NAME,
task_id=PROJECT_TASK_ID,
name=TASK_NAME,
namespace=NAMESPACE,
image=IMAGE,
use_internal_ip=use_internal_ip,
cluster_hook=get_cluster_mock,
)
cluster_url, ssl_ca_cert = cluster_info.fetch_cluster_info()
cluster_url, ssl_ca_cert = gke_op.fetch_cluster_info()

assert cluster_url == CLUSTER_PRIVATE_URL if use_internal_ip else CLUSTER_URL
assert ssl_ca_cert == SSL_CA_CERT

@pytest.mark.db_test
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
@mock.patch(f"{GKE_OP_PATH}.get_cluster")
def test_default_gcp_conn_id(self, get_cluster_mock, mock_get_credentials):
get_cluster_mock.return_value = mock.MagicMock(
**{
"endpoint": "test-host",
"private_cluster_config.private_endpoint": "test-private-host",
"master_auth.cluster_ca_certificate": SSL_CA_CERT,
}
def test_default_gcp_conn_id(self):
gke_op = GKEStartPodOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
cluster_name=CLUSTER_NAME,
task_id=PROJECT_TASK_ID,
name=TASK_NAME,
namespace=NAMESPACE,
image=IMAGE,
)
assert self.gke_op.hook.gcp_conn_id == "google_cloud_default"
gke_op._cluster_url = CLUSTER_URL
gke_op._ssl_ca_cert = SSL_CA_CERT
hook = gke_op.hook

assert hook.gcp_conn_id == "google_cloud_default"

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
return_value=Connection(conn_id="test_conn"),
)
def test_gcp_conn_id(self, mock_get_credentials):
def test_gcp_conn_id(self, get_con_mock):
gke_op = GKEStartPodOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
@@ -468,14 +462,15 @@ def setup_test(self):
@mock.patch(GKE_DEPLOYMENT_HOOK_PATH)
def test_execute(self, mock_depl_hook, mock_hook, fetch_cluster_info_mock):
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
mock_hook.return_value.get_cluster.side_effect = [mock.MagicMock(), mock.MagicMock()]
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
mock_depl_hook.return_value.get_deployment_status.return_value = READY_DEPLOYMENT
self.gke_op.execute(context=mock.MagicMock())
mock_hook.return_value.get_cluster.assert_called_once()

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
"airflow.hooks.base.BaseHook.get_connection",
return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(GKE_HOOK_PATH)
@@ -485,15 +480,16 @@ def test_execute_autoscaled_cluster(
self, mock_depl_hook, mock_pod_hook, mock_hook, fetch_cluster_info_mock, mock_get_credentials, caplog
):
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
mock_hook.return_value.get_cluster.side_effect = [mock.MagicMock(), mock.MagicMock()]
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True
mock_pod_hook.return_value.apply_from_yaml_file.side_effect = mock.MagicMock()
mock_depl_hook.return_value.get_deployment_status.return_value = READY_DEPLOYMENT
self.gke_op.execute(context=mock.MagicMock())
assert "Kueue installed successfully!" in caplog.text

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
"airflow.hooks.base.BaseHook.get_connection",
return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(GKE_HOOK_PATH)
@@ -502,14 +498,15 @@ def test_execute_autoscaled_cluster_check_error(
self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, mock_get_credentials, caplog
):
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
mock_hook.return_value.get_cluster.side_effect = [mock.MagicMock(), mock.MagicMock()]
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True
mock_pod_hook.return_value.apply_from_yaml_file.side_effect = FailToCreateError("error")
self.gke_op.execute(context=mock.MagicMock())
assert "Kueue is already enabled for the cluster" in caplog.text

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
"airflow.hooks.base.BaseHook.get_connection",
return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(GKE_HOOK_PATH)
@@ -518,6 +515,7 @@ def test_execute_non_autoscaled_cluster_check_error(
self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, mock_get_credentials, caplog
):
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
mock_hook.return_value.get_cluster.side_effect = [mock.MagicMock(), mock.MagicMock()]
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
self.gke_op.execute(context=mock.MagicMock())
assert (
Expand All @@ -526,16 +524,16 @@ def test_execute_non_autoscaled_cluster_check_error(
)

@mock.patch(
"airflow.hooks.base.BaseHook.get_connections",
return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
"airflow.hooks.base.BaseHook.get_connection",
return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(f"{GKE_HOOK_PATH}.get_cluster")
@mock.patch(GKE_HOOK_PATH)
def test_execute_with_impersonation_service_account(
self, mock_hook, mock_get_cluster, fetch_cluster_info_mock, get_con_mock
):
mock_get_cluster.return_value = mock.MagicMock()
mock_get_cluster.side_effect = [mock.MagicMock(), mock.MagicMock()]
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
self.gke_op.impersonation_chain = "test_account@example.com"
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
Expand All @@ -544,14 +542,15 @@ def test_execute_with_impersonation_service_account(

@mock.patch.dict(os.environ, {})
@mock.patch(
"airflow.hooks.base.BaseHook.get_connections",
return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
"airflow.hooks.base.BaseHook.get_connection",
return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(GKE_HOOK_PATH)
def test_execute_with_impersonation_service_chain_one_element(
self, mock_hook, fetch_cluster_info_mock, get_con_mock
):
mock_hook.return_value.get_cluster.side_effect = [mock.MagicMock(), mock.MagicMock()]
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
self.gke_op.impersonation_chain = ["test_account@example.com"]
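Taken together, the Kueue tests above exercise a flow along these lines (a hedged outline built from the mocks and log assertions; hook objects, argument names, and the error wording are illustrative):

from kubernetes.utils import FailToCreateError


def enable_kueue(cluster_hook, pod_hook, cluster, kueue_documents, log):
    # Kueue only makes sense on a cluster that can scale up for its workloads.
    if not cluster_hook.check_cluster_autoscaling_ability(cluster):
        log.error("Cluster cannot autoscale, skipping Kueue installation")  # wording illustrative
        return
    try:
        pod_hook.apply_from_yaml_file(yaml_objects=kueue_documents)
    except FailToCreateError:
        # Re-applying the manifests fails when the Kueue resources already exist.
        log.info("Kueue is already enabled for the cluster")
        return
    log.info("Kueue installed successfully!")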
@@ -579,13 +578,9 @@ def test_default_gcp_conn_id(self):
@mock.patch.dict(os.environ, {})
@mock.patch(
"airflow.hooks.base.BaseHook.get_connection",
return_value=Connection(conn_id="test_conn"),
)
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
)
def test_gcp_conn_id(self, mock_get_credentials, get_con_mock):
def test_gcp_conn_id(self, mock_get_credentials):
gke_op = GKEStartKueueInsideClusterOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
@@ -626,7 +621,7 @@ def setup_method(self):
@mock.patch(KUB_OP_PATH.format("build_pod_request_obj"))
@mock.patch(KUB_OP_PATH.format("get_or_create_pod"))
@mock.patch(
"airflow.hooks.base.BaseHook.get_connections",
"airflow.hooks.base.BaseHook.get_connection",
return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
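A recurring change throughout this test module is that connection lookups are now mocked via BaseHook.get_connection (singular, returning one Connection) rather than get_connections (which returned a list). A minimal standalone example of the pattern (test name is illustrative):

import json
from unittest import mock

from airflow.models import Connection


@mock.patch(
    "airflow.hooks.base.BaseHook.get_connection",
    return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
)
def test_get_connection_is_patched(get_con_mock):
    from airflow.hooks.base import BaseHook

    # Any hook created in the test resolves its connection through the patched
    # method instead of the metadata database.
    conn = BaseHook.get_connection("google_cloud_default")
    assert json.loads(conn.extra)["keyfile_dict"] == '{"private_key": "r4nd0m_k3y"}'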
