Skip to content
This repository has been archived by the owner on May 8, 2023. It is now read-only.

Commit

Permalink
Revert "Fix schema generator (#124)" (#127)
Browse files Browse the repository at this point in the history
This reverts commit 877474a.
  • Loading branch information
mvpatel2000 committed Jul 28, 2022
1 parent 877474a commit a88a1b5
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 176 deletions.
8 changes: 4 additions & 4 deletions tests/fixtures/commented_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ required_choice:
# EnumStringField{MOSAIC, PYTORCH_LIGHTNING} (Required). Description: enum int field.
enumstringfield:
jsonfield: # JSON (Required). Description: Required json type.
boolfield: # bool (Required). Description: bool field.
boolfield: # int (Required). Description: int field.
three: # ChoiceThreeHparam
commonfield: # bool (Required). Description: bool common field.
# ChoiceHparamParent (Required). Description: choice Hparam field. Options: ChoiceOneHparam, ChoiceTwoHparam.
Expand All @@ -77,7 +77,7 @@ required_choice:
# EnumStringField{MOSAIC, PYTORCH_LIGHTNING} (Required). Description: enum int field.
enumstringfield:
jsonfield: # JSON (Required). Description: Required json type.
boolfield: # bool (Required). Description: bool field.
boolfield: # int (Required). Description: int field.
strfield: # str (Required). Description: str field.
# Optional[ChoiceHparamParent] (Required). Description: choice Hparam field. Options: ChoiceOneHparam, ChoiceTwoHparam, ChoiceThreeHparam.
nullable_required_choice:
Expand All @@ -100,7 +100,7 @@ required_choice_list:
# EnumStringField{MOSAIC, PYTORCH_LIGHTNING} (Required). Description: enum int field.
enumstringfield:
jsonfield: # JSON (Required). Description: Required json type.
boolfield: # bool (Required). Description: bool field.
boolfield: # int (Required). Description: int field.
three: # ChoiceThreeHparam
commonfield: # bool (Required). Description: bool common field.
# ChoiceHparamParent (Required). Description: choice Hparam field. Options: ChoiceOneHparam, ChoiceTwoHparam.
Expand All @@ -119,7 +119,7 @@ required_choice_list:
# EnumStringField{MOSAIC, PYTORCH_LIGHTNING} (Required). Description: enum int field.
enumstringfield:
jsonfield: # JSON (Required). Description: Required json type.
boolfield: # bool (Required). Description: bool field.
boolfield: # int (Required). Description: int field.
strfield: # str (Required). Description: str field.
# Optional[List[ChoiceHparamParent]] (Required). Description: choice Hparam field. Options: ChoiceOneHparam, ChoiceTwoHparam, ChoiceThreeHparam.
nullable_required_choice_list:
Expand Down
36 changes: 0 additions & 36 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,42 +178,6 @@
third_action: "Release bears into wild with stylish new haircuts"
""")
],
[
BearsHparams, True,
textwrap.dedent("""
---
bears:
- shaved_bears+first:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
- shaved_bears+second:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
""")
],
[
BearsHparams, False,
textwrap.dedent("""
---
bears:
- shaved_bearsfirst:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
- shaved_bearssecond:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
""")
],
[
BearsHparams, True,
textwrap.dedent("""
---
bears:
shaved_bears:
first_action: "Procure bears"
last_action: "Release bears into wild with stylish new haircuts"
""")
],
])
def test_validate_json_schema_from_strings(hparam_class: Type[Hparams], success: bool, data: str):
with contextlib.nullcontext() if success else pytest.raises(ValidationError):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_yahp_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_register_new_hparam_choice(choice_one_yaml_input: YamlInput):
assert isinstance(choice_one_hparam.choice, ChoiceOneHparam)

# Check that new registered hparams can be created
root_hparams_data['choice'] = {'empty': {}}
root_hparams_data['choice'] = {'empty': None}
choice_empty = ChoiceHparamRoot.create(data=root_hparams_data)

assert isinstance(choice_empty.choice, EmptyHparam)
Expand Down
4 changes: 2 additions & 2 deletions tests/yahp_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def choice_one_hparams(choice_one_yaml_input: YamlInput) -> ChoiceOneHparam:
@dataclass
class ChoiceTwoHparam(ChoiceHparamParent):
primitive_hparam: PrimitiveHparam = hp.required(doc='Primitive Hparams')
boolfield: bool = hp.required(doc='bool field')
boolfield: int = hp.required(doc='int field')

def validate(self):
assert isinstance(self.boolfield, bool)
Expand Down Expand Up @@ -411,7 +411,7 @@ def optional_field_empty_object_yaml_input(hparams_tempdir: pathlib.Path) -> Yam

@pytest.fixture
def optional_field_null_object_yaml_input(hparams_tempdir: pathlib.Path) -> YamlInput:
data = {'choice': {'one': {}}}
data = {'choice': {'one': None}}
return generate_named_tuple_from_data(hparams_tempdir=hparams_tempdir,
input_data=data,
filepath='optional_field_null_object.yaml')
Expand Down
13 changes: 6 additions & 7 deletions yahp/create_object/create_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,21 +661,20 @@ def _get_hparams(
cli_f, output_f, validate = get_hparams_file_from_cli(cli_args=remaining_cli_args,
argparse_name_registry=argparse_name_registry,
argument_parsers=argparsers)

if cli_f is not None:
if f is not None:
raise ValueError('File cannot be specified via both function arguments and the CLI')
f = cli_f

# Validate was specified, so only validate instead of instantiating
if validate:
print(f'Validating YAML against {constructor.__name__}...')
cls = ensure_hparams_cls(constructor)
cls.validate_yaml(f=f)
cls.validate_yaml(f=cli_f)
# exit so we don't attempt to parse and instantiate
print('\nSuccessfully validated YAML!')
sys.exit(0)

if cli_f is not None:
if f is not None:
raise ValueError('File cannot be specified via both function arguments and the CLI')
f = cli_f

if f is not None:
if data is not None:
raise ValueError(
Expand Down
53 changes: 9 additions & 44 deletions yahp/hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,22 +330,16 @@ def validate(self):
)

@classmethod
def _build_json_schema(cls: Type[THparams], _cls_def: Dict[str, Any], allow_recursion: bool) -> None:
"""Recursive private helper for generating and returning a JSONSchema dictionary.
def get_json_schema(cls) -> Dict[str, Any]:
"""Generates and returns a JSONSchema dictionary."""

Args:
_cls_def (Optional[Dict[str, Any]]): Keeps a reference to previously built Hparmam
classes and enums which can be used with references to make schemas more concise
and readable.
allow_recursion (bool): Whether to recursively parse subclasses.
"""
res = {
'type': 'object',
'properties': {},
'additionalProperties': False,
}
class_type_hints = get_type_hints(cls)
for f in sorted(fields(cls), key=lambda f: f.name):
for f in fields(cls):
if not f.init:
continue

Expand All @@ -358,50 +352,21 @@ def _build_json_schema(cls: Type[THparams], _cls_def: Dict[str, Any], allow_recu
hparams_type = type_helpers.HparamsType(class_type_hints[f.name])
# Name is found in registry, set possible values as types in a union type
if cls.hparams_registry and f.name in cls.hparams_registry and len(cls.hparams_registry[f.name].keys()) > 0:
res['properties'][f.name] = get_registry_json_schema(hparams_type, cls.hparams_registry[f.name],
_cls_def, allow_recursion)
res['properties'][f.name] = get_registry_json_schema(hparams_type, cls.hparams_registry[f.name])
else:
res['properties'][f.name] = get_type_json_schema(hparams_type, _cls_def, allow_recursion)
res['properties'][f.name] = get_type_json_schema(hparams_type)
res['properties'][f.name]['description'] = f.metadata['doc']

# Add schema to _cls_def. Hparams classes are always inserted into defs and referenced to
# in built schemas. If this function was called from `get_type_json_schema``, that function
# will add a reference to this class. If this was called from the root Hparams class the
# schema is being generated for, res will be pulled from _cls_defs
_cls_def[cls.__qualname__] = res

@classmethod
def get_json_schema(cls: Type[THparams]) -> Dict[str, Any]:
"""Generates and returns a JSONSchema dictionary."""
_cls_def = {}
cls._build_json_schema(_cls_def=_cls_def, allow_recursion=True)
res = _cls_def[cls.__qualname__]

# Delete top level name. By default, all Hparams classes are added to _cls_def. However,
# the top level Hparams class is not referenced anywhere (as it is the root), so we can
# remove it from _cls_def.
del _cls_def[cls.__qualname__]
# Add definitions to top level of schema
for key, value in _cls_def.items():
if '$defs' not in res:
res['$defs'] = {}
res['$defs'][key] = value

return res

@classmethod
def dump_jsonschema(cls: Type[THparams], f: Union[TextIO, str, pathlib.Path], **kwargs: Any):
"""Dump the JSONSchema to ``f``.
Args:
f (Union[str, None, TextIO, pathlib.PurePath], optional): Writes json to this file.
kwargs: (Any): Keyword args to be passed to `json.dump`.
"""
def dump_jsonschema(cls, f: Union[TextIO, str, pathlib.Path]):
"""Dump the JSONSchema to ``f``."""
if isinstance(f, TextIO) or isinstance(f, TextIOWrapper):
json.dump(cls.get_json_schema(), f, **kwargs)
json.dump(cls.get_json_schema(), f)
else:
with open(f, 'w') as file:
json.dump(cls.get_json_schema(), file, **kwargs)
json.dump(cls.get_json_schema(), file)

@classmethod
def validate_yaml(cls: Type[THparams],
Expand Down
107 changes: 25 additions & 82 deletions yahp/utils/json_schema_helpers.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,31 @@
from __future__ import annotations

import copy
import inspect
import re
from enum import Enum
from typing import Any, Dict, List
from typing import Any, Dict

from yahp.utils import type_helpers


def get_registry_json_schema(f_type: type_helpers.HparamsType, registry: Dict[str, Any], _cls_def: Dict[str, Any],
allow_recursion: bool):
def get_registry_json_schema(f_type: type_helpers.HparamsType, registry: Dict[str, Any]):
"""Convert type into corresponding JSON Schema. As the given name is in the `hparams_registry`,
create objects for each possible entry in the registry and treat as union type.
Args:
f_type (HparamsType): The type to be parsed.
registry (Dict[str, Any]): A registry to unpack.
_cls_def ([Dict[str, Any]]): Keeps a reference to previously built Hparmam
classes and enums which can be used with references to make schemas more concise
and readable.
allow_recursion (bool): Indicates whether parent Hparam class was autoyahp generated
"""
res = {'anyOf': []}
for key in sorted(registry.keys()):
# Accept any string prefixed by the key. In yahp, a key can be specified multiple times using
# key+X syntax, so prefix checking is required
for key, value in registry.items():
res['anyOf'].append({
'type': 'object',
'properties': {
key: get_type_json_schema(type_helpers.HparamsType(registry[key]), _cls_def, allow_recursion)
},
'patternProperties': {
f'^{re.escape(key)}\\+':
get_type_json_schema(type_helpers.HparamsType(registry[key]), _cls_def, allow_recursion)
key: get_type_json_schema(type_helpers.HparamsType(value))
},
'additionalProperties': False,
})
return _check_for_list_and_optional(f_type, res)


def get_type_json_schema(f_type: type_helpers.HparamsType, _cls_def: Dict[str, Any], allow_recursion: bool):
# Import inside function to avoid circular dependencies
from yahp.hparams import Hparams
def get_type_json_schema(f_type: type_helpers.HparamsType):
"""Convert type into corresponding JSON Schema. We first check for union types and recursively
handle each component. If a type is not union, we know it is a singleton type, so it must be
either a primitive, Enum, JSON, or Hparam-like. Dictionaries are treated as JSON types, and
list and optionals are handled in a post-processing step in `_check_for_list_and_optional`.
Args:
f_type (HparamsType): The type to be parsed.
_cls_def ([Dict[str, Any]]): Keeps a reference to previously built Hparmam
classes and enums which can be used with references to make schemas more concise
and readable.
allow_recursion (bool): Indicates whether parent Hparam class was autoyahp generated
"""
# Import inside function to resolve circular dependencies
from yahp.auto_hparams import ensure_hparams_cls
Expand All @@ -65,7 +37,7 @@ def get_type_json_schema(f_type: type_helpers.HparamsType, _cls_def: Dict[str, A
# Add all union types using anyOf
res = {'anyOf': []}
for union_type in f_type.types:
res['anyOf'].append(get_type_json_schema(type_helpers.HparamsType(union_type), _cls_def, allow_recursion))
res['anyOf'].append(get_type_json_schema(type_helpers.HparamsType(union_type)))
# Primitive Types
elif f_type.type is str:
res = {'type': 'string'}
Expand All @@ -77,50 +49,19 @@ def get_type_json_schema(f_type: type_helpers.HparamsType, _cls_def: Dict[str, A
res = {'type': 'number'}
# Enum
elif inspect.isclass(f_type.type) and issubclass(f_type.type, Enum):
# Build schema and add to _cls_def if not present
if f_type.type.__qualname__ not in _cls_def:
# Get all possible keys and values which are of type str
names = list(f_type.type._member_map_.keys())
names.extend([
f_type.type._member_map_[key].value for key in names if type(f_type.type._member_map_[key].value) == str
])
names = sorted(list(set([name.upper() for name in names])))
# Build case insensitive regex so we match both lowercase or upper case
member_names: List[Dict[str, Any]] = [{
'type': 'string',
'pattern': f'(?i)^{re.escape(name)}$'
} for name in names]
# Add all non-str keys and values
enum_attributes = [name.value for name in f_type.type if type(name.value) != str]
enum_attributes = sorted(list(set(enum_attributes)), key=lambda x: str(x))
if len(enum_attributes) > 0:
member_names.append({'enum': enum_attributes})
# Build oneOf to create an enum which is case insensitive
res = {'oneOf': member_names}
_cls_def[f_type.type.__qualname__] = copy.deepcopy(res)
res = {'$ref': f'#/$defs/{f_type.type.__qualname__}'}
# JSON or unschemable types
# Enum attributes can either be specified lowercase or uppercase
member_names = [name.lower() for name in f_type.type._member_names_]
member_names.extend([name.upper() for name in f_type.type._member_names_])
res = {'enum': member_names}
# JSON
elif f_type.type == type_helpers._JSONDict:
res = {
'type': 'object',
}
# Hparam class
elif callable(f_type.type):
# Attempt to autoyahp if the parent class was not autoyahped or if the parameter is Hparams
# class.
if allow_recursion or inspect.isclass(f_type.type) and issubclass(f_type.type, Hparams):
hparam_class = ensure_hparams_cls(f_type.type)
# Disallow recursion if class was autoyahped
allow_recursion = hparam_class == f_type.type
# Build schema and add to _cls_def if not present. _build_json_schema adds to _cls_def
# internally, so we only need to call the function.
if hparam_class not in _cls_def:
hparam_class._build_json_schema(_cls_def=_cls_def, allow_recursion=allow_recursion)
res = {'$ref': f'#/$defs/{hparam_class.__qualname__}'}
# Otherwise, if we have a callable parameter of autoyahped class, either require None if
# its possible or throw an error.
else:
res = {'type': 'object'}
hparam_class = ensure_hparams_cls(f_type.type)
res = hparam_class.get_json_schema()
else:
raise ValueError('Unexpected type when constructing JSON Schema.')

Expand All @@ -130,19 +71,21 @@ def get_type_json_schema(f_type: type_helpers.HparamsType, _cls_def: Dict[str, A
def _check_for_list_and_optional(f_type: type_helpers.HparamsType, schema: Dict[str, Any]) -> Dict[str, Any]:
"""Wrap JSON Schema with list schema or optional schema if specified.
"""
if not f_type.is_list and not f_type.is_optional:
return schema

# Accept singletons
res = {'oneOf': [schema]}
res = schema
# Wrap type in list
if f_type.is_list:
res['oneOf'].append({
res = {
'type': 'array',
'items': schema,
})
'items': res,
}
# Wrap type for optional
if f_type.is_optional:
res['oneOf'].append({'type': 'null'})

res = {
'oneOf': [
{
'type': 'null'
},
res,
]
}
return res

0 comments on commit a88a1b5

Please sign in to comment.