Skip to content

Commit

Permalink
Improve idempotency in MLEngineHook.create_model (apache#7811)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj authored Mar 26, 2020
1 parent 876ca9b commit bfd4251
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 7 deletions.
40 changes: 34 additions & 6 deletions airflow/providers/google/cloud/hooks/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def create_job(
hook = self.get_conn()

self._append_label(job)

self.log.info("Creating job.")
request = hook.projects().jobs().create( # pylint: disable=no-member
parent='projects/{}'.format(project_id),
body=job)
Expand Down Expand Up @@ -236,6 +236,8 @@ def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30):
:type interval: int
:raises: googleapiclient.errors.HttpError
"""
self.log.info("Waiting for job. job_id=%s", job_id)

if interval <= 0:
raise ValueError("Interval must be > 0")
while True:
Expand Down Expand Up @@ -415,16 +417,42 @@ def create_model(
:raises: googleapiclient.errors.HttpError
"""
hook = self.get_conn()
if not model['name']:
if 'name' not in model or not model['name']:
raise ValueError("Model name must be provided and "
"could not be an empty string")
project = 'projects/{}'.format(project_id)

self._append_label(model)

request = hook.projects().models().create( # pylint: disable=no-member
parent=project, body=model)
return request.execute()
try:
request = hook.projects().models().create( # pylint: disable=no-member
parent=project, body=model)
respone = request.execute()
except HttpError as e:
if e.resp.status != 409:
raise e
str(e) # Fills in the error_details field
if not e.error_details or len(e.error_details) != 1:
raise e

error_detail = e.error_details[0]
if error_detail["@type"] != 'type.googleapis.com/google.rpc.BadRequest':
raise e

if "fieldViolations" not in error_detail or len(error_detail['fieldViolations']) != 1:
raise e

field_violation = error_detail['fieldViolations'][0]
if (
field_violation["field"] != "model.name" or
field_violation["description"] != "A model with the same name already exists."
):
raise e
respone = self.get_model(
model_name=model['name'],
project_id=project_id
)

return respone

@CloudBaseHook.fallback_to_default_project_id
def get_model(
Expand Down
70 changes: 69 additions & 1 deletion tests/providers/google/cloud/hooks/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import json
import unittest
from copy import deepcopy
from unittest import mock

import httplib2
from googleapiclient.errors import HttpError
from mock import PropertyMock

Expand Down Expand Up @@ -287,6 +288,73 @@ def test_create_model(self, mock_get_conn):
mock.call().projects().models().create().execute()
])

@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
def test_create_model_idempotency(self, mock_get_conn):
project_id = 'test-project'
model_name = 'test-model'
model = {
'name': model_name,
}
model_with_airflow_version = {
'name': model_name,
'labels': {'airflow-version': hook._AIRFLOW_VERSION}
}
project_path = 'projects/{}'.format(project_id)

(
mock_get_conn.return_value.
projects.return_value.
models.return_value.
create.return_value.
execute.side_effect
) = [
HttpError(
resp=httplib2.Response({"status": 409}),
content=json.dumps(
{
"error": {
"code": 409,
"message": "Field: model.name Error: A model with the same name already exists.",
"status": "ALREADY_EXISTS",
"details": [
{
"@type": "type.googleapis.com/google.rpc.BadRequest",
"fieldViolations": [
{
"field": "model.name",
"description": "A model with the same name already exists."
}
],
}
],
}
}
).encode(),
)
]

(
mock_get_conn.return_value.
projects.return_value.
models.return_value.
get.return_value.
execute.return_value
) = deepcopy(model)

create_model_response = self.hook.create_model(
project_id=project_id, model=deepcopy(model)
)

self.assertEqual(create_model_response, model)
mock_get_conn.assert_has_calls([
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
mock.call().projects().models().create().execute(),
])
mock_get_conn.assert_has_calls([
mock.call().projects().models().get(name='projects/test-project/models/test-model'),
mock.call().projects().models().get().execute()
])

@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
def test_create_model_with_labels(self, mock_get_conn):
project_id = 'test-project'
Expand Down

0 comments on commit bfd4251

Please sign in to comment.