Skip to content

Commit

Permalink
linters
Browse files Browse the repository at this point in the history
  • Loading branch information
ncilfone committed Aug 13, 2021
1 parent f5db46a commit c0baf09
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 30 deletions.
13 changes: 5 additions & 8 deletions examples/tune/ax/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from sklearn.model_selection import train_test_split

from spock.addons.tune import (
ChoiceHyperParameter,
AxTunerConfig,
ChoiceHyperParameter,
RangeHyperParameter,
spockTuner,
)
Expand Down Expand Up @@ -41,10 +41,7 @@ def main():

# Ax config -- this will internally spawn the AxClient service API style which will be returned
# by accessing the tuner_status property on the ConfigArgBuilder object
ax_config = AxTunerConfig(
objective_name='accuracy',
minimize=False
)
ax_config = AxTunerConfig(objective_name="accuracy", minimize=False)

# Use the builder to setup
# Call tuner to indicate that we are going to do some HP tuning -- passing in an ax study object
Expand Down Expand Up @@ -84,9 +81,9 @@ def main():
tuner_status = attrs_obj.tuner_status
# Pull the AxClient object and trial index out of the return dictionary and call 'complete_trial' on the
# AxClient object with the correct raw_data that contains the objective name
tuner_status['client'].complete_trial(
trial_index=tuner_status['trial_index'],
raw_data={'accuracy': (val_acc, 0.0)}
tuner_status["client"].complete_trial(
trial_index=tuner_status["trial_index"],
raw_data={"accuracy": (val_acc, 0.0)},
)
# Always save the current best set of hyper-parameters
attrs_obj.save_best(user_specified_path="/tmp/ax")
Expand Down
14 changes: 7 additions & 7 deletions spock/addons/tune/ax.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class AxTunerStatus(TypedDict):
trial_index: current trial index
"""

client: AxClient
trial_index: int


class AxInterface(BaseInterface):
"""Specific override to support the Ax backend -- supports the service style API from Ax
"""Specific override to support the Ax backend -- supports the service style API from Ax"""

"""
def __init__(self, tuner_config: AxTunerConfig, tuner_namespace):
"""AxInterface init call that maps variables, creates a map to fnc calls, and constructs the necessary
underlying objects
Expand All @@ -47,7 +47,7 @@ def __init__(self, tuner_config: AxTunerConfig, tuner_namespace):
generation_strategy=self._tuner_config.generation_strategy,
enforce_sequential_optimization=self._tuner_config.enforce_sequential_optimization,
random_seed=self._tuner_config.random_seed,
verbose_logging=self._tuner_config.verbose_logging
verbose_logging=self._tuner_config.verbose_logging,
)
# Some variables to use later
self._trial_index = None
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(self, tuner_config: AxTunerConfig, tuner_namespace):
overwrite_existing_experiment=self._tuner_config.overwrite_existing_experiment,
tracking_metric_names=self._tuner_config.tracking_metric_names,
immutable_search_space_and_opt_config=self._tuner_config.immutable_search_space_and_opt_config,
is_test=self._tuner_config.is_test
is_test=self._tuner_config.is_test,
)

@property
Expand All @@ -91,7 +91,7 @@ def best(self):
rollup_dict, _ = self._sample_rollup(best_obj[0])
return (
self._gen_spockspace(rollup_dict),
best_obj[1][0][self._tuner_obj.objective_name]
best_obj[1][0][self._tuner_obj.objective_name],
)

@property
Expand Down Expand Up @@ -134,7 +134,7 @@ def _ax_range(self, name, val):
"type": "range",
"bounds": [low, high],
"value_type": val.type,
"log_scale": val.log_scale
"log_scale": val.log_scale,
}

def _ax_choice(self, name, val):
Expand All @@ -155,5 +155,5 @@ def _ax_choice(self, name, val):
"name": name,
"type": "choice",
"values": val.choices,
"value_type": val.type
"value_type": val.type,
}
2 changes: 1 addition & 1 deletion spock/addons/tune/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
"""Creates the spock config interface that wraps attr -- tune version for hyper-parameters"""
import sys
from typing import List, Optional, Sequence, Tuple, Union
from uuid import uuid4

import attr
import optuna
from ax.modelbridge.generation_strategy import GenerationStrategy

from spock.backend.config import _base_attr
from uuid import uuid4


@attr.s(auto_attribs=True)
Expand Down
17 changes: 6 additions & 11 deletions spock/addons/tune/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,16 @@
# SPDX-License-Identifier: Apache-2.0

"""Handles the base interface"""
import hashlib
import json
from abc import ABC, abstractmethod
from typing import Dict
from typing import Dict, Union

import attr
import hashlib
import json

from spock.addons.tune.config import AxTunerConfig, OptunaTunerConfig
from spock.backend.wrappers import Spockspace

from spock.addons.tune.config import AxTunerConfig
from spock.addons.tune.config import OptunaTunerConfig

from typing import Union


class BaseInterface(ABC):
"""Base interface for the various hyper-parameter tuner backends
Expand All @@ -28,6 +24,7 @@ class BaseInterface(ABC):
_tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types
"""

def __init__(self, tuner_config, tuner_namespace: Spockspace):
"""Base init call that maps a few variables
Expand Down Expand Up @@ -141,9 +138,7 @@ def _config_to_dict(tuner_config: Union[OptunaTunerConfig, AxTunerConfig]):
dictionary of the attrs config object
"""
return {
k: v for k, v in attr.asdict(tuner_config).items() if v is not None
}
return {k: v for k, v in attr.asdict(tuner_config).items() if v is not None}

@staticmethod
def _to_spockspace(tune_dict: Dict):
Expand Down
5 changes: 4 additions & 1 deletion spock/addons/tune/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class OptunaTunerStatus(TypedDict):
study: current optuna study object
"""

trial: optuna.Trial
study: optuna.Study

Expand Down Expand Up @@ -55,7 +56,9 @@ def __init__(self, tuner_config: OptunaTunerConfig, tuner_namespace):
"""
super(OptunaInterface, self).__init__(tuner_config, tuner_namespace)
self._tuner_obj = optuna.create_study(**self._config_to_dict(self._tuner_config))
self._tuner_obj = optuna.create_study(
**self._config_to_dict(self._tuner_config)
)
# Some variables to use later
self._trial = None
self._sample_hash = None
Expand Down
3 changes: 1 addition & 2 deletions spock/addons/tune/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

from typing import Union

from spock.addons.tune.config import AxTunerConfig
from spock.addons.tune.ax import AxInterface
from spock.addons.tune.config import OptunaTunerConfig
from spock.addons.tune.config import AxTunerConfig, OptunaTunerConfig
from spock.addons.tune.optuna import OptunaInterface
from spock.backend.wrappers import Spockspace

Expand Down

0 comments on commit c0baf09

Please sign in to comment.