Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MLEngine code and add deploy and set_default commands #864

Merged
merged 4 commits into from
Mar 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion component_sdk/python/kfp_component/google/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import ml_engine, dataflow
from . import ml_engine, dataflow
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@
from ._create_version import create_version
from ._delete_version import delete_version
from ._train import train
from ._batch_predict import batch_predict
from ._batch_predict import batch_predict
from ._deploy import deploy
from ._set_default_version import set_default_version
30 changes: 13 additions & 17 deletions component_sdk/python/kfp_component/google/ml_engine/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,79 +81,75 @@ def create_model(self, project_id, model):
body = model
).execute()

def get_model(self, project_id, model_name):
def get_model(self, model_name):
"""Gets a model.

Args:
project_id: the ID of the parent project.
model_name: the name of the model.
Returns:
The retrieved model.
"""
return self._ml_client.projects().models().get(
name = 'projects/{}/models/{}'.format(
project_id, model_name)
name = model_name
).execute()

def create_version(self, project_id, model_name, version):
def create_version(self, model_name, version):
"""Creates a new version.

Args:
project_id: the ID of the parent project.
model_name: the name of the parent model.
version: the payload of the version.

Returns:
The created version.
"""
return self._ml_client.projects().models().versions().create(
parent = 'projects/{}/models/{}'.format(project_id, model_name),
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
parent = model_name,
body = version
).execute()

def get_version(self, project_id, model_name, version_name):
def get_version(self, version_name):
"""Gets a version.

Args:
project_id: the ID of the parent project.
model_name: the name of the parent model.
version_name: the name of the version.

Returns:
The retrieved version. None if the version is not found.
"""
try:
return self._ml_client.projects().models().versions().get(
name = 'projects/{}/models/{}/versions/{}'.format(
project_id, model_name, version_name)
name = version_name
).execute()
except errors.HttpError as e:
if e.resp.status == 404:
return None
raise

def delete_version(self, project_id, model_name, version_name):
def delete_version(self, version_name):
"""Deletes a version.

Args:
project_id: the ID of the parent project.
model_name: the name of the parent model.
version_name: the name of the version.

Returns:
The delete operation. None if the version is not found.
"""
try:
return self._ml_client.projects().models().versions().delete(
name = 'projects/{}/models/{}/versions/{}'.format(
project_id, model_name, version_name)
name = version_name
).execute()
except errors.HttpError as e:
if e.resp.status == 404:
logging.info('The version has already been deleted.')
return None
raise

def set_default_version(self, version_name):
    """Sets a version as the default version.

    Args:
        version_name: the name of the version.

    Returns:
        The response of the executed setDefault request.
    """
    versions_api = self._ml_client.projects().models().versions()
    request = versions_api.setDefault(name=version_name)
    return request.execute()

def get_operation(self, operation_name):
"""Gets an operation.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

from googleapiclient import errors

def wait_existing_version(ml_client, project_id, model_name,
version_name, wait_interval):
def wait_existing_version(ml_client, version_name, wait_interval):
while True:
existing_version = ml_client.get_version(
project_id, model_name, version_name)
existing_version = ml_client.get_version(version_name)
if not existing_version:
return None
state = existing_version.get('state', None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,5 @@ def _dump_metadata(self):

def _dump_job(self, job):
logging.info('Dumping job: {}'.format(job))
gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(job))
gcp_common.dump_file('/tmp/outputs/job_id.txt', job['jobId'])
gcp_common.dump_file('/tmp/kfp/output/ml_engine/job.json', json.dumps(job))
gcp_common.dump_file('/tmp/kfp/output/ml_engine/job_id.txt', job['jobId'])
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,23 @@
from ._client import MLEngineClient
from .. import common as gcp_common

def create_model(project_id, name=None, model=None):
def create_model(project_id, model_id=None, model=None):
"""Creates a MLEngine model.

Args:
project_id (str): the ID of the parent project of the model.
name (str): optional, the name of the model. If absent, a new name will
model_id (str): optional, the name of the model. If absent, a new name will
be generated.
model (dict): the payload of the model.
"""
return CreateModelOp(project_id, name, model).execute()
return CreateModelOp(project_id, model_id, model).execute()

class CreateModelOp:
def __init__(self, project_id, name, model):
def __init__(self, project_id, model_id, model):
self._ml = MLEngineClient()
self._project_id = project_id
self._model_name = name
self._model_id = model_id
self._model_name = None
if model:
self._model = model
else:
Expand All @@ -53,8 +54,7 @@ def execute(self):
model = self._model)
except errors.HttpError as e:
if e.resp.status == 409:
existing_model = self._ml.get_model(
self._project_id, self._model_name)
existing_model = self._ml.get_model(self._model_name)
if not self._is_dup_model(existing_model):
raise
logging.info('The same model {} has been submitted'
Expand All @@ -67,9 +67,11 @@ def execute(self):
return created_model

def _set_model_name(self, context_id):
if not self._model_name:
self._model_name = 'model_' + context_id
self._model['name'] = gcp_common.normalize_name(self._model_name)
if not self._model_id:
self._model_id = 'model_' + context_id
self._model['name'] = gcp_common.normalize_name(self._model_id)
self._model_name = 'projects/{}/models/{}'.format(
self._project_id, self._model_id)


def _is_dup_model(self, existing_model):
Expand All @@ -82,11 +84,11 @@ def _is_dup_model(self, existing_model):
def _dump_metadata(self):
display.display(display.Link(
'https://console.cloud.google.com/mlengine/models/{}?project={}'.format(
self._model_name, self._project_id),
self._model_id, self._project_id),
'Model Details'
))

def _dump_model(self, model):
logging.info('Dumping model: {}'.format(model))
gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(model))
gcp_common.dump_file('/tmp/outputs/model_name.txt', self._model_name)
gcp_common.dump_file('/tmp/kfp/output/ml_engine/model.json', json.dumps(model))
gcp_common.dump_file('/tmp/kfp/output/ml_engine/model_name.txt', self._model_name)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import time
import re

from googleapiclient import errors
from fire import decorators
Expand All @@ -25,18 +26,17 @@
from ._common_ops import wait_existing_version, wait_for_operation_done

@decorators.SetParseFns(python_version=str, runtime_version=str)
def create_version(project_id, model_name, deployemnt_uri=None, version_name=None,
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
def create_version(model_name, deployemnt_uri=None, version_id=None,
runtime_version=None, python_version=None, version=None,
replace_existing=False, wait_interval=30):
"""Creates a MLEngine version and wait for the operation to be done.

Args:
project_id (str): required, the ID of the parent project.
model_name (str): required, the name of the parent model.
deployment_uri (str): optional, the Google Cloud Storage location of
the trained model used to create the version.
version_name (str): optional, the name of the version. If it is not
provided, the operation uses a random name.
version_id (str): optional, the user provided short name of
the version. If it is not provided, the operation uses a random name.
runtime_version (str): optional, the Cloud ML Engine runtime version
to use for this deployment. If not set, Cloud ML Engine uses
the default stable version, 1.0.
Expand All @@ -53,23 +53,28 @@ def create_version(project_id, model_name, deployemnt_uri=None, version_name=Non
version = {}
if deployemnt_uri:
version['deploymentUri'] = deployemnt_uri
if version_name:
version['name'] = version_name
if version_id:
version['name'] = version_id
if runtime_version:
version['runtimeVersion'] = runtime_version
if python_version:
version['pythonVersion'] = python_version

return CreateVersionOp(project_id, model_name, version,
return CreateVersionOp(model_name, version,
replace_existing, wait_interval).execute_and_wait()

class CreateVersionOp:
def __init__(self, project_id, model_name, version,
def __init__(self, model_name, version,
replace_existing, wait_interval):
self._ml = MLEngineClient()
self._project_id = project_id
self._model_name = gcp_common.normalize_name(model_name)
self._model_name = model_name
self._project_id, self._model_short_name = self._parse_model_name(model_name)
# The name of the version resource, which is in the format
# of projects/*/models/*/versions/*
self._version_name = None
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
# The user-provided short name of the version.
self._version_id = None
# The full payload of the version resource.
self._version = version
self._replace_existing = replace_existing
self._wait_interval = wait_interval
Expand All @@ -81,7 +86,7 @@ def execute_and_wait(self):
self._set_version_name(ctx.context_id())
self._dump_metadata()
existing_version = wait_existing_version(self._ml,
self._project_id, self._model_name, self._version_name,
self._version_name,
self._wait_interval)
if existing_version and self._is_dup_version(existing_version):
return self._handle_completed_version(existing_version)
Expand All @@ -95,15 +100,21 @@ def execute_and_wait(self):

created_version = self._create_version_and_wait()
return self._handle_completed_version(created_version)

def _parse_model_name(self, model_name):
match = re.search(r'^projects/([^/]+)/models/([^/]+)$', model_name)
if not match:
raise ValueError('model name "{}" is not in desired format.'.format(model_name))
return (match.group(1), match.group(2))

def _set_version_name(self, context_id):
version_name = self._version.get('name', None)
if not version_name:
version_name = 'ver_' + context_id
version_name = gcp_common.normalize_name(version_name)
self._version_name = version_name
self._version['name'] = version_name

name = self._version.get('name', None)
if not name:
name = 'ver_' + context_id
name = gcp_common.normalize_name(name)
self._version_id = name
self._version['name'] = name
self._version_name = '{}/versions/{}'.format(self._model_name, name)

def _cancel(self):
if self._delete_operation_name:
Expand All @@ -113,8 +124,7 @@ def _cancel(self):
self._ml.cancel_operation(self._create_operation_name)

def _create_version_and_wait(self):
operation = self._ml.create_version(self._project_id,
self._model_name, self._version)
operation = self._ml.create_version(self._model_name, self._version)
# Cache operation name for cancellation.
self._create_operation_name = operation.get('name')
try:
Expand All @@ -128,8 +138,7 @@ def _create_version_and_wait(self):
return operation.get('response', None)

def _delete_version_and_wait(self):
operation = self._ml.delete_version(
self._project_id, self._model_name, self._version_name)
operation = self._ml.delete_version(self._version_name)
# Cache operation name for cancellation.
self._delete_operation_name = operation.get('name')
try:
Expand All @@ -147,20 +156,22 @@ def _handle_completed_version(self, version):
error_message = version.get('errorMessage', 'Unknown failure')
raise RuntimeError('Version is in failed state: {}'.format(
error_message))
# Workaround issue that CMLE doesn't return the full version name.
version['name'] = self._version_name
self._dump_version(version)
return version

def _dump_metadata(self):
display.display(display.Link(
'https://console.cloud.google.com/mlengine/models/{}/versions/{}?project={}'.format(
self._model_name, self._version_name, self._project_id),
self._model_short_name, self._version_id, self._project_id),
'Version Details'
))

def _dump_version(self, version):
logging.info('Dumping version: {}'.format(version))
gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(version))
gcp_common.dump_file('/tmp/outputs/version_name.txt', version['name'])
gcp_common.dump_file('/tmp/kfp/output/ml_engine/version.json', json.dumps(version))
gcp_common.dump_file('/tmp/kfp/output/ml_engine/version_name.txt', version['name'])

def _is_dup_version(self, existing_version):
return not gcp_common.check_resource_changed(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,39 +22,33 @@
from .. import common as gcp_common
from ._common_ops import wait_existing_version, wait_for_operation_done

def delete_version(project_id, model_name, version_name, wait_interval=30):
def delete_version(version_name, wait_interval=30):
"""Deletes a MLEngine version and wait.

Args:
project_id (str): required, the ID of the parent project.
model_name (str): required, the name of the parent model.
version_name (str): required, the name of the version.
wait_interval (int): the interval to wait for a long running operation.
"""
DeleteVersionOp(project_id, model_name, version_name,
wait_interval).execute_and_wait()
DeleteVersionOp(version_name, wait_interval).execute_and_wait()

class DeleteVersionOp:
def __init__(self, project_id, model_name, version_name, wait_interval):
gaoning777 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, version_name, wait_interval):
self._ml = MLEngineClient()
self._project_id = project_id
self._model_name = gcp_common.normalize_name(model_name)
self._version_name = gcp_common.normalize_name(version_name)
self._version_name = version_name
self._wait_interval = wait_interval
self._delete_operation_name = None

def execute_and_wait(self):
with KfpExecutionContext(on_cancel=self._cancel):
existing_version = wait_existing_version(self._ml,
self._project_id, self._model_name, self._version_name,
self._version_name,
self._wait_interval)
if not existing_version:
logging.info('The version has already been deleted.')
return None

logging.info('Deleting existing version...')
operation = self._ml.delete_version(
self._project_id, self._model_name, self._version_name)
operation = self._ml.delete_version(self._version_name)
# Cache operation name for cancellation.
self._delete_operation_name = operation.get('name')
try:
Expand Down
Loading