-
Notifications
You must be signed in to change notification settings - Fork 441
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementing v1alpha2 grid search suggestion algorithm (#622)
* Implementing v1alpha2 grid search algorithm * Fix indendation * Build grid image
- Loading branch information
1 parent
0f6fdeb
commit cb25807
Showing
12 changed files
with
261 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
FROM python:3 | ||
|
||
ADD . /usr/src/app/github.com/kubeflow/katib | ||
WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/grid/v1alpha2 | ||
RUN pip install --no-cache-dir -r requirements.txt | ||
ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/api/v1alpha2/python | ||
|
||
ENTRYPOINT ["python", "main.py"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import grpc | ||
import time | ||
from pkg.api.v1alpha2.python import api_pb2_grpc | ||
from pkg.suggestion.v1alpha2.grid_service import GridService | ||
from concurrent import futures | ||
|
||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 | ||
DEFAULT_PORT = "0.0.0.0:6789" | ||
|
||
def serve(): | ||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | ||
api_pb2_grpc.add_SuggestionServicer_to_server(GridService(), server) | ||
server.add_insecure_port(DEFAULT_PORT) | ||
print("Listening...") | ||
server.start() | ||
try: | ||
while True: | ||
time.sleep(_ONE_DAY_IN_SECONDS) | ||
except KeyboardInterrupt: | ||
server.stop(0) | ||
|
||
if __name__ == "__main__": | ||
serve() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
grpcio | ||
duecredit | ||
cloudpickle==0.5.6 | ||
numpy>=1.13.3 | ||
scikit-learn>=0.19.0 | ||
scipy>=0.19.1 | ||
forestci | ||
protobuf | ||
googleapis-common-protos |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
apiVersion: "kubeflow.org/v1alpha2" | ||
kind: Experiment | ||
metadata: | ||
namespace: kubeflow | ||
name: grid-experiment | ||
spec: | ||
parallelTrialCount: 3 | ||
maxTrialCount: 12 | ||
maxFailedTrialCount: 3 | ||
objective: | ||
type: maximize | ||
goal: 0.99 | ||
objectiveMetricName: Validation-accuracy | ||
additionalMetricNames: | ||
- accuracy | ||
algorithm: | ||
algorithmName: grid | ||
algorithmSettings: | ||
- name: --num-layers | ||
value: "5" | ||
- name: --optimizer | ||
value: "3" | ||
trialTemplate: | ||
goTemplate: | ||
rawTemplate: |- | ||
apiVersion: batch/v1 | ||
kind: Job | ||
metadata: | ||
name: {{.Trial}} | ||
namespace: {{.NameSpace}} | ||
spec: | ||
template: | ||
spec: | ||
containers: | ||
- name: {{.Trial}} | ||
image: katib/mxnet-mnist-example | ||
command: | ||
- "python" | ||
- "/mxnet/example/image-classification/train_mnist.py" | ||
- "--batch-size=64" | ||
{{- with .HyperParameters}} | ||
{{- range .}} | ||
- "{{.Name}}={{.Value}}" | ||
{{- end}} | ||
{{- end}} | ||
restartPolicy: Never | ||
parameters: | ||
- name: --lr | ||
parameterType: double | ||
feasibleSpace: | ||
min: "0.01" | ||
max: "0.03" | ||
- name: --num-layers | ||
parameterType: int | ||
feasibleSpace: | ||
min: "1" | ||
max: "15" | ||
- name: --optimizer | ||
parameterType: categorical | ||
feasibleSpace: | ||
list: | ||
- sgd | ||
- adam | ||
- ftrl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
apiVersion: extensions/v1beta1 | ||
kind: Deployment | ||
metadata: | ||
name: katib-suggestion-grid | ||
namespace: kubeflow | ||
labels: | ||
app: katib | ||
component: suggestion-grid | ||
spec: | ||
replicas: 1 | ||
template: | ||
metadata: | ||
name: katib-suggestion-grid | ||
labels: | ||
app: katib | ||
component: suggestion-grid | ||
spec: | ||
containers: | ||
- name: katib-suggestion-grid | ||
image: katib/v1alpha2/suggestion-grid | ||
imagePullPolicy: IfNotPresent | ||
ports: | ||
- name: api | ||
containerPort: 6789 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
apiVersion: v1 | ||
kind: Service | ||
metadata: | ||
name: katib-suggestion-grid | ||
namespace: kubeflow | ||
labels: | ||
app: katib | ||
component: suggestion-grid | ||
spec: | ||
type: ClusterIP | ||
ports: | ||
- port: 6789 | ||
protocol: TCP | ||
name: api | ||
selector: | ||
app: katib | ||
component: suggestion-grid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from logging import getLogger, StreamHandler, INFO, DEBUG | ||
import itertools | ||
import grpc | ||
import numpy as np | ||
from pkg.api.v1alpha2.python import api_pb2 | ||
from pkg.api.v1alpha2.python import api_pb2_grpc | ||
from . import parsing_util | ||
|
||
class GridService(api_pb2_grpc.SuggestionServicer): | ||
def __init__(self): | ||
self.manager_addr = "katib-manager" | ||
self.manager_port = 6789 | ||
self.default_grid = 10 | ||
|
||
def _get_experiment(self, name): | ||
channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port) | ||
with api_pb2.beta_create_Manager_stub(channel) as client: | ||
exp = client.GetExperiment(api_pb2.GetExperimentRequest(experiment_name=name), 10) | ||
return exp.experiment | ||
|
||
def _get_algorithm_settings(self, experiment_name): | ||
channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port) | ||
with api_pb2.beta_create_Manager_stub(channel) as client: | ||
alg = client.GetAlgorithmExtraSettings(api_pb2.GetAlgorithmExtraSettingsRequest( | ||
experiment_name=experiment_name), 10) | ||
params = alg.extra_algorithm_settings | ||
alg_settings = {} | ||
for param in params: | ||
alg_settings[param.name] = param.value | ||
return alg_settings | ||
|
||
def _get_trials(self, experiment_name): | ||
channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port) | ||
with api_pb2.beta_create_Manager_stub(channel) as client: | ||
trials = client.GetTrialList(api_pb2.GetTrialListRequest( | ||
experiment_name=experiment_name), 10) | ||
return trials.trials | ||
|
||
def _create_all_combinations(self, parameters, alg_settings): | ||
param_ranges = [] | ||
cur_index = 0 | ||
parameter_config = parsing_util.parse_parameter_configs(parameters) | ||
default_grid_size = alg_settings.get("DefaultGrid", self.default_grid) | ||
for idx, param_type in enumerate(parameter_config.parameter_types): | ||
param_name = parameter_config.names[idx] | ||
if param_type in [api_pb2.DOUBLE, api_pb2.INT]: | ||
num = alg_settings.get(param_name, default_grid_size) | ||
param_values = \ | ||
np.linspace(parameter_config.lower_bounds[0, cur_index], | ||
parameter_config.upper_bounds[0, cur_index], | ||
num=num) | ||
cur_index += 1 | ||
if param_type == api_pb2.INT: | ||
param_values = param_values.astype(np.int64) | ||
elif param_type == api_pb2.DISCRETE: | ||
for discrete_param in parameter_config.discrete_info: | ||
if param_name == discrete_param["name"]: | ||
param_values = discrete_param["values"] | ||
break | ||
cur_index += 1 | ||
elif param_type == api_pb2.CATEGORICAL: | ||
for categ_param in parameter_config.categorical_info: | ||
if param_name == categ_param["name"]: | ||
param_values = categ_param["values"] | ||
break | ||
cur_index += categ_param["number"] | ||
param_ranges.append(param_values) | ||
all_combinations = [comb for comb in itertools.product(*param_ranges)] | ||
return all_combinations, parameter_config | ||
|
||
def GetSuggestions(self, request, context): | ||
""" | ||
Main function to provide suggestion. | ||
""" | ||
experiment_name = request.experiment_name | ||
request_number = request.request_number | ||
experiment = self._get_experiment(experiment_name) | ||
parameters = experiment.spec.parameter_specs.parameters | ||
alg_settings = self._get_algorithm_settings(experiment_name) | ||
combinations, parameter_config = self._create_all_combinations(parameters, alg_settings) | ||
total_combinations = len(combinations) | ||
|
||
allocated_trials = self._get_trials(experiment_name) | ||
total_allocated_trials = len(allocated_trials) | ||
return_start_index = total_allocated_trials | ||
return_end_index = return_start_index + request_number | ||
|
||
if return_start_index > total_combinations: | ||
return_start_index = 0 | ||
return_end_index = return_start_index + request_number | ||
elif return_start_index + request_number > total_combinations: | ||
return_start_index = total_combinations - request_number | ||
return_end_index = total_combinations | ||
if return_start_index < 0: | ||
return_start_index = 0 | ||
|
||
trial_specs = [] | ||
for elem in combinations[return_start_index:return_end_index]: | ||
suggestion = parsing_util.parse_x_next_tuple(elem, parameter_config.parameter_types, | ||
parameter_config.names) | ||
trial_spec = api_pb2.TrialSpec() | ||
trial_spec.experiment_name = experiment_name | ||
for param in suggestion: | ||
trial_spec.parameter_assignments.assignments.add(name=param['name'], | ||
value=str(param['value'])) | ||
trial_specs.append(trial_spec) | ||
reply = api_pb2.GetSuggestionsReply() | ||
for trial_spec in trial_specs: | ||
reply.trials.add(spec=trial_spec) | ||
return reply |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters