Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VladaZakharova committed Feb 8, 2024
1 parent b587481 commit 076feab
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 58 deletions.
6 changes: 2 additions & 4 deletions airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,14 +559,12 @@ def apply_from_yaml_file(
):
"""
Perform an action from a yaml file on a Pod.
This is done until the given Pod reaches given State, or raises an error.
:param yaml_file: Contains the path to yaml file.
:param yaml_objects: List of YAML objects; used instead of reading the yaml_file.
:param verbose: If True, print confirmation from create action. Default is False.
:param namespace: Contains the namespace to create all
resources inside. The namespace must preexist otherwise the resource creation will fail. If the API
object in the yaml file already contains a namespace definition this parameter has no effect.
:param namespace: Contains the namespace to create all resources inside. The namespace must
preexist otherwise the resource creation will fail.
"""
k8s_client = self.get_conn()

Expand Down
2 changes: 2 additions & 0 deletions tests/providers/google/cloud/hooks/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ def _get_credentials(self):

@mock.patch("kubernetes.client.AppsV1Api")
def test_check_kueue_deployment_running(self, gke_deployment_hook, caplog):
self.gke_hook.get_credentials = self._get_credentials
gke_deployment_hook.return_value.read_namespaced_deployment_status.side_effect = [
NOT_READY_DEPLOYMENT,
READY_DEPLOYMENT,
Expand All @@ -479,6 +480,7 @@ def test_check_kueue_deployment_running(self, gke_deployment_hook, caplog):

@mock.patch("kubernetes.client.AppsV1Api")
def test_check_kueue_deployment_raise_exception(self, gke_deployment_hook, caplog):
self.gke_hook.get_credentials = self._get_credentials
gke_deployment_hook.return_value.read_namespaced_deployment_status.side_effect = ValueError()
with pytest.raises(ValueError):
self.gke_hook.check_kueue_deployment_running(name=CLUSTER_NAME, namespace=NAMESPACE)
Expand Down
115 changes: 61 additions & 54 deletions tests/providers/google/cloud/operators/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
FILE_NAME = "/tmp/mock_name"
KUB_OP_PATH = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.{}"
GKE_HOOK_MODULE_PATH = "airflow.providers.google.cloud.hooks.kubernetes_engine"
GKE_HOOK_CLIENT_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEHook.get_cluster_manager_client"
GKE_HOOK_CLUSTER_AUTOSCALING_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEHook.check_cluster_autoscaling_ability"
GKE_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEHook"
GKE_POD_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEPodHook"
GKE_DEPLOYMENT_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEDeploymentHook"
Expand Down Expand Up @@ -99,7 +101,8 @@ class TestGoogleCloudPlatformContainerOperator:
],
)
@mock.patch(GKE_HOOK_PATH)
def test_create_execute(self, mock_hook, body):
@mock.patch(f"{GKE_OP_PATH}.get_cluster")
def test_create_execute(self, mock_hook, mock_cluster_hook, body):
operator = GKECreateClusterOperator(
project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, body=body, task_id=PROJECT_TASK_ID
)
Expand Down Expand Up @@ -256,15 +259,21 @@ def setup_method(self):
name=TASK_NAME,
namespace=NAMESPACE,
)
self.gke_op._cluster_url = CLUSTER_URL
self.gke_op._ssl_ca_cert = SSL_CA_CERT

def test_template_fields(self):
assert set(KubernetesPodOperator.template_fields).issubset(GKEStartPodOperator.template_fields)

@mock.patch.dict(os.environ, {})
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
return_value=Connection(conn_id="test_conn"),
)
@mock.patch(KUB_OPERATOR_EXEC)
@mock.patch(TEMP_FILE)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
def test_execute(self, fetch_cluster_info_mock, file_mock, exec_mock):
def test_execute(self, fetch_cluster_info_mock, file_mock, mock_get_conn, exec_mock):
self.gke_op.execute(context=mock.MagicMock())
fetch_cluster_info_mock.assert_called_once()

Expand All @@ -288,9 +297,10 @@ def test_config_file_throws_error(self):
)
@mock.patch(KUB_OPERATOR_EXEC)
@mock.patch(TEMP_FILE)
@mock.patch(GKE_HOOK_PATH)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
def test_execute_with_impersonation_service_account(
self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock
self, fetch_cluster_info_mock, mock_cluster_hook, file_mock, exec_mock, get_con_mock
):
self.gke_op.impersonation_chain = "test_account@example.com"
self.gke_op.execute(context=mock.MagicMock())
Expand All @@ -303,9 +313,10 @@ def test_execute_with_impersonation_service_account(
)
@mock.patch(KUB_OPERATOR_EXEC)
@mock.patch(TEMP_FILE)
@mock.patch(GKE_HOOK_PATH)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
def test_execute_with_impersonation_service_chain_one_element(
self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock
self, fetch_cluster_info_mock, mock_cluster_hook, file_mock, exec_mock, get_con_mock
):
self.gke_op.impersonation_chain = ["test_account@example.com"]
self.gke_op.execute(context=mock.MagicMock())
Expand All @@ -314,25 +325,20 @@ def test_execute_with_impersonation_service_chain_one_element(

@pytest.mark.db_test
@pytest.mark.parametrize("use_internal_ip", [True, False])
@mock.patch(f"{GKE_HOOK_PATH}.get_cluster")
def test_cluster_info(self, get_cluster_mock, use_internal_ip):
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
@mock.patch(f"{GKE_OP_PATH}.get_cluster")
def test_cluster_info(self, get_cluster_mock, mock_get_credentials, use_internal_ip):
get_cluster_mock.return_value = mock.MagicMock(
**{
"endpoint": "test-host",
"private_cluster_config.private_endpoint": "test-private-host",
"master_auth.cluster_ca_certificate": SSL_CA_CERT,
}
)
GKEStartPodOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
cluster_name=CLUSTER_NAME,
task_id=PROJECT_TASK_ID,
name=TASK_NAME,
namespace=NAMESPACE,
image=IMAGE,
use_internal_ip=use_internal_ip,
)
self.gke_op.execute(context=mock.MagicMock())
cluster_info = GKEClusterAuthDetails(
project_id=TEST_GCP_PROJECT_ID,
cluster_name=CLUSTER_NAME,
Expand All @@ -345,35 +351,26 @@ def test_cluster_info(self, get_cluster_mock, use_internal_ip):
assert ssl_ca_cert == SSL_CA_CERT

@pytest.mark.db_test
@mock.patch(f"{GKE_HOOK_PATH}.get_cluster")
def test_default_gcp_conn_id(self, get_cluster_mock):
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
@mock.patch(f"{GKE_OP_PATH}.get_cluster")
def test_default_gcp_conn_id(self, get_cluster_mock, mock_get_credentials):
get_cluster_mock.return_value = mock.MagicMock(
**{
"endpoint": "test-host",
"private_cluster_config.private_endpoint": "test-private-host",
"master_auth.cluster_ca_certificate": SSL_CA_CERT,
}
)
gke_op = GKEStartPodOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
cluster_name=CLUSTER_NAME,
task_id=PROJECT_TASK_ID,
name=TASK_NAME,
namespace=NAMESPACE,
image=IMAGE,
)
gke_op._cluster_url = CLUSTER_URL
gke_op._ssl_ca_cert = SSL_CA_CERT
hook = gke_op.hook

assert hook.gcp_conn_id == "google_cloud_default"
assert self.gke_op.hook.gcp_conn_id == "google_cloud_default"

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
return_value=Connection(conn_id="test_conn"),
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
def test_gcp_conn_id(self, get_con_mock):
def test_gcp_conn_id(self, mock_get_credentials):
gke_op = GKEStartPodOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
Expand Down Expand Up @@ -452,7 +449,8 @@ def test_on_finish_action_handler(


class TestGKEStartKueueInsideClusterOperator:
def setup_method(self):
@pytest.fixture(autouse=True)
def setup_test(self):
self.gke_op = GKEStartKueueInsideClusterOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
Expand All @@ -465,34 +463,26 @@ def setup_method(self):
self.gke_op._cluster_url = CLUSTER_URL
self.gke_op._ssl_ca_cert = SSL_CA_CERT

def refresh_token(request):
self.credentials.token = "New"

self.credentials = mock.MagicMock()
self.credentials.token = "Old"
self.credentials.expired = False
self.credentials.refresh = refresh_token

def _get_credentials(self):
return self.credentials

@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(GKE_HOOK_PATH)
@mock.patch(GKE_DEPLOYMENT_HOOK_PATH)
def test_execute(self, mock_depl_hook, mock_hook, fetch_cluster_info_mock):
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
mock_hook.return_value.get_credentials.return_value = self.credentials
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
mock_depl_hook.return_value.get_deployment_status.return_value = READY_DEPLOYMENT
self.gke_op.execute(context=mock.MagicMock())
mock_hook.return_value.get_cluster.assert_called_once()

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(GKE_HOOK_PATH)
@mock.patch(GKE_POD_HOOK_PATH)
@mock.patch(GKE_DEPLOYMENT_HOOK_PATH)
def test_execute_autoscaled_cluster(
self, mock_depl_hook, mock_pod_hook, mock_hook, fetch_cluster_info_mock, caplog
self, mock_depl_hook, mock_pod_hook, mock_hook, fetch_cluster_info_mock, mock_get_credentials, caplog
):
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True
Expand All @@ -501,23 +491,31 @@ def test_execute_autoscaled_cluster(
self.gke_op.execute(context=mock.MagicMock())
assert "Kueue installed successfully!" in caplog.text

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(GKE_HOOK_PATH)
@mock.patch(GKE_POD_HOOK_PATH)
def test_execute_autoscaled_cluster_check_error(
self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, caplog
self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, mock_get_credentials, caplog
):
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True
mock_pod_hook.return_value.apply_from_yaml_file.side_effect = FailToCreateError("error")
self.gke_op.execute(context=mock.MagicMock())
assert "Kueue is already enabled for the cluster" in caplog.text

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(GKE_HOOK_PATH)
@mock.patch(GKE_POD_HOOK_PATH)
def test_execute_non_autoscaled_cluster_check_error(
self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, caplog
self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, mock_get_credentials, caplog
):
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
Expand All @@ -527,11 +525,15 @@ def test_execute_non_autoscaled_cluster_check_error(
in caplog.text
)

@mock.patch(
"airflow.hooks.base.BaseHook.get_connections",
return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
)
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
@mock.patch(f"{GKE_HOOK_PATH}.get_cluster")
@mock.patch(GKE_HOOK_PATH)
def test_execute_with_impersonation_service_account(
self, mock_hook, mock_get_cluster, fetch_cluster_info_mock
self, mock_hook, mock_get_cluster, fetch_cluster_info_mock, get_con_mock
):
mock_get_cluster.return_value = mock.MagicMock()
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
Expand Down Expand Up @@ -574,11 +576,16 @@ def test_default_gcp_conn_id(self):

assert hook.gcp_conn_id == "google_cloud_default"

@mock.patch.dict(os.environ, {})
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
"airflow.hooks.base.BaseHook.get_connection",
return_value=Connection(conn_id="test_conn"),
)
def test_gcp_conn_id(self, get_con_mock):
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=mock.MagicMock(),
)
def test_gcp_conn_id(self, mock_get_credentials, get_con_mock):
gke_op = GKEStartKueueInsideClusterOperator(
project_id=TEST_GCP_PROJECT_ID,
location=PROJECT_LOCATION,
Expand Down

0 comments on commit 076feab

Please sign in to comment.