Skip to content

Commit

Permalink
Recreate k8s client on auth failure (kubeflow#215)
Browse files Browse the repository at this point in the history
* Recreate k8s client on auth failure

* Recreate client on each call

* Fix syntax

* Remove debug msg

* Try to fix unit tests
  • Loading branch information
richardsliu authored and k8s-ci-robot committed Sep 22, 2018
1 parent 05e2145 commit 029a9d0
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 22 deletions.
28 changes: 17 additions & 11 deletions py/kubeflow/testing/argo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ def log_status(workflow):


def handle_retriable_exception(exception):
    """Decide whether a failed call should be retried.

    Passed as the ``retry_on_exception`` predicate to the ``@retry``
    decorator in this module.

    Args:
      exception: The exception raised by the wrapped call.

    Returns:
      True if the call should be retried, False otherwise. An auth failure
      (HTTP 401 or 403 from the K8s API) reloads the kube config and is
      retried; a util.TimeoutError is never retried; any other exception
      is retried.
    """
    if (isinstance(exception, rest.ApiException) and
            exception.status in (401, 403)):
        # Due to https://github.com/kubernetes-client/python-base/issues/59,
        # we need to reload the kube config (which refreshes the GCP token).
        # See also https://github.com/kubeflow/testing/issues/207.
        # TODO(richardsliu): Remove this workaround when the k8s client issue
        # is resolved.
        util.load_kube_config()
        return True
    return not isinstance(exception, util.TimeoutError)
Expand All @@ -49,27 +52,31 @@ def handle_retriable_exception(exception):
@retry(wait_exponential_multiplier=1000, wait_exponential_max=10000,
       stop_max_delay=5*60*1000,
       retry_on_exception=handle_retriable_exception)
def get_namespaced_custom_object_with_retries(namespace, name):
    """Call the get_namespaced_custom_object API with retries.

    Retries are driven by the @retry decorator, using
    handle_retriable_exception to decide which failures are retriable.
    NOTE(review): the decorator arguments look like milliseconds (backoff
    capped at 10s, giving up after 5 minutes) -- confirm against the retry
    library in use.

    Args:
      namespace: namespace for the workflow.
      name: name of the workflow.

    Returns:
      The result of CustomObjectsApi.get_namespaced_custom_object for the
      named object.
    """
    # Due to https://github.com/kubernetes-client/python-base/issues/59,
    # we need to recreate the API client since it may contain stale auth
    # tokens.
    # TODO(richardsliu): Remove this workaround when the k8s client issue
    # is resolved.
    client = k8s_client.ApiClient()
    crd_api = k8s_client.CustomObjectsApi(client)
    return crd_api.get_namespaced_custom_object(
        GROUP, VERSION, namespace, PLURAL, name)


def wait_for_workflows(client, namespace, names,
def wait_for_workflows(namespace, names,
timeout=datetime.timedelta(minutes=30),
polling_interval=datetime.timedelta(seconds=30),
status_callback=None):
"""Wait for multiple workflows to finish.
Args:
client: K8s api client.
namespace: namespace for the workflow.
names: Names of the workflows to wait for.
timeout: How long to wait for the workflow.
Expand All @@ -88,7 +95,7 @@ def wait_for_workflows(client, namespace, names,
all_results = []

for n in names:
results = get_namespaced_custom_object_with_retries(client, namespace, n)
results = get_namespaced_custom_object_with_retries(namespace, n)
all_results.append(results)
if status_callback:
status_callback(results)
Expand All @@ -111,14 +118,13 @@ def wait_for_workflows(client, namespace, names,

return []

def wait_for_workflow(client, namespace, name,
def wait_for_workflow(namespace, name,
timeout=datetime.timedelta(minutes=30),
polling_interval=datetime.timedelta(seconds=30),
status_callback=None):
"""Wait for the specified workflow to finish.
Args:
client: K8s api client.
namespace: namespace for the workflow.
name: Name of the workflow
timeout: How long to wait for the workflow.
Expand All @@ -130,6 +136,6 @@ def wait_for_workflow(client, namespace, name,
Raises:
TimeoutError: If timeout waiting for the job to finish.
"""
results = wait_for_workflows(client, namespace, [name],
results = wait_for_workflows(namespace, [name],
timeout, polling_interval, status_callback)
return results[0]
4 changes: 1 addition & 3 deletions py/kubeflow/testing/run_e2e_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import datetime
import fnmatch
import logging
from kubernetes import client as k8s_client
import os
import tempfile
from kubeflow.testing import argo_client
Expand Down Expand Up @@ -132,7 +131,6 @@ def run(args, file_handler): # pylint: disable=too-many-statements,too-many-bran
util.configure_kubectl(args.project, args.zone, args.cluster)
util.load_kube_config()

api_client = k8s_client.ApiClient()
workflow_names = []
ui_urls = {}

Expand Down Expand Up @@ -237,7 +235,7 @@ def run(args, file_handler): # pylint: disable=too-many-statements,too-many-bran
success = True
workflow_phase = {}
try:
results = argo_client.wait_for_workflows(api_client, get_namespace(args),
results = argo_client.wait_for_workflows(get_namespace(args),
workflow_names,
timeout=datetime.timedelta(minutes=60),
status_callback=argo_client.log_status)
Expand Down
2 changes: 2 additions & 0 deletions py/kubeflow/testing/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ def _save_kube_config(config_map):
kubernetes_configuration.Configuration.set_default(config)
else:
loader.load_and_set(client_configuration) # pylint: disable=too-many-function-args
# Dump the loaded config.
run(["kubectl", "config", "view"])


def maybe_activate_service_account():
Expand Down
15 changes: 7 additions & 8 deletions py/kubeflow/tests/argo_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import unittest

from kubeflow.testing import argo_client
from kubernetes import client as k8s_client
import mock
import os
import yaml
Expand All @@ -13,14 +12,14 @@ def setUp(self):
self.test_dir = os.path.join(os.path.dirname(__file__), "test-data")

def test_wait_for_workflow(self):
    """wait_for_workflow returns a result for a successful workflow.

    argo_client builds its own ApiClient on every call, so the class is
    patched where argo_client looks it up, and the canned workflow is
    served through the patched client's call_api.
    """
    workflow_path = os.path.join(self.test_dir, "successful_workflow.yaml")
    with open(workflow_path) as hf:
        # safe_load: the fixture is plain data, and yaml.load without an
        # explicit Loader is rejected by newer PyYAML releases.
        response = yaml.safe_load(hf)

    with mock.patch(
            "kubeflow.testing.argo_client.k8s_client.ApiClient") as mock_client:
        client = mock_client.return_value
        client.call_api.return_value = response
        result = argo_client.wait_for_workflow("some-namespace", "some-set")
        self.assertIsNotNone(result)

if __name__ == "__main__":
unittest.main()

0 comments on commit 029a9d0

Please sign in to comment.