Skip to content

Commit

Permalink
Callable Types & Post Init Hooks (#226)
Browse files Browse the repository at this point in the history
* Added support for simple `typing.Callable` types (WIP: advanced versions)
* Added support for post init hooks that allow for validation on parameters defined within `@spock` decorated classes. 
Additionally, added some common validation check to utils (within, greater than, less than, etc.)
* Updated unit tests to support Python 3.10
* Additional unit tests
* linted
  • Loading branch information
ncilfone authored Mar 11, 2022
1 parent 2efe7d0 commit 0d4b82a
Show file tree
Hide file tree
Showing 23 changed files with 695 additions and 61 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-pytest-s3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand All @@ -26,7 +26,7 @@ jobs:
- uses: actions/cache@v2
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}
key: cache-v1-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}

- name: Install dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python-pytest-tune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ["3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand All @@ -26,7 +26,7 @@ jobs:
- uses: actions/cache@v2
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TUNE_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TEST_EXTRAS_REQUIREMENTS_REQUIREMENTS.txt') }}
key: cache-v1-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TUNE_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TEST_EXTRAS_REQUIREMENTS_REQUIREMENTS.txt') }}

- name: Install dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand All @@ -26,7 +26,7 @@ jobs:
- uses: actions/cache@v2
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}
key: cache-v1-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}

- name: Install dependencies
run: |
Expand Down
22 changes: 13 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
<p align="center">
<a href="https://opensource.org/licenses/Apache-2.0"><img src="https://img.shields.io/badge/License-Apache%202.0-9cf"/></a>
<a href="https://bestpractices.coreinfrastructure.org/projects/5551"><img src="https://bestpractices.coreinfrastructure.org/projects/5551/badge"/></a>
<a><img src="https://github.com/fidelity/spock/workflows/pytest/badge.svg?branch=master"/></a>
<a href="https://coveralls.io/github/fidelity/spock?branch=master"><img src="https://coveralls.io/repos/github/fidelity/spock/badge.svg?branch=master"/></a>
<a><img src="https://github.com/fidelity/spock/workflows/docs/badge.svg"/></a>
</p>

<p align="center">
<a><img src="https://img.shields.io/badge/python-3.6+-informational.svg"/></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"/></a>
<a href="https://badge.fury.io/py/spock-config"><img src="https://badge.fury.io/py/spock-config.svg"/></a>
<a href="https://coveralls.io/github/fidelity/spock?branch=master"><img src="https://coveralls.io/repos/github/fidelity/spock/badge.svg?branch=master"/></a>
<a><img src="https://github.com/fidelity/spock/workflows/pytest/badge.svg?branch=master"/></a>
<a><img src="https://github.com/fidelity/spock/workflows/docs/badge.svg"/></a>
<a href="https://pepy.tech/badge/spock-config"><img src="https://static.pepy.tech/personalized-badge/spock-config?period=total&units=international_system&left_color=grey&right_color=orange&left_text=Downloads"/></a>
</p>

<h3 align="center">
Expand Down Expand Up @@ -97,6 +101,12 @@ See [Releases](https://github.com/fidelity/spock/releases) for more information.

<details>

#### March 11th, 2022
* Added support for simple `typing.Callable` types (WIP: advanced versions)
* Added support for post init hooks that allow for validation on parameters defined within `@spock` decorated classes.
Additionally, added some common validation check to utils (within, greater than, less than, etc.)
* Updated unit tests to support Python 3.10

#### January 26th, 2022
* Added `evolve` support to the underlying `SpockBuilder` class. This provides functionality similar to the underlying
attrs library ([attrs.evolve](https://www.attrs.org/en/stable/api.html#attrs.evolve)). `evolve()` creates a new
Expand All @@ -110,12 +120,6 @@ passed into `*args` within the main `SpockBuilder` API
* Updated main API interface for better top-level imports (backwards compatible): `ConfigArgBuilder`->`SpockBuilder`
* Added stubs to the underlying decorator that should help with type hinting in VSCode (pylance/pyright)

#### December 14, 2021
* Refactored the backend to better handle nested dependencies (and for clarity)
* Refactored the docs to use Docusaurus

#### August 17, 2021
* Added hyper-parameter tuning backend support for Ax via Service API
</details>

## Original Implementation
Expand Down
2 changes: 1 addition & 1 deletion requirements/S3_REQUIREMENTS.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
boto3~=1.20
botocore~=1.24
hurry.filesize==0.9
hurry.filesize~=0.9
s3transfer~=0.5
14 changes: 10 additions & 4 deletions spock/backend/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# SPDX-License-Identifier: Apache-2.0

"""Handles the building/saving of the configurations from the Spock config classes"""

import sys
import typing
from abc import ABC, abstractmethod
from enum import EnumMeta
from typing import List
Expand All @@ -17,7 +18,7 @@
from spock.backend.spaces import BuilderSpace
from spock.backend.wrappers import Spockspace
from spock.graph import Graph
from spock.utils import make_argument
from spock.utils import _SpockVariadicGenericAlias, make_argument


class BaseBuilder(ABC): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -255,10 +256,11 @@ def _make_group_override_parser(parser, class_obj, class_name):
)
for val in class_obj.__attrs_attrs__:
val_type = val.metadata["type"] if "type" in val.metadata else val.type
# Check if the val type has __args__ -- this catches lists?
# Check if the val type has __args__ -- this catches GenericAlias classes
# TODO (ncilfone): Fix up this super super ugly logic
if (
hasattr(val_type, "__args__")
not isinstance(val_type, _SpockVariadicGenericAlias)
and hasattr(val_type, "__args__")
and ((list(set(val_type.__args__))[0]).__module__ == class_name)
and attr.has((list(set(val_type.__args__))[0]))
):
Expand All @@ -274,6 +276,10 @@ def _make_group_override_parser(parser, class_obj, class_name):
arg_name = f"--{str(attr_name)}.{val.name}"
val_type = str
group_parser = make_argument(arg_name, val_type, group_parser)
# This catches callables -- need to be of type str which will be use in importlib
elif isinstance(val.type, _SpockVariadicGenericAlias):
arg_name = f"--{str(attr_name)}.{val.name}"
group_parser = make_argument(arg_name, str, group_parser)
else:
arg_name = f"--{str(attr_name)}.{val.name}"
group_parser = make_argument(arg_name, val_type, group_parser)
Expand Down
3 changes: 3 additions & 0 deletions spock/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ def _process_class(cls, kw_only: bool, make_init: bool, dynamic: bool):
auto_attribs=True,
init=make_init,
)
# Copy over the post init function
if hasattr(cls, "__post_hook__"):
obj.__post_hook__ = cls.__post_hook__
# For each class we dynamically create we need to register it within the system modules for pickle to work
setattr(sys.modules["spock"].backend.config, obj.__name__, obj)
# Swap the __doc__ string from cls to obj
Expand Down
85 changes: 82 additions & 3 deletions spock/backend/field_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,26 @@

"""Handles registering field attributes for spock classes -- deals with the recursive nature of dependencies"""

import importlib
import sys
from abc import ABC, abstractmethod
from enum import EnumMeta
from typing import List, Type

from attr import NOTHING, Attribute

from spock.args import SpockArguments
from spock.backend.spaces import AttributeSpace, BuilderSpace, ConfigSpace
from spock.exceptions import _SpockInstantiationError, _SpockNotOptionalError
from spock.utils import _check_iterable, _is_spock_instance, _is_spock_tune_instance
from spock.exceptions import (
_SpockInstantiationError,
_SpockNotOptionalError,
_SpockValueError,
)
from spock.utils import (
_check_iterable,
_is_spock_instance,
_is_spock_tune_instance,
_SpockVariadicGenericAlias,
)


class RegisterFieldTemplate(ABC):
Expand Down Expand Up @@ -318,6 +328,69 @@ def _handle_and_register_enum(
builder_space.spock_space[enum_cls.__name__] = attr_space.field


class RegisterCallableField(RegisterFieldTemplate):
"""Class that registers callable types
Attributes:
special_keys: dictionary to check special keys
"""

def __init__(self):
"""Init call to RegisterSimpleField
Args:
"""
super(RegisterCallableField, self).__init__()

def handle_attribute_from_config(
self, attr_space: AttributeSpace, builder_space: BuilderSpace
):
"""Handles setting a simple attribute when it is a spock class type
Args:
attr_space: holds information about a single attribute that is mapped to a ConfigSpace
builder_space: named_tuple containing the arguments and spock_space
Returns:
"""
# These are always going to be strings... cast just in case
str_field = str(
builder_space.arguments[attr_space.config_space.name][
attr_space.attribute.name
]
)
module, fn = str_field.rsplit(".", 1)
try:
call_ref = getattr(importlib.import_module(module), fn)
attr_space.field = call_ref
except Exception as e:
raise _SpockValueError(
f"Attempted to import module {module} and callable {fn} however it could not be found on the current "
f"python path: {e}"
)

def handle_optional_attribute_type(
self, attr_space: AttributeSpace, builder_space: BuilderSpace
):
"""Not implemented for this type
Args:
attr_space: holds information about a single attribute that is mapped to a ConfigSpace
builder_space: named_tuple containing the arguments and spock_space
Raises:
_SpockNotOptionalError
"""
print("hi")
raise _SpockNotOptionalError(
f"Parameter `{attr_space.attribute.name}` within `{attr_space.config_space.name}` is of "
f"type `{type(attr_space.attribute.type)}` which seems to be unsupported -- "
f"are you missing an @spock decorator on a base python class?"
)


class RegisterSimpleField(RegisterFieldTemplate):
"""Class that registers basic python types
Expand Down Expand Up @@ -606,6 +679,9 @@ def recurse_generate(cls, spock_cls, builder_space: BuilderSpace):
# References to tuner classes
elif _is_spock_tune_instance(attribute.type):
handler = RegisterTuneCls()
# References to callables
elif isinstance(attribute.type, _SpockVariadicGenericAlias):
handler = RegisterCallableField()
# Basic field
else:
handler = RegisterSimpleField()
Expand All @@ -617,6 +693,9 @@ def recurse_generate(cls, spock_cls, builder_space: BuilderSpace):
# error on instantiation
try:
spock_instance = spock_cls(**fields)
# If there is a __post_hook__ dunder method then call it
if hasattr(spock_cls, "__post_hook__"):
spock_instance.__post_hook__()
except Exception as e:
raise _SpockInstantiationError(
f"Spock class `{spock_cls.__name__}` could not be instantiated -- attrs message: {e}"
Expand Down
23 changes: 20 additions & 3 deletions spock/backend/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,31 @@ def _clean_output(self, out_dict):
for idx, list_val in enumerate(val):
tmp_dict = {}
for inner_key, inner_val in list_val.items():
tmp_dict = self._convert(tmp_dict, inner_val, inner_key)
tmp_dict = self._convert_tuples_2_lists(
tmp_dict, inner_val, inner_key
)
val[idx] = tmp_dict
clean_inner_dict = val
else:
for inner_key, inner_val in val.items():
clean_inner_dict = self._convert(
clean_inner_dict = self._convert_tuples_2_lists(
clean_inner_dict, inner_val, inner_key
)
clean_dict.update({key: clean_inner_dict})
return clean_dict

def _convert(self, clean_inner_dict, inner_val, inner_key):
def _convert_tuples_2_lists(self, clean_inner_dict, inner_val, inner_key):
"""Convert tuples to lists
Args:
clean_inner_dict: dictionary to update
inner_val: current value
inner_key: current key
Returns:
updated dictionary where tuples are cast back to lists
"""
# Convert tuples to lists so they get written correctly
if isinstance(inner_val, tuple):
clean_inner_dict.update(
Expand Down Expand Up @@ -277,6 +290,10 @@ def _recursively_handle_clean(
if repeat_flag:
clean_val = list(set(clean_val))[-1]
out_dict.update({key: clean_val})
# Catch any callables -- convert back to the str representation
elif callable(val):
call_2_str = f"{val.__module__}.{val.__name__}"
out_dict.update({key: call_2_str})
# If it's a spock class but has a parent then just use the class name to reference the values
elif (val_name in all_cls) and parent_name is not None:
out_dict.update({key: val_name})
Expand Down
Loading

0 comments on commit 0d4b82a

Please sign in to comment.