Skip to content

Commit

Permalink
add container image uris support in train API (#837)
Browse files Browse the repository at this point in the history
  • Loading branch information
hongye-sun authored Feb 21, 2019
1 parent 2ec15c2 commit e282bd3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
20 changes: 16 additions & 4 deletions component_sdk/python/kfp_component/google/ml_engine/_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from ._create_job import create_job

@decorators.SetParseFns(python_version=str, runtime_version=str)
def train(project_id, python_module, package_uris,
region, args=None, job_dir=None, python_version=None,
runtime_version=None, training_input=None, job_id_prefix=None,
wait_interval=30):
def train(project_id, python_module=None, package_uris=None,
region=None, args=None, job_dir=None, python_version=None,
runtime_version=None, master_image_uri=None, worker_image_uri=None,
training_input=None, job_id_prefix=None, wait_interval=30):
"""Creates a MLEngine training job.
Args:
Expand All @@ -44,6 +44,10 @@ def train(project_id, python_module, package_uris,
runtime_version (str): Optional. The Cloud ML Engine runtime version
to use for training. If not set, Cloud ML Engine uses the
default stable version, 1.0.
master_image_uri (str): The Docker image to run on the master replica.
This image must be in Container Registry.
worker_image_uri (str): The Docker image to run on the worker replica.
This image must be in Container Registry.
training_input (dict): Input parameters to create a training job.
job_id_prefix (str): the prefix of the generated job id.
wait_interval (int): optional wait interval between calls
Expand All @@ -65,6 +69,14 @@ def train(project_id, python_module, package_uris,
training_input['pythonVersion'] = python_version
if runtime_version:
training_input['runtimeVersion'] = runtime_version
if master_image_uri:
if 'masterConfig' not in training_input:
training_input['masterConfig'] = {}
training_input['masterConfig']['imageUri'] = master_image_uri
if worker_image_uri:
if 'workerConfig' not in training_input:
training_input['workerConfig'] = {}
training_input['workerConfig']['imageUri'] = worker_image_uri
job = {
'trainingInput': training_input
}
Expand Down
11 changes: 9 additions & 2 deletions component_sdk/python/tests/google/ml_engine/test__train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def test_train_succeed(self, mock_create_job):
training_input={
'runtimeVersion': '1.10',
'pythonVersion': '2.7'
}, job_id_prefix='job-')
}, job_id_prefix='job-', master_image_uri='tensorflow:latest',
worker_image_uri='debian:latest')

mock_create_job.assert_called_with('proj-1', {
'trainingInput': {
Expand All @@ -39,6 +40,12 @@ def test_train_succeed(self, mock_create_job):
'args': ['arg-1', 'arg-2'],
'jobDir': 'gs://test/job/dir',
'runtimeVersion': '1.10',
'pythonVersion': '2.7'
'pythonVersion': '2.7',
'masterConfig': {
'imageUri': 'tensorflow:latest'
},
'workerConfig': {
'imageUri': 'debian:latest'
}
}
}, 'job-', 30)

0 comments on commit e282bd3

Please sign in to comment.