Skip to content
This repository has been archived by the owner on May 8, 2023. It is now read-only.

Json schema #129

Merged
merged 14 commits into from
Jul 28, 2022
1 change: 1 addition & 0 deletions meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ requirements:
- python >=3.7
- pyyaml >=5.4.1
- ruamel.yaml >=0.17.10
- jsonschema >=4.7.2, <4.8

test:
requires:
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

exec(open('yahp/version.py', 'r', encoding='utf-8').read())

install_requires = ['PyYAML>=5.4.1', 'ruamel.yaml>=0.17.10', 'docstring_parser>=0.14.1,<=0.15']
install_requires = [
'PyYAML>=5.4.1', 'ruamel.yaml>=0.17.10', 'docstring_parser>=0.14.1,<=0.15', 'jsonschema>=4.7.2,<4.8'
]

extra_deps = {}

Expand Down
226 changes: 226 additions & 0 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import contextlib
import json
import os
import pathlib
import textwrap
from typing import Type

import pytest
import yaml
from jsonschema import ValidationError

from tests.yahp_fixtures import (BearsHparams, ChoiceHparamParent, KitchenSinkHparams, PrimitiveHparam,
ShavingBearsHparam)
from yahp.hparams import Hparams


@pytest.mark.parametrize('hparam_class,success,data', [
[
PrimitiveHparam, True,
textwrap.dedent("""
---
intfield: 1
strfield: hello
floatfield: 0.5
boolfield: true
enumintfield: ONE
enumstringfield: mosaic
jsonfield:
empty_item: {}
random_item: 1
random_item2: two
random_item3: true
random_item4: 0.1
random_subdict:
random_subdict_item5: 12
random_subdict_item6: 1337
random_list:
- 1
- 3
- 99
random_list_of_dict:
- sub_dict: 12
sub_dict_item: 1
- sub_dict2: 14
sub_dict_item: 43
""")
],
[
PrimitiveHparam, False,
textwrap.dedent("""
---
strfield: hello
floatfield: 0.5
boolfield: true
enumintfield: ONE
enumstringfield: mosaic
jsonfield:
empty_item: {}
random_item: 1
random_item2: two
random_item3: true
random_item4: 0.1
random_subdict:
random_subdict_item5: 12
random_subdict_item6: 1337
random_list:
- 1
- 3
- 99
random_list_of_dict:
- sub_dict: 12
sub_dict_item: 1
- sub_dict2: 14
sub_dict_item: 43
""")
],
[
KitchenSinkHparams,
True,
textwrap.dedent("""
---
required_int_field: 1
nullable_required_int_field:
required_bool_field: False
nullable_required_bool_field:
required_enum_field_list:
- RED
- red
required_enum_field_with_default: green
nullable_required_enum_field: None
nullable_required_enum_field:
required_union_bool_str_field: hi
nullable_required_union_bool_str_field:
required_int_list_field:
- 1
- 2
- 42
nullable_required_list_int_field:
- 24
nullable_required_list_union_bool_str_field:
required_list_union_bool_str_field:
- True
- hi
- False
- bye
required_subhparams_field: { default_false: False, default_true: True }
nullable_required_subhparams_field:
required_subhparams_field_list:
- { default_false: False }
- {}
nullable_required_subhparams_field_list:
required_choice:
one:
commonfield: True
intfield: 3
nullable_required_choice:
required_choice_list: []
nullable_required_choice_list:
optional_float_field_default_1: 1.1
"""),
],
[ChoiceHparamParent, True,
textwrap.dedent("""
---
commonfield: True
""")],
[ChoiceHparamParent, False,
textwrap.dedent("""
---
commonfield: 1
""")],
[
ShavingBearsHparam, True,
textwrap.dedent("""
---
parameters:
random_field:
shaved_bears:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
other_random_field: "cool"
""")
],
[BearsHparams, True, textwrap.dedent("""
---
bears:
""")],
[
BearsHparams, True,
textwrap.dedent("""
---
bears:
- shaved_bears:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
""")
],
[
BearsHparams, True,
textwrap.dedent("""
---
bears:
- unshaved_bears:
second_action: "Procure bears"
third_action: "Release bears into wild with stylish new haircuts"
""")
],
[
BearsHparams, True,
textwrap.dedent("""
---
bears:
- shaved_bears:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
- unshaved_bears:
second_action: "Procure bears"
third_action: "Release bears into wild with stylish new haircuts"
""")
],
])
def test_validate_json_schema_from_strings(hparam_class: Type[Hparams], success: bool, data: str):
with contextlib.nullcontext() if success else pytest.raises(ValidationError):
hparam_class.validate_yaml(data=yaml.safe_load(data))


@pytest.mark.parametrize('hparam_class,success,file', [
[ShavingBearsHparam, True,
os.path.join(os.path.dirname(__file__), 'inheritance/shaving_bears.yaml')],
])
def test_validate_json_schema_from_file(hparam_class: Type[Hparams], success: bool, file: str):
with contextlib.nullcontext() if success else pytest.raises(ValidationError):
hparam_class.validate_yaml(f=file)


@pytest.mark.parametrize('hparam_class', [
ShavingBearsHparam,
ChoiceHparamParent,
PrimitiveHparam,
KitchenSinkHparams,
BearsHparams,
])
def test_write_and_read_json_schema_from_name(hparam_class: Type[Hparams], tmp_path: pathlib.Path):
file = os.path.join(tmp_path, 'schema.json')
hparam_class.dump_jsonschema(file)
with open(file) as f:
loaded_schema = json.load(f)
generated_schema = hparam_class.get_json_schema()
assert loaded_schema == generated_schema


@pytest.mark.parametrize('hparam_class', [
ShavingBearsHparam,
ChoiceHparamParent,
PrimitiveHparam,
KitchenSinkHparams,
BearsHparams,
])
def test_write_and_read_json_schema_from_file(hparam_class: Type[Hparams], tmp_path: pathlib.Path):
file = os.path.join(tmp_path, 'schema.json')
with open(file, 'w') as f:
hparam_class.dump_jsonschema(f)
with open(file) as f:
loaded_schema = json.load(f)
generated_schema = hparam_class.get_json_schema()
assert loaded_schema == generated_schema
67 changes: 65 additions & 2 deletions tests/yahp_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import textwrap
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Any, Dict, List, NamedTuple, Optional, Union, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, Union, cast

import pytest
import yaml

import yahp as hp
from yahp.hparams import Hparams
from yahp.types import JSON


Expand Down Expand Up @@ -118,7 +119,7 @@ def primitive_yaml_input(hparams_tempdir: pathlib.Path) -> YamlInput:
floatfield: 0.5
boolfield: true
enumintfield: ONE
enumstringfield: pytorch_lightning
enumstringfield: mosaic
jsonfield:
empty_item: {}
random_item: 1
Expand Down Expand Up @@ -591,3 +592,65 @@ def optional_required_yaml_input(hparams_tempdir) -> YamlInput:
return generate_named_tuple_from_data(hparams_tempdir=hparams_tempdir,
input_data=data,
filepath='optional_required.yaml')


@dataclass
class ShavedBearsHparam(hp.Hparams):
first_action: str = hp.required(doc='str field')
last_action: str = hp.required(doc='str field')

def validate(self):
assert isinstance(self.first_action, str)
assert isinstance(self.last_action, str)
super().validate()


@dataclass
class UnshavedBearsHparam(hp.Hparams):
second_action: str = hp.required(doc='str field')
third_action: str = hp.required(doc='str field')

def validate(self):
assert isinstance(self.second_action, str)
assert isinstance(self.third_action, str)
super().validate()


@dataclass
class ParametersHparam(hp.Hparams):
random_field: Optional[int] = hp.required(doc='int field')
shaved_bears: ShavedBearsHparam = hp.required(doc='ShavedBears Hparams')
other_random_field: str = hp.required(doc='str field')

def validate(self):
assert isinstance(self.random_field, int)
assert isinstance(self.shaved_bears, ParametersHparam)
self.shaved_bears.validate()
assert isinstance(self.other_random_field, str)
super().validate()


@dataclass
class ShavingBearsHparam(hp.Hparams):
parameters: ParametersHparam = hp.required(doc='Parameters Hparams')

def validate(self):
assert isinstance(self.parameters, ParametersHparam)
self.parameters.validate()
super().validate()


bears_registry: Dict[str, Type[hp.Hparams]] = {
'shaved_bears': ShavedBearsHparam,
'unshaved_bears': UnshavedBearsHparam,
}


@dataclass
class BearsHparams(hp.Hparams):

hparams_registry = {
'bears': bears_registry,
}

bears: Optional[List[Hparams]] = hp.required(doc='bear field')
11 changes: 8 additions & 3 deletions yahp/create_object/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def get_hparams_file_from_cli(
cli_args: List[str],
argparse_name_registry: ArgparseNameRegistry,
argument_parsers: List[argparse.ArgumentParser],
) -> Tuple[Optional[str], Optional[str]]:
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
parser = argparse.ArgumentParser(add_help=False)
argument_parsers.append(parser)
argparse_name_registry.reserve('f', 'file', 'd', 'dump')
argparse_name_registry.reserve('f', 'file', 'd', 'dump', 'validate')
parser.add_argument('-f',
'--file',
type=str,
Expand All @@ -139,8 +139,13 @@ def get_hparams_file_from_cli(
metavar='stdout',
help='Dump the resulting Hparams to the specified YAML file (defaults to `stdout`) and exit.',
)
parser.add_argument(
'--validate',
action='store_true',
help='Whether to validate YAML against Hparams.',
)
parsed_args, cli_args[:] = parser.parse_known_args(cli_args)
return parsed_args.file, parsed_args.dump
return parsed_args.file, parsed_args.dump, parsed_args.validate


def get_commented_map_options_from_cli(
Expand Down
20 changes: 14 additions & 6 deletions yahp/create_object/create_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def _get_hparams(
)
if cm_options is not None:
output_file, interactive, add_docs = cm_options
print(f'Generating a template for {constructor.__name__}')
print(f'Generating a template for {constructor.__name__}...')
cls = ensure_hparams_cls(constructor)
if output_file == 'stdout':
cls.dump(add_docs=add_docs, interactive=interactive, output=sys.stdout)
Expand All @@ -655,13 +655,21 @@ def _get_hparams(
with open(output_file, 'x') as f:
cls.dump(add_docs=add_docs, interactive=interactive, output=f)
# exit so we don't attempt to parse and instantiate if generate template is passed
print()
print('Finished')
print('\nFinished')
sys.exit(0)

cli_f, output_f, validate = get_hparams_file_from_cli(cli_args=remaining_cli_args,
argparse_name_registry=argparse_name_registry,
argument_parsers=argparsers)
# Validate was specified, so only validate instead of instantiating
if validate:
print(f'Validating YAML against {constructor.__name__}...')
cls = ensure_hparams_cls(constructor)
cls.validate_yaml(f=cli_f)
# exit so we don't attempt to parse and instantiate
print('\nSuccessfully validated YAML!')
sys.exit(0)

cli_f, output_f = get_hparams_file_from_cli(cli_args=remaining_cli_args,
argparse_name_registry=argparse_name_registry,
argument_parsers=argparsers)
if cli_f is not None:
if f is not None:
raise ValueError('File cannot be specified via both function arguments and the CLI')
Expand Down
Loading