Skip to content
This repository has been archived by the owner on May 8, 2023. It is now read-only.

Commit

Permalink
JSON Schema Support (#123)
Browse files Browse the repository at this point in the history
* yaml validation

* fix lint and typing

* add IO tests

* update installs

* Update meta.yaml

Co-authored-by: ravi-mosaicml <ravi@mosaicml.com>

* Update setup.py

Co-authored-by: ravi-mosaicml <ravi@mosaicml.com>

* Update yahp/hparams.py

Co-authored-by: ravi-mosaicml <ravi@mosaicml.com>

* resolve multiple comments

* support autoyahp

* add optional support

* support registries

* switch to anyOf

* nits and kitchen sink test

Co-authored-by: ravi-mosaicml <ravi@mosaicml.com>
  • Loading branch information
mvpatel2000 and ravi-mosaicml committed Jul 22, 2022
1 parent a3d21be commit 8d1880e
Show file tree
Hide file tree
Showing 8 changed files with 481 additions and 14 deletions.
1 change: 1 addition & 0 deletions meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ requirements:
- python >=3.7
- pyyaml >=5.4.1
- ruamel.yaml >=0.17.10
- jsonschema >=4.7.2, <4.8

test:
requires:
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

exec(open('yahp/version.py', 'r', encoding='utf-8').read())

install_requires = ['PyYAML>=5.4.1', 'ruamel.yaml>=0.17.10', 'docstring_parser>=0.14.1,<=0.15']
install_requires = [
'PyYAML>=5.4.1', 'ruamel.yaml>=0.17.10', 'docstring_parser>=0.14.1,<=0.15', 'jsonschema>=4.7.2,<4.8'
]

extra_deps = {}

Expand Down
226 changes: 226 additions & 0 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import contextlib
import json
import os
import pathlib
import textwrap
from typing import Type

import pytest
import yaml
from jsonschema import ValidationError

from tests.yahp_fixtures import (BearsHparams, ChoiceHparamParent, KitchenSinkHparams, PrimitiveHparam,
ShavingBearsHparam)
from yahp.hparams import Hparams


@pytest.mark.parametrize('hparam_class,success,data', [
[
PrimitiveHparam, True,
textwrap.dedent("""
---
intfield: 1
strfield: hello
floatfield: 0.5
boolfield: true
enumintfield: ONE
enumstringfield: mosaic
jsonfield:
empty_item: {}
random_item: 1
random_item2: two
random_item3: true
random_item4: 0.1
random_subdict:
random_subdict_item5: 12
random_subdict_item6: 1337
random_list:
- 1
- 3
- 99
random_list_of_dict:
- sub_dict: 12
sub_dict_item: 1
- sub_dict2: 14
sub_dict_item: 43
""")
],
[
PrimitiveHparam, False,
textwrap.dedent("""
---
strfield: hello
floatfield: 0.5
boolfield: true
enumintfield: ONE
enumstringfield: mosaic
jsonfield:
empty_item: {}
random_item: 1
random_item2: two
random_item3: true
random_item4: 0.1
random_subdict:
random_subdict_item5: 12
random_subdict_item6: 1337
random_list:
- 1
- 3
- 99
random_list_of_dict:
- sub_dict: 12
sub_dict_item: 1
- sub_dict2: 14
sub_dict_item: 43
""")
],
[
KitchenSinkHparams,
True,
textwrap.dedent("""
---
required_int_field: 1
nullable_required_int_field:
required_bool_field: False
nullable_required_bool_field:
required_enum_field_list:
- RED
- red
required_enum_field_with_default: green
nullable_required_enum_field: None
nullable_required_enum_field:
required_union_bool_str_field: hi
nullable_required_union_bool_str_field:
required_int_list_field:
- 1
- 2
- 42
nullable_required_list_int_field:
- 24
nullable_required_list_union_bool_str_field:
required_list_union_bool_str_field:
- True
- hi
- False
- bye
required_subhparams_field: { default_false: False, default_true: True }
nullable_required_subhparams_field:
required_subhparams_field_list:
- { default_false: False }
- {}
nullable_required_subhparams_field_list:
required_choice:
one:
commonfield: True
intfield: 3
nullable_required_choice:
required_choice_list: []
nullable_required_choice_list:
optional_float_field_default_1: 1.1
"""),
],
[ChoiceHparamParent, True,
textwrap.dedent("""
---
commonfield: True
""")],
[ChoiceHparamParent, False,
textwrap.dedent("""
---
commonfield: 1
""")],
[
ShavingBearsHparam, True,
textwrap.dedent("""
---
parameters:
random_field:
shaved_bears:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
other_random_field: "cool"
""")
],
[BearsHparams, True, textwrap.dedent("""
---
bears:
""")],
[
BearsHparams, True,
textwrap.dedent("""
---
bears:
- shaved_bears:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
""")
],
[
BearsHparams, True,
textwrap.dedent("""
---
bears:
- unshaved_bears:
second_action: "Procure bears"
third_action: "Release bears into wild with stylish new haircuts"
""")
],
[
BearsHparams, True,
textwrap.dedent("""
---
bears:
- shaved_bears:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
- unshaved_bears:
second_action: "Procure bears"
third_action: "Release bears into wild with stylish new haircuts"
""")
],
])
def test_validate_json_schema_from_strings(hparam_class: Type[Hparams], success: bool, data: str):
    """Validate an in-memory YAML document against ``hparam_class``.

    ``data`` is parsed with ``yaml.safe_load`` and the result is handed to
    ``hparam_class.validate_yaml``. When ``success`` is false the call is
    expected to raise ``ValidationError``; otherwise it must not raise.
    """
    if success:
        expectation = contextlib.nullcontext()
    else:
        expectation = pytest.raises(ValidationError)

    with expectation:
        parsed_document = yaml.safe_load(data)
        hparam_class.validate_yaml(data=parsed_document)


@pytest.mark.parametrize(
    'hparam_class,success,file',
    [
        (ShavingBearsHparam, True, os.path.join(os.path.dirname(__file__), 'inheritance/shaving_bears.yaml')),
    ],
)
def test_validate_json_schema_from_file(hparam_class: Type[Hparams], success: bool, file: str):
    """Validate a YAML file on disk against ``hparam_class``.

    The path in ``file`` is passed to ``hparam_class.validate_yaml`` via the
    ``f`` keyword. When ``success`` is false the call is expected to raise
    ``ValidationError``; otherwise it must not raise.
    """
    if success:
        expectation = contextlib.nullcontext()
    else:
        expectation = pytest.raises(ValidationError)

    with expectation:
        hparam_class.validate_yaml(f=file)


@pytest.mark.parametrize(
    'hparam_class',
    [ShavingBearsHparam, ChoiceHparamParent, PrimitiveHparam, KitchenSinkHparams, BearsHparams],
)
def test_write_and_read_json_schema_from_name(hparam_class: Type[Hparams], tmp_path: pathlib.Path):
    """Dump the schema by passing a file *name* and check that it round-trips.

    ``dump_jsonschema`` receives the path of ``schema.json`` inside
    ``tmp_path``; the file is then read back with ``json.load`` and compared
    with the result of ``get_json_schema``.
    """
    schema_path = os.path.join(tmp_path, 'schema.json')
    hparam_class.dump_jsonschema(schema_path)

    with open(schema_path) as schema_file:
        schema_on_disk = json.load(schema_file)

    schema_in_memory = hparam_class.get_json_schema()
    assert schema_on_disk == schema_in_memory


@pytest.mark.parametrize(
    'hparam_class',
    [ShavingBearsHparam, ChoiceHparamParent, PrimitiveHparam, KitchenSinkHparams, BearsHparams],
)
def test_write_and_read_json_schema_from_file(hparam_class: Type[Hparams], tmp_path: pathlib.Path):
    """Dump the schema into an open file *object* and check that it round-trips.

    ``dump_jsonschema`` receives a handle opened for writing on ``schema.json``
    inside ``tmp_path``; the file is then read back with ``json.load`` and
    compared with the result of ``get_json_schema``.
    """
    schema_path = os.path.join(tmp_path, 'schema.json')

    with open(schema_path, 'w') as sink:
        hparam_class.dump_jsonschema(sink)

    with open(schema_path) as source:
        schema_on_disk = json.load(source)

    schema_in_memory = hparam_class.get_json_schema()
    assert schema_on_disk == schema_in_memory
67 changes: 65 additions & 2 deletions tests/yahp_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import textwrap
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Any, Dict, List, NamedTuple, Optional, Union, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, Union, cast

import pytest
import yaml

import yahp as hp
from yahp.hparams import Hparams
from yahp.types import JSON


Expand Down Expand Up @@ -118,7 +119,7 @@ def primitive_yaml_input(hparams_tempdir: pathlib.Path) -> YamlInput:
floatfield: 0.5
boolfield: true
enumintfield: ONE
enumstringfield: pytorch_lightning
enumstringfield: mosaic
jsonfield:
empty_item: {}
random_item: 1
Expand Down Expand Up @@ -591,3 +592,65 @@ def optional_required_yaml_input(hparams_tempdir) -> YamlInput:
return generate_named_tuple_from_data(hparams_tempdir=hparams_tempdir,
input_data=data,
filepath='optional_required.yaml')


# Fixture Hparams with two required string fields. It is registered under the
# 'shaved_bears' key of ``bears_registry`` and used as the type of
# ``ParametersHparam.shaved_bears``.
@dataclass
class ShavedBearsHparam(hp.Hparams):
    first_action: str = hp.required(doc='str field')
    last_action: str = hp.required(doc='str field')

    def validate(self):
        """Assert that both action fields are strings, then run the base validation."""
        for field_name in ('first_action', 'last_action'):
            assert isinstance(getattr(self, field_name), str)
        super().validate()


# Fixture Hparams with two required string fields. It is registered under the
# 'unshaved_bears' key of ``bears_registry``.
@dataclass
class UnshavedBearsHparam(hp.Hparams):
    second_action: str = hp.required(doc='str field')
    third_action: str = hp.required(doc='str field')

    def validate(self):
        """Assert that both action fields are strings, then run the base validation."""
        for field_name in ('second_action', 'third_action'):
            assert isinstance(getattr(self, field_name), str)
        super().validate()


# Fixture Hparams that nests a ``ShavedBearsHparam`` next to an optional int
# field and a string field. It is used as the type of
# ``ShavingBearsHparam.parameters``.
@dataclass
class ParametersHparam(hp.Hparams):
    random_field: Optional[int] = hp.required(doc='int field')
    shaved_bears: ShavedBearsHparam = hp.required(doc='ShavedBears Hparams')
    other_random_field: str = hp.required(doc='str field')

    def validate(self):
        """Check every field against its declared type, then run the base validation.

        Raises:
            AssertionError: If a field does not hold a value of its declared type.
        """
        # ``random_field`` is declared ``Optional[int]``, so ``None`` is a legal value.
        assert self.random_field is None or isinstance(self.random_field, int)
        # The nested field is declared as ``ShavedBearsHparam``; checking it against
        # ``ParametersHparam`` would reject every correctly constructed instance.
        assert isinstance(self.shaved_bears, ShavedBearsHparam)
        self.shaved_bears.validate()
        assert isinstance(self.other_random_field, str)
        super().validate()


# Top-level fixture Hparams whose single required field is a nested
# ``ParametersHparam``.
@dataclass
class ShavingBearsHparam(hp.Hparams):
    parameters: ParametersHparam = hp.required(doc='Parameters Hparams')

    def validate(self):
        """Assert the nested type, validate the nested object, then run the base validation."""
        nested = self.parameters
        assert isinstance(nested, ParametersHparam)
        nested.validate()
        super().validate()


# Maps each registry key to the Hparams class it selects; referenced by
# ``BearsHparams.hparams_registry`` for its ``bears`` field.
bears_registry: Dict[str, Type[hp.Hparams]] = dict(
    shaved_bears=ShavedBearsHparam,
    unshaved_bears=UnshavedBearsHparam,
)


# Fixture Hparams with one optional list field whose entries are chosen
# through ``bears_registry``.
@dataclass
class BearsHparams(hp.Hparams):

    # Deliberately left unannotated so that it stays a plain class attribute
    # rather than becoming a dataclass field.
    hparams_registry = {'bears': bears_registry}

    bears: Optional[List[Hparams]] = hp.required(doc='bear field')
11 changes: 8 additions & 3 deletions yahp/create_object/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def get_hparams_file_from_cli(
cli_args: List[str],
argparse_name_registry: ArgparseNameRegistry,
argument_parsers: List[argparse.ArgumentParser],
) -> Tuple[Optional[str], Optional[str]]:
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
parser = argparse.ArgumentParser(add_help=False)
argument_parsers.append(parser)
argparse_name_registry.reserve('f', 'file', 'd', 'dump')
argparse_name_registry.reserve('f', 'file', 'd', 'dump', 'validate')
parser.add_argument('-f',
'--file',
type=str,
Expand All @@ -139,8 +139,13 @@ def get_hparams_file_from_cli(
metavar='stdout',
help='Dump the resulting Hparams to the specified YAML file (defaults to `stdout`) and exit.',
)
parser.add_argument(
'--validate',
action='store_true',
help='Whether to validate YAML against Hparams.',
)
parsed_args, cli_args[:] = parser.parse_known_args(cli_args)
return parsed_args.file, parsed_args.dump
return parsed_args.file, parsed_args.dump, parsed_args.validate


def get_commented_map_options_from_cli(
Expand Down
20 changes: 14 additions & 6 deletions yahp/create_object/create_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def _get_hparams(
)
if cm_options is not None:
output_file, interactive, add_docs = cm_options
print(f'Generating a template for {constructor.__name__}')
print(f'Generating a template for {constructor.__name__}...')
cls = ensure_hparams_cls(constructor)
if output_file == 'stdout':
cls.dump(add_docs=add_docs, interactive=interactive, output=sys.stdout)
Expand All @@ -655,13 +655,21 @@ def _get_hparams(
with open(output_file, 'x') as f:
cls.dump(add_docs=add_docs, interactive=interactive, output=f)
# exit so we don't attempt to parse and instantiate if generate template is passed
print()
print('Finished')
print('\nFinished')
sys.exit(0)

cli_f, output_f, validate = get_hparams_file_from_cli(cli_args=remaining_cli_args,
argparse_name_registry=argparse_name_registry,
argument_parsers=argparsers)
# Validate was specified, so only validate instead of instantiating
if validate:
print(f'Validating YAML against {constructor.__name__}...')
cls = ensure_hparams_cls(constructor)
cls.validate_yaml(f=cli_f)
# exit so we don't attempt to parse and instantiate
print('\nSuccessfully validated YAML!')
sys.exit(0)

cli_f, output_f = get_hparams_file_from_cli(cli_args=remaining_cli_args,
argparse_name_registry=argparse_name_registry,
argument_parsers=argparsers)
if cli_f is not None:
if f is not None:
raise ValueError('File cannot be specified via both function arguments and the CLI')
Expand Down
Loading

0 comments on commit 8d1880e

Please sign in to comment.