diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..61d9445e --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,18 @@ +--- +name: Pull request +about: Create a pull request for merge + +--- + +## What does this PR do? +E.g. Describe the added feature or what issue it fixes #(issue)... + +## Checklist + - [ ] Did you adhere to [PEP-8](https://www.python.org/dev/peps/pep-0008/) standards? + - [ ] Did you run black and isort prior to submitting your PR? + - [ ] Does your PR pass all existing unit tests? + - [ ] Did you add associated unit tests for any additional functionality? + - [ ] Did you provide documentation ([Numpy Docstring format](https://numpydoc.readthedocs.io/en/latest/format.html#style-guide)) whenever possible, even for simple functions or classes? + +## Review +Request will go to reviewers to approve for merge. \ No newline at end of file diff --git a/.github/workflows/python-coverage.yaml b/.github/workflows/python-coverage.yaml index 2e9dcff5..518aad60 100644 --- a/.github/workflows/python-coverage.yaml +++ b/.github/workflows/python-coverage.yaml @@ -21,11 +21,19 @@ jobs: with: python-version: '3.8' + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TUNE_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TEST_EXTRAS_REQUIREMENTS_REQUIREMENTS.txt') }} + - name: Install dependencies and dev dependencies run: | python -m pip install --upgrade pip - pip install -r DEV_REQUIREMENTS.txt - pip install -r S3_REQUIREMENTS.txt + pip install -r REQUIREMENTS.txt + pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt + pip install -r ./requirements/TUNE_REQUIREMENTS.txt + pip install -r ./requirements/TEST_EXTRAS_REQUIREMENTS.txt - name: Test with pytest run: | diff --git a/.github/workflows/python-docs.yaml b/.github/workflows/python-docs.yaml index 88c90275..5df2ab7b 100644 --- a/.github/workflows/python-docs.yaml +++ b/.github/workflows/python-docs.yaml @@ -18,12 +18,18 @@ jobs: uses: actions/setup-python@v2 with: python-version: '3.8' + + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + - name: Install dependencies and dev dependencies run: | python -m pip install --upgrade pip pip install -e .[s3] - pip install -r DEV_REQUIREMENTS.txt - pip install -r S3_REQUIREMENTS.txt + pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt - name: Build docs with Portray env: diff --git a/.github/workflows/python-lint.yaml b/.github/workflows/python-lint.yaml new file mode 100644 index 00000000..35d77eee --- /dev/null +++ b/.github/workflows/python-lint.yaml @@ -0,0 +1,42 @@ +# This workflow will run isort and black linters on PRs + +name: lint + +# on: workflow_dispatch +on: + pull_request: + branches: [master] + push: + branches: [master] + +jobs: + run_lint: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + - uses: 
actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + + - name: Install dependencies and dev dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt + + - name: Run isort linter + run: | + isort --check . --skip="debug" --skip="versioneer.py" --skip="tests" --skip="_version.py" + + - name: Run black linter + run: | + black --check . --exclude="versioneer.py|_version.py|debug|tests" diff --git a/.github/workflows/python-pytest-s3.yaml b/.github/workflows/python-pytest-s3.yaml index e359199e..3cc65eaa 100644 --- a/.github/workflows/python-pytest-s3.yaml +++ b/.github/workflows/python-pytest-s3.yaml @@ -23,16 +23,21 @@ jobs: with: python-version: ${{ matrix.python-version }} + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + - name: Install dependencies run: | python -m pip install --upgrade pip pip install -e . - pip install -r DEV_REQUIREMENTS.txt - pip install -r S3_REQUIREMENTS.txt + pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt - name: Test with pytest run: | - pytest tests/s3 --cov=spock --cov-config=.coveragerc --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov-report=xml --cov-report=html + pytest tests/s3 --cov=spock --cov-config=.coveragerc --junitxml=junit/test-results-s3-${{ matrix.python-version }}.xml --cov-report=xml --cov-report=html - name: Upload pytest test results uses: actions/upload-artifact@v2 diff --git a/.github/workflows/python-pytest-tune.yaml b/.github/workflows/python-pytest-tune.yaml new file mode 100644 index 00000000..ca13e070 --- /dev/null +++ b/.github/workflows/python-pytest-tune.yaml @@ -0,0 +1,50 @@ +# This workflow will install Python dependencies, run S3 tests with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: pytest-tune + +on: + pull_request: + branches: [master] + push: + branches: [master] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TUNE_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TEST_EXTRAS_REQUIREMENTS_REQUIREMENTS.txt') }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . 
+ pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt + pip install -r ./requirements/TUNE_REQUIREMENTS.txt + pip install -r ./requirements/TEST_EXTRAS_REQUIREMENTS.txt + + - name: Test with pytest + run: | + pytest tests/tune --cov=spock --cov-config=.coveragerc --junitxml=junit/test-results-tune-${{ matrix.python-version }}.xml --cov-report=xml --cov-report=html + + - name: Upload pytest test results + uses: actions/upload-artifact@v2 + with: + name: pytest-results-${{ matrix.python-version }} + path: junit/test-results-${{ matrix.python-version }}.xml + # Use always() to always run this step to publish test results when there are test failures + if: ${{ always() }} diff --git a/.github/workflows/python-pytest.yml b/.github/workflows/python-pytest.yml index 4f246e96..75cd5333 100644 --- a/.github/workflows/python-pytest.yml +++ b/.github/workflows/python-pytest.yml @@ -23,11 +23,16 @@ jobs: with: python-version: ${{ matrix.python-version }} + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + - name: Install dependencies run: | python -m pip install --upgrade pip pip install -e . - pip install -r DEV_REQUIREMENTS.txt + pip install -r ./requirements/DEV_REQUIREMENTS.txt - name: Test with pytest run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dd05fe30..d3036f40 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,6 +4,7 @@ Requests in the public repository. ## Contribution Guidelines 1. Adhere to [PEP-8](https://www.python.org/dev/peps/pep-0008/) standards. -2. Any changes to core functionality must pass all existing unit tests. -3. Additional functionality should have associated unit tests. -4. Provide documentation (Google Docstring format) whenever possible, even for simple functions or classes. \ No newline at end of file +2. Run black and isort linters before creating a PR. +3. Any changes to core functionality must pass all existing unit tests. +4. Additional functionality should have associated unit tests. +5. Provide documentation ([Numpy Docstring format](https://numpydoc.readthedocs.io/en/latest/format.html#style-guide)) whenever possible, even for simple functions or classes. \ No newline at end of file diff --git a/NOTICE.txt b/NOTICE.txt index 60910e68..f01ab0e6 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -4,7 +4,7 @@ // ------------------------------------------------------------------ Spock -Copyright [2109-2020] FMR LLC +Copyright [2019-2021] FMR LLC This product includes software developed at FMR LLC (https://www.fidelity.com/). @@ -12,5 +12,13 @@ FMR LLC (https://www.fidelity.com/). 
This product relies on the following works (and the dependencies thereof), installed separately: - attrs | https://github.com/python-attrs/attrs | MIT License - GitPython | https://github.com/gitpython-developers/GitPython | BSD 3-Clause License +- pytomlpp | https://github.com/bobfang1992/pytomlpp | MIT License - PyYAML | https://github.com/yaml/pyyaml | MIT License -- toml | https://github.com/toml-lang/toml | MIT License \ No newline at end of file + + +Optional extensions rely on the following works (and the dependencies thereof), installed separately: +- boto3 | https://github.com/boto/boto3 | Apache License 2.0 +- botocore | https://github.com/boto/botocore | Apache License 2.0 +- hurry.filesize | https://pypi.org/project/hurry.filesize/ | ZPL 2.1 +- optuna | https://optuna.org/ | MIT License +- s3transfer | https://github.com/boto/s3transfer | Apache License 2.0 \ No newline at end of file diff --git a/README.md b/README.md index 5afe2e18..7072d77f 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![License](https://img.shields.io/badge/License-Apache%202.0-9cf)](https://opensource.org/licenses/Apache-2.0) [![Python](https://img.shields.io/badge/python-3.6+-informational.svg)]() +[![Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![PyPI version](https://badge.fury.io/py/spock-config.svg)](https://badge.fury.io/py/spock-config) [![Coverage Status](https://coveralls.io/repos/github/fidelity/spock/badge.svg?branch=master)](https://coveralls.io/github/fidelity/spock?branch=master) ![Tests](https://github.com/fidelity/spock/workflows/pytest/badge.svg?branch=master) @@ -45,6 +46,10 @@ recent features, bugfixes, and hotfixes. See [Releases](https://github.com/fidelity/spock/releases) for more information. +#### July 21, 2021 +* Added hyper-parameter tuning support with `pip install spock-config[tune]` +* Hyper-parameter tuning backend support for Optuna define-and-run API (WIP for Ax) + #### May 6th, 2021 * Added S3 support with `pip install spock-config[s3]` * S3 addon supports automatically handling loading/saving from paths defined with `s3://` URI(s) by passing in an @@ -81,6 +86,8 @@ Example `spock` usage is located [here](https://github.com/fidelity/spock/blob/m parameter configuration to YAML, TOML, or JSON with a single chained command (with extra runtime info such as Git info, Python version, machine FQDN, etc). The saved markdown file can be used as the configuration input to reproduce prior runtime configurations. 
+* [Hyper-Parameter Tuner Addon](https://fidelity.github.io/spock/docs/addons/tuner/About.html): Provides a unified + interface for hyper-parameter tuning that supports various backends (Optuna, WIP: Ax) +* [S3 Addon](https://fidelity.github.io/spock/docs/addons/S3/): Automatically detects `s3://` URI(s) and handles loading and saving `spock` configuration files when an active `boto3.Session` is passed in (plus any additional `S3Transfer` configurations) diff --git a/REQUIREMENTS.txt b/REQUIREMENTS.txt index 7db37490..b33e46e0 100644 --- a/REQUIREMENTS.txt +++ b/REQUIREMENTS.txt @@ -1,4 +1,4 @@ attrs GitPython -pyYAML -toml \ No newline at end of file +pytomlpp +pyYAML \ No newline at end of file diff --git a/docs/Installation.md b/docs/Installation.md index 792f89ce..0dd8087c 100644 --- a/docs/Installation.md +++ b/docs/Installation.md @@ -3,7 +3,7 @@ ### Requirements * Python: 3.6+ -* Dependencies: attrs, GitPython, PyYAML, toml +* Base Dependencies: attrs, GitPython, PyYAML, pytomlpp * Tested OS: Unix (Ubuntu 16.04, Ubuntu 18.04), OSX (10.14.6) ### Install/Upgrade @@ -14,10 +14,21 @@ pip install spock-config ``` #### w/ S3 Extension + +Extra Dependencies: boto3, botocore, hurry.filesize, s3transfer + ```bash pip install spock-config[s3] ``` +#### w/ Hyper-Parameter Tuner Extension + +Extra Dependencies: optuna + +```bash +pip install spock-config[tune] +``` + #### Pip From Source ```bash pip install git+https://github.com/fidelity/spock diff --git a/docs/Motivation.md b/docs/Motivation.md index bdc6a76a..7463a480 100644 --- a/docs/Motivation.md +++ b/docs/Motivation.md @@ -87,6 +87,8 @@ set of parameters. * Tractability and Reproducibility: Save runtime parameter configuration to YAML, TOML, or JSON with a single chained command (with extra runtime info such as Git info, Python version, machine FQDN, etc). The saved markdown file can be used as the configuration input to reproduce prior runtime configurations. +* Hyper-Parameter Tuner Addon: Provides a unified interface for hyper-parameter tuning that supports various + backends (Optuna, WIP: Ax) * S3 Addon: Automatically detects `s3://` URI(s) and handles loading and saving `spock` configuration files when an active `boto3.Session` is passed in (plus any additional `S3Transfer` configurations) diff --git a/docs/Quick-Start.md b/docs/Quick-Start.md index 2f64ac80..ba107af9 100644 --- a/docs/Quick-Start.md +++ b/docs/Quick-Start.md @@ -107,7 +107,7 @@ fancier_parameter: 64.64 most_fancy_parameter: [768, 768, 512, 128] ``` -Finally, we would run our script and pass the path to the configuration file to the command line (-c or --config): +Finally, we would run our script and pass the path to the configuration file to the command line (`-c` or `--config`): ```bash $ python simple.py -c simple.yaml @@ -131,4 +131,16 @@ configuration(s): fancy_parameter float parameter that multiplies a value fancier_parameter float parameter that gets added to product of val and fancy_parameter most_fancy_parameter List[int] values to apply basic algebra to +``` + +### Spock As a Drop In For Argparser + +`spock` can easily be used as a drop in for argparser. This means that all parameter definitions are required to come in +from the command line or from setting defaults within the `@spock` decorated classes. Simply do not pass a `-c` or +`--config` argument at the command line and instead pass in all of the automatically generated cmd-line arguments. 
+ + +```bash +$ python simple.py --BasicConfig.parameter --BasicConfig.fancy_parameter 8.8 --BasicConfig.fancier_parameter 64.64 \ + --BasicConfig.most_fancy_parameter [768, 768, 512, 128] ``` \ No newline at end of file diff --git a/docs/addons/S3.md b/docs/addons/S3.md index d2ac5d32..36dfb89c 100644 --- a/docs/addons/S3.md +++ b/docs/addons/S3.md @@ -46,11 +46,11 @@ session = boto3.Session( ### Using the S3Config Object -As an example let's create a basic `@spock` decorated class, instantiate a `S3Config` object from `spock.addons` with +As an example let's create a basic `@spock` decorated class, instantiate a `S3Config` object from `spock.addons.s3` with the `boto3.session.Session` we created above, and pass it to the `ConfigArgBuilder`. ```python -from spock.addons import S3Config +from spock.addons.s3 import S3Config from spock.builder import ConfigArgBuilder from spock.config import spock from typing import List @@ -123,8 +123,8 @@ With a `S3Config` object passed into the `ConfigArgBuilder` the S3 URI will auto If you require any other settings for uploading or downloading files from S3 the `S3Config` class has two extra attributes: -`download_config` which takes a `S3DownloadConfig` object from `spock.addons` which supports all ExtraArgs from +`download_config` which takes a `S3DownloadConfig` object from `spock.addons.s3` which supports all ExtraArgs from [S3Transfer.ALLOWED_DOWNLOAD_ARGS](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_DOWNLOAD_ARGS) -`upload_config` which takes a `S3UploadConfig` object from `spock.addons` which supports all ExtraArgs from +`upload_config` which takes a `S3UploadConfig` object from `spock.addons.s3` which supports all ExtraArgs from [S3Transfer.ALLOWED_UPLOAD_ARGS](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS) diff --git a/docs/addons/tuner/About.md b/docs/addons/tuner/About.md new file mode 100644 index 00000000..1187be8f --- /dev/null +++ b/docs/addons/tuner/About.md @@ -0,0 +1,22 @@ +# Hyper-Parameter Tuning Support + +This series of docs will describe the basics of hyper-parameter support within `spock`. `spock` tries to be as hands-off +as possible with the underlying backends that support hyper-parameter tuning and only provides a common and simplified +interface to define hyper-parameter tuning runs. The rest is left up to the user to define and handle, so as not to +handcuff the user with overly simplified functionality. + +All examples can be found [here](https://github.com/fidelity/spock/blob/master/examples). + +### Installing + +Install `spock` with the extra hyper-parameter tuning related dependencies. + +```bash +pip install spock-config[tune] +``` + +### Supported Backends +* [Optuna](https://optuna.readthedocs.io/en/stable/index.html) + +### WIP/Planned Backends +* [Ax](https://ax.dev/) \ No newline at end of file diff --git a/docs/addons/tuner/Basics.md b/docs/addons/tuner/Basics.md new file mode 100644 index 00000000..2401f874 --- /dev/null +++ b/docs/addons/tuner/Basics.md @@ -0,0 +1,95 @@ +# Tune Basics + +Just like the basic `spock` functionality, hyper-parameters are defined via a class-based solution. All parameters +must be defined in a class or multiple classes by decorating with the `@spockTuner` decorator. Parameters are defined +as one of the two basic types, `RangeHyperParameter` or `ChoiceHyperParameter`. 
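+ +As a quick preview, a minimal sketch of such a class is shown below (the class and field names here are illustrative only; the full worked iris example follows in the later sections): + +```python +from spock.addons.tune import ChoiceHyperParameter, RangeHyperParameter, spockTuner + + +@spockTuner +class ExampleHP: + learning_rate: RangeHyperParameter # sampled from a bounded int or float range + activation: ChoiceHyperParameter # sampled from a discrete set of values +```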
+ +Once built (with a specific backend), all parameters can be found within an automatically generated namespace +object that contains both the fixed and sampled parameters that can be accessed with the given `@spock` or +`@spockTuner` class names. + +### Supported Hyper-Parameter Types +`spock` supports the two following types for hyper-parameters, `RangeHyperParameter` or `ChoiceHyperParameter`. + +The `RangeHyperParameter` type is used for hyper-parameters that are to be drawn from a sampled range of `int` or +`float` while the `ChoiceHyperParameter` type is used for hyper-parameters that are to be sampled from a discrete set +of values that can be of base type `int`, `float`, `bool`, or `str`. + +`RangeHyperParameter` requires the following inputs: + +- type: string of either int or float depending on the needed type +- bounds: a tuple of two values that define the lower and upper bound of the range (int or float) +- log_scale: boolean to activate log scaling of the range + +`ChoiceHyperParameter` requires the following inputs: + +- type: string of either int, float, bool, str depending on the needed type +- choices: a list of any length that contains the discrete values to sample from + +### Defining a spockTuner Class + +Let's start building out a very simple example (logistic regression of iris w/ sklearn) that we will continue to use +within the tutorial: `tune.py` + +Tune functions exactly the same as base `spock` functionality. We import the basic units of functionality +from `spock.addons.tune`, define our class using the `@spockTuner` decorator, and define our parameters with +supported argument types. We also pull in the sample iris data from sklearn. + +```python +from spock.addons.tune import ChoiceHyperParameter +from spock.addons.tune import RangeHyperParameter +from spock.addons.tune import spockTuner +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split + + +@spockTuner +class LogisticRegressionHP: + c: RangeHyperParameter + solver: ChoiceHyperParameter + +# Load the iris data +X, y = load_iris(return_X_y=True) + +# Split the Iris data +X_train, X_valid, y_train, y_valid = train_test_split(X, y) + +``` + +The `@spockTuner` decorated classes are passed to the `ConfigArgBuilder` in the exact same way as basic `@spock` +decorated classes. This returns a `spock` builder object which can be used to call different methods. + +```python +attrs_obj = ConfigArgBuilder( + LogisticRegressionHP, + desc="Example Logistic Regression Hyper-Parameter Tuning", +) +``` + + +### Creating a Configuration File + +Just like basic spock functionality, values in `spock` are set primarily using external configuration files. For our +hyper-parameters we just defined above our `tune.yaml` file might look something like this (remember each class requires +specific inputs): + +```yaml +# Hyper-parameter config +LogisticRegressionHP: + c: + bounds: + - 0.01 + - 10.0 + log_scale: true + type: float + solver: + choices: + - lbfgs + - saga + type: str +``` + + +### Continuing + +The rest of the docs are backend specific so refer to the correct backend specific documentation. \ No newline at end of file diff --git a/docs/addons/tuner/Optuna.md b/docs/addons/tuner/Optuna.md new file mode 100644 index 00000000..97b4513e --- /dev/null +++ b/docs/addons/tuner/Optuna.md @@ -0,0 +1,96 @@ +# Optuna Support + +`spock` integrates with the Optuna hyper-parameter optimization framework through the provided +ask-and-run interface and the define-and-run API. 
See [docs](https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/009_ask_and_tell.html#define-and-run). + +All examples can be found [here](https://github.com/fidelity/spock/blob/master/examples). + +### Defining the Backend + +So let's continue in our Optuna specific version of `tune.py`: + +It's important to note that you can still use the `@spock` decorator to define any non hyper-parameters! For +posterity let's add some fixed parameters (those that are not part of hyper-parameter tuning) that we will use +elsewhere in our code. + +```python +from spock.config import spock + +@spock +class BasicParams: + n_trials: int + max_iter: int +``` + +Now we need to tell `spock` that we intend on doing hyper-parameter tuning and which backend we would like to use. We +do this by calling the `tuner` method on the `ConfigArgBuilder` object passing in a configuration object for the +backend of choice (just like in basic functionality this is a chained command, thus the builder object will still be +returned). For Optuna one uses `OptunaTunerConfig`. This config mirrors all options that would be passed into +the `optuna.study.create_study` function call so that `spock` can setup the define-and-run API. (Note: The `@spockTuner` +decorated classes are passed to the `ConfigArgBuilder` in the exact same way as basic `@spock` +decorated classes.) + +```python +from spock.addons.tune import OptunaTunerConfig + +# Optuna config -- this will internally configure the study object for the define-and-run style which will be returned +# by accessing the tuner_status property on the ConfigArgBuilder object +optuna_config = OptunaTunerConfig( + study_name="Iris Logistic Regression", direction="maximize" +) + +# Use the builder to setup +# Call tuner to indicate that we are going to do some HP tuning -- passing in an optuna study object +attrs_obj = ConfigArgBuilder( + LogisticRegressionHP, + BasicParams, + desc="Example Logistic Regression Hyper-Parameter Tuning", +).tuner(tuner_config=optuna_config) + +``` + +### Generate Functionality Still Exists + +To get the set of fixed parameters (those that are not hyper-parameters) one simply calls the `generate()` function +just like they would for normal `spock` usage to get the fixed parameter `spockspace`. + +Continuing in `tune.py`: + +```python + +# Here we need some of the fixed parameters first so we can just call the generate fnc to grab all the fixed params +# prior to starting the sampling process +fixed_params = attrs_obj.generate() +``` + +### Sample as an Alternative to Generate + +The `sample()` call is the crux of `spock` hyper-parameter tuning support. It draws a hyper-parameter sample from the +underlying backend sampler and combines it with fixed parameters and returns a single `Spockspace` with all +useable parameters (defined with dot notation). For Optuna -- Under the hood `spock` uses the define-and-run Optuna +interface -- thus it handles the underlying 'ask' call. The `spock` builder object has a `@property` called +`tuner_status` that returns any necessary backend objects in a dictionary that the user needs to interface with. In the +case of Optuna, this contains both the Optuna `study` and `trial` (as dictionary keys). 
We use the return of +`tuner_status` to handle the 'tell' call based on the metric of interest (here, simple validation accuracy). + +Continuing in `tune.py`: + +```python +# Iterate through a bunch of optuna trials +for _ in range(fixed_params.BasicParams.n_trials): + # Call sample on the spock object + hp_attrs = attrs_obj.sample() + # Use the currently sampled parameters in a simple LogisticRegression from sklearn + clf = LogisticRegression( + C=hp_attrs.LogisticRegressionHP.c, + solver=hp_attrs.LogisticRegressionHP.solver, + max_iter=hp_attrs.BasicParams.max_iter + ) + clf.fit(X_train, y_train) + val_acc = clf.score(X_valid, y_valid) + # Get the status of the tuner -- this dict will contain all the objects needed to update + tuner_status = attrs_obj.tuner_status + # Pull the study and trials object out of the return dictionary and pass it to the tell call using the study + # object + tuner_status["study"].tell(tuner_status["trial"], val_acc) +``` \ No newline at end of file diff --git a/docs/addons/tuner/Saving.md b/docs/addons/tuner/Saving.md new file mode 100644 index 00000000..9a16c43d --- /dev/null +++ b/docs/addons/tuner/Saving.md @@ -0,0 +1,106 @@ +# Saving Hyper-Parameter Configs -- Base, Samples, and Best + +`spock` provides the capability to save the configuration for each stage of hyper-parameter tuning. + +### Saving Base Hyper-Parameter Definitions + +First, if we wanted to save the configuration state of the defined hyper-parameter ranges (i.e. the definitions of the +parameters that are not sampled) we simply chain the `save()` call post `tuner()` call just like we did with basic +`spock` usage. If there are defined hyper-parameters from `@spockTuner` these will automatically get written into the +markdown file along with the fixed parameters. + +For instance in `tune.py`: + +```python + +# Chain the .save call which will dump the hyper-parameter definitions to the configuration file +attrs_obj = ConfigArgBuilder( + LogisticRegressionHP, + BasicParams, + desc="Example Logistic Regression Hyper-Parameter Tuning", +).tuner(tuner_config=optuna_config).save(user_specified_path='/tmp') +``` + +Would result in the following YAML file: + +```yaml + BasicParams: + max_iter: 150 + n_trials: 10 + LogisticRegressionHP: + c: + bounds: + - 0.01 + - 10.0 + log_scale: true + type: float + solver: + choices: + - lbfgs + - saga + type: str +``` + +### Saving Individual Hyper-Parameter Samples + +If we want to save each individual hyper-parameter sample we again use the `save()` call with the addition of the +`add_tuner_sample=True` keyword arg and chain it before the `sample()` call. The order might be slightly confusing +but this is to allow all methods to return the builder object except for the `sample()` and `generate()` calls, +which return a `Spockspace`. The saver will append `hp.sample.[0-9]+` to the filename to identify each sample +configuration. + +For instance in `tune.py`: + +```python + +# Now we iterate through a bunch of optuna trials +for _ in range(fixed_params.BasicParams.n_trials): + hp_attrs = attrs_obj.save( + add_tuner_sample=True, user_specified_path="/tmp" + ).sample() +``` + +Would result in `n_trials` files named `hp.sample.[0-9]+.{uuid}.spock.cfg`. 
For instance opening a file named +`hp.sample.1.d1cc7a30-10f0-4d2c-b076-513fe3494566.spock.cfg.yaml` we would see the first sample set of the +hyper-parameters: + +```yaml + BasicParams: + max_iter: 150 + n_trials: 10 + LogisticRegressionHP: + c: 0.21495978453310358 + solver: lbfgs +``` + +### Saving the Best Hyper-Parameter Samples + +If we want to keep track of the current/final best hyper-parameter set based on the optimization metric we use the +`save_best()` call on the builder object. This function takes all the same arguments as the `save()` method but +maintains only a single configuration file that is the current/final best hyper-parameter configuration. The saver will +append `hp.best.` to the filename to identify the best configuration. Note: Make sure this function is only called post +all backend handling (in the case of Optuna -- the 'tell' call) for the sample or else an exception will be raised as +the best configuration will not yet be registered. + +For instance in `tune.py`: + +```python +# Now we iterate through a bunch of optuna trials +for _ in range(fixed_params.BasicParams.n_trials): + hp_attrs = attrs_obj.sample() + # Use the currently sampled parameters in a simple LogisticRegression from sklearn + clf = LogisticRegression( + C=hp_attrs.LogisticRegressionHP.c, + solver=hp_attrs.LogisticRegressionHP.solver, + max_iter=hp_attrs.BasicParams.max_iter + ) + clf.fit(X_train, y_train) + val_acc = clf.score(X_valid, y_valid) + # Get the status of the tuner -- this dict will contain all the objects needed to update + tuner_status = attrs_obj.tuner_status + # Pull the study and trials object out of the return dictionary and pass it to the tell call using the study + # object + tuner_status["study"].tell(tuner_status["trial"], val_acc) + # Always save the current best set of hyper-parameters + attrs_obj.save_best(user_specified_path='/tmp') +``` \ No newline at end of file diff --git a/docs/advanced_features/Command-Line-Overrides.md b/docs/advanced_features/Command-Line-Overrides.md index 46efc070..128cbe94 100644 --- a/docs/advanced_features/Command-Line-Overrides.md +++ b/docs/advanced_features/Command-Line-Overrides.md @@ -117,5 +117,17 @@ We could override the parameters like so (note that the len must match the defin ```bash $ python tutorial.py --config tutorial.yaml --TypeConfig.nested_list.NestedListStuff.one [1,2] \ ---TypeConfig.nested_list.NestedListStuff.two [ciao,ciao] +--TypeConfig.nested_list.NestedListStuff.two ['ciao','ciao'] +``` + +### Spock As a Drop In For Argparser + +`spock` can easily be used as a drop in for argparser. This means that all parameter definitions as required to come in +from the command line or from setting defaults within the `@spock` decorated classes. Simply do not pass a `-c` or +`--config` argument at the command line and instead pass in all of the automatically generated cmd-line arguments. + + +```bash +$ python tutorial.py --TypeConfig.nested_list.NestedListStuff.one [1,2] \ + --TypeConfig.nested_list.NestedListStuff.two [ciao,ciao] ... ``` \ No newline at end of file diff --git a/docs/advanced_features/Composition.md b/docs/advanced_features/Composition.md index e0450a80..1c2b4de4 100644 --- a/docs/advanced_features/Composition.md +++ b/docs/advanced_features/Composition.md @@ -60,6 +60,6 @@ nesterov: true ### Warning -You can add as many configuration files as you want to a `config` tag however be aware of circular dependencies (we -do not check for these yet) and that the lower a configuration file is in the order (i.e. 
later in the list) that -it will take precedence over the others. \ No newline at end of file +You can add as many configuration files as you want to a `config` tag however be aware of circular dependencies (this +should get caught and raise an exception) and that the lower a configuration file is in the order (i.e. later in the +list) that it will take precedence over the others. \ No newline at end of file diff --git a/examples/legacy/quick-start/simple.py b/examples/legacy/quick-start/simple.py index 290e364c..d979d111 100644 --- a/examples/legacy/quick-start/simple.py +++ b/examples/legacy/quick-start/simple.py @@ -13,7 +13,12 @@ class BasicConfig: def add_namespace(config): # Lets just do some basic algebra here - val_sum = sum([(config.fancy_parameter * val) + config.fancier_parameter for val in config.most_fancy_parameter]) + val_sum = sum( + [ + (config.fancy_parameter * val) + config.fancier_parameter + for val in config.most_fancy_parameter + ] + ) # If the boolean is true let's round if config.parameter: val_sum = round(val_sum) @@ -38,10 +43,14 @@ def main(): val_sum_namespace = add_namespace(config.BasicConfig) print(val_sum_namespace) # Or pass by parameter - val_sum_parameter = add_by_parameter(config.BasicConfig.fancy_parameter, config.BasicConfig.most_fancy_parameter, - config.BasicConfig.fancier_parameter, config.BasicConfig.parameter) + val_sum_parameter = add_by_parameter( + config.BasicConfig.fancy_parameter, + config.BasicConfig.most_fancy_parameter, + config.BasicConfig.fancier_parameter, + config.BasicConfig.parameter, + ) print(val_sum_parameter) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/legacy/tutorial/advanced/basic_nn.py b/examples/legacy/tutorial/advanced/basic_nn.py index 86015cb9..cf81e88c 100644 --- a/examples/legacy/tutorial/advanced/basic_nn.py +++ b/examples/legacy/tutorial/advanced/basic_nn.py @@ -6,12 +6,16 @@ class BasicNet(nn.Module): def __init__(self, model_config): super(BasicNet, self).__init__() # Make a dictionary of activation functions to select from - self.act_fncs = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh} + self.act_fncs = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh} self.use_act = self.act_fncs.get(model_config.activation)() # Define the layers manually (avoiding list comprehension for clarity) self.layer_1 = nn.Linear(model_config.n_features, model_config.hidden_sizes[0]) - self.layer_2 = nn.Linear(model_config.hidden_sizes[0], model_config.hidden_sizes[1]) - self.layer_3 = nn.Linear(model_config.hidden_sizes[1], model_config.hidden_sizes[2]) + self.layer_2 = nn.Linear( + model_config.hidden_sizes[0], model_config.hidden_sizes[1] + ) + self.layer_3 = nn.Linear( + model_config.hidden_sizes[1], model_config.hidden_sizes[2] + ) # Define some dropout layers self.dropout = [] if model_config.dropout is not None: diff --git a/examples/legacy/tutorial/advanced/tutorial.py b/examples/legacy/tutorial/advanced/tutorial.py index 7bf7e7be..c2aad1c7 100644 --- a/examples/legacy/tutorial/advanced/tutorial.py +++ b/examples/legacy/tutorial/advanced/tutorial.py @@ -1,8 +1,9 @@ +import torch from basic_nn import BasicNet + from spock.args import * from spock.builder import ConfigArgBuilder from spock.config import spock_config -import torch @spock_config @@ -11,8 +12,8 @@ class ModelConfig: n_features: IntArg dropout: ListOptArg[float] hidden_sizes: TupleArg[int] = TupleArg.defaults((32, 32, 32)) - activation: ChoiceArg(choice_set=['relu', 'gelu', 'tanh'], default='relu') - optimizer: 
ChoiceArg(choice_set=['SGD', 'Adam']) + activation: ChoiceArg(choice_set=["relu", "gelu", "tanh"], default="relu") + optimizer: ChoiceArg(choice_set=["SGD", "Adam"]) cache_path: StrOptArg @@ -38,42 +39,61 @@ class SGDConfig(OptimizerConfig): def train(x_data, y_data, model, model_config, data_config, optimizer_config): - if model_config.optimizer == 'SGD': - optimizer = torch.optim.SGD(model.parameters(), lr=optimizer_config.lr, momentum=optimizer_config.momentum, - nesterov=optimizer_config.nesterov) - elif model_config.optimizer == 'Adam': + if model_config.optimizer == "SGD": + optimizer = torch.optim.SGD( + model.parameters(), + lr=optimizer_config.lr, + momentum=optimizer_config.momentum, + nesterov=optimizer_config.nesterov, + ) + elif model_config.optimizer == "Adam": optimizer = torch.optim.Adam(model.parameters(), lr=optimizer_config.lr) else: - raise ValueError(f'Optimizer choice {optimizer_config.optimizer} not available') + raise ValueError(f"Optimizer choice {optimizer_config.optimizer} not available") n_steps_per_epoch = data_config.n_samples % data_config.batch_size for epoch in range(optimizer_config.n_epochs): for i in range(n_steps_per_epoch): # Ugly data slicing for simplicity - x_batch = x_data[i * n_steps_per_epoch:(i + 1) * n_steps_per_epoch, ] - y_batch = y_data[i * n_steps_per_epoch:(i + 1) * n_steps_per_epoch, ] + x_batch = x_data[ + i * n_steps_per_epoch : (i + 1) * n_steps_per_epoch, + ] + y_batch = y_data[ + i * n_steps_per_epoch : (i + 1) * n_steps_per_epoch, + ] optimizer.zero_grad() output = model(x_batch) loss = torch.nn.CrossEntropyLoss(output, y_batch) loss.backward() if optimizer_config.grad_clip: - torch.nn.utils.clip_grad_value(model.parameters(), optimizer_config.grad_clip) + torch.nn.utils.clip_grad_value( + model.parameters(), optimizer_config.grad_clip + ) optimizer.step() - print(f'Finished Epoch {epoch+1}') + print(f"Finished Epoch {epoch+1}") def main(): # A simple description - description = 'spock Advanced Tutorial' + description = "spock Advanced Tutorial" # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder(ModelConfig, DataConfig, SGDConfig, desc=description).generate() + config = ConfigArgBuilder( + ModelConfig, DataConfig, SGDConfig, desc=description + ).generate() # Instantiate our neural net using basic_nn = BasicNet(model_config=config.ModelConfig) # Make some random data (BxH): H has dim of features in x_data = torch.rand(config.DataConfig.n_samples, config.ModelConfig.n_features) y_data = torch.randint(0, 3, (config.DataConfig.n_samples,)) # Run some training - train(x_data, y_data, basic_nn, config.ModelConfig, config.DataConfig, config.SGDConfig) + train( + x_data, + y_data, + basic_nn, + config.ModelConfig, + config.DataConfig, + config.SGDConfig, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/legacy/tutorial/basic/basic_nn.py b/examples/legacy/tutorial/basic/basic_nn.py index f4dc96b3..fdbe4a1c 100644 --- a/examples/legacy/tutorial/basic/basic_nn.py +++ b/examples/legacy/tutorial/basic/basic_nn.py @@ -6,12 +6,16 @@ class BasicNet(nn.Module): def __init__(self, model_config): super(BasicNet, self).__init__() # Make a dictionary of activation functions to select from - self.act_fncs = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh} + self.act_fncs = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh} self.use_act = self.act_fncs.get(model_config.activation)() # Define the layers manually (avoiding list comprehension for clarity) 
self.layer_1 = nn.Linear(model_config.n_features, model_config.hidden_sizes[0]) - self.layer_2 = nn.Linear(model_config.hidden_sizes[0], model_config.hidden_sizes[1]) - self.layer_3 = nn.Linear(model_config.hidden_sizes[1], model_config.hidden_sizes[2]) + self.layer_2 = nn.Linear( + model_config.hidden_sizes[0], model_config.hidden_sizes[1] + ) + self.layer_3 = nn.Linear( + model_config.hidden_sizes[1], model_config.hidden_sizes[2] + ) # Define some dropout layers self.dropout_1 = nn.Dropout(model_config.dropout[0]) self.dropout_2 = nn.Dropout(model_config.dropout[1]) diff --git a/examples/legacy/tutorial/basic/tutorial.py b/examples/legacy/tutorial/basic/tutorial.py index ae7ae324..a75a8418 100644 --- a/examples/legacy/tutorial/basic/tutorial.py +++ b/examples/legacy/tutorial/basic/tutorial.py @@ -1,8 +1,9 @@ +import torch from basic_nn import BasicNet + from spock.args import * from spock.builder import ConfigArgBuilder from spock.config import spock_config -import torch @spock_config @@ -11,15 +12,18 @@ class ModelConfig: n_features: IntArg dropout: ListArg[float] hidden_sizes: TupleArg[int] - activation: ChoiceArg(choice_set=['relu', 'gelu', 'tanh']) + activation: ChoiceArg(choice_set=["relu", "gelu", "tanh"]) def main(): # A simple description - description = 'spock Tutorial' + description = "spock Tutorial" # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder( - ModelConfig, desc=description, create_save_path=True).save(file_extension='.toml').generate() + config = ( + ConfigArgBuilder(ModelConfig, desc=description, create_save_path=True) + .save(file_extension=".toml") + .generate() + ) # Instantiate our neural net using basic_nn = BasicNet(model_config=config.ModelConfig) # Make some random data (BxH): H has dim of features in @@ -28,5 +32,5 @@ def main(): print(result) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/quick-start/simple.py b/examples/quick-start/simple.py index 0f00a00f..76cd15f9 100644 --- a/examples/quick-start/simple.py +++ b/examples/quick-start/simple.py @@ -1,6 +1,7 @@ +from typing import List + from spock.builder import ConfigArgBuilder from spock.config import spock -from typing import List @spock @@ -14,6 +15,7 @@ class BasicConfig: most_fancy_parameter: values to apply basic algebra to """ + parameter: bool fancy_parameter: float fancier_parameter: float @@ -22,7 +24,12 @@ class BasicConfig: def add_namespace(config): # Lets just do some basic algebra here - val_sum = sum([(config.fancy_parameter * val) + config.fancier_parameter for val in config.most_fancy_parameter]) + val_sum = sum( + [ + (config.fancy_parameter * val) + config.fancier_parameter + for val in config.most_fancy_parameter + ] + ) # If the boolean is true let's round if config.parameter: val_sum = round(val_sum) @@ -40,17 +47,21 @@ def add_by_parameter(multiply_param, list_vals, add_param, tf_round): def main(): # Chain the generate function to the class call - config = ConfigArgBuilder(BasicConfig, desc='Quick start example').generate() + config = ConfigArgBuilder(BasicConfig, desc="Quick start example").generate() # One can now access the Spock config object by class name with the returned namespace print(config.BasicConfig.parameter) # And pass the namespace to our first function val_sum_namespace = add_namespace(config.BasicConfig) print(val_sum_namespace) # Or pass by parameter - val_sum_parameter = add_by_parameter(config.BasicConfig.fancy_parameter, 
config.BasicConfig.most_fancy_parameter, - config.BasicConfig.fancier_parameter, config.BasicConfig.parameter) + val_sum_parameter = add_by_parameter( + config.BasicConfig.fancy_parameter, + config.BasicConfig.most_fancy_parameter, + config.BasicConfig.fancier_parameter, + config.BasicConfig.parameter, + ) print(val_sum_parameter) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tune/optuna/__init__.py b/examples/tune/optuna/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/examples/tune/optuna/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/examples/tune/optuna/tune.py b/examples/tune/optuna/tune.py new file mode 100644 index 00000000..de645497 --- /dev/null +++ b/examples/tune/optuna/tune.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +"""A simple example using sklearn and Optuna support""" + +# Spock ONLY supports the define-and-run style interface from Optuna +# https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/009_ask_and_tell.html#define-and-run + + +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split + +from spock.addons.tune import ( + ChoiceHyperParameter, + OptunaTunerConfig, + RangeHyperParameter, + spockTuner, +) +from spock.builder import ConfigArgBuilder +from spock.config import spock + + +@spock +class BasicParams: + n_trials: int + max_iter: int + + +@spockTuner +class LogisticRegressionHP: + c: RangeHyperParameter + solver: ChoiceHyperParameter + + +def main(): + # Load the iris data + X, y = load_iris(return_X_y=True) + + # Split the Iris data + X_train, X_valid, y_train, y_valid = train_test_split(X, y) + + # Optuna config -- this will internally spawn the study object for the define-and-run style which will be returned + # by accessing the tuner_status property on the ConfigArgBuilder object + optuna_config = OptunaTunerConfig( + study_name="Iris Logistic Regression", direction="maximize" + ) + + # Use the builder to setup + # Call tuner to indicate that we are going to do some HP tuning -- passing in an optuna study object + attrs_obj = ( + ConfigArgBuilder( + LogisticRegressionHP, + BasicParams, + desc="Example Logistic Regression Hyper-Parameter Tuning", + ) + .tuner(tuner_config=optuna_config) + .save(user_specified_path="/tmp") + ) + + # Here we need some of the fixed parameters first so we can just call the generate fnc to grab all the fixed params + # prior to starting the sampling process + fixed_params = attrs_obj.generate() + + # Now we iterate through a bunch of optuna trials + for _ in range(fixed_params.BasicParams.n_trials): + # The crux of spock support -- call save w/ the add_tuner_sample flag to write the current draw to file and + # then call sample to return the composed Spockspace of the fixed parameters and the sampled parameters + # Under the hood spock uses the define-and-run Optuna interface -- thus it handled the underlying 'ask' call + # and returns the necessary trial object in the return dictionary to call 'tell' with the study object + hp_attrs = attrs_obj.save( + add_tuner_sample=True, user_specified_path="/tmp" + ).sample() + # Use the currently sampled parameters in a simple LogisticRegression from sklearn + clf = LogisticRegression( + C=hp_attrs.LogisticRegressionHP.c, + solver=hp_attrs.LogisticRegressionHP.solver, + max_iter=hp_attrs.BasicParams.max_iter, + ) + clf.fit(X_train, y_train) + val_acc = clf.score(X_valid, y_valid) + # Get the status of the 
tuner -- this dict will contain all the objects needed to update + tuner_status = attrs_obj.tuner_status + # Pull the study and trials object out of the return dictionary and pass it to the tell call using the study + # object + tuner_status["study"].tell(tuner_status["trial"], val_acc) + # Always save the current best set of hyper-parameters + attrs_obj.save_best(user_specified_path="/tmp") + + # Grab the best config and metric + best_config, best_metric = attrs_obj.best + print(f"Best HP Config:\n{best_config}") + print(f"Best Metric: {best_metric}") + + +if __name__ == "__main__": + main() diff --git a/examples/tune/optuna/tune.yaml b/examples/tune/optuna/tune.yaml new file mode 100644 index 00000000..5a56e06d --- /dev/null +++ b/examples/tune/optuna/tune.yaml @@ -0,0 +1,15 @@ +################ +# tune.yaml +################ +BasicParams: + n_trials: 10 + max_iter: 150 + +LogisticRegressionHP: + c: + type: float + bounds: [1E-07, 10.0] + log_scale: true + solver: + type: str + choices: ["lbfgs", "saga"] \ No newline at end of file diff --git a/examples/tutorial/advanced/basic_nn.py b/examples/tutorial/advanced/basic_nn.py index 86015cb9..cf81e88c 100644 --- a/examples/tutorial/advanced/basic_nn.py +++ b/examples/tutorial/advanced/basic_nn.py @@ -6,12 +6,16 @@ class BasicNet(nn.Module): def __init__(self, model_config): super(BasicNet, self).__init__() # Make a dictionary of activation functions to select from - self.act_fncs = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh} + self.act_fncs = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh} self.use_act = self.act_fncs.get(model_config.activation)() # Define the layers manually (avoiding list comprehension for clarity) self.layer_1 = nn.Linear(model_config.n_features, model_config.hidden_sizes[0]) - self.layer_2 = nn.Linear(model_config.hidden_sizes[0], model_config.hidden_sizes[1]) - self.layer_3 = nn.Linear(model_config.hidden_sizes[1], model_config.hidden_sizes[2]) + self.layer_2 = nn.Linear( + model_config.hidden_sizes[0], model_config.hidden_sizes[1] + ) + self.layer_3 = nn.Linear( + model_config.hidden_sizes[1], model_config.hidden_sizes[2] + ) # Define some dropout layers self.dropout = [] if model_config.dropout is not None: diff --git a/examples/tutorial/advanced/tutorial.py b/examples/tutorial/advanced/tutorial.py index 13a01a3c..09c54993 100644 --- a/examples/tutorial/advanced/tutorial.py +++ b/examples/tutorial/advanced/tutorial.py @@ -1,23 +1,23 @@ -from basic_nn import BasicNet from enum import Enum +from typing import List, Optional, Tuple + +import torch +from basic_nn import BasicNet + from spock.args import SavePath from spock.builder import ConfigArgBuilder from spock.config import spock -import torch -from typing import List -from typing import Optional -from typing import Tuple class Activation(Enum): - relu = 'relu' - gelu = 'gelu' - tanh = 'tanh' + relu = "relu" + gelu = "gelu" + tanh = "tanh" class Optimizer(Enum): - sgd = 'SGD' - adam = 'Adam' + sgd = "SGD" + adam = "Adam" @spock @@ -26,7 +26,7 @@ class ModelConfig: n_features: int dropout: Optional[List[float]] hidden_sizes: Tuple[int, int, int] = (32, 32, 32) - activation: Activation = 'relu' + activation: Activation = "relu" optimizer: Optimizer cache_path: Optional[str] @@ -53,42 +53,61 @@ class SGDConfig(OptimizerConfig): def train(x_data, y_data, model, model_config, data_config, optimizer_config): - if model_config.optimizer == 'SGD': - optimizer = torch.optim.SGD(model.parameters(), lr=optimizer_config.lr, momentum=optimizer_config.momentum, - 
nesterov=optimizer_config.nesterov) - elif model_config.optimizer == 'Adam': + if model_config.optimizer == "SGD": + optimizer = torch.optim.SGD( + model.parameters(), + lr=optimizer_config.lr, + momentum=optimizer_config.momentum, + nesterov=optimizer_config.nesterov, + ) + elif model_config.optimizer == "Adam": optimizer = torch.optim.Adam(model.parameters(), lr=optimizer_config.lr) else: - raise ValueError(f'Optimizer choice {optimizer_config.optimizer} not available') + raise ValueError(f"Optimizer choice {optimizer_config.optimizer} not available") n_steps_per_epoch = data_config.n_samples % data_config.batch_size for epoch in range(optimizer_config.n_epochs): for i in range(n_steps_per_epoch): # Ugly data slicing for simplicity - x_batch = x_data[i * n_steps_per_epoch:(i + 1) * n_steps_per_epoch, ] - y_batch = y_data[i * n_steps_per_epoch:(i + 1) * n_steps_per_epoch, ] + x_batch = x_data[ + i * n_steps_per_epoch : (i + 1) * n_steps_per_epoch, + ] + y_batch = y_data[ + i * n_steps_per_epoch : (i + 1) * n_steps_per_epoch, + ] optimizer.zero_grad() output = model(x_batch) loss = torch.nn.CrossEntropyLoss(output, y_batch) loss.backward() if optimizer_config.grad_clip: - torch.nn.utils.clip_grad_value(model.parameters(), optimizer_config.grad_clip) + torch.nn.utils.clip_grad_value( + model.parameters(), optimizer_config.grad_clip + ) optimizer.step() - print(f'Finished Epoch {epoch+1}') + print(f"Finished Epoch {epoch+1}") def main(): # A simple description - description = 'spock Advanced Tutorial' + description = "spock Advanced Tutorial" # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder(ModelConfig, DataConfig, SGDConfig, desc=description).generate() + config = ConfigArgBuilder( + ModelConfig, DataConfig, SGDConfig, desc=description + ).generate() # Instantiate our neural net using basic_nn = BasicNet(model_config=config.ModelConfig) # Make some random data (BxH): H has dim of features in x_data = torch.rand(config.DataConfig.n_samples, config.ModelConfig.n_features) y_data = torch.randint(0, 3, (config.DataConfig.n_samples,)) # Run some training - train(x_data, y_data, basic_nn, config.ModelConfig, config.DataConfig, config.SGDConfig) + train( + x_data, + y_data, + basic_nn, + config.ModelConfig, + config.DataConfig, + config.SGDConfig, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/basic/basic_nn.py b/examples/tutorial/basic/basic_nn.py index f4dc96b3..fdbe4a1c 100644 --- a/examples/tutorial/basic/basic_nn.py +++ b/examples/tutorial/basic/basic_nn.py @@ -6,12 +6,16 @@ class BasicNet(nn.Module): def __init__(self, model_config): super(BasicNet, self).__init__() # Make a dictionary of activation functions to select from - self.act_fncs = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh} + self.act_fncs = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh} self.use_act = self.act_fncs.get(model_config.activation)() # Define the layers manually (avoiding list comprehension for clarity) self.layer_1 = nn.Linear(model_config.n_features, model_config.hidden_sizes[0]) - self.layer_2 = nn.Linear(model_config.hidden_sizes[0], model_config.hidden_sizes[1]) - self.layer_3 = nn.Linear(model_config.hidden_sizes[1], model_config.hidden_sizes[2]) + self.layer_2 = nn.Linear( + model_config.hidden_sizes[0], model_config.hidden_sizes[1] + ) + self.layer_3 = nn.Linear( + model_config.hidden_sizes[1], model_config.hidden_sizes[2] + ) # Define some dropout layers self.dropout_1 = 
nn.Dropout(model_config.dropout[0]) self.dropout_2 = nn.Dropout(model_config.dropout[1]) diff --git a/examples/tutorial/basic/tutorial.py b/examples/tutorial/basic/tutorial.py index 19c953ae..52f832b3 100644 --- a/examples/tutorial/basic/tutorial.py +++ b/examples/tutorial/basic/tutorial.py @@ -1,11 +1,12 @@ -from basic_nn import BasicNet from enum import Enum +from typing import List, Tuple + +import torch +from basic_nn import BasicNet + from spock.args import SavePath from spock.builder import ConfigArgBuilder from spock.config import spock -import torch -from typing import List -from typing import Tuple class Activation(Enum): @@ -16,9 +17,10 @@ class Activation(Enum): gelu: gelu activation tanh: tanh activation """ - relu = 'relu' - gelu = 'gelu' - tanh = 'tanh' + + relu = "relu" + gelu = "gelu" + tanh = "tanh" @spock @@ -32,6 +34,7 @@ class ModelConfig: hidden_sizes: hidden size for each layer activation: choice from the Activation enum of the activation function to use """ + save_path: SavePath n_features: int dropout: List[float] @@ -41,10 +44,13 @@ class ModelConfig: def main(): # A simple description - description = 'spock Basic Tutorial' + description = "spock Basic Tutorial" # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder( - ModelConfig, desc=description, create_save_path=True).save(file_extension='.toml').generate() + config = ( + ConfigArgBuilder(ModelConfig, desc=description, create_save_path=True) + .save(file_extension=".toml") + .generate() + ) # Instantiate our neural net using basic_nn = BasicNet(model_config=config.ModelConfig) # Make some random data (BxH): H has dim of features in @@ -53,5 +59,5 @@ def main(): print(result) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index 2a1ba4df..198b5585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +[tool.isort] +profile = "black" + [tool.portray] extra_dirs = ["resources"] @@ -65,6 +68,15 @@ Motivation = "docs/Motivation.md" [[tool.portray.mkdocs.nav]] [[tool.portray.mkdocs.nav."Addons"]] "S3" = "docs/addons/S3.md" + [[tool.portray.mkdocs.nav."Addons"]] + [[tool.portray.mkdocs.nav."Addons"."Hyper-Parameter Tuning"]] + "About" = "docs/addons/tuner/About.md" + [[tool.portray.mkdocs.nav."Addons"."Hyper-Parameter Tuning"]] + "Basics" = "docs/addons/tuner/Basics.md" + [[tool.portray.mkdocs.nav."Addons"."Hyper-Parameter Tuning"]] + "Optuna" = "docs/addons/tuner/Optuna.md" + [[tool.portray.mkdocs.nav."Addons"."Hyper-Parameter Tuning"]] + "Saving" = "docs/addons/tuner/Saving.md" [[tool.portray.mkdocs.nav]] Contributing = "CONTRIBUTING.md" diff --git a/DEV_REQUIREMENTS.txt b/requirements/DEV_REQUIREMENTS.txt similarity index 73% rename from DEV_REQUIREMENTS.txt rename to requirements/DEV_REQUIREMENTS.txt index 3934641c..fc406fb0 100644 --- a/DEV_REQUIREMENTS.txt +++ b/requirements/DEV_REQUIREMENTS.txt @@ -1,6 +1,7 @@ --r REQUIREMENTS.txt +black coveralls coverage +isort moto portray pytest diff --git a/S3_REQUIREMENTS.txt b/requirements/S3_REQUIREMENTS.txt similarity index 100% rename from S3_REQUIREMENTS.txt rename to requirements/S3_REQUIREMENTS.txt diff --git a/requirements/TEST_EXTRAS_REQUIREMENTS.txt b/requirements/TEST_EXTRAS_REQUIREMENTS.txt new file mode 100644 index 00000000..ff88936c --- /dev/null +++ b/requirements/TEST_EXTRAS_REQUIREMENTS.txt @@ -0,0 +1 @@ +scikit-learn \ No newline at end of file diff --git a/requirements/TUNE_REQUIREMENTS.txt 
b/requirements/TUNE_REQUIREMENTS.txt new file mode 100644 index 00000000..04534b27 --- /dev/null +++ b/requirements/TUNE_REQUIREMENTS.txt @@ -0,0 +1,4 @@ +optuna==2.8.0 +#torchvision +#torch +#ax-platform \ No newline at end of file diff --git a/setup.py b/setup.py index cb474a50..74093ab0 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,32 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Spock Setup""" -from pkg_resources import parse_requirements import setuptools +from pkg_resources import parse_requirements + import versioneer -with open('README.md', 'r') as fid: +with open("README.md", "r") as fid: long_description = fid.read() -with open('REQUIREMENTS.txt', 'r') as fid: +with open("REQUIREMENTS.txt", "r") as fid: install_reqs = [str(req) for req in parse_requirements(fid)] -with open('S3_REQUIREMENTS.txt', 'r') as fid: +with open("./requirements/S3_REQUIREMENTS.txt", "r") as fid: s3_reqs = [str(req) for req in parse_requirements(fid)] +with open("./requirements/TUNE_REQUIREMENTS.txt", "r") as fid: + tune_reqs = [str(req) for req in parse_requirements(fid)] + setuptools.setup( - name='spock-config', - description='Spock is a framework designed to help manage complex parameter configurations for Python applications', + name="spock-config", + description="Spock is a framework designed to help manage complex parameter configurations for Python applications", long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), author="FMR LLC", @@ -33,24 +37,33 @@ "Intended Audience :: Developers", "Natural Language :: English", "License :: OSI Approved :: Apache Software License", - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Operating System :: OS Independent", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules" + "Topic :: Software Development :: Libraries :: Python Modules", ], project_urls={ "Source": "https://github.com/fidelity/spock", "Documentation": "https://fidelity.github.io/spock/", - "Bug Tracker": "https://fidelity.github.io/spock/issues" + "Bug Tracker": "https://fidelity.github.io/spock/issues", }, - keywords=['configuration', 'argparse', 'parameters', 'machine learning', 'deep learning', 'reproducibility'], - packages=setuptools.find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), - python_requires='>=3.6', + keywords=[ + "configuration", + "argparse", + "parameters", + "machine learning", + "deep learning", + "reproducibility", + ], + packages=setuptools.find_packages( + exclude=["*.tests", "*.tests.*", "tests.*", "tests"] + ), + python_requires=">=3.6", install_requires=install_reqs, - extras_require={'s3': s3_reqs} + extras_require={"s3": s3_reqs, "tune": tune_reqs}, ) diff --git a/spock/__init__.py b/spock/__init__.py index 0ea44bea..46d8e091 100644 --- a/spock/__init__.py +++ b/spock/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright 
FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -13,5 +13,5 @@ __all__ = ["args", "builder", "config"] -__version__ = get_versions()['version'] -del get_versions \ No newline at end of file +__version__ = get_versions()["version"] +del get_versions diff --git a/spock/addons/__init__.py b/spock/addons/__init__.py index 854473e1..ce9724fd 100644 --- a/spock/addons/__init__.py +++ b/spock/addons/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -8,8 +8,6 @@ Please refer to the documentation provided in the README.md """ -from spock.addons.s3.utils import S3Config -from spock.addons.s3.configs import S3DownloadConfig -from spock.addons.s3.configs import S3UploadConfig -__all__ = ["s3", "S3Config", "S3DownloadConfig", "S3UploadConfig"] + +__all__ = ["s3", "tune"] diff --git a/spock/addons/s3/__init__.py b/spock/addons/s3/__init__.py index 6927bcbe..eb6f546a 100644 --- a/spock/addons/s3/__init__.py +++ b/spock/addons/s3/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -9,4 +9,6 @@ Please refer to the documentation provided in the README.md """ -__all__ = ["utils"] +from spock.addons.s3.configs import S3Config, S3DownloadConfig, S3UploadConfig + +__all__ = ["configs", "utils", "S3Config", "S3DownloadConfig", "S3UploadConfig"] diff --git a/spock/addons/s3/configs.py b/spock/addons/s3/configs.py index 01818c02..1c948676 100644 --- a/spock/addons/s3/configs.py +++ b/spock/addons/s3/configs.py @@ -1,45 +1,54 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles all S3 related configurations""" import attr + try: import boto3 from botocore.client import BaseClient from s3transfer.manager import TransferManager except ImportError: - print('Missing libraries to support S3 functionality. Please re-install spock with the extra s3 dependencies -- ' - 'pip install spock-config[s3]') -import typing - + print( + "Missing libraries to support S3 functionality. 
Please re-install spock with the extra s3 dependencies -- " + "pip install spock-config[s3]" + ) +from typing import Optional # Iterate through the allowed download args for S3 and map into optional attr.ib download_attrs = { val: attr.ib( default=None, type=str, - validator=attr.validators.optional(attr.validators.instance_of(str)) - ) for val in TransferManager.ALLOWED_DOWNLOAD_ARGS} + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + for val in TransferManager.ALLOWED_DOWNLOAD_ARGS +} # Make the class dynamically -S3DownloadConfig = attr.make_class(name="S3DownloadConfig", attrs=download_attrs, kw_only=True, frozen=True) +S3DownloadConfig = attr.make_class( + name="S3DownloadConfig", attrs=download_attrs, kw_only=True, frozen=True +) # Iterate through the allowed upload args for S3 and map into optional attr.ib upload_attrs = { val: attr.ib( default=None, type=str, - validator=attr.validators.optional(attr.validators.instance_of(str)) - ) for val in TransferManager.ALLOWED_UPLOAD_ARGS + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + for val in TransferManager.ALLOWED_UPLOAD_ARGS } # Make the class dynamically -S3UploadConfig = attr.make_class(name="S3UploadConfig", attrs=upload_attrs, kw_only=True, frozen=True) +S3UploadConfig = attr.make_class( + name="S3UploadConfig", attrs=upload_attrs, kw_only=True, frozen=True +) @attr.s(auto_attribs=True) @@ -56,13 +65,14 @@ class S3Config: upload_config: S3UploadConfig for extra upload configs (optional) """ + session: boto3.Session # s3_session: BaseClient = attr.ib(init=False) - s3_session: typing.Optional[BaseClient] = None - temp_folder: typing.Optional[str] = '/tmp/' + s3_session: Optional[BaseClient] = None + temp_folder: Optional[str] = "/tmp/" download_config: S3DownloadConfig = S3DownloadConfig() upload_config: S3UploadConfig = S3UploadConfig() def __attrs_post_init__(self): if self.s3_session is None: - self.s3_session = self.session.client('s3') + self.s3_session = self.session.client("s3") diff --git a/spock/addons/s3/utils.py b/spock/addons/s3/utils.py index b463f57b..78678bf9 100644 --- a/spock/addons/s3/utils.py +++ b/spock/addons/s3/utils.py @@ -1,25 +1,28 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles all S3 related ops -- allows for s3 functionality to be optional to keep req deps light""" import attr + try: import boto3 from botocore.client import BaseClient except ImportError: - print('Missing libraries to support S3 functionality. Please re-install spock with the extra s3 dependencies -- ' - 'pip install spock-config[s3]') -from hurry.filesize import size + print( + "Missing libraries to support S3 functionality. 
Please re-install spock with the extra s3 dependencies -- " + "pip install spock-config[s3]" + ) import os -from urllib.parse import urlparse -from spock.addons.s3.configs import S3Config -from spock.addons.s3.configs import S3DownloadConfig -from spock.addons.s3.configs import S3UploadConfig import sys import typing +from urllib.parse import urlparse + +from hurry.filesize import size + +from spock.addons.s3.configs import S3Config, S3DownloadConfig, S3UploadConfig def handle_s3_load_path(path: str, s3_config: S3Config) -> str: @@ -39,15 +42,20 @@ def handle_s3_load_path(path: str, s3_config: S3Config) -> str: """ if s3_config is None: - raise ValueError('Load from S3 -- Missing S3Config object which is necessary to handle S3 style paths') + raise ValueError( + "Load from S3 -- Missing S3Config object which is necessary to handle S3 style paths" + ) bucket, obj, fid = get_s3_bucket_object_name(s3_path=path) # Construct the full temp path - temp_path = f'{s3_config.temp_folder}/{fid}' + temp_path = f"{s3_config.temp_folder}/{fid}" # Strip double slashes if exist - temp_path = temp_path.replace(r'//', r'/') + temp_path = temp_path.replace(r"//", r"/") temp_path = download_s3( - bucket=bucket, obj=obj, temp_path=temp_path, s3_session=s3_config.s3_session, - download_config=s3_config.download_config + bucket=bucket, + obj=obj, + temp_path=temp_path, + s3_session=s3_config.s3_session, + download_config=s3_config.download_config, ) return temp_path @@ -68,13 +76,18 @@ def handle_s3_save_path(temp_path: str, s3_path: str, name: str, s3_config: S3Co """ if s3_config is None: - raise ValueError('Save to S3 -- Missing S3Config object which is necessary to handle S3 style paths') + raise ValueError( + "Save to S3 -- Missing S3Config object which is necessary to handle S3 style paths" + ) # Fix posix strip - s3_path = s3_path.replace('s3:/', 's3://') - bucket, obj, fid = get_s3_bucket_object_name(f'{s3_path}/{name}') + s3_path = s3_path.replace("s3:/", "s3://") + bucket, obj, fid = get_s3_bucket_object_name(f"{s3_path}/{name}") upload_s3( - bucket=bucket, obj=obj, temp_path=temp_path, - s3_session=s3_config.s3_session, upload_config=s3_config.upload_config + bucket=bucket, + obj=obj, + temp_path=temp_path, + s3_session=s3_config.s3_session, + upload_config=s3_config.upload_config, ) @@ -93,11 +106,16 @@ def get_s3_bucket_object_name(s3_path: str) -> typing.Tuple[str, str, str]: """ parsed = urlparse(s3_path) - return parsed.netloc, parsed.path.lstrip('/'), os.path.basename(parsed.path) + return parsed.netloc, parsed.path.lstrip("/"), os.path.basename(parsed.path) -def download_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient, - download_config: S3DownloadConfig) -> str: +def download_s3( + bucket: str, + obj: str, + temp_path: str, + s3_session: BaseClient, + download_config: S3DownloadConfig, +) -> str: """Attempts to download the file from the S3 uri to a temp location using any extra arguments to the download *Args*: @@ -115,9 +133,13 @@ def download_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient, """ try: # Unroll the extra options for those values that are not None - extra_options = {k: v for k, v in attr.asdict(download_config).items() if v is not None} - file_size = s3_session.head_object(Bucket=bucket, Key=obj, **extra_options)['ContentLength'] - print(f'Attempting to download s3://{bucket}/{obj} (size: {size(file_size)})') + extra_options = { + k: v for k, v in attr.asdict(download_config).items() if v is not None + } + file_size = 
s3_session.head_object(Bucket=bucket, Key=obj, **extra_options)[ + "ContentLength" + ] + print(f"Attempting to download s3://{bucket}/{obj} (size: {size(file_size)})") current_progress = 0 n_ticks = 50 @@ -126,21 +148,34 @@ def _s3_progress_bar(chunk): # Increment progress current_progress += chunk done = int(n_ticks * (current_progress / file_size)) - sys.stdout.write(f"\r[%s%s] " - f"{int(current_progress/file_size) * 100}%%" % ('=' * done, ' ' * (n_ticks - done))) + sys.stdout.write( + f"\r[%s%s] " + f"{int(current_progress/file_size) * 100}%%" + % ("=" * done, " " * (n_ticks - done)) + ) sys.stdout.flush() - sys.stdout.write('\n\n') + sys.stdout.write("\n\n") + # Download with the progress callback - s3_session.download_file(bucket, obj, temp_path, Callback=_s3_progress_bar, ExtraArgs=extra_options) + s3_session.download_file( + bucket, obj, temp_path, Callback=_s3_progress_bar, ExtraArgs=extra_options + ) return temp_path except IOError: - print(f'Failed to download file from S3 ' - f'(bucket: {bucket}, object: {obj}) ' - f'and write to {temp_path}') - - -def upload_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient, - upload_config: S3UploadConfig): + print( + f"Failed to download file from S3 " + f"(bucket: {bucket}, object: {obj}) " + f"and write to {temp_path}" + ) + + +def upload_s3( + bucket: str, + obj: str, + temp_path: str, + s3_session: BaseClient, + upload_config: S3UploadConfig, +): """Attempts to upload the local file to the S3 uri using any extra arguments to the upload *Args*: @@ -156,9 +191,11 @@ def upload_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient, """ try: # Unroll the extra options for those values that are not None - extra_options = {k: v for k, v in attr.asdict(upload_config).items() if v is not None} + extra_options = { + k: v for k, v in attr.asdict(upload_config).items() if v is not None + } file_size = os.path.getsize(temp_path) - print(f'Attempting to upload s3://{bucket}/{obj} (size: {size(file_size)})') + print(f"Attempting to upload s3://{bucket}/{obj} (size: {size(file_size)})") current_progress = 0 n_ticks = 50 @@ -167,13 +204,21 @@ def _s3_progress_bar(chunk): # Increment progress current_progress += chunk done = int(n_ticks * (current_progress / file_size)) - sys.stdout.write(f"\r[%s%s] " - f"{int(current_progress/file_size) * 100}%%" % ('=' * done, ' ' * (n_ticks - done))) + sys.stdout.write( + f"\r[%s%s] " + f"{int(current_progress/file_size) * 100}%%" + % ("=" * done, " " * (n_ticks - done)) + ) sys.stdout.flush() - sys.stdout.write('\n\n') + sys.stdout.write("\n\n") + # Upload with progress callback - s3_session.upload_file(temp_path, bucket, obj, Callback=_s3_progress_bar, ExtraArgs=extra_options) + s3_session.upload_file( + temp_path, bucket, obj, Callback=_s3_progress_bar, ExtraArgs=extra_options + ) except IOError: - print(f'Failed to upload file to S3 ' - f'(bucket: {bucket}, object: {obj}) ' - f'from {temp_path}') + print( + f"Failed to upload file to S3 " + f"(bucket: {bucket}, object: {obj}) " + f"from {temp_path}" + ) diff --git a/spock/addons/tune/__init__.py b/spock/addons/tune/__init__.py new file mode 100644 index 00000000..6982e392 --- /dev/null +++ b/spock/addons/tune/__init__.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +""" +Spock is a framework that helps manage complex parameter configurations for Python applications + +Please refer to the documentation provided in the README.md +""" +from spock.addons.tune.config import ( + 
ChoiceHyperParameter, + OptunaTunerConfig, + RangeHyperParameter, + spockTuner, +) + +__all__ = [ + "builder", + "config", + "spockTuner", + "RangeHyperParameter", + "ChoiceHyperParameter", + "OptunaTunerConfig", +] diff --git a/spock/addons/tune/builder.py b/spock/addons/tune/builder.py new file mode 100644 index 00000000..dd585549 --- /dev/null +++ b/spock/addons/tune/builder.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the tuner builder backend""" + +from spock.backend.builder import BaseBuilder +from spock.utils import make_argument + + +class TunerBuilder(BaseBuilder): + def __init__(self, *args, **kwargs): + """TunerBuilder init + + Args: + *args: list of input classes that link to a backend + configs: None or List of configs to read from + desc: description for the arg parser + no_cmd_line: flag to force no command line reads + **kwargs: any extra keyword args + """ + super().__init__(*args, module_name="spock.addons.tune.config", **kwargs) + + def _handle_arguments(self, args, class_obj): + """Ovverides base -- Handles all argument mapping + + Creates a dictionary of named parameters that are mapped to the final type of object + + *Args*: + + args: read file arguments + class_obj: instance of a class obj + + *Returns*: + + fields: dictionary of mapped parameters + + """ + attr_name = class_obj.__name__ + fields = { + val.name: val.type(**args[attr_name][val.name]) + for val in class_obj.__attrs_attrs__ + } + return fields + + @staticmethod + def _make_group_override_parser(parser, class_obj, class_name): + """Makes a name specific override parser for a given class obj + + Takes a class object of the backend and adds a new argument group with argument names given with name + Class.val.(unrolled config parameters) so that individual parameters specific to a class can be overridden. 
+ + *Args*: + + parser: argument parser + class_obj: instance of a backend class + class_name: used for module matching + + *Returns*: + + parser: argument parser with new class specific overrides + + """ + attr_name = class_obj.__name__ + group_parser = parser.add_argument_group( + title=str(attr_name) + " Specific Overrides" + ) + for val in class_obj.__attrs_attrs__: + val_type = val.metadata["type"] if "type" in val.metadata else val.type + for arg in val_type.__attrs_attrs__: + arg_name = f"--{str(attr_name)}.{val.name}.{arg.name}" + group_parser = make_argument(arg_name, arg.type, group_parser) + return parser + + def _extract_fnc(self, val, module_name): + return self._extract_other_types(val.type, module_name) diff --git a/spock/addons/tune/config.py b/spock/addons/tune/config.py new file mode 100644 index 00000000..5974f789 --- /dev/null +++ b/spock/addons/tune/config.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Creates the spock config interface that wraps attr -- tune version for hyper-parameters""" +import sys +from typing import List, Optional, Sequence, Tuple, Union + +import attr +import optuna + +from spock.backend.config import _base_attr + + +@attr.s(auto_attribs=True) +class OptunaTunerConfig: + storage: Optional[Union[str, optuna.storages.BaseStorage]] = None + sampler: Optional[optuna.samplers.BaseSampler] = None + pruner: Optional[optuna.pruners.BasePruner] = None + study_name: Optional[str] = None + direction: Optional[Union[str, optuna.study.StudyDirection]] = None + load_if_exists: bool = False + directions: Optional[Sequence[Union[str, optuna.study.StudyDirection]]] = None + + +def _spock_tune(cls): + """Ovverides basic spock_attr decorator with another name + + Using a different name allows spock to easily determine which parameters are normal and which are + meant to be used in a hyper-parameter tuning backend + + *Args*: + + cls: basic class def + + *Returns*: + + cls: slotted attrs class that is frozen and kw only + """ + bases, attrs_dict = _base_attr(cls) + # Dynamically make an attr class + obj = attr.make_class( + name=cls.__name__, bases=bases, attrs=attrs_dict, kw_only=True, frozen=True + ) + # For each class we dynamically create we need to register it within the system modules for pickle to work + setattr(sys.modules["spock"].addons.tune.config, obj.__name__, obj) + # Swap the __doc__ string from cls to obj + obj.__doc__ = cls.__doc__ + return obj + + +# Make the alias for the decorator +spockTuner = _spock_tune + + +@attr.s +class RangeHyperParameter: + """Range based hyper-parameter that is sampled uniformly + + Attributes: + type: type of the hyper-parameter (note: spock will attempt to autocast into this type) + bounds: min and max of the hyper-parameter range + log_scale: log scale the values before sampling + + """ + + type = attr.ib( + type=str, + validator=[ + attr.validators.instance_of(str), + attr.validators.in_(["float", "int"]), + ], + ) + bounds = attr.ib( + type=Union[Tuple[float, float], Tuple[int, int]], + validator=attr.validators.deep_iterable( + member_validator=attr.validators.instance_of((float, int)), + iterable_validator=attr.validators.instance_of(tuple), + ), + ) + log_scale = attr.ib(type=bool, validator=attr.validators.instance_of(bool)) + + +@attr.s +class ChoiceHyperParameter: + """Choice based hyper-parameter that is sampled uniformly + + Attributes: + type: type of the hyper-parameter -- (note: spock will attempt to autocast into this type) + choices: 
list of variable length that contains all the possible choices to select from + + """ + + type = attr.ib( + type=str, + validator=[ + attr.validators.instance_of(str), + attr.validators.in_(["float", "int", "str", "bool"]), + ], + ) + choices = attr.ib( + type=Union[List[str], List[int], List[float], List[bool]], + validator=attr.validators.deep_iterable( + member_validator=attr.validators.instance_of((float, int, bool, str)), + iterable_validator=attr.validators.instance_of(list), + ), + ) diff --git a/spock/addons/tune/interface.py b/spock/addons/tune/interface.py new file mode 100644 index 00000000..eac45dbf --- /dev/null +++ b/spock/addons/tune/interface.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the base interface""" +from abc import ABC, abstractmethod +from typing import Dict + +import attr + +from spock.backend.wrappers import Spockspace + + +class BaseInterface(ABC): + def __init__(self, tuner_config, tuner_namespace: Spockspace): + """Base init call that maps a few variables + + *Args*: + + _tuner_config: necessary object to determine the interface and sample correctly from the underlying library + _tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types + + """ + + self._tuner_config = { + k: v for k, v in attr.asdict(tuner_config).items() if v is not None + } + self._tuner_namespace = tuner_namespace + + @abstractmethod + def sample(self): + """Calls the underlying library sample to get a single sample/draw from the hyper-parameter + sets (e.g. ranges, choices) + + *Returns*: + + Spockspace of the current hyper-parameter draw + + """ + pass + + @abstractmethod + def _construct(self): + """Constructs the base object needed by the underlying library to construct the correct object that allows + for hyper-parameter sampling + + *Returns*: + + Any typed object needed for support + + """ + pass + + @staticmethod + def _gen_attr_classes(tune_dict: Dict): + for k, v in tune_dict.items(): + attrs_dict = { + ik: attr.ib( + validator=attr.validators.instance_of(type(iv)), type=type(iv) + ) + for ik, iv in v.items() + } + obj = attr.make_class(name=k, attrs=attrs_dict, kw_only=True, frozen=True) + tune_dict.update({k: obj(**v)}) + return tune_dict + + @staticmethod + def _to_spockspace(tune_dict: Dict): + """Converts a dict to a Spockspace + + *Args*: + + tune_dict: current dictionary + + *Returns*: + + Spockspace of dict + + """ + return Spockspace(**tune_dict) + + @staticmethod + def _get_caster(val): + """Gets a callable type object from a string type + + *Args*: + + val: current attr val: + + *Returns*: + + type class object + + """ + return __builtins__[val.type] + + @property + @abstractmethod + def tuner_status(self): + """Returns a dictionary of all the necessary underlying tuner internals to report the result""" + pass + + @property + @abstractmethod + def best(self): + """Returns a Spockspace of the best hyper-parameter config and the associated metric value""" diff --git a/spock/addons/tune/optuna.py b/spock/addons/tune/optuna.py new file mode 100644 index 00000000..9b80e298 --- /dev/null +++ b/spock/addons/tune/optuna.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the optuna backend""" + +import hashlib +import json +from warnings import warn + +import attr +import optuna + +from spock.addons.tune.config import OptunaTunerConfig +from spock.addons.tune.interface import BaseInterface + + 
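# NOTE (editorial sketch, not part of the original patch): a rough illustration of how the
# define-and-run pieces in this new module are expected to fit together. `LRConfig`,
# `tune_namespace`, and `objective` are hypothetical stand-ins (a @spockTuner class, the
# tuner Spockspace built from it, and a user-supplied metric function); OptunaTunerConfig,
# OptunaInterface, sample(), and tuner_status are the pieces this patch actually defines.
#
#     study_config = OptunaTunerConfig(study_name="demo", direction="minimize")
#     interface = OptunaInterface(tuner_config=study_config, tuner_namespace=tune_namespace)
#     for _ in range(25):
#         drawn = interface.sample()         # Spockspace with one drawn value per hyper-parameter
#         status = interface.tuner_status    # {"trial": current optuna trial, "study": optuna study}
#         status["study"].tell(status["trial"], objective(drawn))  # report the metric back to optuna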
+class OptunaInterface(BaseInterface): + """Specific override to support the optuna backend + + *Attributes*: + + _map_type: dictionary that maps class names and types to fns that create optuna distributions + _tuner_obj: necessary object to determine the interface and sample correctly from the underlying library + _tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types + _param_obj: underlying object that optuna study can sample from (flat dictionary) + + """ + + def __init__(self, tuner_config: OptunaTunerConfig, tuner_namespace): + """OptunaInterface init call that maps variables, creates a map to fnc calls, and constructs the necessary + underlying objects + + *Args*: + + tuner_config: necessary object to determine the interface and sample correctly from the underlying library + tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types + + """ + super(OptunaInterface, self).__init__(tuner_config, tuner_namespace) + self._tuner_obj = optuna.create_study(**self._tuner_config) + self._trial = None + self._sample_hash = None + self._trial_status_hash = None + # Mapping spock underlying classes to optuna distributions (define-and-run interface) + self._map_type = { + "RangeHyperParameter": { + "int": self._uniform_int_dist, + "float": self._uniform_float_dist, + }, + "ChoiceHyperParameter": { + "int": self._categorical_dist, + "float": self._categorical_dist, + "str": self._categorical_dist, + "bool": self._categorical_dist, + }, + } + # Build the correct underlying dictionary object for Optuna + self._param_obj = self._construct() + + @property + def tuner_status(self): + return {"trial": self._trial, "study": self._tuner_obj} + + @property + def best(self): + rollup_dict, _ = self._trial_rollup(self._tuner_obj.best_trial) + return ( + self._to_spockspace(self._gen_attr_classes(rollup_dict)), + self._tuner_obj.best_value, + ) + + def sample(self): + self._trial = self._tuner_obj.ask(self._param_obj) + # Roll this back out into a Spockspace so it can be merged into the fixed parameter Spockspace + # Also need to un-dot the param names to rebuild the nested structure + rollup_dict, sample_hash = self._trial_rollup(self._trial) + self._sample_hash = sample_hash + return self._to_spockspace(self._gen_attr_classes(rollup_dict)) + + @staticmethod + def _trial_rollup(trial): + """Rollup the trial into a dictionary that can be converted to a spockspace with the correct names and roots + + *Returns*: + + dictionary of rolled up sampled parameters + md5 hash of the dictionary contents + + """ + key_set = {k.split(".")[0] for k in trial.params.keys()} + rollup_dict = {val: {} for val in key_set} + for k, v in trial.params.items(): + split_names = k.split(".") + rollup_dict[split_names[0]].update({split_names[1]: v}) + dict_hash = hashlib.md5( + json.dumps(rollup_dict, sort_keys=True).encode("utf-8") + ).digest() + return rollup_dict, dict_hash + + def _construct(self): + """Constructs the base object needed by the underlying library to construct the correct object that allows + for hyper-parameter sampling + + *Returns*: + + flat dictionary of all hyper-parameters named with dot notation (class.param_name) + + """ + optuna_dict = {} + # These will only be nested one level deep given the tuner syntax + for k, v in vars(self._tuner_namespace).items(): + for ik, iv in vars(v).items(): + param_fn = self._map_type[type(iv).__name__][iv.type] + optuna_dict.update({f"{k}.{ik}": param_fn(iv)}) + return optuna_dict + + 
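    # NOTE (editorial sketch, not part of the original patch): for a hypothetical @spockTuner
    # class `LRConfig` holding `lr` (RangeHyperParameter, type "float", bounds (0.001, 0.1),
    # log_scale False) and `batch_size` (ChoiceHyperParameter, type "int", choices [16, 32, 64]),
    # _construct() above would produce a flat mapping along the lines of:
    #
    #     {
    #         "LRConfig.lr": optuna.distributions.UniformDistribution(low=0.001, high=0.1),
    #         "LRConfig.batch_size": optuna.distributions.CategoricalDistribution(choices=[16, 32, 64]),
    #     }
    #
    # i.e. exactly the fixed distributions that study.ask() consumes in sample() above.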
@staticmethod + def _uniform_float_dist(val): + """Assemble the optuna.distributions.(Log)UniformDistribution object + + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.UniformDistribution.html + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.LogUniformDistribution.html + + *Args*: + + val: current attr val + + *Returns*: + + optuna.distributions.UniformDistribution or optuna.distributions.LogUniformDistribution + + """ + try: + low = float(val.bounds[0]) + high = float(val.bounds[1]) + except TypeError: + print( + f"Attempted to cast into type: {val.type} but failed -- check the inputs to RangeHyperParameter" + ) + log_scale = val.log_scale + return ( + optuna.distributions.LogUniformDistribution(low=low, high=high) + if log_scale + else optuna.distributions.UniformDistribution(low=low, high=high) + ) + + @staticmethod + def _uniform_int_dist(val): + """Assemble the optuna.distributions.Int(Log)UniformDistribution object + + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntLogUniformDistribution.html + + *Args*: + + val: current attr val + + *Returns*: + + optuna.distributions.IntUniformDistribution or optuna.distributions.IntLogUniformDistribution + + """ + try: + low = int(val.bounds[0]) + high = int(val.bounds[1]) + except TypeError: + print( + f"Attempted to cast into type: {val.type} but failed -- check the inputs to RangeHyperParameter" + ) + log_scale = val.log_scale + return ( + optuna.distributions.IntLogUniformDistribution(low=low, high=high) + if log_scale + else optuna.distributions.IntUniformDistribution(low=low, high=high) + ) + + def _categorical_dist(self, val): + """Assemble the optuna.distributions.CategoricalDistribution object + + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html + + *Args*: + + val: current attr val + + *Returns*: + + optuna.distributions.CategoricalDistribution + + """ + caster = self._get_caster(val) + # Just attempt to cast in a try except + try: + val.choices = [caster(v) for v in val.choices] + except TypeError: + print( + f"Attempted to cast into type: {val.type} but failed -- check the inputs to ChoiceHyperParameter" + ) + return optuna.distributions.CategoricalDistribution(choices=val.choices) diff --git a/spock/addons/tune/payload.py b/spock/addons/tune/payload.py new file mode 100644 index 00000000..200fa75c --- /dev/null +++ b/spock/addons/tune/payload.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the tuner payload backend""" + +from spock.backend.payload import BasePayload +from spock.backend.utils import get_attr_fields + + +class TunerPayload(BasePayload): + """Handles building the payload for tuners + + This class builds out the payload from config files of multiple types. 
It handles various + file types and also composition of config files via a recursive calls + + *Attributes*: + + _loaders: maps of each file extension to the loader class + + """ + + def __init__(self, s3_config=None): + """Init for TunerPayload + + *Args*: + + s3_config: optional S3 config object + + """ + super().__init__(s3_config=s3_config) + + def __call__(self, *args, **kwargs): + """Call to allow self chaining + + *Args*: + + *args: + **kwargs: + + *Returns*: + + Payload: instance of self + + """ + return TunerPayload(*args, **kwargs) + + @staticmethod + def _update_payload(base_payload, input_classes, ignore_classes, payload): + # Get basic args + attr_fields = get_attr_fields(input_classes=input_classes) + # Get the ignore fields + ignore_fields = get_attr_fields(input_classes=ignore_classes) + for k, v in base_payload.items(): + if k not in ignore_fields: + if k != "config": + # Dict infers that we are overriding a global setting in a specific config + if isinstance(v, dict): + # we're in a namespace + # Check for incorrect specific override of global def + if k not in attr_fields: + raise TypeError( + f"Referring to a class space {k} that is undefined" + ) + for i_keys in v.keys(): + if i_keys not in attr_fields[k]: + raise ValueError( + f"Provided an unknown argument named {k}.{i_keys}" + ) + if k in payload and isinstance(v, dict): + payload[k].update(v) + else: + payload[k] = v + # Handle tuple conversion here -- lazily + for ik, iv in v.items(): + if "bounds" in iv: + iv["bounds"] = tuple(iv["bounds"]) + return payload + + @staticmethod + def _handle_payload_override(payload, key, value): + key_split = key.split(".") + curr_ref = payload + for idx, split in enumerate(key_split): + # If the root isn't in the payload then it needs to be added but only for the first key split + if idx == 0 and (split not in payload): + payload.update({split: {}}) + # Check if it's the last value and figure out the override + if idx == (len(key_split) - 1): + # Handle bool(s) a bit differently as they are store_true + if isinstance(curr_ref, dict) and isinstance(value, bool): + if value is not False: + curr_ref[split] = value + # If we are at the dictionary level we should be able to just payload override + elif isinstance(curr_ref, dict) and not isinstance(value, bool): + curr_ref[split] = value + else: + raise ValueError( + f"cmd-line override failed for {key} -- " + f"Failed to find key {split} within lowest level Dict" + ) + # If it's not keep walking the current payload + else: + curr_ref = curr_ref[split] + return payload diff --git a/spock/addons/tune/tuner.py b/spock/addons/tune/tuner.py new file mode 100644 index 00000000..438e4b97 --- /dev/null +++ b/spock/addons/tune/tuner.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the tuner interface interface""" + +from typing import Union + +from spock.addons.tune.config import OptunaTunerConfig +from spock.addons.tune.optuna import OptunaInterface +from spock.backend.wrappers import Spockspace + + +class TunerInterface: + """Handles the general tuner interface by creating the necessary underlying tuner class and dispatches necessary + ops to the class instance + + *Attributes*: + + _fixed_namespace: fixed parameter namespace used for combination with a sample draw + _lib_interface: class instance of the underlying hyper-parameter library + + """ + + def __init__( + self, + tuner_config: Union[OptunaTunerConfig], + tuner_namespace: Spockspace, + fixed_namespace: 
Spockspace, + ): + """Init call to the TunerInterface + + *Args*: + + tuner_config: necessary object to determine the interface and sample correctly from the underlying library + tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types + fixed_namespace: namespace of fixed parameters + + """ + self._fixed_namespace = fixed_namespace + # Todo: add ax type check here + accept_types = OptunaTunerConfig + if not isinstance(tuner_config, accept_types): + raise TypeError( + f"Passed incorrect tuner_config type of {type(tuner_config)} -- must be of type " + f"{repr(accept_types)}" + ) + if isinstance(tuner_config, OptunaTunerConfig): + self._lib_interface = OptunaInterface( + tuner_config=tuner_config, tuner_namespace=tuner_namespace + ) + # # TODO: Add ax class logic + # elif isinstance(tuner_config, (ax.Experiment, ax.SimpleExperiment)): + # pass + + def sample(self): + """Public interface to underlying library sepcific sample that returns a single sample/draw from the + hyper-parameter sets (e.g. ranges, choices) and combines them with the fixed parameters into a single Spockspace + + *Returns*: + + Spockspace of drawn sample of hyper-parameters and fixed parameters + + """ + curr_sample = self._lib_interface.sample() + # Merge w/ fixed parameters + return Spockspace(**vars(curr_sample), **vars(self._fixed_namespace)) + + @property + def tuner_status(self): + """Returns a dictionary of all the necessary underlying tuner internals to report the result""" + return self._lib_interface.tuner_status + + @property + def best(self): + """Returns a Spockspace of the best hyper-parameter config and the associated metric value""" + return self._lib_interface.best diff --git a/spock/args.py b/spock/args.py index e98032b4..d0e0109e 100644 --- a/spock/args.py +++ b/spock/args.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles import aliases to allow backwards compat with backends""" # from spock.backend.dataclass.args import * -from spock.backend.attr.typed import SavePath +from spock.backend.typed import SavePath diff --git a/spock/backend/__init__.py b/spock/backend/__init__.py index d8767c44..5d4e8add 100644 --- a/spock/backend/__init__.py +++ b/spock/backend/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -8,4 +8,5 @@ Please refer to the documentation provided in the README.md """ -__all__ = ["attr", "base"] + +__all__ = ["builder", "config", "payload", "saver", "typed"] diff --git a/spock/backend/attr/__init__.py b/spock/backend/attr/__init__.py deleted file mode 100644 index 1cd0336a..00000000 --- a/spock/backend/attr/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -""" -Spock is a framework that helps manage complex parameter configurations for Python applications - -Please refer to the documentation provided in the README.md -""" - -__all__ = ["builder", "config", "payload", "saver", "typed"] diff --git a/spock/backend/attr/builder.py b/spock/backend/attr/builder.py deleted file mode 100644 index 9b76ee7d..00000000 --- a/spock/backend/attr/builder.py +++ /dev/null @@ -1,125 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -"""Handles the building/saving of the configurations from the Spock config classes""" - -import attr -from enum import 
EnumMeta -import re -import sys -from warnings import warn -from spock.backend.base import BaseBuilder - - -class AttrBuilder(BaseBuilder): - """Attr specific builder - - Class that handles building for the attr backend - - *Attributes* - - input_classes: list of input classes that link to a backend - _configs: None or List of configs to read from - _create_save_path: boolean to make the path to save to - _desc: description for the arg parser - _no_cmd_line: flag to force no command line reads - save_path: list of path(s) to save the configs to - - """ - def __init__(self, *args, configs=None, create_save_path=False, desc='', no_cmd_line=False, **kwargs): - super().__init__(*args, configs=configs, create_save_path=create_save_path, desc=desc, - no_cmd_line=no_cmd_line, **kwargs) - for arg in self.input_classes: - if not attr.has(arg): - raise TypeError('*arg inputs to ConfigArgBuilder must all be class instances with attrs attributes') - - def print_usage_and_exit(self, msg=None, sys_exit=True, exit_code=1): - print(f'usage: {sys.argv[0]} -c [--config] config1 [config2, config3, ...]') - print(f'\n{self._desc if self._desc != "" else ""}\n') - print('configuration(s):\n') - self._handle_help_info() - if msg is not None: - print(msg) - if sys_exit: - sys.exit(exit_code) - - def _handle_help_info(self): - self._attrs_help(self.input_classes) - - def _handle_arguments(self, args, class_obj): - attr_name = class_obj.__name__ - class_names = [val.__name__ for val in self.input_classes] - # Handle repeated classes - if attr_name in class_names and attr_name in args and isinstance(args[attr_name], list): - fields = self._handle_repeated(args[attr_name], attr_name, class_names) - # Handle non-repeated classes - else: - fields = {} - for val in class_obj.__attrs_attrs__: - # Check if namespace is named and then check for key -- checking for local class def - if attr_name in args and val.name in args[attr_name]: - fields[val.name] = self._handle_nested_class(args, args[attr_name][val.name], class_names) - # If not named then just check for keys -- checking for global def - elif val.name in args: - fields[val.name] = self._handle_nested_class(args, args[val.name], class_names) - # Check for special keys to set - if 'special_key' in val.metadata and val.metadata['special_key'] is not None: - if val.name in args: - self.save_path = args[val.name] - elif val.default is not None: - self.save_path = val.default - return fields - - def _handle_repeated(self, args, check_value, class_names): - """Handles repeated classes as lists - - *Args*: - - args: dictionary of arguments from the configs - check_value: value to check classes against - class_names: current class names - - *Returns*: - - list of input_class[match)idx[0]] types filled with repeated values - - """ - # Check to see if the value trying to be set is actually an input class - match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] - return [self.input_classes[match_idx[0]](**val) for val in args] - - def _handle_nested_class(self, args, check_value, class_names): - """Handles passing another class to the field dictionary - - *Args*: - args: dictionary of arguments from the configs - check_value: value to check classes against - class_names: current class names - - *Returns*: - - either the check_value or the necessary class - - """ - # Check to see if the value trying to be set is actually an input class - match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] - # If so then create the needed class 
object by unrolling the args to **kwargs and return it - if len(match_idx) > 0: - if len(match_idx) > 1: - raise ValueError('Match error -- multiple classes with the same name definition') - else: - if args.get(self.input_classes[match_idx[0]].__name__) is None: - raise ValueError(f'Missing config file definition for the referenced class ' - f'{self.input_classes[match_idx[0]].__name__}') - current_arg = args.get(self.input_classes[match_idx[0]].__name__) - if isinstance(current_arg, list): - class_value = [self.input_classes[match_idx[0]](**val) for val in current_arg] - else: - class_value = self.input_classes[match_idx[0]](**current_arg) - return_value = class_value - # else return the expected value - else: - return_value = check_value - return return_value diff --git a/spock/backend/attr/payload.py b/spock/backend/attr/payload.py deleted file mode 100644 index 4c8c6dd0..00000000 --- a/spock/backend/attr/payload.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -"""Handles payloads from markup files""" - -from itertools import chain -from spock.backend.attr.utils import convert_to_tuples -from spock.backend.attr.utils import get_type_fields -from spock.backend.attr.utils import deep_update -from spock.backend.base import BasePayload - - -class AttrPayload(BasePayload): - """Handles building the payload for attrs backend - - This class builds out the payload from config files of multiple types. It handles various - file types and also composition of config files via a recursive calls - - *Attributes*: - - _loaders: maps of each file extension to the loader class - - """ - def __init__(self, s3_config=None): - super().__init__(s3_config=s3_config) - - def __call__(self, *args, **kwargs): - """Call to allow self chaining - - *Args*: - - *args: - **kwargs: - - *Returns*: - - Payload: instance of self - - """ - return AttrPayload(*args, **kwargs) - - @staticmethod - def _update_payload(base_payload, input_classes, payload): - # Get basic args - attr_fields = {attr.__name__: [val.name for val in attr.__attrs_attrs__] for attr in input_classes} - # Class names - class_names = [val.__name__ for val in input_classes] - # Parse out the types if generic - type_fields = get_type_fields(input_classes) - for keys, values in base_payload.items(): - # check if the keys, value pair is expected by the attr class - if keys != 'config': - # Dict infers that we are overriding a global setting in a specific config - if isinstance(values, dict): - # we're in a namespace - # Check for incorrect specific override of global def - if keys not in attr_fields: - raise TypeError(f'Referring to a class space {keys} that is undefined') - for i_keys in values.keys(): - if i_keys not in attr_fields[keys]: - raise ValueError(f'Provided an unknown argument named {keys}.{i_keys}') - else: - # Check if the key is actually a reference to another class - if keys in class_names: - if isinstance(values, list): - # Check for incorrect specific override of global def - if keys not in attr_fields: - raise ValueError(f'Referring to a class space {keys} that is undefined') - # We are in a repeated class def - # Raise if the key set is different from the defined set (i.e. 
incorrect arguments) - key_set = set(list(chain(*[list(val.keys()) for val in values]))) - for i_keys in key_set: - if i_keys not in attr_fields[keys]: - raise ValueError(f'Provided an unknown argument named {keys}.{i_keys}') - # Chain all the values from multiple spock classes into one list - elif keys not in list(chain(*attr_fields.values())): - raise ValueError(f'Provided an unknown argument named {keys}') - # Chain all the values from multiple spock classes into one list - elif keys not in list(chain(*attr_fields.values())): - raise ValueError(f'Provided an unknown argument named {keys}') - if keys in payload and isinstance(values, dict): - payload[keys].update(values) - else: - payload[keys] = values - tuple_payload = convert_to_tuples(payload, type_fields, class_names) - payload = deep_update(payload, tuple_payload) - return payload diff --git a/spock/backend/attr/saver.py b/spock/backend/attr/saver.py deleted file mode 100644 index a0c45d06..00000000 --- a/spock/backend/attr/saver.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -"""Handles prepping and saving the Spock config""" - -import attr -from spock.backend.base import BaseSaver - - -class AttrSaver(BaseSaver): - """Base class for saving configs for the attrs backend - - Contains methods to build a correct output payload and then writes to file based on the file - extension - - *Attributes*: - - _writers: maps file extension to the correct i/o handler - - """ - def __init__(self, s3_config=None): - super().__init__(s3_config=s3_config) - - def __call__(self, *args, **kwargs): - return AttrSaver(*args, **kwargs) - - def _clean_up_values(self, payload, file_extension): - # Dictionary to recursively write to - out_dict = {} - # All of the classes are defined at the top level - all_spock_cls = set(vars(payload).keys()) - out_dict = self._recursively_handle_clean(payload, out_dict, all_cls=all_spock_cls) - # Convert values - clean_dict = self._clean_output(out_dict) - return clean_dict - - def _recursively_handle_clean(self, payload, out_dict, parent_name=None, all_cls=None): - """Recursively works through spock classes and adds clean data to a dictionary - - Given a payload (Spockspace) work recursively through items that don't have parents to catch all - parameter definitions while correctly mapping nested class definitions to their base level class thus - allowing the output markdown to be a valid input file - - *Args*: - - payload: current payload (namespace) - out_dict: output dictionary - parent_name: name of the parent spock class if nested - all_cls: all top level spock class definitions - - *Returns*: - - out_dict: modified dictionary with the cleaned data - - """ - for key, val in vars(payload).items(): - val_name = type(val).__name__ - # This catches basic lists and list of classes - if isinstance(val, list): - # Check if each entry is a spock class - clean_val = [] - repeat_flag = False - for l_val in val: - cls_name = type(l_val).__name__ - # For those that are a spock class and are repeated (cls_name == key) simply convert to dict - if (cls_name in all_cls) and (cls_name == key): - clean_val.append(attr.asdict(l_val)) - # For those whose cls is different than the key just append the cls name - elif cls_name in all_cls: - # Change the flag as this is a repeated class -- which needs to be compressed into a single - # k:v pair - repeat_flag = True - clean_val.append(cls_name) - # Fall back to the passed in values - else: - 
clean_val.append(l_val) - # Handle repeated classes - if repeat_flag: - clean_val = list(set(clean_val))[-1] - out_dict.update({key: clean_val}) - # If it's a spock class but has a parent then just use the class name to reference the values - elif(val_name in all_cls) and parent_name is not None: - out_dict.update({key: val_name}) - # Check if it's a spock class without a parent -- iterate the values and recurse to catch more lists - elif val_name in all_cls: - new_dict = self._recursively_handle_clean(val, {}, parent_name=key, all_cls=all_cls) - out_dict.update({key: new_dict}) - # Either base type or no nested values that could be Spock classes - else: - out_dict.update({key: val}) - return out_dict diff --git a/spock/backend/base.py b/spock/backend/base.py deleted file mode 100644 index a6bd6364..00000000 --- a/spock/backend/base.py +++ /dev/null @@ -1,928 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -"""Handles base Spock classes""" - -from abc import ABC -from abc import abstractmethod -import argparse -import attr -from attr import NOTHING -from enum import EnumMeta -import os -from pathlib import Path -import re -import sys -from uuid import uuid1 -import yaml -from spock.handlers import JSONHandler -from spock.handlers import TOMLHandler -from spock.handlers import YAMLHandler -from spock.utils import add_info -from spock.utils import check_path_s3 -from spock.utils import make_argument -from typing import List - - -class Spockspace(argparse.Namespace): - """Inherits from Namespace to implement a pretty print on the obj - - Overwrites the __repr__ method with a pretty version of printing - - """ - def __init__(self, **kwargs): - super(Spockspace, self).__init__(**kwargs) - - def __repr__(self): - # Remove aliases in YAML print - yaml.Dumper.ignore_aliases = lambda *args: True - return yaml.dump(self.__dict__, default_flow_style=False) - - -class BaseHandler(ABC): - """Base class for saver and payload - - *Attributes*: - - _writers: maps file extension to the correct i/o handler - _s3_config: optional S3Config object to handle s3 access - - """ - def __init__(self, s3_config=None): - self._supported_extensions = {'.yaml': YAMLHandler, '.toml': TOMLHandler, '.json': JSONHandler} - self._s3_config = s3_config - - def _check_extension(self, file_extension: str): - if file_extension not in self._supported_extensions: - raise TypeError(f'File extension {file_extension} not supported -- \n' - f'File extension must be from {list(self._supported_extensions.keys())}') - - -class BaseSaver(BaseHandler): # pylint: disable=too-few-public-methods - """Base class for saving configs - - Contains methods to build a correct output payload and then writes to file based on the file - extension - - *Attributes*: - - _writers: maps file extension to the correct i/o handler - _s3_config: optional S3Config object to handle s3 access - - """ - def __init__(self, s3_config=None): - super(BaseSaver, self).__init__(s3_config=s3_config) - - def save(self, payload, path, file_name=None, create_save_path=False, extra_info=True, file_extension='.yaml'): #pylint: disable=too-many-arguments - """Writes Spock config to file - - Cleans and builds an output payload and then correctly writes it to file based on the - specified file extension - - *Args*: - - payload: current config payload - path: path to save - file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to uuid if None - create_save_path: boolean to create the path if 
non-existent - extra_info: boolean to write extra info - file_extension: what type of file to write - - *Returns*: - - None - - """ - # Check extension - self._check_extension(file_extension=file_extension) - # Make the filename -- always append a uuid for unique-ness - uuid_str = str(uuid1()) - fname = '' if file_name is None else f'{file_name}.' - name = f'{fname}{uuid_str}.spock.cfg{file_extension}' - # Fix up values -- parameters - out_dict = self._clean_up_values(payload, file_extension) - # Get extra info - extra_dict = add_info() if extra_info else None - try: - self._supported_extensions.get(file_extension)().save( - out_dict=out_dict, info_dict=extra_dict, path=str(path), name=name, - create_path=create_save_path, s3_config=self._s3_config - ) - except OSError as e: - print(f'Unable to write to given path: {path / name}') - raise e - - @abstractmethod - def _clean_up_values(self, payload, file_extension): - """Clean up the config payload so it can be written to file - - *Args*: - - payload: dirty payload - extra_info: boolean to add extra info - file_extension: type of file to write - - *Returns*: - - clean_dict: cleaned output payload - - """ - - def _clean_output(self, out_dict): - """Clean up the dictionary so it can be written to file - - *Args*: - - out_dict: cleaned dictionary - extra_info: boolean to add extra info - - *Returns*: - - clean_dict: cleaned output payload - - """ - # Convert values - clean_dict = {} - for key, val in out_dict.items(): - clean_inner_dict = {} - if isinstance(val, list): - for idx, list_val in enumerate(val): - tmp_dict = {} - for inner_key, inner_val in list_val.items(): - tmp_dict = self._convert(tmp_dict, inner_val, inner_key) - val[idx] = tmp_dict - clean_inner_dict = val - else: - for inner_key, inner_val in val.items(): - clean_inner_dict = self._convert(clean_inner_dict, inner_val, inner_key) - clean_dict.update({key: clean_inner_dict}) - return clean_dict - - def _convert(self, clean_inner_dict, inner_val, inner_key): - # Convert tuples to lists so they get written correctly - if isinstance(inner_val, tuple): - clean_inner_dict.update({inner_key: self._recursive_tuple_to_list(inner_val)}) - elif inner_val is not None: - clean_inner_dict.update({inner_key: inner_val}) - return clean_inner_dict - - def _recursive_tuple_to_list(self, value): - """Recursively turn tuples into lists - - Recursively looks through tuple(s) and convert to lists - - *Args*: - - value: value to check and set typ if necessary - typed: type of the generic alias to check against - - *Returns*: - - value: updated value with correct type casts - - """ - # Check for __args__ as it signifies a generic and make sure it's not already been cast as a tuple - # from a composed payload - list_v = [] - for v in value: - if isinstance(v, tuple): - v = self._recursive_tuple_to_list(v) - list_v.append(v) - else: - list_v.append(v) - return list_v - - -class BaseBuilder(ABC): # pylint: disable=too-few-public-methods - """Base class for building the backend specific builders - - This class handles the interface to the backend with the generic ConfigArgBuilder so that different - backends can be used to handle processing - - *Attributes* - - input_classes: list of input classes that link to a backend - _configs: None or List of configs to read from - _create_save_path: boolean to make the path to save to - _desc: description for the arg parser - _no_cmd_line: flag to force no command line reads - _max_indent: maximum to indent between help prints - save_path: list of path(s) to save 
the configs to - - """ - def __init__(self, *args, configs=None, create_save_path=False, desc='', no_cmd_line=False, - max_indent=4, **kwargs): - self.input_classes = args - self._configs = configs - self._create_save_path = create_save_path - self._desc = desc - self._no_cmd_line = no_cmd_line - self._max_indent = max_indent - self.save_path = None - - @abstractmethod - def print_usage_and_exit(self, msg=None, sys_exit=True): - """Prints the help message and exits - - *Args*: - - msg: message to print pre exit - - *Returns*: - - None - - """ - - @abstractmethod - def _handle_help_info(self): - """Handles walking through classes to get help info - - For each class this function will search __doc__ and attempt to pull out help information for both the class - itself and each attribute within the class - - *Returns*: - - None - - """ - - @abstractmethod - def _handle_arguments(self, args, class_obj): - """Handles all argument mapping - - Creates a dictionary of named parameters that are mapped to the final type of object - - *Args*: - - args: read file arguments - class_obj: instance of a class obj - - *Returns*: - - fields: dictionary of mapped parameters - - """ - - def generate(self, dict_args): - """Method to auto-generate the actual class instances from the generated args - - Based on the generated arguments groups and the args read in from the config file(s) - this function instantiates the classes with the necessary field or attr values - - *Args*: - - dict_args: dictionary of arguments from the configs - - *Returns*: - - namespace containing automatically generated instances of the classes - """ - auto_dict = {} - for attr_classes in self.input_classes: - attr_build = self._auto_generate(dict_args, attr_classes) - if isinstance(attr_build, list): - class_name = list({type(val).__name__ for val in attr_build}) - if len(class_name) > 1: - raise ValueError('Repeated class has more than one unique name') - auto_dict.update({class_name[0]: attr_build}) - else: - auto_dict.update({type(attr_build).__name__: attr_build}) - return Spockspace(**auto_dict) - # return argparse.Namespace(**auto_dict) - - def _auto_generate(self, args, input_class): - """Builds an instance of a DataClass - - Builds an instance with the necessary field values from the argument - dictionary read from the config file(s) - - *Args*: - - args: dictionary of arguments read from the config file(s) - data_class: data class to build - - *Returns*: - - An instance of data_class with correct values assigned to fields - """ - # Handle the basic data types - fields = self._handle_arguments(args, input_class) - if isinstance(fields, list): - return_value = fields - else: - self._handle_late_defaults(args, fields, input_class) - return_value = input_class(**fields) - return return_value - - def _handle_late_defaults(self, args, fields, input_class): - """Handles late defaults when the type is non-standard - - If the default type is not a base python type then we need to catch those defaults here and build the correct - values from the input classes while maintaining the optional nature. 
The trick is to exclude all 'base' types - as these defaults are covered by the attr default value - - *Args*: - - args: dictionary of arguments read from the config file(s) - fields: current fields returned from _handle_arguments - input_class: which input class being checked for late defaults - - *Returns*: - - fields: updated field dictionary with late defaults set - - """ - names = [val.name for val in input_class.__attrs_attrs__] - class_names = [val.__name__ for val in self.input_classes] - field_list = list(fields.keys()) - arg_list = list(args.keys()) - # Exclude all the base types that are supported -- these can be set by attrs - exclude_list = ['_Nothing', 'NoneType', 'bool', 'int', 'float', 'str', 'list', 'tuple'] - for val in names: - if val not in field_list: - default_type_name = type(getattr(input_class.__attrs_attrs__, val).default).__name__ - if default_type_name not in exclude_list: - default_name = getattr(input_class.__attrs_attrs__, val).default.__name__ - else: - default_name = None - if default_name is not None and default_name in arg_list: - if isinstance(args.get(default_name), list): - default_value = [self.input_classes[class_names.index(default_name)](**arg_val) - for arg_val in args.get(default_name)] - else: - default_value = self.input_classes[class_names.index(default_name)](**args.get(default_name)) - fields.update({val: default_value}) - return fields - - def get_config_paths(self): - """Get config paths from all methods - - Config paths can enter from either the command line or be added in the class init call - as a kwarg (configs=[]) - - *Returns*: - - args: namespace of args - - """ - # Check if the no_cmd_line is not flagged and if the configs are not empty - - if self._no_cmd_line and (self._configs is None): - raise ValueError("Flag set for preventing command line read but no paths were passed to the config kwarg") - if not self._no_cmd_line: - args = self._build_override_parsers(desc=self._desc) - else: - args = argparse.Namespace(config=[], help=False) - if self._configs is not None: - args = self._get_from_kwargs(args, self._configs) - return args - - def _build_override_parsers(self, desc): - """Creates parsers for command-line overrides - - Builds the basic command line parser for configs and help then iterates through each attr instance to make - namespace specific cmd line override parsers - - *Args*: - - desc: argparser description - - *Returns*: - - args: argument namespace - - """ - parser = argparse.ArgumentParser(description=desc, add_help=False) - parser.add_argument('-c', '--config', required=False, nargs='+', default=[]) - parser.add_argument('-h', '--help', action='store_true') - # Build out each class override specific parser - for val in self.input_classes: - parser = self._make_group_override_parser(parser=parser, class_obj=val) - args = parser.parse_args() - return args - - def _make_group_override_parser(self, parser, class_obj): - """Makes a name specific override parser for a given class obj - - Takes a class object of the backend and adds a new argument group with argument names given with name - Class.name so that individual parameters specific to a class can be overridden. 
- - *Args*: - - parser: argument parser - class_obj: instance of a backend class - - *Returns*: - - parser: argument parser with new class specific overrides - - """ - attr_name = class_obj.__name__ - group_parser = parser.add_argument_group(title=str(attr_name) + " Specific Overrides") - for val in class_obj.__attrs_attrs__: - val_type = val.metadata['type'] if 'type' in val.metadata else val.type - # Check if the val type has __args__ - # TODO (ncilfone): Fix up this super super ugly logic - if hasattr(val_type, '__args__') and ((list(set(val_type.__args__))[0]).__module__ == 'spock.backend.attr.config') and attr.has((list(set(val_type.__args__))[0])): - args = (list(set(val_type.__args__))[0]) - for inner_val in args.__attrs_attrs__: - arg_name = f"--{str(attr_name)}.{val.name}.{args.__name__}.{inner_val.name}" - group_parser = make_argument(arg_name, List[inner_val.type], group_parser) - else: - arg_name = f"--{str(attr_name)}.{val.name}" - group_parser = make_argument(arg_name, val_type, group_parser) - return parser - - @staticmethod - def _get_from_kwargs(args, configs): - """Get configs from the configs kwarg - - - *Args*: - - args: argument namespace - configs: config kwarg - - *Returns*: - - args: arg namespace - - """ - if type(configs).__name__ == 'list': - args.config.extend(configs) - else: - raise TypeError(f'configs kwarg must be of type list -- given {type(configs)}') - return args - - @staticmethod - def _find_attribute_idx(newline_split_docs): - """Finds the possible split between the header and Attribute annotations - - *Args*: - - newline_split_docs: new line split text - - Returns: - - idx: -1 if none or the idx of Attributes - - """ - for idx, val in enumerate(newline_split_docs): - re_check = re.search(r'(?i)Attribute?s?:', val) - if re_check is not None: - return idx - return -1 - - def _split_docs(self, obj): - """Possibly splits head class doc string from attribute docstrings - - Attempts to find the first contiguous line within the Google style docstring to use as the class docstring. - Splits the docs base on the Attributes tag if present. - - *Args*: - - obj: class object to rip info from - - *Returns*: - - class_doc: class docstring if present or blank str - attr_doc: list of attribute doc strings - - """ - if obj.__doc__ is not None: - # Split by new line - newline_split_docs = obj.__doc__.split('\n') - # Cleanup l/t whitespace - newline_split_docs = [val.strip() for val in newline_split_docs] - else: - newline_split_docs = [] - # Find the break between the class docs and the Attribute section -- if this returns -1 then there is no - # Attributes section - attr_idx = self._find_attribute_idx(newline_split_docs) - head_docs = newline_split_docs[:attr_idx] if attr_idx != -1 else newline_split_docs - attr_docs = newline_split_docs[attr_idx:] if attr_idx != -1 else [] - # Grab only the first contiguous line as everything else will probably be too verbose (e.g. 
the - # mid-level docstring that has detailed descriptions - class_doc = '' - for idx, val in enumerate(head_docs): - class_doc += f' {val}' - if idx + 1 != len(head_docs) and head_docs[idx + 1] == '': - break - # Clean up any l/t whitespace - class_doc = class_doc.strip() - return class_doc, attr_docs - - @staticmethod - def _match_attribute_docs(attr_name, attr_docs, attr_type_str, attr_default=NOTHING): - """Matches class attributes with attribute docstrings via regex - - *Args*: - - attr_name: attribute name - attr_docs: list of attribute docstrings - attr_type_str: str representation of the attribute type - attr_default: str representation of a possible default value - - *Returns*: - - dictionary of packed attribute information - - """ - # Regex match each value - a_str = None - for a_doc in attr_docs: - match_re = re.search(r'(?i)^' + attr_name + '?:', a_doc) - # Find only the first match -- if more than one than ignore - if match_re: - a_str = a_doc[match_re.end():].strip() - return {attr_name: { - 'type': attr_type_str, - 'desc': a_str if a_str is not None else "", - 'default': "(default: " + repr(attr_default) + ")" if type(attr_default).__name__ != '_Nothing' - else "", - 'len': {'name': len(attr_name), 'type': len(attr_type_str)} - }} - - def _handle_attributes_print(self, info_dict): - """Prints attribute information in an argparser style format - - *Args*: - - info_dict: packed attribute info dictionary to print - - """ - # Figure out indents - max_param_length = max([len(k) for k in info_dict.keys()]) - max_type_length = max([v['len']['type'] for v in info_dict.values()]) - # Print akin to the argparser - for k, v in info_dict.items(): - print(f' {k}' + (' ' * (max_param_length - v["len"]["name"] + self._max_indent)) + - f'{v["type"]}' + (' ' * (max_type_length - v["len"]["type"] + self._max_indent)) + - f'{v["desc"]} {v["default"]}') - # Blank for spacing :-/ - print('') - - def _extract_other_types(self, typed): - """Takes a high level type and recursively extracts any enum or class types - - *Args*: - - typed: highest level type - - *Returns*: - - return_list: list of nums (dot notation of module_path.enum_name or module_path.class_name) - - """ - return_list = [] - if hasattr(typed, '__args__'): - for val in typed.__args__: - recurse_return = self._extract_other_types(val) - if isinstance(recurse_return, list): - return_list.extend(recurse_return) - else: - return_list.append(self._extract_other_types(val)) - elif isinstance(typed, EnumMeta) or (typed.__module__ == 'spock.backend.attr.config'): - return f'{typed.__module__}.{typed.__name__}' - return return_list - - def _attrs_help(self, input_classes): - """Handles walking through a list classes to get help info - - For each class this function will search __doc__ and attempt to pull out help information for both the class - itself and each attribute within the class. 
If it finds a repeated class in a iterable object it will - recursively call self to handle information - - *Args*: - - input_classes: list of attr classes - - *Returns*: - - None - - """ - # List to catch Enums and classes and handle post spock wrapped attr classes - other_list = [] - covered_set = set() - for attrs_class in input_classes: - # Split the docs into class docs and any attribute docs - class_doc, attr_docs = self._split_docs(attrs_class) - print(' ' + attrs_class.__name__ + f' ({class_doc})') - # Keep a running info_dict of all the attribute level info - info_dict = {} - for val in attrs_class.__attrs_attrs__: - # If the type is an enum we need to handle it outside of this attr loop - # Match the style of nested enums and return a string of module.name notation - if isinstance(val.type, EnumMeta): - other_list.append(f'{val.type.__module__}.{val.type.__name__}') - # if there is a type (implied Iterable) -- check it for nested Enums or classes - nested_others = self._extract_other_types(val.metadata['type']) if 'type' in val.metadata else [] - if len(nested_others) > 0: - other_list.extend(nested_others) - # Grab the base or type info depending on what is provided - type_string = repr(val.metadata['type']) if 'type' in val.metadata else val.metadata['base'] - # Regex out the typing info if present - type_string = re.sub(r'typing.', '', type_string) - # Regex out any nested_others that have module path information - for other_val in nested_others: - split_other = f"{'.'.join(other_val.split('.')[:-1])}." - type_string = re.sub(split_other, '', type_string) - # Regex the string to see if it matches any Enums in the __main__ module space - # for val in sys.modules - # Construct the type with the metadata - if 'optional' in val.metadata: - type_string = f"Optional[{type_string}]" - info_dict.update(self._match_attribute_docs(val.name, attr_docs, type_string, val.default)) - # Add to covered so we don't print help twice in the case of some recursive nesting - covered_set.add(f'{attrs_class.__module__}.{attrs_class.__name__}') - self._handle_attributes_print(info_dict=info_dict) - # Convert the enum list to a set to remove dupes and then back to a list so it is iterable -- set diff to not - # repeat - other_list = list(set(other_list) - covered_set) - # Iterate any Enum type classes - for other in other_list: - # if it's longer than 2 then it's an embedded Spock class - if '.'.join(other.split('.')[:-1]) == 'spock.backend.attr.config': - class_type = self._get_from_sys_modules(other) - # Invoke recursive call for the class - self._attrs_help([class_type]) - # Fall back to enum style - else: - enum = self._get_from_sys_modules(other) - # Split the docs into class docs and any attribute docs - class_doc, attr_docs = self._split_docs(enum) - print(' ' + enum.__name__ + f' ({class_doc})') - info_dict = {} - for val in enum: - info_dict.update(self._match_attribute_docs( - attr_name=val.name, - attr_docs=attr_docs, - attr_type_str=type(val.value).__name__ - )) - self._handle_attributes_print(info_dict=info_dict) - - @staticmethod - def _get_from_sys_modules(cls_name): - """Gets the class from a dot notation name - - *Args*: - - cls_name: dot notation enum name - - *Returns*: - - module: enum class - - """ - # Split on dot notation - split_string = cls_name.split('.') - module = None - for idx, val in enumerate(split_string): - # idx = 0 will always be a call to the sys.modules dict - if idx == 0: - module = sys.modules[val] - # all other idx are paths along the module that need to be 
traversed - # idx = -1 will always be the final Enum object name we want to grab (final getattr call) - else: - module = getattr(module, val) - return module - - -class BasePayload(BaseHandler): # pylint: disable=too-few-public-methods - """Handles building the payload for config file(s) - - This class builds out the payload from config files of multiple types. It handles various - file types and also composition of config files via recursive calls - - *Attributes*: - - _loaders: maps of each file extension to the loader class - __s3_config: optional S3Config object to handle s3 access - - """ - def __init__(self, s3_config=None): - super(BasePayload, self).__init__(s3_config=s3_config) - - @staticmethod - @abstractmethod - def _update_payload(base_payload, input_classes, payload): - """Updates the payload - - Checks the parameters defined in the config files against the provided classes and if - passable adds them to the payload - - *Args*: - - base_payload: current payload - input_classes: class to roll into - payload: total payload - - *Returns*: - - payload: updated payload - - """ - - def payload(self, input_classes, path, cmd_args, deps): - """Builds the payload from config files - - Public exposed call to build the payload and set any command line overrides - - *Args*: - - input_classes: list of backend classes - path: path to config file(s) - cmd_args: command line overrides - deps: dictionary of config dependencies - - *Returns*: - - payload: dictionary of all mapped parameters - - """ - payload = self._payload(input_classes, path, deps, root=True) - payload = self._handle_overrides(payload, cmd_args) - return payload - - def _payload(self, input_classes, path, deps, root=False): - """Private call to construct the payload - - Main function call that builds out the payload from config files of multiple types. It handles - various file types and also composition of config files via a recursive calls - - *Args*: - input_classes: list of backend classes - path: path to config file(s) - deps: dictionary of config dependencies - - *Returns*: - - payload: dictionary of all mapped parameters - - """ - # Match to loader based on file-extension - config_extension = Path(path).suffix.lower() - # Verify extension - self._check_extension(file_extension=config_extension) - # Load from file - base_payload = self._supported_extensions.get(config_extension)().load(path, s3_config=self._s3_config) - # Check and? update the dependencies - deps = self._handle_dependencies(deps, path, root) - payload = {} - if 'config' in base_payload: - payload = self._handle_includes( - base_payload, config_extension, input_classes, path, payload, deps) - payload = self._update_payload(base_payload, input_classes, payload) - return payload - - @staticmethod - def _handle_dependencies(deps, path, root): - """Handles config file dependencies - - Checks to see if the config path (full or relative) has already been encountered. Essentially a DFS for graph - cycles - - *Args*: - - deps: dictionary of config dependencies - path: current config path - root: boolean if root - - *Returns*: - - deps: updated dependencies - - """ - if root and path in deps.get('paths'): - raise ValueError(f'Duplicate Read -- Config file {path} has already been encountered. ' - f'Please remove duplicate reads of config files.') - elif path in deps.get('paths') or path in deps.get('rel_paths'): - raise ValueError(f'Cyclical Dependency -- Config file {path} has already been encountered. 
' - f'Please remove cyclical dependencies between config files.') - else: - # Update the dependency lists - deps.get('paths').append(path) - deps.get('rel_paths').append(os.path.basename(path)) - if root: - deps.get('roots').append(path) - return deps - - def _handle_includes(self, base_payload, config_extension, input_classes, path, payload, deps): # pylint: disable=too-many-arguments - """Handles config composition - - For all of the config tags in the config file this function will recursively call the payload function - with the composition path to get the additional payload(s) from the composed file(s) -- checks for file - validity or if it is an S3 URI via regex - - *Args*: - - base_payload: base payload that has a config kwarg - config_extension: file type - input_classes: defined backend classes - path: path to base file - payload: payload pulled from composed files - deps: dictionary of config dependencies - - *Returns*: - - payload: payload update from composed files - - """ - included_params = {} - for inc_path in base_payload['config']: - if check_path_s3(inc_path): - use_path = inc_path - elif os.path.exists(inc_path): - use_path = inc_path - elif os.path.join(os.path.dirname(path), inc_path): - use_path = os.path.join(os.path.dirname(path), inc_path) - else: - raise RuntimeError(f'Could not find included {config_extension} file {inc_path} or is not an S3 URI!') - included_params.update(self._payload(input_classes, use_path, deps)) - payload.update(included_params) - return payload - - def _handle_overrides(self, payload, args): - """Handle command line overrides - - Iterate through the command line override values, determine at what level to set them, and set them if possible - - *Args*: - - payload: current payload dictionary - args: command line override args - - *Returns*: - - payload: updated payload dictionary with override values set - - """ - skip_keys = ['config', 'help'] - for k, v in vars(args).items(): - if k not in skip_keys and v is not None: - payload = self._handle_payload_override(payload, k, v) - return payload - - @staticmethod - def _handle_payload_override(payload, key, value): - """Handles the complex logic needed for List[spock class] overrides - - Messy logic that sets overrides for the various different types. The hardest being List[spock class] since str - names have to be mapped backed to sys.modules and can be set at either the general or class level. 
- - *Args*: - - payload: current payload dictionary - key: current arg key - value: value at current arg key - - *Returns*: - - payload: modified payload with overrides - - """ - key_split = key.split('.') - curr_ref = payload - for idx, split in enumerate(key_split): - # If the root isn't in the payload then it needs to be added but only for the first key split - if idx == 0 and (split not in payload): - payload.update({split: {}}) - # Check for curr_ref switch over -- verify by checking the sys modules names - if idx != 0 and (split in payload) and (isinstance(curr_ref, str)) and (hasattr(sys.modules['spock'].backend.attr.config, split)): - curr_ref = payload[split] - elif idx != 0 and (split in payload) and (isinstance(payload[split], str)) and (hasattr(sys.modules['spock'].backend.attr.config, payload[split])): - curr_ref = payload[split] - # elif check if it's the last value and figure out the override - elif idx == (len(key_split)-1): - # Handle bool(s) a bit differently as they are store_true - if isinstance(curr_ref, dict) and isinstance(value, bool): - if value is not False: - curr_ref[split] = value - # If we are at the dictionary level we should be able to just payload override - elif isinstance(curr_ref, dict) and not isinstance(value, bool): - curr_ref[split] = value - # If we are at a list level it must be some form of repeated class since this is the end of the class - # tree -- check the instance type but also make sure the cmd-line override is the correct len - elif isinstance(curr_ref, list) and len(value) == len(curr_ref): - # Walk the list and check for the key - for ref_idx, val in enumerate(curr_ref): - if split in val: - val[split] = value[ref_idx] - else: - raise ValueError(f'cmd-line override failed for {key} -- ' - f'Failed to find key {split} within lowest level List[Dict]') - elif isinstance(curr_ref, list) and len(value) != len(curr_ref): - raise ValueError(f'cmd-line override failed for {key} -- ' - f'Specified key {split} with len {len(value)} does not match len {len(curr_ref)} ' - f'of List[Dict]') - else: - raise ValueError(f'cmd-line override failed for {key} -- ' - f'Failed to find key {split} within lowest level Dict') - # If it's not keep walking the current payload - else: - curr_ref = curr_ref[split] - return payload diff --git a/spock/backend/builder.py b/spock/backend/builder.py new file mode 100644 index 00000000..e3ec2d1b --- /dev/null +++ b/spock/backend/builder.py @@ -0,0 +1,849 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the building/saving of the configurations from the Spock config classes""" + +import re +import sys +from abc import ABC, abstractmethod +from enum import EnumMeta +from typing import List + +import attr +from attr import NOTHING + +from spock.backend.wrappers import Spockspace +from spock.utils import make_argument + + +class BaseBuilder(ABC): # pylint: disable=too-few-public-methods + """Base class for building the backend specific builders + + This class handles the interface to the backend with the generic ConfigArgBuilder so that different + backends can be used to handle processing + + *Attributes* + + input_classes: list of input classes that link to a backend + _configs: None or List of configs to read from + _desc: description for the arg parser + _no_cmd_line: flag to force no command line reads + _max_indent: maximum to indent between help prints + save_path: list of path(s) to save the configs to + + """ + + def __init__(self, *args, max_indent=4, 
module_name, **kwargs): + self.input_classes = args + self._module_name = module_name + self._max_indent = max_indent + self.save_path = None + + @staticmethod + @abstractmethod + def _make_group_override_parser(parser, class_obj, class_name): + """Makes a name specific override parser for a given class obj + + Takes a class object of the backend and adds a new argument group with argument names given with name + Class.name so that individual parameters specific to a class can be overridden. + + *Args*: + + parser: argument parser + class_obj: instance of a backend class + class_name: used for module matching + + *Returns*: + + parser: argument parser with new class specific overrides + + """ + + def handle_help_info(self): + """Handles walking through classes to get help info + + For each class this function will search __doc__ and attempt to pull out help information for both the class + itself and each attribute within the class + + *Returns*: + + None + + """ + self._attrs_help(self.input_classes, self._module_name) + + def _handle_arguments(self, args, class_obj): + """Handles all argument mapping + + Creates a dictionary of named parameters that are mapped to the final type of object + + *Args*: + + args: read file arguments + class_obj: instance of a class obj + + *Returns*: + + fields: dictionary of mapped parameters + + """ + attr_name = class_obj.__name__ + class_names = [val.__name__ for val in self.input_classes] + # Handle repeated classes + if ( + attr_name in class_names + and attr_name in args + and isinstance(args[attr_name], list) + ): + fields = self._handle_repeated(args[attr_name], attr_name, class_names) + # Handle non-repeated classes + else: + fields = {} + for val in class_obj.__attrs_attrs__: + # Check if namespace is named and then check for key -- checking for local class def + if attr_name in args and val.name in args[attr_name]: + fields[val.name] = self._handle_nested_class( + args, args[attr_name][val.name], class_names + ) + # If not named then just check for keys -- checking for global def + elif val.name in args: + fields[val.name] = self._handle_nested_class( + args, args[val.name], class_names + ) + # Check for special keys to set + if ( + "special_key" in val.metadata + and val.metadata["special_key"] is not None + ): + if val.name in args: + self.save_path = args[val.name] + elif val.default is not None: + self.save_path = val.default + return fields + + def _handle_repeated(self, args, check_value, class_names): + """Handles repeated classes as lists + + *Args*: + + args: dictionary of arguments from the configs + check_value: value to check classes against + class_names: current class names + + *Returns*: + + list of input_class[match)idx[0]] types filled with repeated values + + """ + # Check to see if the value trying to be set is actually an input class + match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] + return [self.input_classes[match_idx[0]](**val) for val in args] + + def _handle_nested_class(self, args, check_value, class_names): + """Handles passing another class to the field dictionary + + *Args*: + args: dictionary of arguments from the configs + check_value: value to check classes against + class_names: current class names + + *Returns*: + + either the check_value or the necessary class + + """ + # Check to see if the value trying to be set is actually an input class + match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] + # If so then create the needed class object by unrolling 
the args to **kwargs and return it + if len(match_idx) > 0: + if len(match_idx) > 1: + raise ValueError( + "Match error -- multiple classes with the same name definition" + ) + else: + if args.get(self.input_classes[match_idx[0]].__name__) is None: + raise ValueError( + f"Missing config file definition for the referenced class " + f"{self.input_classes[match_idx[0]].__name__}" + ) + current_arg = args.get(self.input_classes[match_idx[0]].__name__) + if isinstance(current_arg, list): + class_value = [ + self.input_classes[match_idx[0]](**val) for val in current_arg + ] + else: + class_value = self.input_classes[match_idx[0]](**current_arg) + return_value = class_value + # else return the expected value + else: + return_value = check_value + return return_value + + def generate(self, dict_args): + """Method to auto-generate the actual class instances from the generated args + + Based on the generated arguments groups and the args read in from the config file(s) + this function instantiates the classes with the necessary field or attr values + + *Args*: + + dict_args: dictionary of arguments from the configs + + *Returns*: + + namespace containing automatically generated instances of the classes + """ + auto_dict = {} + for attr_classes in self.input_classes: + attr_build = self._auto_generate(dict_args, attr_classes) + if isinstance(attr_build, list): + class_name = list({type(val).__name__ for val in attr_build}) + if len(class_name) > 1: + raise ValueError("Repeated class has more than one unique name") + auto_dict.update({class_name[0]: attr_build}) + else: + auto_dict.update({type(attr_build).__name__: attr_build}) + return Spockspace(**auto_dict) + + def _auto_generate(self, args, input_class): + """Builds an instance of an attr class + + Builds an instance with the necessary field values from the argument + dictionary read from the config file(s) + + *Args*: + + args: dictionary of arguments read from the config file(s) + data_class: data class to build + + *Returns*: + + An instance of data_class with correct values assigned to fields + """ + # Handle the basic data types + fields = self._handle_arguments(args, input_class) + if isinstance(fields, list): + return_value = fields + else: + self._handle_late_defaults(args, fields, input_class) + return_value = input_class(**fields) + return return_value + + def _handle_late_defaults(self, args, fields, input_class): + """Handles late defaults when the type is non-standard + + If the default type is not a base python type then we need to catch those defaults here and build the correct + values from the input classes while maintaining the optional nature. 
The trick is to exclude all 'base' types + as these defaults are covered by the attr default value + + *Args*: + + args: dictionary of arguments read from the config file(s) + fields: current fields returned from _handle_arguments + input_class: which input class being checked for late defaults + + *Returns*: + + fields: updated field dictionary with late defaults set + + """ + names = [val.name for val in input_class.__attrs_attrs__] + class_names = [val.__name__ for val in self.input_classes] + field_list = list(fields.keys()) + arg_list = list(args.keys()) + # Exclude all the base types that are supported -- these can be set by attrs + exclude_list = [ + "_Nothing", + "NoneType", + "bool", + "int", + "float", + "str", + "list", + "tuple", + ] + for val in names: + if val not in field_list: + default_type_name = type( + getattr(input_class.__attrs_attrs__, val).default + ).__name__ + if default_type_name not in exclude_list: + default_name = getattr( + input_class.__attrs_attrs__, val + ).default.__name__ + else: + default_name = None + if default_name is not None and default_name in arg_list: + if isinstance(args.get(default_name), list): + default_value = [ + self.input_classes[class_names.index(default_name)]( + **arg_val + ) + for arg_val in args.get(default_name) + ] + else: + default_value = self.input_classes[ + class_names.index(default_name) + ](**args.get(default_name)) + fields.update({val: default_value}) + return fields + + def build_override_parsers(self, parser): + """Creates parsers for command-line overrides + + Builds the basic command line parser for configs and help then iterates through each attr instance to make + namespace specific cmd line override parsers + + *Args*: + + parser: argument parser + + *Returns*: + + parser: argument parser with new class specific overrides + + """ + # Build out each class override specific parser + for val in self.input_classes: + parser = self._make_group_override_parser( + parser=parser, class_obj=val, class_name=self._module_name + ) + return parser + + @staticmethod + def _get_from_kwargs(args, configs): + """Get configs from the configs kwarg + + *Args*: + + args: argument namespace + configs: config kwarg + + *Returns*: + + args: arg namespace + + """ + if isinstance(configs, list): + args.config.extend(configs) + else: + raise TypeError( + f"configs kwarg must be of type list -- given {type(configs)}" + ) + return args + + @staticmethod + def _find_attribute_idx(newline_split_docs): + """Finds the possible split between the header and Attribute annotations + + *Args*: + + newline_split_docs: new line split text + + Returns: + + idx: -1 if none or the idx of Attributes + + """ + for idx, val in enumerate(newline_split_docs): + re_check = re.search(r"(?i)Attribute?s?:", val) + if re_check is not None: + return idx + return -1 + + def _split_docs(self, obj): + """Possibly splits head class doc string from attribute docstrings + + Attempts to find the first contiguous line within the Google style docstring to use as the class docstring. + Splits the docs base on the Attributes tag if present. 
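For illustration, a minimal sketch of the Google-style docstring layout this splitter assumes; the class and field names here are hypothetical and only show roughly what class_doc and attr_docs end up holding:

    class TrainConfig:
        """Training run settings

        Longer description that the help printer deliberately skips

        Attributes:
            lr: learning rate for the optimizer
            epochs: number of passes over the data
        """
        lr: float = 0.01
        epochs: int = 10

    # Splitting the docstring above would yield roughly:
    #   class_doc -> "-- Training run settings"
    #   attr_docs -> ["Attributes:", "lr: learning rate for the optimizer",
    #                 "epochs: number of passes over the data", ""]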
+ + *Args*: + + obj: class object to rip info from + + *Returns*: + + class_doc: class docstring if present or blank str + attr_doc: list of attribute doc strings + + """ + if obj.__doc__ is not None: + # Split by new line + newline_split_docs = obj.__doc__.split("\n") + # Cleanup leading/trailing whitespace + newline_split_docs = [val.strip() for val in newline_split_docs] + else: + newline_split_docs = [] + # Find the break between the class docs and the Attribute section -- if this returns -1 then there is no + # Attributes section + attr_idx = self._find_attribute_idx(newline_split_docs) + head_docs = ( + newline_split_docs[:attr_idx] if attr_idx != -1 else newline_split_docs + ) + attr_docs = newline_split_docs[attr_idx:] if attr_idx != -1 else [] + # Grab only the first contiguous line as everything else will probably be too verbose (e.g. the + # mid-level docstring that has detailed descriptions) + class_doc = "" + for idx, val in enumerate(head_docs): + class_doc += f" {val}" + if idx + 1 != len(head_docs) and head_docs[idx + 1] == "": + break + # Clean up any leading/trailing whitespace + class_doc = class_doc.strip() + if len(class_doc) > 0: + class_doc = f"-- {class_doc}" + return class_doc, attr_docs + + @staticmethod + def _match_attribute_docs( + attr_name, attr_docs, attr_type_str, attr_default=NOTHING + ): + """Matches class attributes with attribute docstrings via regex + + *Args*: + + attr_name: attribute name + attr_docs: list of attribute docstrings + attr_type_str: str representation of the attribute type + attr_default: str representation of a possible default value + + *Returns*: + + dictionary of packed attribute information + + """ + # Regex match each value + a_str = None + for a_doc in attr_docs: + match_re = re.search(r"(?i)^" + attr_name + "?:", a_doc) + # Find only the first match -- if more than one then ignore + if match_re: + a_str = a_doc[match_re.end() :].strip() + return { + attr_name: { + "type": attr_type_str, + "desc": a_str if a_str is not None else "", + "default": "(default: " + repr(attr_default) + ")" + if type(attr_default).__name__ != "_Nothing" + else "", + "len": {"name": len(attr_name), "type": len(attr_type_str)}, + } + } + + def _handle_attributes_print(self, info_dict): + """Prints attribute information in an argparser style format + + *Args*: + + info_dict: packed attribute info dictionary to print + + """ + # Figure out indents + max_param_length = max([len(k) for k in info_dict.keys()]) + max_type_length = max([v["len"]["type"] for v in info_dict.values()]) + # Print akin to the argparser + for k, v in info_dict.items(): + print( + f" {k}" + + (" " * (max_param_length - v["len"]["name"] + self._max_indent)) + + f'{v["type"]}' + + (" " * (max_type_length - v["len"]["type"] + self._max_indent)) + + f'{v["desc"]} {v["default"]}' + ) + # Blank for spacing :-/ + print("") + + def _extract_other_types(self, typed, module_name): + """Takes a high level type and recursively extracts any enum or class types + + *Args*: + + typed: highest level type + module_name: name of module to match + + *Returns*: + + return_list: list of enum or class names (dot notation of module_path.enum_name or module_path.class_name) + + """ + return_list = [] + if hasattr(typed, "__args__"): + for val in typed.__args__: + recurse_return = self._extract_other_types(val, module_name) + if isinstance(recurse_return, list): + return_list.extend(recurse_return) + else: + return_list.append(self._extract_other_types(val, module_name)) + elif isinstance(typed, EnumMeta) or (typed.__module__ == module_name): +
return [f"{typed.__module__}.{typed.__name__}"] + return return_list + + def _attrs_help(self, input_classes, module_name): + """Handles walking through a list classes to get help info + + For each class this function will search __doc__ and attempt to pull out help information for both the class + itself and each attribute within the class. If it finds a repeated class in a iterable object it will + recursively call self to handle information + + *Args*: + + input_classes: list of attr classes + module_name: name of module to match + + *Returns*: + + None + + """ + # Handle the main loop + other_list = self._handle_help_main(input_classes, module_name) + self._handle_help_enums(other_list=other_list, module_name=module_name) + + @staticmethod + def _get_type_string(val, nested_others): + """Gets the type of the attr val as a string + + *Args*: + + val: current attr being processed + nested_others: list of nested others to deal with that might have module path info in the string + + *Returns*: + + type_string: type of the attr as a str + + """ + # Grab the base or type info depending on what is provided + if "type" in val.metadata: + type_string = repr(val.metadata["type"]) + elif "base" in val.metadata: + type_string = val.metadata["base"] + elif hasattr(val.type, "__name__"): + type_string = val.type.__name__ + else: + type_string = str(val.type) + # Regex out the typing info if present + type_string = re.sub(r"typing.", "", type_string) + # Regex out any nested_others that have module path information + for other_val in nested_others: + split_other = f"{'.'.join(other_val.split('.')[:-1])}." + type_string = re.sub(split_other, "", type_string) + # Regex the string to see if it matches any Enums in the __main__ module space + # Construct the type with the metadata + if "optional" in val.metadata: + type_string = f"Optional[{type_string}]" + return type_string + + def _handle_help_main(self, input_classes, module_name): + """Handles the print of the main class types + + *Args*: + + input_classes: current set of input classes + module_name: module name to match + + *Returns*: + + other_list: extended list of other classes/enums to process + + """ + # List to catch Enums and classes and handle post spock wrapped attr classes + other_list = [] + covered_set = set() + for attrs_class in input_classes: + # Split the docs into class docs and any attribute docs + class_doc, attr_docs = self._split_docs(attrs_class) + print(" " + attrs_class.__name__ + f" {class_doc}") + # Keep a running info_dict of all the attribute level info + info_dict = {} + for val in attrs_class.__attrs_attrs__: + # If the type is an enum we need to handle it outside of this attr loop + # Match the style of nested enums and return a string of module.name notation + if isinstance(val.type, EnumMeta): + other_list.append(f"{val.type.__module__}.{val.type.__name__}") + # if there is a type (implied Iterable) -- check it for nested Enums or classes + nested_others = self._extract_fnc(val, module_name) + if len(nested_others) > 0: + other_list.extend(nested_others) + # Get the type represented as a string + type_string = self._get_type_string(val, nested_others) + info_dict.update( + self._match_attribute_docs( + val.name, attr_docs, type_string, val.default + ) + ) + # Add to covered so we don't print help twice in the case of some recursive nesting + covered_set.add(f"{attrs_class.__module__}.{attrs_class.__name__}") + self._handle_attributes_print(info_dict=info_dict) + # Convert the enum list to a set to remove dupes and 
then back to a list so it is iterable -- set diff to not + # repeat + return list(set(other_list) - covered_set) + + def _handle_help_enums(self, other_list, module_name): + """handles any extra enums from non main args + + *Args*: + + other_list: extended list of other classes/enums to process + module_name: module name to match + + *Returns*: + + None + + """ + # Iterate any Enum type classes + for other in other_list: + # if it's longer than 2 then it's an embedded Spock class + if ".".join(other.split(".")[:-1]) == module_name: + class_type = self._get_from_sys_modules(other) + # Invoke recursive call for the class + self._attrs_help([class_type], module_name) + # Fall back to enum style + else: + enum = self._get_from_sys_modules(other) + # Split the docs into class docs and any attribute docs + class_doc, attr_docs = self._split_docs(enum) + print(" " + enum.__name__ + f" ({class_doc})") + info_dict = {} + for val in enum: + info_dict.update( + self._match_attribute_docs( + attr_name=val.name, + attr_docs=attr_docs, + attr_type_str=type(val.value).__name__, + ) + ) + self._handle_attributes_print(info_dict=info_dict) + + @abstractmethod + def _extract_fnc(self, val, module_name): + """Function that gets the nested lists within classes + + *Args*: + + val: current attr + module_name: matching module name + + *Returns*: + + list of any nested classes/enums + + """ + + @staticmethod + def _get_from_sys_modules(cls_name): + """Gets the class from a dot notation name + + *Args*: + + cls_name: dot notation enum name + + *Returns*: + + module: enum class + + """ + # Split on dot notation + split_string = cls_name.split(".") + module = None + for idx, val in enumerate(split_string): + # idx = 0 will always be a call to the sys.modules dict + if idx == 0: + module = sys.modules[val] + # all other idx are paths along the module that need to be traversed + # idx = -1 will always be the final Enum object name we want to grab (final getattr call) + else: + module = getattr(module, val) + return module + + +class AttrBuilder(BaseBuilder): + """Attr specific builder + + Class that handles building for the attr backend + + *Attributes* + + input_classes: list of input classes that link to a backend + _configs: None or List of configs to read from + _create_save_path: boolean to make the path to save to + _desc: description for the arg parser + _no_cmd_line: flag to force no command line reads + save_path: list of path(s) to save the configs to + + """ + + def __init__(self, *args, **kwargs): + """AttrBuilder init + + Args: + *args: list of input classes that link to a backend + configs: None or List of configs to read from + desc: description for the arg parser + no_cmd_line: flag to force no command line reads + **kwargs: any extra keyword args + """ + super().__init__(*args, module_name="spock.backend.config", **kwargs) + + @staticmethod + def _make_group_override_parser(parser, class_obj, class_name): + """Makes a name specific override parser for a given class obj + + Takes a class object of the backend and adds a new argument group with argument names given with name + Class.name so that individual parameters specific to a class can be overridden. 
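For illustration, a minimal sketch of the dotted Class.name flag pattern such a group produces -- plain argparse only, with hypothetical class and parameter names rather than the actual spock machinery:

    import argparse

    # Each spock class becomes an argument group; each attribute becomes a dotted flag
    parser = argparse.ArgumentParser(add_help=False)
    group = parser.add_argument_group(title="TrainConfig Specific Overrides")
    group.add_argument("--TrainConfig.lr", type=float)
    group.add_argument("--TrainConfig.epochs", type=int)

    # e.g. python train.py -c config.yaml --TrainConfig.lr 0.001
    args = parser.parse_args(["--TrainConfig.lr", "0.001"])
    print(vars(args))  # {'TrainConfig.lr': 0.001, 'TrainConfig.epochs': None}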
+ + *Args*: + + parser: argument parser + class_obj: instance of a backend class + class_name: used for module matching + + *Returns*: + + parser: argument parser with new class specific overrides + + """ + attr_name = class_obj.__name__ + group_parser = parser.add_argument_group( + title=str(attr_name) + " Specific Overrides" + ) + for val in class_obj.__attrs_attrs__: + val_type = val.metadata["type"] if "type" in val.metadata else val.type + # Check if the val type has __args__ -- this catches lists? + # TODO (ncilfone): Fix up this super super ugly logic + if ( + hasattr(val_type, "__args__") + and ((list(set(val_type.__args__))[0]).__module__ == class_name) + and attr.has((list(set(val_type.__args__))[0])) + ): + args = list(set(val_type.__args__))[0] + for inner_val in args.__attrs_attrs__: + arg_name = f"--{str(attr_name)}.{val.name}.{args.__name__}.{inner_val.name}" + group_parser = make_argument( + arg_name, List[inner_val.type], group_parser + ) + # If it's a reference to a class it needs to be an arg of a simple string as class matching will take care + # of it later on + elif val_type.__module__ == "spock.backend.config": + arg_name = f"--{str(attr_name)}.{val.name}" + val_type = str + group_parser = make_argument(arg_name, val_type, group_parser) + else: + arg_name = f"--{str(attr_name)}.{val.name}" + group_parser = make_argument(arg_name, val_type, group_parser) + return parser + + def _handle_arguments(self, args, class_obj): + attr_name = class_obj.__name__ + class_names = [val.__name__ for val in self.input_classes] + # Handle repeated classes + if ( + attr_name in class_names + and attr_name in args + and isinstance(args[attr_name], list) + ): + fields = self._handle_repeated(args[attr_name], attr_name, class_names) + # Handle non-repeated classes + else: + fields = {} + for val in class_obj.__attrs_attrs__: + # Check if namespace is named and then check for key -- checking for local class def + if attr_name in args and val.name in args[attr_name]: + fields[val.name] = self._handle_nested_class( + args, args[attr_name][val.name], class_names + ) + # If not named then just check for keys -- checking for global def + elif val.name in args: + fields[val.name] = self._handle_nested_class( + args, args[val.name], class_names + ) + # Check for special keys to set + if ( + "special_key" in val.metadata + and val.metadata["special_key"] is not None + ): + if val.name in args: + self.save_path = args[val.name] + elif val.default is not None: + self.save_path = val.default + return fields + + def _handle_repeated(self, args, check_value, class_names): + """Handles repeated classes as lists + + *Args*: + + args: dictionary of arguments from the configs + check_value: value to check classes against + class_names: current class names + + *Returns*: + + list of input_class[match)idx[0]] types filled with repeated values + + """ + # Check to see if the value trying to be set is actually an input class + match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] + return [self.input_classes[match_idx[0]](**val) for val in args] + + def _handle_nested_class(self, args, check_value, class_names): + """Handles passing another class to the field dictionary + + *Args*: + args: dictionary of arguments from the configs + check_value: value to check classes against + class_names: current class names + + *Returns*: + + either the check_value or the necessary class + + """ + # Check to see if the value trying to be set is actually an input class + match_idx = [idx for idx, val 
in enumerate(class_names) if val == check_value] + # If so then create the needed class object by unrolling the args to **kwargs and return it + if len(match_idx) > 0: + if len(match_idx) > 1: + raise ValueError( + "Match error -- multiple classes with the same name definition" + ) + else: + if args.get(self.input_classes[match_idx[0]].__name__) is None: + raise ValueError( + f"Missing config file definition for the referenced class " + f"{self.input_classes[match_idx[0]].__name__}" + ) + current_arg = args.get(self.input_classes[match_idx[0]].__name__) + if isinstance(current_arg, list): + class_value = [ + self.input_classes[match_idx[0]](**val) for val in current_arg + ] + else: + class_value = self.input_classes[match_idx[0]](**current_arg) + return_value = class_value + # else return the expected value + else: + return_value = check_value + return return_value + + def _extract_fnc(self, val, module_name): + """Function that gets the nested lists within classes + + *Args*: + + val: current attr + module_name: matching module name + + *Returns*: + + list of any nested classes/enums + + """ + return ( + self._extract_other_types(val.metadata["type"], module_name) + if "type" in val.metadata + else [] + ) diff --git a/spock/backend/attr/config.py b/spock/backend/config.py similarity index 67% rename from spock/backend/attr/config.py rename to spock/backend/config.py index a081d14d..9f481294 100644 --- a/spock/backend/attr/config.py +++ b/spock/backend/config.py @@ -1,16 +1,18 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Creates the spock config interface that wraps attr""" import sys + import attr -from spock.backend.attr.typed import katra +from spock.backend.typed import katra -def spock_attr(cls): + +def _base_attr(cls): """Map type hints to katras Connector function that maps type hinting style to the defined katra style which uses the more strict @@ -24,7 +26,6 @@ def spock_attr(cls): cls: slotted attrs class that is frozen and kw only """ - # Since we are not using the @attr.s decorator we need to get the parent classes for inheritance # We do this by using the mro and grabbing anything that is not the first and last indices in the list and wrapping # it into a tuple @@ -35,7 +36,7 @@ def spock_attr(cls): bases = () # Make a blank attrs dict for new attrs attrs_dict = {} - if hasattr(cls, '__annotations__'): + if hasattr(cls, "__annotations__"): for k, v in cls.__annotations__.items(): # If the cls has the attribute then a default was set if hasattr(cls, k): @@ -43,10 +44,30 @@ def spock_attr(cls): else: default = None attrs_dict.update({k: katra(typed=v, default=default)}) + return bases, attrs_dict + + +def spock_attr(cls): + """Map type hints to katras + + Connector function that maps type hinting style to the defined katra style which uses the more strict + attr.ib() definition + + *Args*: + + cls: basic class def + + *Returns*: + + cls: slotted attrs class that is frozen and kw only + """ + bases, attrs_dict = _base_attr(cls) # Dynamically make an attr class - obj = attr.make_class(name=cls.__name__, bases=bases, attrs=attrs_dict, kw_only=True, frozen=True) + obj = attr.make_class( + name=cls.__name__, bases=bases, attrs=attrs_dict, kw_only=True, frozen=True + ) # For each class we dynamically create we need to register it within the system modules for pickle to work - setattr(sys.modules['spock'].backend.attr.config, obj.__name__, obj) + setattr(sys.modules["spock"].backend.config, obj.__name__, obj) # Swap 
the __doc__ string from cls to obj obj.__doc__ = cls.__doc__ return obj diff --git a/spock/backend/handler.py b/spock/backend/handler.py new file mode 100644 index 00000000..534b9b38 --- /dev/null +++ b/spock/backend/handler.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Base handler Spock class""" + +from abc import ABC + +from spock.handlers import JSONHandler, TOMLHandler, YAMLHandler + + +class BaseHandler(ABC): + """Base class for saver and payload + + *Attributes*: + + _writers: maps file extension to the correct i/o handler + _s3_config: optional S3Config object to handle s3 access + + """ + + def __init__(self, s3_config=None): + self._supported_extensions = { + ".yaml": YAMLHandler, + ".toml": TOMLHandler, + ".json": JSONHandler, + } + self._s3_config = s3_config + + def _check_extension(self, file_extension: str): + if file_extension not in self._supported_extensions: + raise TypeError( + f"File extension {file_extension} not supported -- \n" + f"File extension must be from {list(self._supported_extensions.keys())}" + ) diff --git a/spock/backend/payload.py b/spock/backend/payload.py new file mode 100644 index 00000000..5fc7ec86 --- /dev/null +++ b/spock/backend/payload.py @@ -0,0 +1,486 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles payloads from markup files""" + +import os +import sys +from abc import abstractmethod +from itertools import chain +from pathlib import Path + +from spock.backend.handler import BaseHandler +from spock.backend.utils import ( + convert_to_tuples, + deep_update, + get_attr_fields, + get_type_fields, +) +from spock.utils import check_path_s3 + + +class BasePayload(BaseHandler): # pylint: disable=too-few-public-methods + """Handles building the payload for config file(s) + + This class builds out the payload from config files of multiple types. 
It handles various + file types and also composition of config files via recursive calls + + *Attributes*: + + _supported_extensions: maps each file extension to the loader class + _s3_config: optional S3Config object to handle s3 access + + """ + + def __init__(self, s3_config=None): + super(BasePayload, self).__init__(s3_config=s3_config) + + @staticmethod + @abstractmethod + def _update_payload(base_payload, input_classes, ignore_classes, payload): + """Updates the payload + + Checks the parameters defined in the config files against the provided classes and if + passable adds them to the payload + + *Args*: + + base_payload: current payload + input_classes: class to roll into + ignore_classes: list of classes to ignore + payload: total payload + + *Returns*: + + payload: updated payload + + """ + + def payload(self, input_classes, ignore_classes, path, cmd_args, deps): + """Builds the payload from config files + + Public exposed call to build the payload and set any command line overrides + + *Args*: + + input_classes: list of backend classes + ignore_classes: list of classes to ignore + path: path to config file(s) + cmd_args: command line overrides + deps: dictionary of config dependencies + + *Returns*: + + payload: dictionary of all mapped parameters + + """ + payload = self._payload(input_classes, ignore_classes, path, deps, root=True) + payload = self._handle_overrides(payload, ignore_classes, cmd_args) + return payload + + def _payload(self, input_classes, ignore_classes, path, deps, root=False): + """Private call to construct the payload + + Main function call that builds out the payload from config files of multiple types. It handles + various file types and also composition of config files via recursive calls + + *Args*: + input_classes: list of backend classes + ignore_classes: list of classes to ignore + path: path to config file(s) + deps: dictionary of config dependencies + + *Returns*: + + payload: dictionary of all mapped parameters + + """ + # empty payload + payload = {} + if path is not None: + # Match to loader based on file-extension + config_extension = Path(path).suffix.lower() + # Verify extension + self._check_extension(file_extension=config_extension) + # Load from file + base_payload = self._supported_extensions.get(config_extension)().load( + path, s3_config=self._s3_config + ) + # Check and update the dependencies + deps = self._handle_dependencies(deps, path, root) + if "config" in base_payload: + payload = self._handle_includes( + base_payload, + config_extension, + input_classes, + ignore_classes, + path, + payload, + deps, + ) + payload = self._update_payload( + base_payload, input_classes, ignore_classes, payload + ) + return payload + + @staticmethod + def _handle_dependencies(deps, path, root): + """Handles config file dependencies + + Checks to see if the config path (full or relative) has already been encountered. Essentially a DFS for graph + cycles + + *Args*: + + deps: dictionary of config dependencies + path: current config path + root: boolean if root + + *Returns*: + + deps: updated dependencies + + """ + if root and path in deps.get("paths"): + raise ValueError( + f"Duplicate Read -- Config file {path} has already been encountered. " + f"Please remove duplicate reads of config files." + ) + elif path in deps.get("paths") or path in deps.get("rel_paths"): + raise ValueError( + f"Cyclical Dependency -- Config file {path} has already been encountered. " + f"Please remove cyclical dependencies between config files."
+ ) + else: + # Update the dependency lists + deps.get("paths").append(path) + deps.get("rel_paths").append(os.path.basename(path)) + if root: + deps.get("roots").append(path) + return deps + + def _handle_includes( + self, + base_payload, + config_extension, + input_classes, + ignore_classes, + path, + payload, + deps, + ): # pylint: disable=too-many-arguments + """Handles config composition + + For all of the config tags in the config file this function will recursively call the payload function + with the composition path to get the additional payload(s) from the composed file(s) -- checks for file + validity or if it is an S3 URI via regex + + *Args*: + + base_payload: base payload that has a config kwarg + config_extension: file type + input_classes: defined backend classes + ignore_classes: list of classes to ignore + path: path to base file + payload: payload pulled from composed files + deps: dictionary of config dependencies + + *Returns*: + + payload: payload update from composed files + + """ + included_params = {} + for inc_path in base_payload["config"]: + if check_path_s3(inc_path): + use_path = inc_path + elif os.path.exists(inc_path): + use_path = inc_path + elif os.path.join(os.path.dirname(path), inc_path): + use_path = os.path.join(os.path.dirname(path), inc_path) + else: + raise RuntimeError( + f"Could not find included {config_extension} file {inc_path} or is not an S3 URI!" + ) + included_params.update( + self._payload(input_classes, ignore_classes, use_path, deps) + ) + payload.update(included_params) + return payload + + def _handle_overrides(self, payload, ignore_classes, args): + """Handle command line overrides + + Iterate through the command line override values, determine at what level to set them, and set them if possible + + *Args*: + + payload: current payload dictionary + args: command line override args + + *Returns*: + + payload: updated payload dictionary with override values set + + """ + skip_keys = ["config", "help"] + pruned_args = self._prune_args(args, ignore_classes) + for k, v in pruned_args.items(): + if k not in skip_keys and v is not None: + payload = self._handle_payload_override(payload, k, v) + return payload + + @staticmethod + def _prune_args(args, ignore_classes): + """Prunes ignored class names from the cmd line args list to prevent incorrect access + + *Args*: + + args: current cmd line args + ignore_classes: list of class names to ignore + + *Returns*: + + dictionary of pruned cmd line args + + """ + ignored_stems = [val.__name__ for val in ignore_classes] + return { + k: v for k, v in vars(args).items() if k.split(".")[0] not in ignored_stems + } + + @staticmethod + @abstractmethod + def _handle_payload_override(payload, key, value): + """Handles the complex logic needed for List[spock class] overrides + + Messy logic that sets overrides for the various different types. The hardest being List[spock class] since str + names have to be mapped backed to sys.modules and can be set at either the general or class level. + + *Args*: + + payload: current payload dictionary + key: current arg key + value: value at current arg key + + *Returns*: + + payload: modified payload with overrides + + """ + + +class AttrPayload(BasePayload): + """Handles building the payload for attrs backend + + This class builds out the payload from config files of multiple types. 
It handles various + file types and also composition of config files via a recursive calls + + *Attributes*: + + _loaders: maps of each file extension to the loader class + + """ + + def __init__(self, s3_config=None): + """Init for AttrPayload + + *Args*: + + s3_config: optional S3 config object + + """ + super().__init__(s3_config=s3_config) + + def __call__(self, *args, **kwargs): + """Call to allow self chaining + + *Args*: + + *args: + **kwargs: + + *Returns*: + + Payload: instance of self + + """ + return AttrPayload(*args, **kwargs) + + @staticmethod + def _update_payload(base_payload, input_classes, ignore_classes, payload): + # Get basic args + attr_fields = get_attr_fields(input_classes=input_classes) + # Get the ignore fields + ignore_fields = get_attr_fields(input_classes=ignore_classes) + # Class names + class_names = [val.__name__ for val in input_classes] + # Parse out the types if generic + type_fields = get_type_fields(input_classes) + for keys, values in base_payload.items(): + if keys not in ignore_fields: + # check if the keys, value pair is expected by the attr class + if keys != "config": + # Dict infers that we are overriding a global setting in a specific config + if isinstance(values, dict): + # we're in a namespace + # Check for incorrect specific override of global def + if keys not in attr_fields: + raise TypeError( + f"Referring to a class space {keys} that is undefined" + ) + for i_keys in values.keys(): + if i_keys not in attr_fields[keys]: + raise ValueError( + f"Provided an unknown argument named {keys}.{i_keys}" + ) + else: + # Check if the key is actually a reference to another class + if keys in class_names: + if isinstance(values, list): + # Check for incorrect specific override of global def + if keys not in attr_fields: + raise ValueError( + f"Referring to a class space {keys} that is undefined" + ) + # We are in a repeated class def + # Raise if the key set is different from the defined set (i.e. incorrect arguments) + key_set = set( + list(chain(*[list(val.keys()) for val in values])) + ) + for i_keys in key_set: + if i_keys not in attr_fields[keys]: + raise ValueError( + f"Provided an unknown argument named {keys}.{i_keys}" + ) + # Chain all the values from multiple spock classes into one list + elif keys not in list(chain(*attr_fields.values())): + raise ValueError( + f"Provided an unknown argument named {keys}" + ) + # Chain all the values from multiple spock classes into one list + elif keys not in list(chain(*attr_fields.values())): + raise ValueError( + f"Provided an unknown argument named {keys}" + ) + if keys in payload and isinstance(values, dict): + payload[keys].update(values) + else: + payload[keys] = values + tuple_payload = convert_to_tuples(payload, type_fields, class_names) + payload = deep_update(payload, tuple_payload) + return payload + + @staticmethod + def _handle_payload_override(payload, key, value): + """Handles the complex logic needed for List[spock class] overrides + + Messy logic that sets overrides for the various different types. The hardest being List[spock class] since str + names have to be mapped backed to sys.modules and can be set at either the general or class level. 
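Before the full override implementation below, it helps to see the simple case it generalizes: a dotted command-line key names a class namespace plus an attribute inside it, and the override walks the payload dict to set that leaf. A stripped-down sketch assuming plain (non-repeated) classes, with `TrainConfig.lr` as a purely hypothetical key:

```python
def set_dotted_override(payload: dict, key: str, value) -> dict:
    """Set a dotted override such as 'TrainConfig.lr' in a nested payload dict.

    Hypothetical helper for illustration only -- the real logic must also map class
    names back through sys.modules and handle List[spock class] (repeated) configs.
    """
    *parents, leaf = key.split(".")
    curr = payload
    for part in parents:
        curr = curr.setdefault(part, {})  # create the class namespace if it is absent
    curr[leaf] = value
    return payload


# set_dotted_override({}, "TrainConfig.lr", 0.01) -> {"TrainConfig": {"lr": 0.01}}
```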
+ + *Args*: + + payload: current payload dictionary + key: current arg key + value: value at current arg key + + *Returns*: + + payload: modified payload with overrides + + """ + key_split = key.split(".") + curr_ref = payload + # Handle non existing parts of the payload for specific cases + root_classes = [ + idx + for idx, val in enumerate(key_split) + if hasattr(sys.modules["spock"].backend.config, val) + ] + # Verify any classes have roots in the payload dict + for idx in root_classes: + # Update all root classes if not present + if key_split[idx] not in payload: + payload.update({key_split[idx]: {}}) + # If not updating the root then it is a reference to another class which might not be in the payload + # Make sure it's there by setting it -- since this is an override setting is fine as these should be the + # final say in the param values so don't worry about clashing + if idx != 0: + payload[key_split[0]][key_split[idx - 1]] = key_split[idx] + # Check also for repeated classes -- value will be a list when the type is not + var = getattr( + getattr( + sys.modules["spock"].backend.config, key_split[idx] + ).__attrs_attrs__, + key_split[-1], + ) + if isinstance(value, list) and var.type != list: + # If the dict is blank we need to handle the creation of the list of dicts + if len(payload[key_split[idx]]) == 0: + payload.update( + { + key_split[idx]: [ + {key_split[-1]: None} for _ in range(len(value)) + ] + } + ) + # If it's already partially filled we need to update not overwrite + else: + for val in payload[key_split[idx]]: + val.update({key_split[-1]: None}) + + for idx, split in enumerate(key_split): + # Check for curr_ref switch over -- verify by checking the sys modules names + if ( + idx != 0 + and (split in payload) + and (isinstance(curr_ref, str)) + and (hasattr(sys.modules["spock"].backend.config, split)) + ): + curr_ref = payload[split] + # Look ahead to check if the next value exists in the dictionary + elif ( + idx != 0 + and (split in payload) + and (isinstance(payload[split], str)) + and (hasattr(sys.modules["spock"].backend.config, payload[split])) + ): + curr_ref = payload[split] + # elif check if it's the last value and figure out the override + elif idx == (len(key_split) - 1): + # Handle bool(s) a bit differently as they are store_true + if isinstance(curr_ref, dict) and isinstance(value, bool): + if value is not False: + curr_ref[split] = value + # If we are at the dictionary level we should be able to just payload override + elif isinstance(curr_ref, dict) and not isinstance(value, bool): + curr_ref[split] = value + # If we are at a list level it must be some form of repeated class since this is the end of the class + # tree -- check the instance type but also make sure the cmd-line override is the correct len + elif isinstance(curr_ref, list) and len(value) == len(curr_ref): + # Walk the list and check for the key + for ref_idx, val in enumerate(curr_ref): + if split in val: + val[split] = value[ref_idx] + else: + raise ValueError( + f"cmd-line override failed for {key} -- " + f"Failed to find key {split} within lowest level List[Dict]" + ) + elif isinstance(curr_ref, list) and len(value) != len(curr_ref): + raise ValueError( + f"cmd-line override failed for {key} -- " + f"Specified key {split} with len {len(value)} does not match len {len(curr_ref)} " + f"of List[Dict]" + ) + else: + raise ValueError( + f"cmd-line override failed for {key} -- " + f"Failed to find key {split} within lowest level Dict" + ) + # If it's not keep walking the current payload + else: 
+ curr_ref = curr_ref[split] + return payload diff --git a/spock/backend/saver.py b/spock/backend/saver.py new file mode 100644 index 00000000..27f782ad --- /dev/null +++ b/spock/backend/saver.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles prepping and saving the Spock config""" + +from abc import abstractmethod +from uuid import uuid4 + +import attr + +from spock.backend.handler import BaseHandler +from spock.utils import add_info + + +class BaseSaver(BaseHandler): # pylint: disable=too-few-public-methods + """Base class for saving configs + + Contains methods to build a correct output payload and then writes to file based on the file + extension + + *Attributes*: + + _writers: maps file extension to the correct i/o handler + _s3_config: optional S3Config object to handle s3 access + + """ + + def __init__(self, s3_config=None): + super(BaseSaver, self).__init__(s3_config=s3_config) + + def save( + self, + payload, + path, + file_name=None, + create_save_path=False, + extra_info=True, + file_extension=".yaml", + tuner_payload=None, + fixed_uuid=None, + ): # pylint: disable=too-many-arguments + """Writes Spock config to file + + Cleans and builds an output payload and then correctly writes it to file based on the + specified file extension + + *Args*: + + payload: current config payload + path: path to save + file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to uuid if None + create_save_path: boolean to create the path if non-existent + extra_info: boolean to write extra info + file_extension: what type of file to write + tuner_payload: tuner level payload (unsampled) + fixed_uuid: fixed uuid to allow for file overwrite + + *Returns*: + + None + + """ + # Check extension + self._check_extension(file_extension=file_extension) + # Make the filename -- always append a uuid for unique-ness + uuid_str = str(uuid4()) if fixed_uuid is None else fixed_uuid + fname = "" if file_name is None else f"{file_name}." 
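As the docstring above notes, the written file name always carries a uuid for uniqueness, so the output takes the form `[file_name.]<uuid>.spock.cfg<file_extension>`. A tiny standalone sketch of that scheme (mirroring the assembly in the surrounding method; the example output is shortened):

```python
from uuid import uuid4


def spock_output_name(file_name=None, file_extension=".yaml", fixed_uuid=None) -> str:
    """Build the output name [file_name.]<uuid>.spock.cfg<extension> (illustrative sketch)."""
    uuid_str = str(uuid4()) if fixed_uuid is None else fixed_uuid
    prefix = "" if file_name is None else f"{file_name}."
    return f"{prefix}{uuid_str}.spock.cfg{file_extension}"


# spock_output_name("train_run") -> e.g. "train_run.2b6f9c1e-....spock.cfg.yaml"
```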
+ name = f"{fname}{uuid_str}.spock.cfg{file_extension}" + # Fix up values -- parameters + out_dict = self._clean_up_values(payload, file_extension) + # Fix up the tuner values if present + tuner_dict = ( + self._clean_tuner_values(tuner_payload) + if tuner_payload is not None + else None + ) + if tuner_dict is not None: + out_dict.update(tuner_dict) + # Get extra info + extra_dict = add_info() if extra_info else None + try: + self._supported_extensions.get(file_extension)().save( + out_dict=out_dict, + info_dict=extra_dict, + path=str(path), + name=name, + create_path=create_save_path, + s3_config=self._s3_config, + ) + except OSError as e: + print(f"Unable to write to given path: {path / name}") + raise e + + @abstractmethod + def _clean_up_values(self, payload, file_extension): + """Clean up the config payload so it can be written to file + + *Args*: + + payload: dirty payload + extra_info: boolean to add extra info + file_extension: type of file to write + + *Returns*: + + clean_dict: cleaned output payload + + """ + + @abstractmethod + def _clean_tuner_values(self, payload): + """Cleans up the base tuner payload that is not sampled + + *Args*: + + payload: dirty payload + + *Returns*: + + clean_dict: cleaned output payload + + """ + + def _clean_output(self, out_dict): + """Clean up the dictionary so it can be written to file + + *Args*: + + out_dict: cleaned dictionary + extra_info: boolean to add extra info + + *Returns*: + + clean_dict: cleaned output payload + + """ + # Convert values + clean_dict = {} + for key, val in out_dict.items(): + clean_inner_dict = {} + if isinstance(val, list): + for idx, list_val in enumerate(val): + tmp_dict = {} + for inner_key, inner_val in list_val.items(): + tmp_dict = self._convert(tmp_dict, inner_val, inner_key) + val[idx] = tmp_dict + clean_inner_dict = val + else: + for inner_key, inner_val in val.items(): + clean_inner_dict = self._convert( + clean_inner_dict, inner_val, inner_key + ) + clean_dict.update({key: clean_inner_dict}) + return clean_dict + + def _convert(self, clean_inner_dict, inner_val, inner_key): + # Convert tuples to lists so they get written correctly + if isinstance(inner_val, tuple): + clean_inner_dict.update( + {inner_key: self._recursive_tuple_to_list(inner_val)} + ) + elif inner_val is not None: + clean_inner_dict.update({inner_key: inner_val}) + return clean_inner_dict + + def _recursive_tuple_to_list(self, value): + """Recursively turn tuples into lists + + Recursively looks through tuple(s) and convert to lists + + *Args*: + + value: value to check and set typ if necessary + typed: type of the generic alias to check against + + *Returns*: + + value: updated value with correct type casts + + """ + # Check for __args__ as it signifies a generic and make sure it's not already been cast as a tuple + # from a composed payload + list_v = [] + for v in value: + if isinstance(v, tuple): + v = self._recursive_tuple_to_list(v) + list_v.append(v) + else: + list_v.append(v) + return list_v + + +class AttrSaver(BaseSaver): + """Base class for saving configs for the attrs backend + + Contains methods to build a correct output payload and then writes to file based on the file + extension + + *Attributes*: + + _writers: maps file extension to the correct i/o handler + + """ + + def __init__(self, s3_config=None): + super().__init__(s3_config=s3_config) + + def __call__(self, *args, **kwargs): + return AttrSaver(*args, **kwargs) + + def _clean_up_values(self, payload, file_extension): + # Dictionary to recursively write to + out_dict = 
{} + # All of the classes are defined at the top level + all_spock_cls = set(vars(payload).keys()) + out_dict = self._recursively_handle_clean( + payload, out_dict, all_cls=all_spock_cls + ) + # Convert values + clean_dict = self._clean_output(out_dict) + return clean_dict + + def _clean_tuner_values(self, payload): + # Just a double nested dict comprehension to unroll to dicts + out_dict = { + k: {ik: vars(iv) for ik, iv in vars(v).items()} + for k, v in vars(payload).items() + } + # Convert values + clean_dict = self._clean_output(out_dict) + return clean_dict + + def _recursively_handle_clean( + self, payload, out_dict, parent_name=None, all_cls=None + ): + """Recursively works through spock classes and adds clean data to a dictionary + + Given a payload (Spockspace) work recursively through items that don't have parents to catch all + parameter definitions while correctly mapping nested class definitions to their base level class thus + allowing the output markdown to be a valid input file + + *Args*: + + payload: current payload (namespace) + out_dict: output dictionary + parent_name: name of the parent spock class if nested + all_cls: all top level spock class definitions + + *Returns*: + + out_dict: modified dictionary with the cleaned data + + """ + for key, val in vars(payload).items(): + val_name = type(val).__name__ + # This catches basic lists and list of classes + if isinstance(val, list): + # Check if each entry is a spock class + clean_val = [] + repeat_flag = False + for l_val in val: + cls_name = type(l_val).__name__ + # For those that are a spock class and are repeated (cls_name == key) simply convert to dict + if (cls_name in all_cls) and (cls_name == key): + clean_val.append(attr.asdict(l_val)) + # For those whose cls is different than the key just append the cls name + elif cls_name in all_cls: + # Change the flag as this is a repeated class -- which needs to be compressed into a single + # k:v pair + repeat_flag = True + clean_val.append(cls_name) + # Fall back to the passed in values + else: + clean_val.append(l_val) + # Handle repeated classes + if repeat_flag: + clean_val = list(set(clean_val))[-1] + out_dict.update({key: clean_val}) + # If it's a spock class but has a parent then just use the class name to reference the values + elif (val_name in all_cls) and parent_name is not None: + out_dict.update({key: val_name}) + # Check if it's a spock class without a parent -- iterate the values and recurse to catch more lists + elif val_name in all_cls: + new_dict = self._recursively_handle_clean( + val, {}, parent_name=key, all_cls=all_cls + ) + out_dict.update({key: new_dict}) + # Either base type or no nested values that could be Spock classes + else: + out_dict.update({key: val}) + return out_dict diff --git a/spock/backend/attr/typed.py b/spock/backend/typed.py similarity index 72% rename from spock/backend/attr/typed.py rename to spock/backend/typed.py index 8452e111..fcb8678b 100644 --- a/spock/backend/attr/typed.py +++ b/spock/backend/typed.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles the definitions of arguments types for Spock (backend: attrs)""" @@ -8,9 +8,10 @@ import sys from enum import EnumMeta from functools import partial -from typing import TypeVar -from typing import Union +from typing import TypeVar, Union + import attr + minor = sys.version_info.minor if minor < 7: from typing import GenericMeta as _GenericAlias @@ -24,6 +25,7 @@ class SavePath(str): Defines 
a special key use to save the current Spock config to file """ + def __new__(cls, x): return super().__new__(cls, x) @@ -40,7 +42,7 @@ def _get_name_py_version(typed): name of the type """ - return typed._name if hasattr(typed, '_name') else typed.__name__ + return typed._name if hasattr(typed, "_name") else typed.__name__ def _extract_base_type(typed): @@ -57,7 +59,7 @@ def _extract_base_type(typed): name of type """ - if hasattr(typed, '__args__'): + if hasattr(typed, "__args__"): name = _get_name_py_version(typed=typed) bracket_val = f"{name}[{_extract_base_type(typed.__args__[0])}]" return bracket_val @@ -81,19 +83,28 @@ def _recursive_generic_validator(typed): return_type: recursively built deep_iterable validators """ - if hasattr(typed, '__args__'): + if hasattr(typed, "__args__"): # If there are more __args__ then we still need to recurse as it is still a GenericAlias - return_type = attr.validators.deep_iterable( - member_validator=_recursive_generic_validator(typed.__args__[0]), - iterable_validator=attr.validators.instance_of(typed.__origin__) - ) + # Iterate through since there might be multiple types? + if len(typed.__args__) > 1: + return_type = attr.validators.deep_iterable( + member_validator=_recursive_generic_validator(typed.__args__), + iterable_validator=attr.validators.instance_of(typed.__origin__), + ) + else: + return_type = attr.validators.deep_iterable( + member_validator=_recursive_generic_validator(typed.__args__[0]), + iterable_validator=attr.validators.instance_of(typed.__origin__), + ) return return_type else: # If no more __args__ then we are to the base type and need to bubble up the type # But we need to check against base types and enums if isinstance(typed, EnumMeta): base_type, allowed = _check_enum_props(typed) - return_type = attr.validators.and_(attr.validators.instance_of(base_type), attr.validators.in_(allowed)) + return_type = attr.validators.and_( + attr.validators.instance_of(base_type), attr.validators.in_(allowed) + ) else: return_type = attr.validators.instance_of(typed) return return_type @@ -122,19 +133,34 @@ def _generic_alias_katra(typed, default=None, optional=False): # base python class from which a GenericAlias is derived base_typed = typed.__origin__ if default is not None: - x = attr.ib(validator=_recursive_generic_validator(typed), default=default, type=base_typed, - metadata={'base': _extract_base_type(typed), 'type': typed}) + x = attr.ib( + validator=_recursive_generic_validator(typed), + default=default, + type=base_typed, + metadata={"base": _extract_base_type(typed), "type": typed}, + ) # x = attr.ib(validator=_recursive_generic_iterator(typed), default=default, type=base_typed, # metadata={'base': _extract_base_type(typed)}) elif optional: # if there's no default, but marked as optional, then set the default to None - x = attr.ib(validator=attr.validators.optional(_recursive_generic_validator(typed)), type=base_typed, - default=default, metadata={'optional': True, 'base': _extract_base_type(typed), 'type': typed}) + x = attr.ib( + validator=attr.validators.optional(_recursive_generic_validator(typed)), + type=base_typed, + default=default, + metadata={ + "optional": True, + "base": _extract_base_type(typed), + "type": typed, + }, + ) # x = attr.ib(validator=attr.validators.optional(_recursive_generic_iterator(typed)), type=base_typed, # default=default, metadata={'optional': True, 'base': _extract_base_type(typed)}) else: - x = attr.ib(validator=_recursive_generic_validator(typed), type=base_typed, - metadata={'base': 
_extract_base_type(typed), 'type': typed}) + x = attr.ib( + validator=_recursive_generic_validator(typed), + type=base_typed, + metadata={"base": _extract_base_type(typed), "type": typed}, + ) # x = attr.ib(validator=_recursive_generic_iterator(typed), type=base_typed, # metadata={'base': _extract_base_type(typed)}) return x @@ -184,10 +210,18 @@ def _enum_katra(typed, default=None, optional=False): """ # First check if the types of Enum are the same base_type, allowed = _check_enum_props(typed) - if base_type.__name__ == 'type': - x = _enum_class_katra(typed=typed, allowed=allowed, default=default, optional=optional) + if base_type.__name__ == "type": + x = _enum_class_katra( + typed=typed, allowed=allowed, default=default, optional=optional + ) else: - x = _enum_base_katra(typed=typed, base_type=base_type, allowed=allowed, default=default, optional=optional) + x = _enum_base_katra( + typed=typed, + base_type=base_type, + allowed=allowed, + default=default, + optional=optional, + ) return x @@ -214,15 +248,32 @@ def _enum_base_katra(typed, base_type, allowed, default=None, optional=False): """ if default is not None: x = attr.ib( - validator=[attr.validators.instance_of(base_type), attr.validators.in_(allowed)], - default=default, type=typed, metadata={'base': typed.__name__}) + validator=[ + attr.validators.instance_of(base_type), + attr.validators.in_(allowed), + ], + default=default, + type=typed, + metadata={"base": typed.__name__}, + ) elif optional: x = attr.ib( - validator=attr.validators.optional([attr.validators.instance_of(base_type), attr.validators.in_(allowed)]), - default=default, type=typed, metadata={'base': typed.__name__, 'optional': True}) + validator=attr.validators.optional( + [attr.validators.instance_of(base_type), attr.validators.in_(allowed)] + ), + default=default, + type=typed, + metadata={"base": typed.__name__, "optional": True}, + ) else: - x = attr.ib(validator=[attr.validators.instance_of(base_type), attr.validators.in_(allowed)], type=typed, - metadata={'base': typed.__name__}) + x = attr.ib( + validator=[ + attr.validators.instance_of(base_type), + attr.validators.in_(allowed), + ], + type=typed, + metadata={"base": typed.__name__}, + ) return x @@ -242,7 +293,7 @@ def _in_type(instance, attribute, value, options): """ if type(value) not in options: - raise ValueError(f'{attribute.name} must be in {options}') + raise ValueError(f"{attribute.name} must be in {options}") def _enum_class_katra(typed, allowed, default=None, optional=False): @@ -269,14 +320,24 @@ def _enum_class_katra(typed, allowed, default=None, optional=False): """ if default is not None: x = attr.ib( - validator=[partial(_in_type, options=allowed)], default=default, type=typed, - metadata={'base': typed.__name__}) + validator=[partial(_in_type, options=allowed)], + default=default, + type=typed, + metadata={"base": typed.__name__}, + ) elif optional: x = attr.ib( validator=attr.validators.optional([partial(_in_type, options=allowed)]), - default=default, type=typed, metadata={'base': typed.__name__, 'optional': True}) + default=default, + type=typed, + metadata={"base": typed.__name__, "optional": True}, + ) else: - x = attr.ib(validator=[partial(_in_type, options=allowed)], type=typed, metadata={'base': typed.__name__}) + x = attr.ib( + validator=[partial(_in_type, options=allowed)], + type=typed, + metadata={"base": typed.__name__}, + ) return x @@ -306,7 +367,7 @@ def _type_katra(typed, default=None, optional=False): elif isinstance(typed, _GenericAlias): name = 
_get_name_py_version(typed=typed) else: - raise TypeError('Encountered an unexpected type in _type_katra') + raise TypeError("Encountered an unexpected type in _type_katra") special_key = None # Default booleans to false and optional due to the nature of a boolean if isinstance(typed, type) and name == "bool": @@ -320,14 +381,25 @@ def _type_katra(typed, default=None, optional=False): typed = str if default is not None: # if a default is provided, that takes precedence - x = attr.ib(validator=attr.validators.instance_of(typed), default=default, type=typed, - metadata={'base': name, 'special_key': special_key}) + x = attr.ib( + validator=attr.validators.instance_of(typed), + default=default, + type=typed, + metadata={"base": name, "special_key": special_key}, + ) elif optional: - x = attr.ib(validator=attr.validators.optional(attr.validators.instance_of(typed)), default=default, type=typed, - metadata={'optional': True, 'base': name, 'special_key': special_key}) + x = attr.ib( + validator=attr.validators.optional(attr.validators.instance_of(typed)), + default=default, + type=typed, + metadata={"optional": True, "base": name, "special_key": special_key}, + ) else: - x = attr.ib(validator=attr.validators.instance_of(typed), type=typed, metadata={'base': name, - 'special_key': special_key}) + x = attr.ib( + validator=attr.validators.instance_of(typed), + type=typed, + metadata={"base": name, "special_key": special_key}, + ) return x @@ -349,7 +421,7 @@ def _handle_optional_typing(typed): # Set optional to false optional = False # Check if it has __args__ to look for optionality as it is a GenericAlias - if hasattr(typed, '__args__'): + if hasattr(typed, "__args__"): # If it is more than one than it is most likely optional but check against NoneType in the tuple to verify # Check the length of type __args__ type_args = typed.__args__ @@ -364,6 +436,8 @@ def _handle_optional_typing(typed): def _check_generic_recursive_single_type(typed): """Checks generics for the single types -- mixed types of generics are not allowed + DEPRECATED -- NOW SUPPORTS MIXED TYPES OF TUPLES + *Args*: typed: type @@ -372,13 +446,14 @@ def _check_generic_recursive_single_type(typed): """ # Check if it has __args__ to look for optionality as it is a GenericAlias - if hasattr(typed, '__args__'): - if len(set(typed.__args__)) > 1: - type_list = [str(val) for val in typed.__args__] - raise TypeError(f"Passing multiple different subscript types to GenericAlias is not supported: {type_list}") - else: - for val in typed.__args__: - _check_generic_recursive_single_type(typed=val) + # if hasattr(typed, '__args__'): + # if len(set(typed.__args__)) > 1: + # type_list = [str(val) for val in typed.__args__] + # raise TypeError(f"Passing multiple different subscript types to GenericAlias is not supported: {type_list}") + # else: + # for val in typed.__args__: + # _check_generic_recursive_single_type(typed=val) + pass def katra(typed, default=None): @@ -405,7 +480,9 @@ def katra(typed, default=None): _check_generic_recursive_single_type(typed) # We need to check if the type is a _GenericAlias so that we can handle subscripted general types # If it is subscript typed it will not be T which python uses as a generic type name - if isinstance(typed, _GenericAlias) and (not isinstance(typed.__args__[0], TypeVar)): + if isinstance(typed, _GenericAlias) and ( + not isinstance(typed.__args__[0], TypeVar) + ): x = _generic_alias_katra(typed=typed, default=default, optional=optional) elif isinstance(typed, EnumMeta): x = 
_enum_katra(typed=typed, default=default, optional=optional) diff --git a/spock/backend/attr/utils.py b/spock/backend/utils.py similarity index 77% rename from spock/backend/attr/utils.py rename to spock/backend/utils.py index 51e84e63..3ac147ab 100644 --- a/spock/backend/attr/utils.py +++ b/spock/backend/utils.py @@ -1,11 +1,29 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Attr utility functions for Spock""" +def get_attr_fields(input_classes): + """Gets the attribute fields from all classes + + *Args*: + + input_classes: current list of input classes + + *Returns*: + + dictionary of all attrs attribute fields + + """ + return { + attr.__name__: [val.name for val in attr.__attrs_attrs__] + for attr in input_classes + } + + def get_type_fields(input_classes): """Creates a dictionary of names and types @@ -23,8 +41,8 @@ def get_type_fields(input_classes): for attr in input_classes: input_attr = {} for val in attr.__attrs_attrs__: - if 'type' in val.metadata: - input_attr.update({val.name: val.metadata['type']}) + if "type" in val.metadata: + input_attr.update({val.name: val.metadata["type"]}) else: input_attr.update({val.name: None}) type_fields.update({attr.__name__: input_attr}) @@ -73,14 +91,16 @@ def convert_to_tuples(input_dict, named_type_dict, class_names): updated_dict = {} all_typed_dict = flatten_type_dict(named_type_dict) for k, v in input_dict.items(): - if k != 'config': + if k != "config": if isinstance(v, dict): updated = convert_to_tuples(v, named_type_dict.get(k), class_names) if updated: updated_dict.update({k: updated}) elif isinstance(v, list) and k in class_names: for val in v: - updated = convert_to_tuples(val, named_type_dict.get(k), class_names) + updated = convert_to_tuples( + val, named_type_dict.get(k), class_names + ) if updated: updated_dict.update({k: updated}) elif all_typed_dict[k] is not None: @@ -133,18 +153,25 @@ def _recursive_list_to_tuple(value, typed, class_names): """ # Check for __args__ as it signifies a generic and make sure it's not already been cast as a tuple # from a composed payload - if hasattr(typed, '__args__') and not isinstance(value, tuple) and not (isinstance(value, str) - and value in class_names): + if ( + hasattr(typed, "__args__") + and not isinstance(value, tuple) + and not (isinstance(value, str) and value in class_names) + ): # Force those with origin tuple types to be of the defined length - if (typed.__origin__.__name__.lower() == 'tuple') and len(value) != len(typed.__args__): - raise ValueError(f'Tuple(s) use a fixed/defined length -- Length of the provided argument ({len(value)}) ' - f'does not match the length of the defined argument ({len(typed.__args__)})') + if (typed.__origin__.__name__.lower() == "tuple") and len(value) != len( + typed.__args__ + ): + raise ValueError( + f"Tuple(s) use a fixed/defined length -- Length of the provided argument ({len(value)}) " + f"does not match the length of the defined argument ({len(typed.__args__)})" + ) # need to recurse before casting as we can't set values in a tuple with idx # Since it's generic it should be iterable to recurse and check it's children for idx, val in enumerate(value): value[idx] = _recursive_list_to_tuple(val, typed.__args__[0], class_names) # First check if list and then swap to tuple if the origin is tuple - if isinstance(value, list) and typed.__origin__.__name__.lower() == 'tuple': + if isinstance(value, list) and typed.__origin__.__name__.lower() == "tuple": value = tuple(value) else: 
return value diff --git a/spock/backend/wrappers.py b/spock/backend/wrappers.py new file mode 100644 index 00000000..ffe2b9c6 --- /dev/null +++ b/spock/backend/wrappers.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles Spock data type wrappers""" + +import argparse + +import yaml + + +class Spockspace(argparse.Namespace): + """Inherits from Namespace to implement a pretty print on the obj + + Overwrites the __repr__ method with a pretty version of printing + + """ + + def __init__(self, **kwargs): + super(Spockspace, self).__init__(**kwargs) + + def __repr__(self): + # Remove aliases in YAML print + yaml.Dumper.ignore_aliases = lambda *args: True + return yaml.dump(self.__dict__, default_flow_style=False) diff --git a/spock/builder.py b/spock/builder.py index d6782064..6fba9819 100644 --- a/spock/builder.py +++ b/spock/builder.py @@ -1,18 +1,23 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles the building/saving of the configurations from the Spock config classes""" +import argparse +import sys +import typing from pathlib import Path +from uuid import uuid4 + import attr -from spock.backend.attr.builder import AttrBuilder -from spock.backend.attr.payload import AttrPayload -from spock.backend.attr.saver import AttrSaver -from spock.utils import check_payload_overwrite -from spock.utils import deep_payload_update -import typing + +from spock.backend.builder import AttrBuilder +from spock.backend.payload import AttrPayload +from spock.backend.saver import AttrSaver +from spock.backend.wrappers import Spockspace +from spock.utils import check_payload_overwrite, deep_payload_update class ConfigArgBuilder: @@ -25,27 +30,88 @@ class ConfigArgBuilder: *Attributes*: + _args: all command line args _arg_namespace: generated argument namespace _builder_obj: instance of a BaseBuilder class - _create_save_path: boolean to make the path to save to _dict_args: dictionary args from the command line _payload_obj: instance of a BasePayload class _saver_obj: instance of a BaseSaver class + _tune_payload_obj: payload for tuner related objects -- instance of TunerPayload class + _tune_obj: instance of TunerBuilder class + _tuner_interface: interface that handles the underlying library for sampling -- instance of TunerInterface + _tuner_state: current state of the hyper-parameter sampler + _tune_namespace: namespace that hold the generated tuner related parameters + _sample_count: current call to the sample function + _fixed_uuid: fixed uuid to write the best file to the same path """ - def __init__(self, *args, configs: typing.Optional[typing.List] = None, create_save_path: bool = False, - desc: str = '', no_cmd_line: bool = False, s3_config=None, **kwargs): - backend = self._set_backend(args) - self._create_save_path = create_save_path - self._builder_obj = backend.get('builder')( - *args, configs=configs, create_save_path=create_save_path, desc=desc, no_cmd_line=no_cmd_line, **kwargs) - self._payload_obj = backend.get('payload')(s3_config=s3_config) - self._saver_obj = backend.get('saver')(s3_config=s3_config) + + def __init__( + self, + *args, + configs: typing.Optional[typing.List] = None, + desc: str = "", + no_cmd_line: bool = False, + s3_config=None, + **kwargs, + ): + """Init call for ConfigArgBuilder + + *Args*: + + *args: tuple of spock decorated classes to process + configs: list of config paths + desc: description for help + no_cmd_line: turn off cmd line 
args + s3_config: s3Config object for S3 support + **kwargs: keyword args + + """ + # Do some verification first + self._verify_attr(args) + self._configs = configs + self._no_cmd_line = no_cmd_line + self._desc = desc + # Build the payload and saver objects + self._payload_obj = AttrPayload(s3_config=s3_config) + self._saver_obj = AttrSaver(s3_config=s3_config) + # Split the fixed parameters from the tuneable ones (if present) + fixed_args, tune_args = self._strip_tune_parameters(args) + # The fixed parameter builder + self._builder_obj = AttrBuilder(*fixed_args, **kwargs) + # The possible tunable parameter builder -- might return None + self._tune_obj, self._tune_payload_obj = self._handle_tuner_objects( + tune_args, s3_config, kwargs + ) + self._tuner_interface = None + self._tuner_state = None + self._tuner_status = None + self._sample_count = 0 + self._fixed_uuid = str(uuid4()) try: - self._dict_args = self._get_payload() + # Get all cmd line args and build overrides + self._args = self._handle_cmd_line() + # Get the actual payload from the config files -- fixed configs + self._dict_args = self._get_payload( + payload_obj=self._payload_obj, + input_classes=self._builder_obj.input_classes, + ignore_args=tune_args, + ) + # Build the Spockspace from the payload and the classes + # Fixed configs self._arg_namespace = self._builder_obj.generate(self._dict_args) + # Get the payload from the config files -- hyper-parameters -- only if the obj is not None + if self._tune_obj is not None: + self._tune_args = self._get_payload( + payload_obj=self._tune_payload_obj, + input_classes=self._tune_obj.input_classes, + ignore_args=fixed_args, + ) + # Build the Spockspace from the payload and the classes + # Tuneable parameters + self._tune_namespace = self._tune_obj.generate(self._tune_args) except Exception as e: - self._builder_obj.print_usage_and_exit(str(e), sys_exit=False) + self._print_usage_and_exit(str(e), sys_exit=False) raise ValueError(e) def __call__(self, *args, **kwargs): @@ -65,8 +131,6 @@ def __call__(self, *args, **kwargs): def generate(self): """Generate method that returns the actual argument namespace - *Args*: - *Returns*: @@ -75,81 +139,325 @@ def generate(self): """ return self._arg_namespace + @property + def tuner_status(self): + """Returns a dictionary of all the necessary underlying tuner internals to report the result""" + return self._tuner_status + + @property + def best(self): + """Returns a Spockspace of the best hyper-parameter config and the associated metric value""" + return self._tuner_interface.best + + def sample(self): + """Sample method that constructs a namespace from the fixed parameters and samples from the tuner space to + generate a Spockspace derived from both + + *Returns*: + + argument namespace(s) -- fixed + drawn sample from tuner backend + + """ + if self._tune_obj is None: + raise ValueError( + f"Called sample method without passing any @spockTuner decorated classes" + ) + if self._tuner_interface is None: + raise ValueError( + f"Called sample method without first calling the tuner method that initializes the " + f"backend library" + ) + return_tuple = self._tuner_state + self._tuner_status = self._tuner_interface.tuner_status + self._tuner_state = self._tuner_interface.sample() + self._sample_count += 1 + return return_tuple + + def tuner(self, tuner_config): + """Chained call that builds the tuner interface for either optuna or ax depending upon the type of the tuner_obj + + *Args*: + + tuner_config: a class of type optuna.study.Study or AX**** + 
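Taken together, `tuner`, `sample`, `save`, and `save_best` support a simple loop: build once, attach a tuner config, draw samples, and persist each sampled config plus the best one. A hedged usage sketch, assuming `MyConfig` and `MyTunerConfig` are `@spock` / `@spockTuner` decorated classes defined elsewhere and `study` is an `optuna.study.Study`:

```python
from spock.builder import ConfigArgBuilder

# MyConfig, MyTunerConfig and study are assumed to exist (see above); this is a usage
# sketch, not a verbatim example from the spock documentation.
builder = ConfigArgBuilder(MyConfig, MyTunerConfig, desc="tuning example").tuner(tuner_config=study)

for _ in range(25):
    namespace = builder.sample()         # fixed parameters + one drawn hyper-parameter sample
    # ... train / evaluate with `namespace`; builder.tuner_status exposes the underlying
    # tuner internals needed to report the resulting metric ...
    builder.save(add_tuner_sample=True)  # writes hp.sample.<n>.<uuid>.spock.cfg.yaml

builder.save_best()                      # writes the best config with a fixed uuid (hp.best.<uuid>.spock.cfg.yaml)
```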
+ *Returns*: + + self so that functions can be chained + + """ + if self._tune_obj is None: + raise ValueError( + f"Called tuner method without passing any @spockTuner decorated classes" + ) + try: + from spock.addons.tune.tuner import TunerInterface + + self._tuner_interface = TunerInterface( + tuner_config=tuner_config, + tuner_namespace=self._tune_namespace, + fixed_namespace=self._arg_namespace, + ) + self._tuner_state = self._tuner_interface.sample() + except ImportError: + print( + "Missing libraries to support tune functionality. Please re-install with the extra tune " + "dependencies -- pip install spock-config[tune]" + ) + return self + + def _print_usage_and_exit(self, msg=None, sys_exit=True, exit_code=1): + """Prints the help message and exits + + *Args*: + + msg: message to print pre exit + + *Returns*: + + None + + """ + print(f"usage: {sys.argv[0]} -c [--config] config1 [config2, config3, ...]") + print(f'\n{self._desc if self._desc != "" else ""}\n') + print("configuration(s):\n") + # Call the fixed parameter help info + self._builder_obj.handle_help_info() + if self._tune_obj is not None: + self._tune_obj.handle_help_info() + if msg is not None: + print(msg) + if sys_exit: + sys.exit(exit_code) + + @staticmethod + def _handle_tuner_objects(tune_args, s3_config, kwargs): + """Handles creating the tuner builder object if @spockTuner classes were passed in + + *Args*: + + tune_args: list of tuner classes + s3_config: s3Config object for S3 support + kwargs: optional keyword args + + *Returns*: + + tuner builder object or None + + """ + if len(tune_args) > 0: + try: + from spock.addons.tune.builder import TunerBuilder + from spock.addons.tune.payload import TunerPayload + + tuner_builder = TunerBuilder(*tune_args, **kwargs) + tuner_payload = TunerPayload(s3_config=s3_config) + return tuner_builder, tuner_payload + except ImportError: + print( + "Missing libraries to support tune functionality. 
Please re-install with the extra tune " + "dependencies -- pip install spock-config[tune]" + ) + else: + return None, None + @staticmethod - def _set_backend(args: typing.List): - """Determines which backend class to use + def _verify_attr(args: typing.Tuple): + """Verifies that all the input classes are attr based *Args*: - args: list of classes passed to the builder + args: tuple of classes passed to the builder *Returns*: - backend: class of backend + None """ # Gather if all attr backend type_attrs = all([attr.has(arg) for arg in args]) if not type_attrs: which_idx = [attr.has(arg) for arg in args].index(False) - if hasattr(args[which_idx], '__name__'): - raise TypeError(f"*args must be of all attrs backend -- missing a @spock decorator on class " - f"{args[which_idx].__name__}") + if hasattr(args[which_idx], "__name__"): + raise TypeError( + f"*args must be of all attrs backend -- missing a @spock decorator on class " + f"{args[which_idx].__name__}" + ) else: - raise TypeError(f"*args must be of all attrs backend -- invalid type " - f"{type(args[which_idx])}") - else: - backend = {'builder': AttrBuilder, 'payload': AttrPayload, 'saver': AttrSaver} - return backend + raise TypeError( + f"*args must be of all attrs backend -- invalid type " + f"{type(args[which_idx])}" + ) + + @staticmethod + def _strip_tune_parameters(args: typing.Tuple): + """Separates the fixed arguments from any hyper-parameter arguments - def _get_config_paths(self): - """Get config paths from all methods + *Args*: + + args: tuple of classes passed to the builder + + *Returns*: + + fixed_args: list of fixed args + tune_args: list of args destined for a tuner backend + + """ + fixed_args = [] + tune_args = [] + for arg in args: + if arg.__module__ == "spock.backend.config": + fixed_args.append(arg) + elif arg.__module__ == "spock.addons.tune.config": + tune_args.append(arg) + return fixed_args, tune_args + + def _handle_cmd_line(self): + """Handle all cmd line related tasks Config paths can enter from either the command line or be added in the class init call - as a kwarg (configs=[]) + as a kwarg (configs=[]) -- also trigger the building of the cmd line overrides for each fixed and + tunable objects *Returns*: args: namespace of args """ - # Call the objects get_config_paths function - args = self._builder_obj.get_config_paths() + # Need to hold an overarching parser here that just gets appended to for both fixed and tunable objects + # Check if the no_cmd_line is not flagged and if the configs are not empty + if self._no_cmd_line and (self._configs is None): + raise ValueError( + "Flag set for preventing command line read but no paths were passed to the config kwarg" + ) + # If cmd_line is flagged then build the parsers if not make any empty Namespace + args = ( + self._build_override_parsers(desc=self._desc) + if not self._no_cmd_line + else argparse.Namespace(config=[], help=False) + ) + # If configs are present from the init call then roll these into the namespace + if self._configs is not None: + args = self._get_from_kwargs(args, self._configs) + return args + + def _build_override_parsers(self, desc): + """Creates parsers for command-line overrides + + Builds the basic command line parser for configs and help then iterates through each attr instance to make + namespace specific cmd line override parsers -- handles calling both the fixed and tunable objects + + *Args*: + + desc: argparser description + + *Returns*: + + args: argument namespace + + """ + # Highest level parser object + parser = 
argparse.ArgumentParser(description=desc, add_help=False) + parser.add_argument("-c", "--config", required=False, nargs="+", default=[]) + parser.add_argument("-h", "--help", action="store_true") + # Handle the builder obj + parser = self._builder_obj.build_override_parsers(parser=parser) + if self._tune_obj is not None: + parser = self._tune_obj.build_override_parsers(parser=parser) + args = parser.parse_args() return args - def _get_payload(self): + @staticmethod + def _get_from_kwargs(args, configs): + """Get configs from the configs kwarg + + *Args*: + + args: argument namespace + configs: config kwarg + + *Returns*: + + args: arg namespace + + """ + if isinstance(configs, list): + args.config.extend(configs) + else: + raise TypeError( + f"configs kwarg must be of type list -- given {type(configs)}" + ) + return args + + def _get_payload(self, payload_obj, input_classes, ignore_args: typing.List): """Get the parameter payload from the config file(s) Calls the various ways to get configs and then parses to retrieve the parameter payload - make sure to call deep update so as to not lose some parameters when only partially updating the payload + *Args*: + + payload_obj: current payload object to call + input_classes: classes to use to get payload + ignore_args: args that were decorated for hyper-parameter tuning + *Returns*: payload: dictionary of parameter values """ - args = self._get_config_paths() - if args.help: + if self._args.help: # Call sys exit with a clean code as this is the help call which is not unexpected behavior - self._builder_obj.print_usage_and_exit(sys_exit=True, exit_code=0) + self._print_usage_and_exit(sys_exit=True, exit_code=0) payload = {} - dependencies = {'paths': [], 'rel_paths': [], 'roots': []} - for configs in args.config: - payload_update = self._payload_obj.payload(self._builder_obj.input_classes, configs, args, dependencies) - check_payload_overwrite(payload, payload_update, configs) - deep_payload_update(payload, payload_update) + dependencies = {"paths": [], "rel_paths": [], "roots": []} + if payload_obj is not None: + # Make sure we are actually trying to map to input classes + if len(input_classes) > 0: + # If configs are present then iterate through them and deal with the payload + if len(self._args.config) > 0: + for configs in self._args.config: + payload_update = payload_obj.payload( + input_classes, + ignore_args, + configs, + self._args, + dependencies, + ) + check_payload_overwrite(payload, payload_update, configs) + deep_payload_update(payload, payload_update) + # If there are no configs present we have to fall back only on cmd line args to fill out the necessary + # data -- this is essentially using spock as a drop in replacement of arg-parser + else: + payload_update = payload_obj.payload( + input_classes, ignore_args, None, self._args, dependencies + ) + check_payload_overwrite(payload, payload_update, None) + deep_payload_update(payload, payload_update) return payload - def save(self, file_name: str = None, user_specified_path: str = None, extra_info: bool = True, - file_extension: str = '.yaml'): - """Saves the current config setup to file with a UUID + def _save( + self, + payload, + file_name: str = None, + user_specified_path: str = None, + create_save_path: bool = True, + extra_info: bool = True, + file_extension: str = ".yaml", + tuner_payload=None, + fixed_uuid=None, + ): + """Private interface -- saves the current config setup to file with a UUID *Args*: - file_name: name of file (will be appended with .spock.cfg.file_extension) 
-- falls back to uuid if None + payload: Spockspace to save + file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to just uuid if None user_specified_path: if user provides a path it will be used as the path to write + create_save_path: bool to create the path to save if called extra_info: additional info to write to saved config (run date and git info) file_extension: file type to write (default: yaml) + tuner_payload: tuner level payload (unsampled) + fixed_uuid: fixed uuid to allow for file overwrite *Returns*: @@ -160,10 +468,115 @@ def save(self, file_name: str = None, user_specified_path: str = None, extra_inf elif self._builder_obj.save_path is not None: save_path = Path(self._builder_obj.save_path) else: - raise ValueError('Save did not receive a valid path from: (1) markup file(s) or (2) ' - 'the keyword arg user_specified_path') + raise ValueError( + "Save did not receive a valid path from: (1) markup file(s) or (2) " + "the keyword arg user_specified_path" + ) # Call the saver class and save function self._saver_obj.save( - self._arg_namespace, save_path, file_name, self._create_save_path, extra_info, file_extension + payload, + save_path, + file_name, + create_save_path, + extra_info, + file_extension, + tuner_payload, + fixed_uuid, ) return self + + def save( + self, + file_name: str = None, + user_specified_path: str = None, + create_save_path: bool = True, + extra_info: bool = True, + file_extension: str = ".yaml", + add_tuner_sample: bool = False, + ): + """Saves the current config setup to file with a UUID + + *Args*: + + file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to just uuid if None + user_specified_path: if user provides a path it will be used as the path to write + create_save_path: bool to create the path to save if called + extra_info: additional info to write to saved config (run date and git info) + file_extension: file type to write (default: yaml) + append_tuner_state: save the current tuner sample to the payload + + *Returns*: + + self so that functions can be chained + """ + if add_tuner_sample: + if self._tune_obj is None: + raise ValueError( + f"Called save method with add_tuner_sample as {add_tuner_sample} without passing any @spockTuner " + f"decorated classes -- please use the add_tuner_sample flag for saving only hyper-parameter tuning " + f"runs" + ) + file_name = ( + f"hp.sample.{self._sample_count+1}" + if file_name is None + else f"{file_name}.hp.sample.{self._sample_count+1}" + ) + self._save( + self._tuner_state, + file_name, + user_specified_path, + create_save_path, + extra_info, + file_extension, + ) + else: + self._save( + self._arg_namespace, + file_name, + user_specified_path, + create_save_path, + extra_info, + file_extension, + tuner_payload=self._tune_namespace + if self._tune_obj is not None + else None, + ) + return self + + def save_best( + self, + file_name: str = None, + user_specified_path: str = None, + create_save_path: bool = True, + extra_info: bool = True, + file_extension: str = ".yaml", + ): + """Saves the current best config setup to file + + *Args*: + + file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to just uuid if None + user_specified_path: if user provides a path it will be used as the path to write + create_save_path: bool to create the path to save if called + extra_info: additional info to write to saved config (run date and git info) + file_extension: file type to write (default: yaml) + + 
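For non-tuning runs the same object chains naturally: `save()` returns `self`, so it can sit before the final `generate()` call. A short usage sketch, assuming `MyConfig` is an `@spock` decorated class and that config paths come from the command line (`-c config.yaml`) or the `configs=[...]` kwarg:

```python
from spock.builder import ConfigArgBuilder

# MyConfig is an assumed @spock decorated class; the output path is illustrative.
config = (
    ConfigArgBuilder(MyConfig, desc="example run")
    .save(user_specified_path="/tmp/spock_out")  # writes <uuid>.spock.cfg.yaml to the given path
    .generate()                                  # returns the Spockspace namespace
)
print(config.MyConfig)  # parameters are grouped under an attribute named after the class
```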
*Returns*: + + self so that functions can be chained + """ + if self._tune_obj is None: + raise ValueError( + f"Called save_best method without passing any @spockTuner decorated classes -- please use the save()" + f" method for saving non hyper-parameter tuning runs" + ) + file_name = f"hp.best" if file_name is None else f"{file_name}.hp.best" + self._save( + Spockspace(**vars(self._arg_namespace), **vars(self.best[0])), + file_name, + user_specified_path, + create_save_path, + extra_info, + file_extension, + fixed_uuid=self._fixed_uuid, + ) diff --git a/spock/config.py b/spock/config.py index 28ee0114..7426af19 100644 --- a/spock/config.py +++ b/spock/config.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Creates the spock config decorator that wraps attrs""" -from spock.backend.attr.config import spock_attr +from spock.backend.config import spock_attr from spock.utils import _is_spock_instance # Simplified decorator for attrs spock = spock_attr # Public alias for checking if an object is a @spock annotated class -isinstance_spock =_is_spock_instance +isinstance_spock = _is_spock_instance diff --git a/spock/handlers.py b/spock/handlers.py index bd4fabda..fb277151 100644 --- a/spock/handlers.py +++ b/spock/handlers.py @@ -1,22 +1,23 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """I/O handlers for various file formats""" -from abc import ABC -from abc import abstractmethod import json import os import re -from spock import __version__ -from spock.utils import check_path_s3 -import toml import typing +from abc import ABC, abstractmethod from warnings import warn + +import pytomlpp import yaml +from spock import __version__ +from spock.utils import check_path_s3 + class Handler(ABC): """Base class for file type loaders @@ -24,6 +25,7 @@ class Handler(ABC): ABC for loaders """ + def load(self, path: str, s3_config=None) -> typing.Dict: """Load function for file type @@ -57,8 +59,15 @@ def _load(self, path: str) -> typing.Dict: """ raise NotImplementedError - def save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str, name: str, - create_path: bool = False, s3_config=None): + def save( + self, + out_dict: typing.Dict, + info_dict: typing.Optional[typing.Dict], + path: str, + name: str, + create_path: bool = False, + s3_config=None, + ): """Write function for file type This will handle local or s3 writes with the boolean is_s3 flag. 
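The local-versus-S3 decision hinges on a small regex check (`check_path_s3` in `spock/utils.py`, shown later in this diff): any path matching a case-insensitive `s3://` prefix is routed through the optional S3 add-on, everything else is treated as a filesystem path. A self-contained sketch of that predicate:

```python
import re

_S3_RE = re.compile(r"(?i)^s3://?")  # case-insensitive; tolerates a single slash after posix stripping


def looks_like_s3(path: str) -> bool:
    """Stand-in for spock.utils.check_path_s3 -- True when the path is an s3:// style URI."""
    return _S3_RE.search(path) is not None


assert looks_like_s3("S3://bucket/configs/run.yaml")
assert not looks_like_s3("/local/configs/run.yaml")
```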
If detected it will conditionally import @@ -84,12 +93,17 @@ def save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], p if is_s3: try: from spock.addons.s3.utils import handle_s3_save_path - handle_s3_save_path(temp_path=write_path, s3_path=path, name=name, s3_config=s3_config) + + handle_s3_save_path( + temp_path=write_path, s3_path=path, name=name, s3_config=s3_config + ) except ImportError: - print('Error importing spock s3 utils after detecting s3:// save path') + print("Error importing spock s3 utils after detecting s3:// save path") @abstractmethod - def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str) -> str: + def _save( + self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str + ) -> str: """Write function for file type *Args*: @@ -124,14 +138,16 @@ def _handle_possible_s3_load_path(path: str, s3_config=None) -> str: if is_s3: try: from spock.addons.s3.utils import handle_s3_load_path + path = handle_s3_load_path(path=path, s3_config=s3_config) except ImportError: - print('Error importing spock s3 utils after detecting s3:// load path') + print("Error importing spock s3 utils after detecting s3:// load path") return path @staticmethod - def _handle_possible_s3_save_path(path: str, name: str, create_path: bool, - s3_config=None) -> typing.Tuple[str, bool]: + def _handle_possible_s3_save_path( + path: str, name: str, create_path: bool, s3_config=None + ) -> typing.Tuple[str, bool]: """Handles the possibility of having to save to a S3 path Checks to see if it detects a S3 uri and if so generates a tmp location to write the file to pre-upload @@ -149,15 +165,17 @@ def _handle_possible_s3_save_path(path: str, name: str, create_path: bool, is_s3 = check_path_s3(path=path) if is_s3: if s3_config is None: - raise ValueError('Save to S3 -- Missing S3Config object which is necessary to handle S3 style paths') - write_path = f'{s3_config.temp_folder}/{name}' + raise ValueError( + "Save to S3 -- Missing S3Config object which is necessary to handle S3 style paths" + ) + write_path = f"{s3_config.temp_folder}/{name}" # Strip double slashes if exist - write_path = write_path.replace(r'//', r'/') + write_path = write_path.replace(r"//", r"/") else: # Handle the path logic for non S3 if not os.path.exists(path) and create_path: os.makedirs(path) - write_path = f'{path}/{name}' + write_path = f"{path}/{name}" return write_path, is_s3 @staticmethod @@ -173,14 +191,14 @@ def write_extra_info(path, info_dict): """ # Write the commented info as new lines - with open(path, 'w+') as fid: + with open(path, "w+") as fid: # Write a spock header - fid.write(f'# Spock Version: {__version__}\n') + fid.write(f"# Spock Version: {__version__}\n") # Write info dict if not None if info_dict is not None: for k, v in info_dict.items(): - fid.write(f'{k}: {v}\n') - fid.write('\n') + fid.write(f"{k}: {v}\n") + fid.write("\n") class YAMLHandler(Handler): @@ -189,19 +207,23 @@ class YAMLHandler(Handler): Base YAML class """ + # override default SafeLoader behavior to correctly # interpret 1e1 (as opposed to 1.e+1) as 10 # https://stackoverflow.com/questions/30458977/yaml-loads-5e-6-as-string-and-not-a-number/30462009#30462009 yaml.SafeLoader.add_implicit_resolver( - u'tag:yaml.org,2002:float', - re.compile(u'''^(?: + "tag:yaml.org,2002:float", + re.compile( + """^(?: [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.') + |\\.(?:nan|NaN|NAN))$""", + re.X, + ), + list("-+0123456789."), ) def _load(self, path: str) -> typing.Dict: @@ -216,12 +238,14 @@ def _load(self, path: str) -> typing.Dict: base_payload: dictionary of read file """ - file_contents = open(path, 'r').read() - file_contents = re.sub(r'--([a-zA-Z0-9_]*)', r'\g<1>: True', file_contents) + file_contents = open(path, "r").read() + file_contents = re.sub(r"--([a-zA-Z0-9_]*)", r"\g<1>: True", file_contents) base_payload = yaml.safe_load(file_contents) return base_payload - def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str): + def _save( + self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str + ): """Write function for YAML type *Args*: @@ -237,7 +261,7 @@ def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], self.write_extra_info(path=path, info_dict=info_dict) # Remove aliases in YAML dump yaml.Dumper.ignore_aliases = lambda *args: True - with open(path, 'a') as yaml_fid: + with open(path, "a") as yaml_fid: yaml.safe_dump(out_dict, yaml_fid, default_flow_style=False) return path @@ -248,6 +272,7 @@ class TOMLHandler(Handler): Base TOML class """ + def _load(self, path: str) -> typing.Dict: """TOML load function @@ -260,10 +285,12 @@ def _load(self, path: str) -> typing.Dict: base_payload: dictionary of read file """ - base_payload = toml.load(path) + base_payload = pytomlpp.load(path) return base_payload - def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str): + def _save( + self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str + ): """Write function for TOML type *Args*: @@ -277,8 +304,8 @@ def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], """ # First write the commented info self.write_extra_info(path=path, info_dict=info_dict) - with open(path, 'a') as toml_fid: - toml.dump(out_dict, toml_fid) + with open(path, "a") as toml_fid: + pytomlpp.dump(out_dict, toml_fid) return path @@ -288,6 +315,7 @@ class JSONHandler(Handler): Base JSON class """ + def _load(self, path: str) -> typing.Dict: """JSON load function @@ -304,7 +332,9 @@ def _load(self, path: str) -> typing.Dict: base_payload = json.load(json_fid) return base_payload - def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str): + def _save( + self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str + ): """Write function for JSON type *Args*: @@ -317,8 +347,10 @@ def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], """ if info_dict is not None: - warn('JSON does not support comments and thus cannot save extra info to file... removing extra info') + warn( + "JSON does not support comments and thus cannot save extra info to file... 
removing extra info" + ) info_dict = None - with open(path, 'a') as json_fid: - json.dump(out_dict, json_fid, indent=4, separators=(',', ': ')) + with open(path, "a") as json_fid: + json.dump(out_dict, json_fid, indent=4, separators=(",", ": ")) return path diff --git a/spock/utils.py b/spock/utils.py index 282430a5..0fbd15e9 100644 --- a/spock/utils.py +++ b/spock/utils.py @@ -1,28 +1,31 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Utility functions for Spock""" import ast -import attr -from enum import EnumMeta import os import re import socket import subprocess import sys -from time import localtime -from time import strftime +from enum import EnumMeta +from time import localtime, strftime from warnings import warn + +import attr import git + minor = sys.version_info.minor if minor < 7: from typing import GenericMeta as _GenericAlias else: from typing import _GenericAlias +from typing import Union + def check_path_s3(path: str) -> bool: """Checks the given path to see if it matches the s3:// regex @@ -37,7 +40,7 @@ def check_path_s3(path: str) -> bool: """ # Make a case insensitive s3 regex with single or double forward slash (due to posix stripping) - s3_regex = re.compile(r'(?i)^s3://?').search(path) + s3_regex = re.compile(r"(?i)^s3://?").search(path) # If it returns an object then the path is an s3 style reference return s3_regex is not None @@ -57,7 +60,7 @@ def _is_spock_instance(__obj: object): bool """ - return (__obj.__module__ == 'spock.backend.attr.config') and attr.has(__obj) + return (__obj.__module__ == "spock.backend.config") and attr.has(__obj) def make_argument(arg_name, arg_type, parser): @@ -80,15 +83,22 @@ def make_argument(arg_name, arg_type, parser): # For generic alias we take the input string and use a custom type callable to convert if isinstance(arg_type, _GenericAlias): parser.add_argument(arg_name, required=False, type=_handle_generic_type_args) + # For Unions -- python 3.6 can't deal with them correctly -- use the same ast method that generics require + elif ( + hasattr(arg_type, "__origin__") + and (arg_type.__origin__ is Union) + and (minor < 7) + ): + parser.add_argument(arg_name, required=False, type=_handle_generic_type_args) # For choice enums we need to check a few things first elif isinstance(arg_type, EnumMeta): type_set = list({type(val.value) for val in arg_type})[0] # if this is an enum of a class switch the type to str as this is how it gets matched - type_set = str if type_set.__name__ == 'type' else type_set + type_set = str if type_set.__name__ == "type" else type_set parser.add_argument(arg_name, required=False, type=type_set) # For booleans we map to store true elif arg_type == bool: - parser.add_argument(arg_name, required=False, action='store_true') + parser.add_argument(arg_name, required=False, action="store_true") # Else we are a simple base type which we can cast to else: parser.add_argument(arg_name, required=False, type=arg_type) @@ -98,7 +108,7 @@ def make_argument(arg_name, arg_type, parser): def _handle_generic_type_args(val): """Evaluates a string containing a Python literal - Seeing a list types will come in as string literal format, use ast to get the actual type + Seeing a list and tuple types will come in as string literal format, use ast to get the actual type *Args*: @@ -140,8 +150,8 @@ def make_blank_git(out_dict): out_dict: output dictionary with added git info """ - for key in ('BRANCH', 'COMMIT SHA', 'STATUS', 'ORIGIN'): - out_dict.update({f'# Git 
{key}': 'UNKNOWN'}) + for key in ("BRANCH", "COMMIT SHA", "STATUS", "ORIGIN"): + out_dict.update({f"# Git {key}": "UNKNOWN"}) return out_dict @@ -161,23 +171,38 @@ def add_repo_info(out_dict): repo = git.Repo(os.getcwd(), search_parent_directories=True) # Check if we are really in a detached head state as later info will fail if we are if minor < 7: - head_result = subprocess.run('git rev-parse --abbrev-ref --symbolic-full-name HEAD', stdout=subprocess.PIPE, - shell=True, check=False) + head_result = subprocess.run( + "git rev-parse --abbrev-ref --symbolic-full-name HEAD", + stdout=subprocess.PIPE, + shell=True, + check=False, + ) else: - head_result = subprocess.run('git rev-parse --abbrev-ref --symbolic-full-name HEAD', capture_output=True, - shell=True, check=False) - if head_result.stdout.decode().rstrip('\n') == 'HEAD': + head_result = subprocess.run( + "git rev-parse --abbrev-ref --symbolic-full-name HEAD", + capture_output=True, + shell=True, + check=False, + ) + if head_result.stdout.decode().rstrip("\n") == "HEAD": out_dict = make_blank_git(out_dict) else: - out_dict.update({'# Git Branch': repo.active_branch.name}) - out_dict.update({'# Git Commit': repo.active_branch.commit.hexsha}) - out_dict.update({'# Git Date': repo.active_branch.commit.committed_datetime}) - if len(repo.untracked_files) > 0 or len(repo.active_branch.commit.diff(None)) > 0: - git_status = 'DIRTY' + out_dict.update({"# Git Branch": repo.active_branch.name}) + out_dict.update({"# Git Commit": repo.active_branch.commit.hexsha}) + out_dict.update( + {"# Git Date": repo.active_branch.commit.committed_datetime} + ) + if ( + len(repo.untracked_files) > 0 + or len(repo.active_branch.commit.diff(None)) > 0 + ): + git_status = "DIRTY" else: - git_status = 'CLEAN' - out_dict.update({'# Git Status': git_status}) - out_dict.update({'# Git Origin': repo.active_branch.commit.repo.remotes.origin.url}) + git_status = "CLEAN" + out_dict.update({"# Git Status": git_status}) + out_dict.update( + {"# Git Origin": repo.active_branch.commit.repo.remotes.origin.url} + ) except git.InvalidGitRepositoryError: # pragma: no cover # But it's okay if we are not out_dict = make_blank_git(out_dict) @@ -195,16 +220,20 @@ def add_generic_info(out_dict): out_dict: output dictionary """ - out_dict.update({'# Machine FQDN': socket.getfqdn()}) - out_dict.update({'# Python Executable': sys.executable}) - out_dict.update({'# Python Version': f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}'}) - out_dict.update({'# Python Script': os.path.realpath(sys.argv[0])}) - out_dict.update({'# Run Date': strftime('%Y-%m-%d', localtime())}) - out_dict.update({'# Run Time': strftime('%H:%M:%S', localtime())}) + out_dict.update({"# Machine FQDN": socket.getfqdn()}) + out_dict.update({"# Python Executable": sys.executable}) + out_dict.update( + { + "# Python Version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + } + ) + out_dict.update({"# Python Script": os.path.realpath(sys.argv[0])}) + out_dict.update({"# Run Date": strftime("%Y-%m-%d", localtime())}) + out_dict.update({"# Run Time": strftime("%H:%M:%S", localtime())}) # Make a best effort to determine if run in a container - out_dict.update({'# Run w/ Docker': str(_maybe_docker())}) + out_dict.update({"# Run w/ Docker": str(_maybe_docker())}) # Make a best effort to determine if run in a container via k8s - out_dict.update({'# Run w/ Kubernetes': str(_maybe_k8s())}) + out_dict.update({"# Run w/ Kubernetes": str(_maybe_k8s())}) return out_dict @@ 
-223,10 +252,12 @@ def _maybe_docker(cgroup_path="/proc/self/cgroup"): """ # A few options seem to be at play here: # 1. Check for /.dockerenv -- docker should create this is any container - bool_env = os.path.exists('/.dockerenv') + bool_env = os.path.exists("/.dockerenv") # 2. Check /proc/self/cgroup for "docker" # https://stackoverflow.com/a/48710609 - bool_cgroup = os.path.isfile(cgroup_path) and any("docker" in line for line in open(cgroup_path)) + bool_cgroup = os.path.isfile(cgroup_path) and any( + "docker" in line for line in open(cgroup_path) + ) return bool_env or bool_cgroup @@ -247,7 +278,9 @@ def _maybe_k8s(cgroup_path="/proc/self/cgroup"): bool_env = os.environ.get("KUBERNETES_SERVICE_HOST") is not None # 2. Similar to docker check /proc/self/cgroup for "kubepods" # https://stackoverflow.com/a/48710609 - bool_cgroup = os.path.isfile(cgroup_path) and any("kubepods" in line for line in open(cgroup_path)) + bool_cgroup = os.path.isfile(cgroup_path) and any( + "kubepods" in line for line in open(cgroup_path) + ) return bool_env or bool_cgroup @@ -279,7 +312,7 @@ def deep_payload_update(source, updates): return source -def check_payload_overwrite(payload, updates, configs, overwrite=''): +def check_payload_overwrite(payload, updates, configs, overwrite=""): """Warns when parameters are overwritten across payloads as order will matter *Args*: @@ -294,11 +327,13 @@ def check_payload_overwrite(payload, updates, configs, overwrite=''): """ for k, v in updates.items(): if isinstance(v, dict) and v: - overwrite += (k + ":") + overwrite += k + ":" current_payload = {} if payload.get(k) is None else payload.get(k) check_payload_overwrite(current_payload, v, configs, overwrite=overwrite) else: if k in payload: - warn(f'Overriding an already set parameter {overwrite + k} from {configs}\n' - f'Be aware that value precedence is set by the order of the config files (last to load)...', - SyntaxWarning) + warn( + f"Overriding an already set parameter {overwrite + k} from {configs}\n" + f"Be aware that value precedence is set by the order of the config files (last to load)...", + SyntaxWarning, + ) diff --git a/tests/base/attr_configs_test.py b/tests/base/attr_configs_test.py index 82db5889..3d32ca86 100644 --- a/tests/base/attr_configs_test.py +++ b/tests/base/attr_configs_test.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 from enum import Enum @@ -88,6 +88,8 @@ class TypeConfig: tuple_p_str: Tuple[str, str] # Required Tuple -- Bool tuple_p_bool: Tuple[bool, bool] + # Required Tuple -- mixed + tuple_p_mixed: Tuple[int, float] # Required choice -- Str choice_p_str: StrChoice # Required choice -- Int diff --git a/tests/base/base_asserts_test.py b/tests/base/base_asserts_test.py index eb092e8a..8067ba41 100644 --- a/tests/base/base_asserts_test.py +++ b/tests/base/base_asserts_test.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 @@ -21,6 +21,7 @@ def test_all_set(self, arg_builder): assert arg_builder.TypeConfig.tuple_p_int == (10, 20) assert arg_builder.TypeConfig.tuple_p_str == ('Spock', 'Package') assert arg_builder.TypeConfig.tuple_p_bool == (True, False) + assert arg_builder.TypeConfig.tuple_p_mixed == (5, 11.5) assert arg_builder.TypeConfig.choice_p_str == 'option_1' assert arg_builder.TypeConfig.choice_p_int == 10 assert arg_builder.TypeConfig.choice_p_float == 10.0 diff --git a/tests/base/test_cmd_line.py 
b/tests/base/test_cmd_line.py index 699cd02b..9a3724f5 100644 --- a/tests/base/test_cmd_line.py +++ b/tests/base/test_cmd_line.py @@ -21,6 +21,7 @@ def arg_builder(monkeypatch): '--TypeConfig.tuple_p_float', '(11.0, 21.0)', '--TypeConfig.tuple_p_int', '(11, 21)', '--TypeConfig.tuple_p_str', "('Hooray', 'Working')", '--TypeConfig.tuple_p_bool', '(False, True)', + '--TypeConfig.tuple_p_mixed', '(5, 11.5)', '--TypeConfig.list_list_p_int', "[[11, 21], [11, 21]]", '--TypeConfig.choice_p_str', 'option_2', '--TypeConfig.choice_p_int', '20', '--TypeConfig.choice_p_float', '20.0', @@ -28,6 +29,7 @@ def arg_builder(monkeypatch): '--TypeConfig.list_list_choice_p_str', "[['option_2'], ['option_2']]", '--TypeConfig.list_choice_p_int', '[20]', '--TypeConfig.list_choice_p_float', '[20.0]', + '--TypeConfig.class_enum', 'NestedStuff', '--NestedStuff.one', '12', '--NestedStuff.two', 'ancora', '--TypeConfig.nested_list.NestedListStuff.one', '[11, 21]', '--TypeConfig.nested_list.NestedListStuff.two', "['Hooray', 'Working']", @@ -48,6 +50,69 @@ def test_class_overrides(self, arg_builder): assert arg_builder.TypeConfig.tuple_p_int == (11, 21) assert arg_builder.TypeConfig.tuple_p_str == ('Hooray', 'Working') assert arg_builder.TypeConfig.tuple_p_bool == (False, True) + assert arg_builder.TypeConfig.tuple_p_mixed == (5, 11.5) + assert arg_builder.TypeConfig.choice_p_str == 'option_2' + assert arg_builder.TypeConfig.choice_p_int == 20 + assert arg_builder.TypeConfig.choice_p_float == 20.0 + assert arg_builder.TypeConfig.list_list_p_int == [[11, 21], [11, 21]] + assert arg_builder.TypeConfig.list_choice_p_str == ['option_2'] + assert arg_builder.TypeConfig.list_list_choice_p_str == [['option_2'], ['option_2']] + assert arg_builder.TypeConfig.list_choice_p_int == [20] + assert arg_builder.TypeConfig.list_choice_p_float == [20.0] + assert arg_builder.TypeConfig.class_enum.one == 12 + assert arg_builder.TypeConfig.class_enum.two == 'ancora' + assert arg_builder.NestedListStuff[0].one == 11 + assert arg_builder.NestedListStuff[0].two == 'Hooray' + assert arg_builder.NestedListStuff[1].one == 21 + assert arg_builder.NestedListStuff[1].two == 'Working' + + +class TestClassOnlyCmdLine: + """Testing command line overrides""" + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', + '--TypeConfig.bool_p', '--TypeConfig.int_p', '11', '--TypeConfig.float_p', '11.0', + '--TypeConfig.string_p', 'Hooray', + '--TypeConfig.list_p_float', '[11.0, 21.0]', '--TypeConfig.list_p_int', '[11, 21]', + '--TypeConfig.list_p_str', "['Hooray', 'Working']", + '--TypeConfig.list_p_bool', '[False, True]', + '--TypeConfig.tuple_p_float', '(11.0, 21.0)', '--TypeConfig.tuple_p_int', '(11, 21)', + '--TypeConfig.tuple_p_str', "('Hooray', 'Working')", + '--TypeConfig.tuple_p_bool', '(False, True)', + '--TypeConfig.tuple_p_mixed', '(5, 11.5)', + '--TypeConfig.list_list_p_int', "[[11, 21], [11, 21]]", + '--TypeConfig.choice_p_str', 'option_2', + '--TypeConfig.choice_p_int', '20', '--TypeConfig.choice_p_float', '20.0', + '--TypeConfig.list_choice_p_str', "['option_2']", + '--TypeConfig.list_list_choice_p_str', "[['option_2'], ['option_2']]", + '--TypeConfig.list_choice_p_int', '[20]', + '--TypeConfig.list_choice_p_float', '[20.0]', + '--TypeConfig.class_enum', 'NestedStuff', + '--TypeConfig.nested', 'NestedStuff', + '--NestedStuff.one', '12', '--NestedStuff.two', 'ancora', + '--TypeConfig.nested_list.NestedListStuff.one', '[11, 21]', + 
'--TypeConfig.nested_list.NestedListStuff.two', "['Hooray', 'Working']", + ]) + config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, desc='Test Builder') + return config.generate() + + def test_class_overrides(self, arg_builder): + assert arg_builder.TypeConfig.bool_p is True + assert arg_builder.TypeConfig.int_p == 11 + assert arg_builder.TypeConfig.float_p == 11.0 + assert arg_builder.TypeConfig.string_p == 'Hooray' + assert arg_builder.TypeConfig.list_p_float == [11.0, 21.0] + assert arg_builder.TypeConfig.list_p_int == [11, 21] + assert arg_builder.TypeConfig.list_p_str == ['Hooray', 'Working'] + assert arg_builder.TypeConfig.list_p_bool == [False, True] + assert arg_builder.TypeConfig.tuple_p_float == (11.0, 21.0) + assert arg_builder.TypeConfig.tuple_p_int == (11, 21) + assert arg_builder.TypeConfig.tuple_p_str == ('Hooray', 'Working') + assert arg_builder.TypeConfig.tuple_p_bool == (False, True) + assert arg_builder.TypeConfig.tuple_p_mixed == (5, 11.5) assert arg_builder.TypeConfig.choice_p_str == 'option_2' assert arg_builder.TypeConfig.choice_p_int == 20 assert arg_builder.TypeConfig.choice_p_float == 20.0 diff --git a/tests/base/test_config_arg_builder.py b/tests/base/test_config_arg_builder.py index 73295b30..58def648 100644 --- a/tests/base/test_config_arg_builder.py +++ b/tests/base/test_config_arg_builder.py @@ -6,6 +6,18 @@ import sys +class TestBasic(AllTypes): + """Testing basic functionality""" + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test.yaml']) + config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig) + return config.generate() + + class TestNoCmdLineKwarg(AllTypes): """Testing to see that the kwarg no cmd line works""" @staticmethod diff --git a/tests/base/test_type_specific.py b/tests/base/test_type_specific.py index e353e174..b3ecea04 100644 --- a/tests/base/test_type_specific.py +++ b/tests/base/test_type_specific.py @@ -62,11 +62,11 @@ def test_enum_class_missing(self, monkeypatch): ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, desc='Test Builder') -class TestMixedGeneric: - def test_mixed_generic(self, monkeypatch): - with monkeypatch.context() as m: - with pytest.raises(TypeError): - @spock - class GenericFail: - generic_fail: Tuple[List[int], List[int], int] +# class TestMixedGeneric: +# def test_mixed_generic(self, monkeypatch): +# with monkeypatch.context() as m: +# with pytest.raises(TypeError): +# @spock +# class GenericFail: +# generic_fail: Tuple[List[int], List[int], int] diff --git a/tests/base/test_writers.py b/tests/base/test_writers.py index 33a1e99d..398074e0 100644 --- a/tests/base/test_writers.py +++ b/tests/base/test_writers.py @@ -43,10 +43,9 @@ def test_yaml_file_writer_create(self, monkeypatch, tmp_path): with monkeypatch.context() as m: m.setattr(sys, 'argv', ['', '--config', './tests/conf/yaml/test.yaml']) - config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, desc='Test Builder', - create_save_path=True) + config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, desc='Test Builder') # Test the chained version - config.save(user_specified_path=f'{tmp_path}/tmp', file_extension='.yaml').generate() + config.save(user_specified_path=f'{tmp_path}/tmp', create_save_path=True, file_extension='.yaml').generate() check_path = f'{str(tmp_path)}/tmp/*.yaml' fname = glob.glob(check_path)[0] with open(fname, 'r') as fin: 
@@ -64,7 +63,6 @@ def test_yaml_file_writer_save_path(self, monkeypatch): # Test the chained version now = datetime.datetime.now() curr_int_time = int(f'{now.year}{now.month}{now.day}{now.hour}{now.second}') - config_values = config.save(file_extension='.yaml', file_name=f'pytest.{curr_int_time}').generate() yaml_regex = re.compile(fr'pytest.{curr_int_time}.' fr'[a-fA-F0-9]{{8}}-[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{4}}-' @@ -101,7 +99,10 @@ def test_yaml_file_writer(self, monkeypatch, tmp_path): config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, desc='Test Builder') # Test the chained version with pytest.raises(FileNotFoundError): - config.save(user_specified_path=f'{str(tmp_path)}/foo.bar/fizz.buzz/', file_extension='.yaml').generate() + config.save( + user_specified_path=f'{str(tmp_path)}/foo.bar/fizz.buzz/', file_extension='.yaml', + create_save_path=False + ).generate() class TestInvalidExtensionTypeRaise: diff --git a/tests/conf/json/test.json b/tests/conf/json/test.json index 8ac6ff3e..e81dff46 100644 --- a/tests/conf/json/test.json +++ b/tests/conf/json/test.json @@ -12,6 +12,7 @@ "tuple_p_int": [10, 20], "tuple_p_str": ["Spock", "Package"], "tuple_p_bool": [true, false], + "tuple_p_mixed": [5, 11.5], "choice_p_str": "option_1", "choice_p_int": 10, "choice_p_float": 10.0, diff --git a/tests/conf/toml/test.toml b/tests/conf/toml/test.toml index 85bcdad0..6e9e0a77 100644 --- a/tests/conf/toml/test.toml +++ b/tests/conf/toml/test.toml @@ -25,6 +25,8 @@ tuple_p_int = [10, 20] tuple_p_str = ["Spock", "Package"] # Required Tuple -- Bool tuple_p_bool = [true, false] +# Required Tuple -- mixed +tuple_p_mixed = [5, 11.5] # Required Choice -- Str type choice_p_str = 'option_1' # Required Choice -- Int diff --git a/tests/conf/yaml/inherited.yaml b/tests/conf/yaml/inherited.yaml index 4d03e7ce..de292a22 100644 --- a/tests/conf/yaml/inherited.yaml +++ b/tests/conf/yaml/inherited.yaml @@ -24,6 +24,8 @@ tuple_p_float: [10.0, 20.0] tuple_p_int: [10, 20] # Required Tuple -- Str tuple_p_str: [Spock, Package] +# Required Tuple -- mixed +tuple_p_mixed: [5, 11.5] # Required Tuple -- Bool tuple_p_bool: [True, False] # Required Choice -- Str diff --git a/tests/conf/yaml/test.yaml b/tests/conf/yaml/test.yaml index b344ad5c..983a0b00 100644 --- a/tests/conf/yaml/test.yaml +++ b/tests/conf/yaml/test.yaml @@ -26,6 +26,8 @@ tuple_p_int: [10, 20] tuple_p_str: [Spock, Package] # Required Tuple -- Bool tuple_p_bool: [True, False] +# Required Tuple -- mixed +tuple_p_mixed: [5, 11.5] # Required Choice -- Str choice_p_str: option_1 # Required Choice -- Int diff --git a/tests/conf/yaml/test_class.yaml b/tests/conf/yaml/test_class.yaml index 88867675..93eb81e7 100644 --- a/tests/conf/yaml/test_class.yaml +++ b/tests/conf/yaml/test_class.yaml @@ -2,7 +2,7 @@ ### Required or Boolean Base Types ### TypeConfig: # Boolean - Set - bool_p_set: true + bool_p: true # Required Int int_p: 10 # Required Float @@ -46,7 +46,7 @@ TypeConfig: # Nested List configuration nested_list: NestedListStuff # Class Enum - class_enum: NestedStuff + class_enum: NestedListStuff NestedListStuff: - one: 10 two: hello diff --git a/tests/conf/yaml/test_hp.yaml b/tests/conf/yaml/test_hp.yaml new file mode 100644 index 00000000..8e15bf82 --- /dev/null +++ b/tests/conf/yaml/test_hp.yaml @@ -0,0 +1,32 @@ +# Test conf for all hyper-parameters +HPOne: + hp_int: + type: int + bounds: [ 10, 100 ] + log_scale: false + hp_float: + type: float + bounds: [ 10.0, 100.0 ] + log_scale: false + hp_int_log: + type: int + bounds: [ 10, 100 ] 
+ log_scale: true + hp_float_log: + type: float + bounds: [ 10.0, 100.0 ] + log_scale: true + +HPTwo: + hp_choice_int: + type: int + choices: [10, 20, 40, 80 ] + hp_choice_float: + type: float + choices: [ 10.0, 20.0, 40.0, 80.0 ] + hp_choice_bool: + type: bool + choices: [ true, false ] + hp_choice_str: + type: str + choices: [ "hello", "ciao", "bonjour" ] \ No newline at end of file diff --git a/tests/conf/yaml/test_hp_cast.yaml b/tests/conf/yaml/test_hp_cast.yaml new file mode 100644 index 00000000..d8869880 --- /dev/null +++ b/tests/conf/yaml/test_hp_cast.yaml @@ -0,0 +1,32 @@ +# Test conf for all hyper-parameters +HPOne: + hp_int: + type: int + bounds: [ 10, 100 ] + log_scale: false + hp_float: + type: float + bounds: [ 10.0, 100.0 ] + log_scale: false + hp_int_log: + type: int + bounds: [ 10, 100 ] + log_scale: true + hp_float_log: + type: float + bounds: [ 10.0, 100.0 ] + log_scale: true + +HPTwo: + hp_choice_int: + type: int + choices: ["hello", "ciao", "bonjour" ] + hp_choice_float: + type: float + choices: [ 10.0, 20.0, 40.0, 80.0 ] + hp_choice_bool: + type: bool + choices: [ true, false ] + hp_choice_str: + type: str + choices: [ "hello", "ciao", "bonjour" ] diff --git a/tests/conf/yaml/test_hp_cast_bounds.yaml b/tests/conf/yaml/test_hp_cast_bounds.yaml new file mode 100644 index 00000000..c7cdb592 --- /dev/null +++ b/tests/conf/yaml/test_hp_cast_bounds.yaml @@ -0,0 +1,32 @@ +# Test conf for all hyper-parameters +HPOne: + hp_int: + type: int + bounds: [ 10, 100 ] + log_scale: false + hp_float: + type: float + bounds: [ 10.0, 100.0 ] + log_scale: false + hp_int_log: + type: int + bounds: [ 'foo', 'bar' ] + log_scale: true + hp_float_log: + type: float + bounds: [ 10.0, 100.0 ] + log_scale: true + +HPTwo: + hp_choice_int: + type: int + choices: [10, 20, 40, 80 ] + hp_choice_float: + type: float + choices: [ 10.0, 20.0, 40.0, 80.0 ] + hp_choice_bool: + type: bool + choices: [ true, false ] + hp_choice_str: + type: str + choices: [ "hello", "ciao", "bonjour" ] \ No newline at end of file diff --git a/tests/conf/yaml/test_hp_compose.yaml b/tests/conf/yaml/test_hp_compose.yaml new file mode 100644 index 00000000..827b147f --- /dev/null +++ b/tests/conf/yaml/test_hp_compose.yaml @@ -0,0 +1,8 @@ +config: [test_hp.yaml] + +# Test conf for all hyper-parameters +HPOne: + hp_int: + type: int + bounds: [ 20, 200 ] + log_scale: false \ No newline at end of file diff --git a/tests/conf/yaml/test_optuna.yaml b/tests/conf/yaml/test_optuna.yaml new file mode 100644 index 00000000..b935e52d --- /dev/null +++ b/tests/conf/yaml/test_optuna.yaml @@ -0,0 +1,12 @@ +############################### +# optuna simple sklearn example +############################### + +LogisticRegressionHP: + c: + type: float + bounds: [1E-07, 10.0] + log_scale: true + solver: + type: str + choices: ["lbfgs", "saga"] \ No newline at end of file diff --git a/tests/s3/test_io.py b/tests/s3/test_io.py index 0f3b0262..089b0de4 100644 --- a/tests/s3/test_io.py +++ b/tests/s3/test_io.py @@ -2,7 +2,7 @@ import datetime from tests.base.base_asserts_test import * from spock.builder import ConfigArgBuilder -from spock.addons import S3Config +from spock.addons.s3 import S3Config from tests.base.attr_configs_test import * from tests.s3.fixtures_test import * import re diff --git a/tests/s3/test_raises.py b/tests/s3/test_raises.py index 6415c05d..bf636678 100644 --- a/tests/s3/test_raises.py +++ b/tests/s3/test_raises.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import datetime from spock.builder import ConfigArgBuilder -from spock.addons 
import S3Config +from spock.addons.s3 import S3Config from tests.base.attr_configs_test import * from tests.s3.fixtures_test import * import sys diff --git a/tests/tune/__init__.py b/tests/tune/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/tests/tune/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/tests/tune/attr_configs_test.py b/tests/tune/attr_configs_test.py new file mode 100644 index 00000000..0d62e82f --- /dev/null +++ b/tests/tune/attr_configs_test.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- + +from spock.addons.tune import spockTuner +from spock.addons.tune import ChoiceHyperParameter +from spock.addons.tune import RangeHyperParameter + + +@spockTuner +class HPOne: + hp_int: RangeHyperParameter + hp_float: RangeHyperParameter + hp_int_log: RangeHyperParameter + hp_float_log: RangeHyperParameter + + +@spockTuner +class HPTwo: + hp_choice_int: ChoiceHyperParameter + hp_choice_float: ChoiceHyperParameter + hp_choice_bool: ChoiceHyperParameter + hp_choice_str: ChoiceHyperParameter + + +@spockTuner +class LogisticRegressionHP: + c: RangeHyperParameter + solver: ChoiceHyperParameter diff --git a/tests/tune/base_asserts_test.py b/tests/tune/base_asserts_test.py new file mode 100644 index 00000000..5a3c1a3e --- /dev/null +++ b/tests/tune/base_asserts_test.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +from math import log10 + + +class AllTypes: + def test_hp_one(self, arg_builder): + assert arg_builder._tune_namespace.HPOne.hp_int.bounds == (10, 100) + assert arg_builder._tune_namespace.HPOne.hp_int.type == 'int' + assert arg_builder._tune_namespace.HPOne.hp_int.log_scale is False + assert arg_builder._tune_namespace.HPOne.hp_int_log.bounds == (10, 100) + assert arg_builder._tune_namespace.HPOne.hp_int_log.type == 'int' + assert arg_builder._tune_namespace.HPOne.hp_int_log.log_scale is True + assert arg_builder._tune_namespace.HPOne.hp_float.bounds == (10.0, 100.0) + assert arg_builder._tune_namespace.HPOne.hp_float.type == 'float' + assert arg_builder._tune_namespace.HPOne.hp_float.log_scale is False + assert arg_builder._tune_namespace.HPOne.hp_float_log.bounds == (10.0, 100.0) + assert arg_builder._tune_namespace.HPOne.hp_float_log.type == 'float' + assert arg_builder._tune_namespace.HPOne.hp_float_log.log_scale is True + + def test_hp_two(self, arg_builder): + assert arg_builder._tune_namespace.HPTwo.hp_choice_int.type == 'int' + assert arg_builder._tune_namespace.HPTwo.hp_choice_int.choices == [10, 20, 40, 80] + assert arg_builder._tune_namespace.HPTwo.hp_choice_float.type == 'float' + assert arg_builder._tune_namespace.HPTwo.hp_choice_float.choices == [10.0, 20.0, 40.0, 80.0] + assert arg_builder._tune_namespace.HPTwo.hp_choice_bool.type == 'bool' + assert arg_builder._tune_namespace.HPTwo.hp_choice_bool.choices == [True, False] + assert arg_builder._tune_namespace.HPTwo.hp_choice_str.type == 'str' + assert arg_builder._tune_namespace.HPTwo.hp_choice_str.choices == ["hello", "ciao", "bonjour"] + + +class SampleTypes: + def test_sampling(self, arg_builder): + # Draw 100 random samples and make sure all fall within all of the bounds or sets + for _ in range(100): + hp_attrs = arg_builder.sample() + assert 10 <= hp_attrs.HPOne.hp_int <= 100 + assert isinstance(hp_attrs.HPOne.hp_int, int) is True + assert 10 <= hp_attrs.HPOne.hp_int_log <= 100 + assert isinstance(hp_attrs.HPOne.hp_int_log, int) is True + assert 10.0 <= hp_attrs.HPOne.hp_float <= 100.0 + assert isinstance(hp_attrs.HPOne.hp_float, float) is True + assert 10.0 <= 
hp_attrs.HPOne.hp_float_log <= 100.0 + assert isinstance(hp_attrs.HPOne.hp_float_log, float) is True + assert hp_attrs.HPTwo.hp_choice_int in [10, 20, 40, 80] + assert isinstance(hp_attrs.HPTwo.hp_choice_int, int) is True + assert hp_attrs.HPTwo.hp_choice_float in [10.0, 20.0, 40.0, 80.0] + assert isinstance(hp_attrs.HPTwo.hp_choice_float, float) is True + assert hp_attrs.HPTwo.hp_choice_bool in [True, False] + assert isinstance(hp_attrs.HPTwo.hp_choice_bool, bool) is True + assert hp_attrs.HPTwo.hp_choice_str in ["hello", "ciao", "bonjour"] + assert isinstance(hp_attrs.HPTwo.hp_choice_str, str) is True \ No newline at end of file diff --git a/tests/tune/test_cmd_line.py b/tests/tune/test_cmd_line.py new file mode 100644 index 00000000..5a8a63a5 --- /dev/null +++ b/tests/tune/test_cmd_line.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- +from tests.tune.attr_configs_test import * +import pytest +import sys +from spock.builder import ConfigArgBuilder +from spock.addons.tune import OptunaTunerConfig + + +class TestOptunaCmdLineOverride: + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_hp.yaml', + '--HPOne.hp_int.bounds', '(1, 1000)', + '--HPOne.hp_int_log.bounds', '(1, 1000)', + '--HPOne.hp_float.bounds', '(1.0, 1000.0)', + '--HPOne.hp_float_log.bounds', '(1.0, 1000.0)', + '--HPTwo.hp_choice_int.choices', '[1, 2, 4, 8]', + '--HPTwo.hp_choice_float.choices', '[1.0, 2.0, 4.0, 8.0]', + '--HPTwo.hp_choice_str.choices', "['is', 'it ', 'me', 'youre', 'looking', 'for']" + ]) + optuna_config = OptunaTunerConfig(study_name="Tests", direction="maximize") + config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) + return config + + def test_hp_one(self, arg_builder): + assert arg_builder._tune_namespace.HPOne.hp_int.bounds == (1, 1000) + assert arg_builder._tune_namespace.HPOne.hp_int.type == 'int' + assert arg_builder._tune_namespace.HPOne.hp_int.log_scale is False + assert arg_builder._tune_namespace.HPOne.hp_int_log.bounds == (1, 1000) + assert arg_builder._tune_namespace.HPOne.hp_int_log.type == 'int' + assert arg_builder._tune_namespace.HPOne.hp_int_log.log_scale is True + assert arg_builder._tune_namespace.HPOne.hp_float.bounds == (1.0, 1000.0) + assert arg_builder._tune_namespace.HPOne.hp_float.type == 'float' + assert arg_builder._tune_namespace.HPOne.hp_float.log_scale is False + assert arg_builder._tune_namespace.HPOne.hp_float_log.bounds == (1.0, 1000.0) + assert arg_builder._tune_namespace.HPOne.hp_float_log.type == 'float' + assert arg_builder._tune_namespace.HPOne.hp_float_log.log_scale is True + + def test_hp_two(self, arg_builder): + assert arg_builder._tune_namespace.HPTwo.hp_choice_int.type == 'int' + assert arg_builder._tune_namespace.HPTwo.hp_choice_int.choices == [1, 2, 4, 8] + assert arg_builder._tune_namespace.HPTwo.hp_choice_float.type == 'float' + assert arg_builder._tune_namespace.HPTwo.hp_choice_float.choices == [1.0, 2.0, 4.0, 8.0] + assert arg_builder._tune_namespace.HPTwo.hp_choice_bool.type == 'bool' + assert arg_builder._tune_namespace.HPTwo.hp_choice_bool.choices == [True, False] + assert arg_builder._tune_namespace.HPTwo.hp_choice_str.type == 'str' + assert arg_builder._tune_namespace.HPTwo.hp_choice_str.choices == ['is', 'it ', 'me', 'youre', 'looking', 'for'] + + def test_sampling(self, arg_builder): + # Draw 100 random samples and make sure all fall within all of the bounds or sets + for _ in range(100): + hp_attrs = 
arg_builder.sample() + assert 1 <= hp_attrs.HPOne.hp_int <= 1000 + assert isinstance(hp_attrs.HPOne.hp_int, int) is True + assert 1 <= hp_attrs.HPOne.hp_int_log <= 1000 + assert isinstance(hp_attrs.HPOne.hp_int_log, int) is True + assert 1.0 <= hp_attrs.HPOne.hp_float <= 1000.0 + assert isinstance(hp_attrs.HPOne.hp_float, float) is True + assert 1.0 <= hp_attrs.HPOne.hp_float_log <= 1000.0 + assert isinstance(hp_attrs.HPOne.hp_float_log, float) is True + assert hp_attrs.HPTwo.hp_choice_int in [1, 2, 4, 8] + assert isinstance(hp_attrs.HPTwo.hp_choice_int, int) is True + assert hp_attrs.HPTwo.hp_choice_float in [1.0, 2.0, 4.0, 8.0] + assert isinstance(hp_attrs.HPTwo.hp_choice_float, float) is True + assert hp_attrs.HPTwo.hp_choice_bool in [True, False] + assert isinstance(hp_attrs.HPTwo.hp_choice_bool, bool) is True + assert hp_attrs.HPTwo.hp_choice_str in ['is', 'it ', 'me', 'youre', 'looking', 'for'] + assert isinstance(hp_attrs.HPTwo.hp_choice_str, str) is True \ No newline at end of file diff --git a/tests/tune/test_optuna.py b/tests/tune/test_optuna.py new file mode 100644 index 00000000..6eb2b523 --- /dev/null +++ b/tests/tune/test_optuna.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +import datetime +from tests.tune.base_asserts_test import * +from tests.tune.attr_configs_test import * +import pytest +import os +import re +import sys +from spock.builder import ConfigArgBuilder +from spock.addons.tune import OptunaTunerConfig +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split + + +class TestOptunaBasic(AllTypes): + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_hp.yaml']) + optuna_config = OptunaTunerConfig(study_name="Basic Tests", direction="maximize") + config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) + return config + + +class TestOptunaCompose(AllTypes): + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_hp_compose.yaml']) + optuna_config = OptunaTunerConfig(study_name="Basic Tests", direction="maximize") + config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) + return config + + def test_hp_one(self, arg_builder): + assert arg_builder._tune_namespace.HPOne.hp_int.bounds == (20, 200) + assert arg_builder._tune_namespace.HPOne.hp_int.type == 'int' + assert arg_builder._tune_namespace.HPOne.hp_int.log_scale is False + assert arg_builder._tune_namespace.HPOne.hp_int_log.bounds == (10, 100) + assert arg_builder._tune_namespace.HPOne.hp_int_log.type == 'int' + assert arg_builder._tune_namespace.HPOne.hp_int_log.log_scale is True + assert arg_builder._tune_namespace.HPOne.hp_float.bounds == (10.0, 100.0) + assert arg_builder._tune_namespace.HPOne.hp_float.type == 'float' + assert arg_builder._tune_namespace.HPOne.hp_float.log_scale is False + assert arg_builder._tune_namespace.HPOne.hp_float_log.bounds == (10.0, 100.0) + assert arg_builder._tune_namespace.HPOne.hp_float_log.type == 'float' + assert arg_builder._tune_namespace.HPOne.hp_float_log.log_scale is True + + +class TestOptunaSample(SampleTypes): + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_hp.yaml']) + optuna_config = OptunaTunerConfig(study_name="Sample 
Tests", direction="maximize") + config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) + return config + + +class TestOptunaSaveTopLevel: + def test_save_top_level(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_optuna.yaml']) + # Optuna config -- this will internally spawn the study object for the define-and-run style which will be returned + # as part of the call to sample() + optuna_config = OptunaTunerConfig( + study_name="Iris Logistic Regression Tests", direction="maximize" + ) + now = datetime.datetime.now() + curr_int_time = int(f'{now.year}{now.month}{now.day}{now.hour}{now.second}') + config = ConfigArgBuilder(LogisticRegressionHP).tuner(optuna_config).save( + user_specified_path="/tmp", file_name=f'pytest.{curr_int_time}', + ).sample() + # Verify the sample was written out to file + yaml_regex = re.compile(fr'pytest.{curr_int_time}.' + fr'[a-fA-F0-9]{{8}}-[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{4}}-' + fr'[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{12}}.spock.cfg.yaml') + matches = [re.fullmatch(yaml_regex, val) for val in os.listdir('/tmp') + if re.fullmatch(yaml_regex, val) is not None] + fname = f'/tmp/{matches[0].string}' + assert os.path.exists(fname) + with open(fname, 'r') as fin: + print(fin.read()) + # Clean up if assert is good + if os.path.exists(fname): + os.remove(fname) + return config + + +class TestIrisOptuna: + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_optuna.yaml']) + # Optuna config -- this will internally spawn the study object for the define-and-run style which will be returned + # as part of the call to sample() + optuna_config = OptunaTunerConfig( + study_name="Iris Logistic Regression Tests", direction="maximize" + ) + config = ConfigArgBuilder(LogisticRegressionHP).tuner(optuna_config) + return config + + def test_iris(self, arg_builder): + # Load the iris data + X, y = load_iris(return_X_y=True) + # Split the Iris data + X_train, X_valid, y_train, y_valid = train_test_split(X, y) + + # Now we iterate through a bunch of optuna trials + for _ in range(10): + # The crux of spock support -- call save w/ the add_tuner_sample flag to write the current draw to file and + # then call save to return the composed Spockspace of the fixed parameters and the sampled parameters + # Under the hood spock uses the define-and-run Optuna interface -- thus it handled the underlying 'ask' call + # and returns the necessary trial object in the return dictionary to call 'tell' with the study object + now = datetime.datetime.now() + curr_int_time = int(f'{now.year}{now.month}{now.day}{now.hour}{now.second}') + hp_attrs = arg_builder.save( + add_tuner_sample=True, user_specified_path="/tmp", file_name=f'pytest.{curr_int_time}', + ).sample() + # Use the currently sampled parameters in a simple LogisticRegression from sklearn + clf = LogisticRegression( + C=hp_attrs.LogisticRegressionHP.c, + solver=hp_attrs.LogisticRegressionHP.solver, + ) + clf.fit(X_train, y_train) + val_acc = clf.score(X_valid, y_valid) + # Get the status of the tuner -- this dict will contain all the objects needed to update + tuner_status = arg_builder.tuner_status + # Pull the study and trials object out of the return dictionary and pass it to the tell call using the study + # object + tuner_status["study"].tell(tuner_status["trial"], val_acc) + # Always save the current best set of hyper-parameters + 
arg_builder.save_best(user_specified_path='/tmp', file_name=f'pytest') + # Verify the sample was written out to file + yaml_regex = re.compile(fr'pytest.{curr_int_time}.hp.sample.[0-9]+.' + fr'[a-fA-F0-9]{{8}}-[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{4}}-' + fr'[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{12}}.spock.cfg.yaml') + matches = [re.fullmatch(yaml_regex, val) for val in os.listdir('/tmp') + if re.fullmatch(yaml_regex, val) is not None] + fname = f'/tmp/{matches[0].string}' + assert os.path.exists(fname) + with open(fname, 'r') as fin: + print(fin.read()) + # Clean up if assert is good + if os.path.exists(fname): + os.remove(fname) + + best_config, best_metric = arg_builder.best + print(f'Best HP Config:\n{best_config}') + print(f'Best Metric: {best_metric}') + # Verify the sample was written out to file + yaml_regex = re.compile(fr'pytest.hp.best.' + fr'[a-fA-F0-9]{{8}}-[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{4}}-' + fr'[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{12}}.spock.cfg.yaml') + matches = [re.fullmatch(yaml_regex, val) for val in os.listdir('/tmp') + if re.fullmatch(yaml_regex, val) is not None] + fname = f'/tmp/{matches[0].string}' + assert os.path.exists(fname) + with open(fname, 'r') as fin: + print(fin.read()) + # Clean up if assert is good + if os.path.exists(fname): + os.remove(fname) diff --git a/tests/tune/test_raises.py b/tests/tune/test_raises.py new file mode 100644 index 00000000..db660449 --- /dev/null +++ b/tests/tune/test_raises.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +from tests.tune.attr_configs_test import * +import pytest +import sys +from spock.builder import ConfigArgBuilder +import optuna + + +class TestIncorrectTunerConfig: + def test_incorrect_tuner_config(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_hp.yaml']) + optuna_config = optuna.create_study(study_name="Tests", direction='minimize') + with pytest.raises(TypeError): + config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) + + +class TestInvalidCastChoice: + def test_invalid_cast_choice(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_hp_cast.yaml']) + optuna_config = optuna.create_study(study_name="Tests", direction='minimize') + with pytest.raises(TypeError): + config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) + + +class TestInvalidCastRange: + def test_invalid_cast_range(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_hp_cast_bounds.yaml']) + optuna_config = optuna.create_study(study_name="Tests", direction='minimize') + with pytest.raises(ValueError): + config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config)
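The new tune add-on tests above double as the clearest walkthrough of the define-and-run Optuna flow this PR introduces. Below is a condensed sketch of that loop, distilled from tests/tune/test_optuna.py rather than taken verbatim from the patch: the spock names it uses (spockTuner, RangeHyperParameter, ChoiceHyperParameter, OptunaTunerConfig, ConfigArgBuilder(...).tuner(), sample(), tuner_status, save_best, best) all appear in this PR, while the objective, paths, and file names are illustrative. The script assumes it is launched with --config pointing at a hyper-parameter YAML such as tests/conf/yaml/test_optuna.yaml.

# Condensed, illustrative sketch of the define-and-run tuning loop from tests/tune/test_optuna.py.
# Assumes invocation like: python sketch.py --config tests/conf/yaml/test_optuna.yaml
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from spock.addons.tune import (ChoiceHyperParameter, OptunaTunerConfig,
                               RangeHyperParameter, spockTuner)
from spock.builder import ConfigArgBuilder


@spockTuner
class LogisticRegressionHP:
    # Ranges/choices for these fields come from the --config YAML, as in the tests above
    c: RangeHyperParameter
    solver: ChoiceHyperParameter


def main():
    # spock spawns the Optuna study internally for the define-and-run interface
    optuna_config = OptunaTunerConfig(study_name="iris sketch", direction="maximize")
    attrs_obj = ConfigArgBuilder(LogisticRegressionHP, desc="sketch").tuner(optuna_config)

    X, y = load_iris(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y)

    for _ in range(10):
        # sample() draws the next trial; add_tuner_sample=True also writes the draw to file
        hp_attrs = attrs_obj.save(add_tuner_sample=True, user_specified_path="/tmp").sample()
        # Use the sampled parameters in a plain sklearn model
        clf = LogisticRegression(
            C=hp_attrs.LogisticRegressionHP.c,
            solver=hp_attrs.LogisticRegressionHP.solver,
        )
        clf.fit(X_train, y_train)
        val_acc = clf.score(X_valid, y_valid)
        # Report the metric back through the underlying study/trial objects
        status = attrs_obj.tuner_status
        status["study"].tell(status["trial"], val_acc)
        # Persist the best draw seen so far
        attrs_obj.save_best(user_specified_path="/tmp", file_name="best")

    best_config, best_metric = attrs_obj.best
    print(f"Best config:\n{best_config}\nBest metric: {best_metric}")


if __name__ == "__main__":
    main()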