diff --git a/docs/advanced_features/Command-Line-Overrides.md b/docs/advanced_features/Command-Line-Overrides.md index 6a81cbe5..46efc070 100644 --- a/docs/advanced_features/Command-Line-Overrides.md +++ b/docs/advanced_features/Command-Line-Overrides.md @@ -32,7 +32,7 @@ class ModelConfig: save_path: SavePath n_features: int dropout: Optional[List[float]] - hidden_sizes: Tuple[int] = (32, 32, 32) + hidden_sizes: Tuple[int, int, int] = (32, 32, 32) activation: Activation = 'relu' optimizer: Optimizer cache_path: Optional[str] @@ -70,14 +70,52 @@ But with command line overrides we can also pass parameter arguments to override file: ```bash -$ python tutorial.py --config tutorial.yaml --cache_path /tmp/trash +$ python tutorial.py --config tutorial.yaml --DataConfig.cache_path /tmp/trash ``` -Each parameter can be overridden at the global level or the class specific level with the syntax `--name.parameter`. For -instance, our previous example would override any parameters named `cache_path` regardless of what class they are -defined in. In this case `cache_path` in both `ModelConfig` and `DataConfig`. To override just a class specific value -we would use the class specific override: +Each parameter can be overridden **ONLY** at the class specific level with the syntax `--classname.parameter`. For +instance, our previous example would only override the `DataConfig.cache_path` and not the `ModelConfig.cache_path` even +though they have the same parameter name (due to the different class names). ```bash $ python tutorial.py --config tutorial.yaml --DataConfig.cache_path /tmp/trash +``` + +### Overriding List/Tuple of Repeated `@spock` Classes + +For `List` of Repeated `@spock` Classes the syntax is slightly different to allow for the repeated nature of the type. +Given the below example code: + +```python +from spock.config import spock +from typing import List + + +@spock +class NestedListStuff: + one: int + two: str + +@spock +class TypeConfig: + nested_list: List[NestedListStuff] # To Set Default Value append '= NestedListStuff' +``` + +With YAML definitions: + +```yaml +# Nested List configuration +nested_list: NestedListStuff +NestedListStuff: + - one: 10 + two: hello + - one: 20 + two: bye +``` + +We could override the parameters like so (note that the len must match the defined length from the YAML): + +```bash +$ python tutorial.py --config tutorial.yaml --TypeConfig.nested_list.NestedListStuff.one [1,2] \ +--TypeConfig.nested_list.NestedListStuff.two [ciao,ciao] ``` \ No newline at end of file diff --git a/docs/advanced_features/Defaults.md b/docs/advanced_features/Defaults.md index 23bc4d89..1304a759 100644 --- a/docs/advanced_features/Defaults.md +++ b/docs/advanced_features/Defaults.md @@ -34,7 +34,7 @@ class ModelConfig: lr: float = 0.01 n_features: int dropout: List[float] - hidden_sizes: Tuple[int] = (32, 32, 32) + hidden_sizes: Tuple[int, int, int] = (32, 32, 32) activation: Activation = 'relu' ``` diff --git a/docs/advanced_features/Inheritance.md b/docs/advanced_features/Inheritance.md index fd32cba2..38d2c566 100644 --- a/docs/advanced_features/Inheritance.md +++ b/docs/advanced_features/Inheritance.md @@ -38,7 +38,7 @@ class ModelConfig: save_path: SavePath n_features: int dropout: Optional[List[float]] - hidden_sizes: Tuple[int] = (32, 32, 32) + hidden_sizes: Tuple[int, int, int] = (32, 32, 32) activation: Activation = 'relu' optimizer: Optimizer diff --git a/docs/advanced_features/Local-Definitions.md b/docs/advanced_features/Local-Definitions.md index f331ccd1..9e52d3fb 100644 --- a/docs/advanced_features/Local-Definitions.md +++ b/docs/advanced_features/Local-Definitions.md @@ -35,7 +35,7 @@ class ModelConfig: save_path: SavePath n_features: int dropout: Optional[List[float]] - hidden_sizes: Tuple[int] = (32, 32, 32) + hidden_sizes: Tuple[int, int, int] = (32, 32, 32) activation: Activation = 'relu' optimizer: Optimizer cache_path: Optional[str] diff --git a/docs/advanced_features/Optional-Parameters.md b/docs/advanced_features/Optional-Parameters.md index 966076c9..899191a1 100644 --- a/docs/advanced_features/Optional-Parameters.md +++ b/docs/advanced_features/Optional-Parameters.md @@ -34,7 +34,7 @@ class ModelConfig: lr: float = 0.01 n_features: int dropout: Optional[List[float]] - hidden_sizes: Tuple[int] = (32, 32, 32) + hidden_sizes: Tuple[int, int, int] = (32, 32, 32) activation: Activation = 'relu' ``` diff --git a/docs/advanced_features/Parameter-Groups.md b/docs/advanced_features/Parameter-Groups.md index 80f209ba..b14a6b09 100644 --- a/docs/advanced_features/Parameter-Groups.md +++ b/docs/advanced_features/Parameter-Groups.md @@ -36,7 +36,7 @@ class ModelConfig: save_path: SavePath n_features: int dropout: Optional[List[float]] - hidden_sizes: Tuple[int] = (32, 32, 32) + hidden_sizes: Tuple[int, int, int] = (32, 32, 32) activation: Activation = 'relu' optimizer: Optimizer diff --git a/docs/basic_tutorial/Building.md b/docs/basic_tutorial/Building.md index 59059b66..26d55de5 100644 --- a/docs/basic_tutorial/Building.md +++ b/docs/basic_tutorial/Building.md @@ -25,7 +25,7 @@ class Activation(Enum): class ModelConfig: n_features: int dropout: List[float] - hidden_sizes: Tuple[int] + hidden_sizes: Tuple[int, int, int] activation: Activation ``` diff --git a/docs/basic_tutorial/Configuration-Files.md b/docs/basic_tutorial/Configuration-Files.md index 0b8ea765..9cf7240a 100644 --- a/docs/basic_tutorial/Configuration-Files.md +++ b/docs/basic_tutorial/Configuration-Files.md @@ -31,7 +31,7 @@ class Activation(Enum): class ModelConfig: n_features: int dropout: List[float] - hidden_sizes: Tuple[int] + hidden_sizes: Tuple[int, int, int] activation: Activation ``` diff --git a/docs/basic_tutorial/Define.md b/docs/basic_tutorial/Define.md index 63476c82..7f739429 100644 --- a/docs/basic_tutorial/Define.md +++ b/docs/basic_tutorial/Define.md @@ -20,9 +20,11 @@ standard library while `Enum` is within the `enum` standard library): | int | Optional[int] | Basic integer type parameter (e.g. 2) | | str | Optional[str] | Basic string type parameter (e.g. 'foo') | | List[type] | Optional[List[type]] | Basic list type parameter of base types such as int, float, etc. (e.g. [10.0, 2.0]) | -| Tuple[type] | Optional[Tuple[type]] | Basic tuple type parameter of base types such as int, float, etc. (e.g. (10, 2)) | +| Tuple[type] | Optional[Tuple[type]] | Basic tuple type parameter of base types such as int, float, etc. Length enforced unlike List. (e.g. (10, 2)) | | Enum | Optional[Enum] | Parameter that must be from a defined set of values of base types such as int, float, etc. | +Use `List` types when the length of the `Iterable` is not fixed and `Tuple` when length needs to be strictly enforced. + Parameters that are specified without the `Optional[]` type will be considered **REQUIRED** and therefore will raise an Exception if not value is specified. @@ -58,7 +60,7 @@ class Activation(Enum): class ModelConfig: n_features: int dropout: List[float] - hidden_sizes: Tuple[int] + hidden_sizes: Tuple[int, int, int] activation: Activation ``` @@ -103,7 +105,7 @@ class ModelConfig: save_path: SavePath n_features: int dropout: List[float] - hidden_sizes: Tuple[int] + hidden_sizes: Tuple[int, int, int] activation: Activation ``` @@ -123,11 +125,11 @@ spock Basic Tutorial configuration(s): ModelConfig (Main model configuration for a basic neural net) - save_path Optional[SavePath] spock special keyword -- path to write out spock config state (default: None) - n_features int number of data features - dropout List[float] dropout rate for each layer - hidden_sizes Tuple[int] hidden size for each layer - activation Activation choice from the Activation enum of the activation function to use + save_path Optional[SavePath] spock special keyword -- path to write out spock config state (default: None) + n_features int number of data features + dropout List[float] dropout rate for each layer + hidden_sizes Tuple[int, int, int] hidden size for each layer + activation Activation choice from the Activation enum of the activation function to use Activation (Options for activation functions) relu str relu activation @@ -141,8 +143,9 @@ In another file let's write our simple neural network code: `basic_nn.py` Notice that even before we've built and linked all of the related `spock` components together we are referencing the parameters we have defined in our `spock` class. Below we are passing in the `ModelConfig` class as a parameter -`model_config` to the `__init__` function where we can then access the parameters with `.` notation. We could have -also passed in individual parameters instead if that is the preferred syntax. +`model_config` to the `__init__` function where we can then access the parameters with `.` notation (if we import +the `ModelConfig` class here and add it as a type hint to `model_config` most IDE auto-complete will work out of the +box). We could have also passed in individual parameters instead if that is the preferred syntax. ```python import torch.nn as nn diff --git a/docs/basic_tutorial/Saving.md b/docs/basic_tutorial/Saving.md index 174e6bf9..2b676454 100644 --- a/docs/basic_tutorial/Saving.md +++ b/docs/basic_tutorial/Saving.md @@ -24,7 +24,7 @@ class ModelConfig: save_path: SavePath n_features: int dropout: List[float] - hidden_sizes: Tuple[int] + hidden_sizes: Tuple[int, int, int] activation: Activation ``` @@ -83,7 +83,26 @@ def main(): # A simple description description = 'spock Tutorial' # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder(ModelConfig, desc=description, create_save_path=True).save().generate() + config = ConfigArgBuilder(ModelConfig, desc=description).save(create_save_path=True).generate() + # One can now access the Spock config object by class name with the returned namespace + # For instance... + print(config.ModelConfig) +``` + +### Override UUID Filename + +By default `spock` uses an automatically generated UUID as the filename when saving. This can be overridden with the +`file_name` keyword argument. The specified filename will be appended with .spock.cfg.file_extension (e.g. .yaml, +.toml or. json). + +In: `tutorial.py` + +```python +def main(): + # A simple description + description = 'spock Tutorial' + # Build out the parser by passing in Spock config objects as *args after description + config = ConfigArgBuilder(ModelConfig, desc=description).save(file_name='cool_name_here').generate() # One can now access the Spock config object by class name with the returned namespace # For instance... print(config.ModelConfig) diff --git a/examples/tutorial/advanced/tutorial.py b/examples/tutorial/advanced/tutorial.py index e292768e..13a01a3c 100644 --- a/examples/tutorial/advanced/tutorial.py +++ b/examples/tutorial/advanced/tutorial.py @@ -25,7 +25,7 @@ class ModelConfig: save_path: SavePath n_features: int dropout: Optional[List[float]] - hidden_sizes: Tuple[int] = (32, 32, 32) + hidden_sizes: Tuple[int, int, int] = (32, 32, 32) activation: Activation = 'relu' optimizer: Optimizer cache_path: Optional[str] diff --git a/examples/tutorial/basic/tutorial.py b/examples/tutorial/basic/tutorial.py index aac3da3c..19c953ae 100644 --- a/examples/tutorial/basic/tutorial.py +++ b/examples/tutorial/basic/tutorial.py @@ -35,7 +35,7 @@ class ModelConfig: save_path: SavePath n_features: int dropout: List[float] - hidden_sizes: Tuple[int] + hidden_sizes: Tuple[int, int, int] activation: Activation diff --git a/spock/backend/attr/builder.py b/spock/backend/attr/builder.py index 35cf87a8..9b76ee7d 100644 --- a/spock/backend/attr/builder.py +++ b/spock/backend/attr/builder.py @@ -46,54 +46,7 @@ def print_usage_and_exit(self, msg=None, sys_exit=True, exit_code=1): sys.exit(exit_code) def _handle_help_info(self): - # List to catch Enum classes and handle post spock wrapped attr classes - enum_list = [] - for attrs_class in self.input_classes: - # Split the docs into class docs and any attribute docs - class_doc, attr_docs = self._split_docs(attrs_class) - print(' ' + attrs_class.__name__ + f' ({class_doc})') - # Keep a running info_dict of all the attribute level info - info_dict = {} - for val in attrs_class.__attrs_attrs__: - # If the type is an enum we need to handle it outside of this attr loop - # Match the style of nested enums and return a string of module.name notation - if isinstance(val.type, EnumMeta): - enum_list.append(f'{val.type.__module__}.{val.type.__name__}') - # if there is a type (implied Iterable) -- check it for nested Enums - nested_enums = self._extract_enum_types(val.metadata['type']) if 'type' in val.metadata else [] - if len(nested_enums) > 0: - enum_list.extend(nested_enums) - # Grab the base or type info depending on what is provided - type_string = repr(val.metadata['type']) if 'type' in val.metadata else val.metadata['base'] - # Regex out the typing info if present - type_string = re.sub(r'typing.', '', type_string) - # Regex out any nested_enums that have module path information - for enum_val in nested_enums: - split_enum = f"{'.'.join(enum_val.split('.')[:-1])}." - type_string = re.sub(split_enum, '', type_string) - # Regex the string to see if it matches any Enums in the __main__ module space - # for val in sys.modules - # Construct the type with the metadata - if 'optional' in val.metadata: - type_string = f"Optional[{type_string}]" - info_dict.update(self._match_attribute_docs(val.name, attr_docs, type_string, val.default)) - self._handle_attributes_print(info_dict=info_dict) - # Convert the enum list to a set to remove dupes and then back to a list so it is iterable - enum_list = list(set(enum_list)) - # Iterate any Enum type classes - for enum in enum_list: - enum = self._get_enum_from_sys_modules(enum) - # Split the docs into class docs and any attribute docs - class_doc, attr_docs = self._split_docs(enum) - print(' ' + enum.__name__ + f' ({class_doc})') - info_dict = {} - for val in enum: - info_dict.update(self._match_attribute_docs( - attr_name=val.name, - attr_docs=attr_docs, - attr_type_str=type(val.value).__name__ - )) - self._handle_attributes_print(info_dict=info_dict) + self._attrs_help(self.input_classes) def _handle_arguments(self, args, class_obj): attr_name = class_obj.__name__ diff --git a/spock/backend/base.py b/spock/backend/base.py index 42859906..f2784659 100644 --- a/spock/backend/base.py +++ b/spock/backend/base.py @@ -8,6 +8,7 @@ from abc import ABC from abc import abstractmethod import argparse +import attr from attr import NOTHING from enum import EnumMeta import os @@ -21,6 +22,7 @@ from spock.handlers import YAMLHandler from spock.utils import add_info from spock.utils import make_argument +from typing import List class Spockspace(argparse.Namespace): @@ -361,9 +363,8 @@ def get_config_paths(self): def _build_override_parsers(self, desc): """Creates parsers for command-line overrides - Builds the basic command line parser for configs and help, iterates through all defined attr to make - a general override parser, and then iterates through each attr instance to make namespace specific override - parsers + Builds the basic command line parser for configs and help then iterates through each attr instance to make + namespace specific cmd line override parsers *Args*: @@ -377,49 +378,13 @@ def _build_override_parsers(self, desc): parser = argparse.ArgumentParser(description=desc, add_help=False) parser.add_argument('-c', '--config', required=False, nargs='+', default=[]) parser.add_argument('-h', '--help', action='store_true') - # Build out a general parser for parent level attr - parser = self._make_general_override_parser(parser=parser, input_classes=self.input_classes) - # Build out each class specific parser + # Build out each class override specific parser for val in self.input_classes: parser = self._make_group_override_parser(parser=parser, class_obj=val) - # args = parser.parse_args() - args, _ = parser.parse_known_args(sys.argv) + args = parser.parse_args() return args - def _make_general_override_parser(self, parser, input_classes): - """Makes a general level override parser - - Flattens all the attrs into a single dictionary and makes a general level parser for the attr name - - *Args*: - - parser: argument parser - input_classes: list of input classes for a specific backend - - *Returns*: - - parser: argument parser with new general overrides - - """ - # Make all names list - all_attr = {} - for class_obj in input_classes: - for val in class_obj.__attrs_attrs__: - val_type = val.metadata['type'] if 'type' in val.metadata else val.type - if hasattr(all_attr, val.name): - if all_attr[val.name] is not val_type: - print(f"Warning: Ignoring general override for {val.name} as the class specific types differ") - else: - all_attr.update({val.name: val_type}) - self._check_protected_keys(all_attr) - group_parser = parser.add_argument_group(title="General Overrides") - for k, v in all_attr.items(): - arg_name = '--' + k - group_parser = make_argument(arg_name, v, group_parser) - return parser - - @staticmethod - def _make_group_override_parser(parser, class_obj): + def _make_group_override_parser(self, parser, class_obj): """Makes a name specific override parser for a given class obj Takes a class object of the backend and adds a new argument group with argument names given with name @@ -439,8 +404,16 @@ def _make_group_override_parser(parser, class_obj): group_parser = parser.add_argument_group(title=str(attr_name) + " Specific Overrides") for val in class_obj.__attrs_attrs__: val_type = val.metadata['type'] if 'type' in val.metadata else val.type - arg_name = '--' + str(attr_name) + '.' + val.name - group_parser = make_argument(arg_name, val_type, group_parser) + # Check if the val type has __args__ + # TODO (ncilfone): Fix up this super super ugly logic + if hasattr(val_type, '__args__') and ((list(set(val_type.__args__))[0]).__module__ == 'spock.backend.attr.config') and attr.has((list(set(val_type.__args__))[0])): + args = (list(set(val_type.__args__))[0]) + for inner_val in args.__attrs_attrs__: + arg_name = f"--{str(attr_name)}.{val.name}.{args.__name__}.{inner_val.name}" + group_parser = make_argument(arg_name, List[inner_val.type], group_parser) + else: + arg_name = f"--{str(attr_name)}.{val.name}" + group_parser = make_argument(arg_name, val_type, group_parser) return parser @staticmethod @@ -610,8 +583,8 @@ def _handle_attributes_print(self, info_dict): # Blank for spacing :-/ print('') - def _extract_enum_types(self, typed): - """Takes a high level type and recursively extracts any enum types + def _extract_other_types(self, typed): + """Takes a high level type and recursively extracts any enum or class types *Args*: @@ -619,28 +592,104 @@ def _extract_enum_types(self, typed): *Returns*: - return_list: list of nums (dot notation of module_path.enum_name) + return_list: list of nums (dot notation of module_path.enum_name or module_path.class_name) """ return_list = [] if hasattr(typed, '__args__'): for val in typed.__args__: - recurse_return = self._extract_enum_types(val) + recurse_return = self._extract_other_types(val) if isinstance(recurse_return, list): return_list.extend(recurse_return) else: - return_list.append(self._extract_enum_types(val)) - elif isinstance(typed, EnumMeta): + return_list.append(self._extract_other_types(val)) + elif isinstance(typed, EnumMeta) or (typed.__module__ == 'spock.backend.attr.config'): return f'{typed.__module__}.{typed.__name__}' return return_list + def _attrs_help(self, input_classes): + """Handles walking through a list classes to get help info + + For each class this function will search __doc__ and attempt to pull out help information for both the class + itself and each attribute within the class. If it finds a repeated class in a iterable object it will + recursively call self to handle information + + *Args*: + + input_classes: list of attr classes + + *Returns*: + + None + + """ + # List to catch Enums and classes and handle post spock wrapped attr classes + other_list = [] + covered_set = set() + for attrs_class in input_classes: + # Split the docs into class docs and any attribute docs + class_doc, attr_docs = self._split_docs(attrs_class) + print(' ' + attrs_class.__name__ + f' ({class_doc})') + # Keep a running info_dict of all the attribute level info + info_dict = {} + for val in attrs_class.__attrs_attrs__: + # If the type is an enum we need to handle it outside of this attr loop + # Match the style of nested enums and return a string of module.name notation + if isinstance(val.type, EnumMeta): + other_list.append(f'{val.type.__module__}.{val.type.__name__}') + # if there is a type (implied Iterable) -- check it for nested Enums or classes + nested_others = self._extract_other_types(val.metadata['type']) if 'type' in val.metadata else [] + if len(nested_others) > 0: + other_list.extend(nested_others) + # Grab the base or type info depending on what is provided + type_string = repr(val.metadata['type']) if 'type' in val.metadata else val.metadata['base'] + # Regex out the typing info if present + type_string = re.sub(r'typing.', '', type_string) + # Regex out any nested_others that have module path information + for other_val in nested_others: + split_other = f"{'.'.join(other_val.split('.')[:-1])}." + type_string = re.sub(split_other, '', type_string) + # Regex the string to see if it matches any Enums in the __main__ module space + # for val in sys.modules + # Construct the type with the metadata + if 'optional' in val.metadata: + type_string = f"Optional[{type_string}]" + info_dict.update(self._match_attribute_docs(val.name, attr_docs, type_string, val.default)) + # Add to covered so we don't print help twice in the case of some recursive nesting + covered_set.add(f'{attrs_class.__module__}.{attrs_class.__name__}') + self._handle_attributes_print(info_dict=info_dict) + # Convert the enum list to a set to remove dupes and then back to a list so it is iterable -- set diff to not + # repeat + other_list = list(set(other_list) - covered_set) + # Iterate any Enum type classes + for other in other_list: + # if it's longer than 2 then it's an embedded Spock class + if '.'.join(other.split('.')[:-1]) == 'spock.backend.attr.config': + class_type = self._get_from_sys_modules(other) + # Invoke recursive call for the class + self._attrs_help([class_type]) + # Fall back to enum style + else: + enum = self._get_from_sys_modules(other) + # Split the docs into class docs and any attribute docs + class_doc, attr_docs = self._split_docs(enum) + print(' ' + enum.__name__ + f' ({class_doc})') + info_dict = {} + for val in enum: + info_dict.update(self._match_attribute_docs( + attr_name=val.name, + attr_docs=attr_docs, + attr_type_str=type(val.value).__name__ + )) + self._handle_attributes_print(info_dict=info_dict) + @staticmethod - def _get_enum_from_sys_modules(enum_name): - """Gets the enum class from a dot notation name + def _get_from_sys_modules(cls_name): + """Gets the class from a dot notation name *Args*: - enum_name: dot notation enum name + cls_name: dot notation enum name *Returns*: @@ -648,7 +697,7 @@ def _get_enum_from_sys_modules(enum_name): """ # Split on dot notation - split_string = enum_name.split('.') + split_string = cls_name.split('.') module = None for idx, val in enumerate(split_string): # idx = 0 will always be a call to the sys.modules dict @@ -831,27 +880,68 @@ def _handle_overrides(self, payload, args): """ skip_keys = ['config', 'help'] for k, v in vars(args).items(): - # If the name has a . then we are at the class level so we need to get the dict and check - if len(k.split('.')) > 1: - dict_key = k.split('.')[0] - val_name = k.split('.')[1] - if k not in skip_keys and v is not None: - # Handle bool types slightly differently as they are store_true - if isinstance(vars(args)[k], bool): - if vars(args)[k] is not False: - payload = self._dict_payload_override(payload, dict_key, val_name, v) - else: - payload = self._dict_payload_override(payload, dict_key, val_name, v) - # else search the first level + if k not in skip_keys and v is not None: + payload = self._handle_payload_override(payload, k, v) + return payload + + @staticmethod + def _handle_payload_override(payload, key, value): + """Handles the complex logic needed for List[spock class] overrides + + Messy logic that sets overrides for the various different types. The hardest being List[spock class] since str + names have to be mapped backed to sys.modules and can be set at either the general or class level. + + *Args*: + + payload: current payload dictionary + key: current arg key + value: value at current arg key + + *Returns*: + + payload: modified payload with overrides + + """ + key_split = key.split('.') + curr_ref = payload + for idx, split in enumerate(key_split): + # If the root isn't in the payload then it needs to be added but only for the first key split + if idx == 0 and (split not in payload): + payload.update({split: {}}) + # Check for curr_ref switch over -- verify by checking the sys modules names + if idx != 0 and (split in payload) and (isinstance(curr_ref, str)) and (hasattr(sys.modules['spock'].backend.attr.config, split)): + curr_ref = payload[split] + elif idx != 0 and (split in payload) and (isinstance(payload[split], str)) and (hasattr(sys.modules['spock'].backend.attr.config, payload[split])): + curr_ref = payload[split] + # elif check if it's the last value and figure out the override + elif idx == (len(key_split)-1): + # Handle bool(s) a bit differently as they are store_true + if isinstance(curr_ref, dict) and isinstance(value, bool): + if value is not False: + curr_ref[split] = value + # If we are at the dictionary level we should be able to just payload override + elif isinstance(curr_ref, dict) and not isinstance(value, bool): + curr_ref[split] = value + # If we are at a list level it must be some form of repeated class since this is the end of the class + # tree -- check the instance type but also make sure the cmd-line override is the correct len + elif isinstance(curr_ref, list) and len(value) == len(curr_ref): + # Walk the list and check for the key + for ref_idx, val in enumerate(curr_ref): + if split in val: + val[split] = value[ref_idx] + else: + raise ValueError(f'cmd-line override failed for {key} -- ' + f'Failed to find key {split} within lowest level List[Dict]') + elif isinstance(curr_ref, list) and len(value) != len(curr_ref): + raise ValueError(f'cmd-line override failed for {key} -- ' + f'Specified key {split} with len {len(value)} does not match len {len(curr_ref)} ' + f'of List[Dict]') + else: + raise ValueError(f'cmd-line override failed for {key} -- ' + f'Failed to find key {split} within lowest level Dict') + # If it's not keep walking the current payload else: - # Override the value in the payload if present - if k not in skip_keys and v is not None: - # Handle bool types slightly differently as they are store_true - if isinstance(vars(args)[k], bool): - if vars(args)[k] is not False: - payload.update({k: v}) - else: - payload.update({k: v}) + curr_ref = curr_ref[split] return payload @staticmethod @@ -873,7 +963,7 @@ def _dict_payload_override(payload, dict_key, val_name, value): payload: updated payload dictionary """ - if not hasattr(payload, dict_key): + if dict_key not in payload: payload.update({dict_key: {}}) - payload[dict_key].update({val_name: value}) + payload[dict_key][val_name] = value return payload diff --git a/spock/config.py b/spock/config.py index 1c91ebf5..28ee0114 100644 --- a/spock/config.py +++ b/spock/config.py @@ -6,6 +6,10 @@ """Creates the spock config decorator that wraps attrs""" from spock.backend.attr.config import spock_attr +from spock.utils import _is_spock_instance # Simplified decorator for attrs spock = spock_attr + +# Public alias for checking if an object is a @spock annotated class +isinstance_spock =_is_spock_instance diff --git a/spock/utils.py b/spock/utils.py index 681ae83f..b2e901ba 100644 --- a/spock/utils.py +++ b/spock/utils.py @@ -6,6 +6,7 @@ """Utility functions for Spock""" import ast +import attr from enum import EnumMeta import os import socket @@ -22,6 +23,24 @@ from typing import _GenericAlias +def _is_spock_instance(__obj: object): + """Checks if the object is a @spock decorated class + + Private interface that checks to see if the object passed in is registered within the spock module and also + is a class with attrs attributes (__attrs_attrs__) + + *Args*: + + __obj: class to inspect + + *Returns*: + + bool + + """ + return (__obj.__module__ == 'spock.backend.attr.config') and attr.has(__obj) + + def make_argument(arg_name, arg_type, parser): """Make argparser argument based on type @@ -34,7 +53,7 @@ def make_argument(arg_name, arg_type, parser): arg_type: type of the argument parser: current parser - Returns: + *Returns*: parser: updated argparser @@ -45,6 +64,8 @@ def make_argument(arg_name, arg_type, parser): # For choice enums we need to check a few things first elif isinstance(arg_type, EnumMeta): type_set = list({type(val.value) for val in arg_type})[0] + # if this is an enum of a class switch the type to str as this is how it gets matched + type_set = str if type_set.__name__ == 'type' else type_set parser.add_argument(arg_name, required=False, type=type_set) # For booleans we map to store true elif arg_type == bool: @@ -56,6 +77,19 @@ def make_argument(arg_name, arg_type, parser): def _handle_generic_type_args(val): + """Evaluates a string containing a Python literal + + Seeing a list types will come in as string literal format, use ast to get the actual type + + *Args*: + + val: string literal + + *Returns*: + + the underlying string literal type + + """ return ast.literal_eval(val) @@ -82,7 +116,7 @@ def make_blank_git(out_dict): out_dict: current output dictionary - Returns: + *Returns*: out_dict: output dictionary with added git info @@ -163,7 +197,7 @@ def _maybe_docker(cgroup_path="/proc/self/cgroup"): cgroup_path: path to cgroup file - Returns: + *Returns*: boolean of best effort docker determination @@ -184,7 +218,7 @@ def _maybe_k8s(cgroup_path="/proc/self/cgroup"): cgroup_path: path to cgroup file - Returns: + *Returns*: boolean of best effort k8s determination diff --git a/tests/attr/test_all_attr.py b/tests/attr/test_all_attr.py index cde10a6e..33f84032 100644 --- a/tests/attr/test_all_attr.py +++ b/tests/attr/test_all_attr.py @@ -7,6 +7,7 @@ import glob import pytest from spock.builder import ConfigArgBuilder +from spock.config import isinstance_spock from tests.attr.attr_configs_test import * import sys @@ -220,16 +221,24 @@ def arg_builder(monkeypatch): with monkeypatch.context() as m: m.setattr(sys, 'argv', ['', '--config', './tests/conf/yaml/test.yaml', - '--bool_p', '--int_p', '11', '--TypeConfig.float_p', '11.0', '--string_p', 'Hooray', - '--list_p_float', '[11.0, 21.0]', '--list_p_int', '[11, 21]', - '--list_p_str', "['Hooray', 'Working']", '--list_p_bool', '[False, True]', - '--tuple_p_float', '(11.0, 21.0)', '--tuple_p_int', '(11, 21)', - '--tuple_p_str', "('Hooray', 'Working')", '--tuple_p_bool', '(False, True)', - '--list_list_p_int', "[[11, 21], [11, 21]]", '--choice_p_str', 'option_2', - '--choice_p_int', '20', '--choice_p_float', '20.0', - '--list_choice_p_str', "['option_2']", - '--list_list_choice_p_str', "[['option_2'], ['option_2']]", - '--list_choice_p_int', '[20]', '--list_choice_p_float', '[20.0]' + '--TypeConfig.bool_p', '--TypeConfig.int_p', '11', '--TypeConfig.float_p', '11.0', + '--TypeConfig.string_p', 'Hooray', + '--TypeConfig.list_p_float', '[11.0,21.0]', '--TypeConfig.list_p_int', '[11, 21]', + '--TypeConfig.list_p_str', "['Hooray', 'Working']", + '--TypeConfig.list_p_bool', '[False, True]', + '--TypeConfig.tuple_p_float', '(11.0, 21.0)', '--TypeConfig.tuple_p_int', '(11, 21)', + '--TypeConfig.tuple_p_str', "('Hooray', 'Working')", + '--TypeConfig.tuple_p_bool', '(False, True)', + '--TypeConfig.list_list_p_int', "[[11, 21], [11, 21]]", + '--TypeConfig.choice_p_str', 'option_2', + '--TypeConfig.choice_p_int', '20', '--TypeConfig.choice_p_float', '20.0', + '--TypeConfig.list_choice_p_str', "['option_2']", + '--TypeConfig.list_list_choice_p_str', "[['option_2'], ['option_2']]", + '--TypeConfig.list_choice_p_int', '[20]', + '--TypeConfig.list_choice_p_float', '[20.0]', + '--NestedStuff.one', '12', '--NestedStuff.two', 'ancora', + '--TypeConfig.nested_list.NestedListStuff.one', '[11, 21]', + '--TypeConfig.nested_list.NestedListStuff.two', "['Hooray', 'Working']", ]) config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, desc='Test Builder') return config.generate() @@ -255,6 +264,71 @@ def test_overrides(self, arg_builder): assert arg_builder.TypeConfig.list_list_choice_p_str == [['option_2'], ['option_2']] assert arg_builder.TypeConfig.list_choice_p_int == [20] assert arg_builder.TypeConfig.list_choice_p_float == [20.0] + assert arg_builder.TypeConfig.class_enum.one == 12 + assert arg_builder.TypeConfig.class_enum.two == 'ancora' + assert arg_builder.NestedListStuff[0].one == 11 + assert arg_builder.NestedListStuff[0].two == 'Hooray' + assert arg_builder.NestedListStuff[1].one == 21 + assert arg_builder.NestedListStuff[1].two == 'Working' + + +class TestClassCmdLineOverride: + """Testing command line overrides""" + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', '--config', + './tests/conf/yaml/test_class.yaml', + '--TypeConfig.bool_p', '--TypeConfig.int_p', '11', '--TypeConfig.float_p', '11.0', + '--TypeConfig.string_p', 'Hooray', + '--TypeConfig.list_p_float', '[11.0,21.0]', '--TypeConfig.list_p_int', '[11, 21]', + '--TypeConfig.list_p_str', "['Hooray', 'Working']", + '--TypeConfig.list_p_bool', '[False, True]', + '--TypeConfig.tuple_p_float', '(11.0, 21.0)', '--TypeConfig.tuple_p_int', '(11, 21)', + '--TypeConfig.tuple_p_str', "('Hooray', 'Working')", + '--TypeConfig.tuple_p_bool', '(False, True)', + '--TypeConfig.list_list_p_int', "[[11, 21], [11, 21]]", + '--TypeConfig.choice_p_str', 'option_2', + '--TypeConfig.choice_p_int', '20', '--TypeConfig.choice_p_float', '20.0', + '--TypeConfig.list_choice_p_str', "['option_2']", + '--TypeConfig.list_list_choice_p_str', "[['option_2'], ['option_2']]", + '--TypeConfig.list_choice_p_int', '[20]', + '--TypeConfig.list_choice_p_float', '[20.0]', + '--NestedStuff.one', '12', '--NestedStuff.two', 'ancora', + '--TypeConfig.nested_list.NestedListStuff.one', '[11, 21]', + '--TypeConfig.nested_list.NestedListStuff.two', "['Hooray', 'Working']", + ]) + config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, desc='Test Builder') + return config.generate() + + def test_class_overrides(self, arg_builder): + assert arg_builder.TypeConfig.bool_p is True + assert arg_builder.TypeConfig.int_p == 11 + assert arg_builder.TypeConfig.float_p == 11.0 + assert arg_builder.TypeConfig.string_p == 'Hooray' + assert arg_builder.TypeConfig.list_p_float == [11.0, 21.0] + assert arg_builder.TypeConfig.list_p_int == [11, 21] + assert arg_builder.TypeConfig.list_p_str == ['Hooray', 'Working'] + assert arg_builder.TypeConfig.list_p_bool == [False, True] + assert arg_builder.TypeConfig.tuple_p_float == (11.0, 21.0) + assert arg_builder.TypeConfig.tuple_p_int == (11, 21) + assert arg_builder.TypeConfig.tuple_p_str == ('Hooray', 'Working') + assert arg_builder.TypeConfig.tuple_p_bool == (False, True) + assert arg_builder.TypeConfig.choice_p_str == 'option_2' + assert arg_builder.TypeConfig.choice_p_int == 20 + assert arg_builder.TypeConfig.choice_p_float == 20.0 + assert arg_builder.TypeConfig.list_list_p_int == [[11, 21], [11, 21]] + assert arg_builder.TypeConfig.list_choice_p_str == ['option_2'] + assert arg_builder.TypeConfig.list_list_choice_p_str == [['option_2'], ['option_2']] + assert arg_builder.TypeConfig.list_choice_p_int == [20] + assert arg_builder.TypeConfig.list_choice_p_float == [20.0] + assert arg_builder.TypeConfig.class_enum.one == 12 + assert arg_builder.TypeConfig.class_enum.two == 'ancora' + assert arg_builder.NestedListStuff[0].one == 11 + assert arg_builder.NestedListStuff[0].two == 'Hooray' + assert arg_builder.NestedListStuff[1].one == 21 + assert arg_builder.NestedListStuff[1].two == 'Working' class TestConfigKwarg(AllTypes): @@ -263,6 +337,7 @@ class TestConfigKwarg(AllTypes): @pytest.fixture def arg_builder(monkeypatch): with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['']) config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, desc='Test Builder', configs=['./tests/conf/yaml/test.yaml']) return config.generate() @@ -423,6 +498,14 @@ def test_yaml_file_writer(self, monkeypatch, tmp_path): config.save(user_specified_path=str(tmp_path)+'/foo.bar/fizz.buzz/', file_extension='.yaml').generate() +class TestIsInstance: + def test_isinstance(self): + """Test that isinstance is behaving correctly""" + assert isinstance_spock(TypeConfig) is True + assert isinstance_spock(object) is False + assert isinstance_spock(StrChoice) is False + + # TOML TESTS class TestAllTypesTOML(AllTypes): """Check all required types work as expected """ diff --git a/tests/conf/yaml/test_class.yaml b/tests/conf/yaml/test_class.yaml new file mode 100644 index 00000000..88867675 --- /dev/null +++ b/tests/conf/yaml/test_class.yaml @@ -0,0 +1,57 @@ +# conf file for all YAML tests +### Required or Boolean Base Types ### +TypeConfig: + # Boolean - Set + bool_p_set: true + # Required Int + int_p: 10 + # Required Float + float_p: 1e1 + # Required String + string_p: Spock + # Required List -- Float + list_p_float: [10.0, 20.0] + # Required List -- Int + list_p_int: [10, 20] + # Required List of Lists + list_list_p_int: [[10, 20], [10, 20]] + # Required List -- Str + list_p_str: [Spock, Package] + # Required List -- Bool + list_p_bool: [True, False] + # Required Tuple -- Float + tuple_p_float: [10.0, 20.0] + # Required Tuple -- Int + tuple_p_int: [10, 20] + # Required Tuple -- Str + tuple_p_str: [Spock, Package] + # Required Tuple -- Bool + tuple_p_bool: [True, False] + # Required Choice -- Str + choice_p_str: option_1 + # Required Choice -- Int + choice_p_int: 10 + # Required Choice -- Str + choice_p_float: 10.0 + # Required List of Choice -- Str + list_choice_p_str: [option_1] + # Required List of List of Choice -- Str + list_list_choice_p_str: [[option_1], [option_1]] + # Required List of Choice -- Int + list_choice_p_int: [10] + # Required List of Choice -- Float + list_choice_p_float: [10.0] + # Nested Configuration + nested: NestedStuff + # Nested List configuration + nested_list: NestedListStuff + # Class Enum + class_enum: NestedStuff +NestedListStuff: + - one: 10 + two: hello + - one: 20 + two: bye +NestedStuff: + one: 11 + two: ciao \ No newline at end of file diff --git a/tests/debug/debug.py b/tests/debug/debug.py index 5864c765..70fa5fad 100644 --- a/tests/debug/debug.py +++ b/tests/debug/debug.py @@ -5,8 +5,9 @@ from typing import Tuple from enum import Enum from spock.builder import ConfigArgBuilder +from spock.config import isinstance_spock from params.first import Test -from params.first import Stuff, OtherStuff +from params.first import NestedListStuff, Stuff, OtherStuff from spock.backend.attr.typed import SavePath import pickle from argparse import Namespace @@ -100,14 +101,16 @@ def main(): - attrs_class = ConfigArgBuilder(Test, Stuff, OtherStuff, desc='I am a description').save(user_specified_path='/tmp').generate() + attrs_class = ConfigArgBuilder( + Test, NestedListStuff, + desc='I am a description' + ).save(user_specified_path='/tmp').generate() # with open('/tmp/debug.pickle', 'wb') as fid: # pickle.dump(attrs_class, file=fid) # with open('/tmp/debug.pickle', 'rb') as fid: # attrs_load = pickle.load(fid) # attrs_class = ConfigArgBuilder(Test, Test2).generate() - print(attrs_class) # print(attrs_load) # dc_class = ConfigArgBuilder(OldInherit).generate() diff --git a/tests/debug/debug.yaml b/tests/debug/debug.yaml index 7d072f63..941d8117 100644 --- a/tests/debug/debug.yaml +++ b/tests/debug/debug.yaml @@ -8,18 +8,18 @@ ### fix_me: [[11, 2], [22, 1]] ##new_choice: pear ##borken: Stuff -Stuff: - one: 10 - two: hello +#Stuff: +# one: 10 +# two: hello ## - one: 10 ## two: hello ## - one: 20 ## two: bye # ##more_borken: OtherStuff -OtherStuff: - three: 20 - four: ciao +#OtherStuff: +# three: 20 +# four: ciao ####Test2: #### other: 12 ####Test: @@ -31,10 +31,26 @@ OtherStuff: # - hi: 20 # bye: 45.0 +#one: 18 +#fail: [[1, 2], [3, 4]] +#flipper: True +#nested_list: NestedListStuff +test: [ 1, 2 ] + Test: - test: [1, 2] - fail: [[1, 2], [3, 4]] - most_broken: Stuff +# test: [ 1, 2 ] +# fail: [ [ 1, 2 ], [ 3, 4 ] ] +# most_broken: Stuff +# one: 18 +# flipper: false + nested_list: NestedListStuff +# new_choice: pear + +NestedListStuff: + - maybe: 10 + more: hello + - maybe: 20 + more: bye # borken: RepeatStuff ##ccccombo_breaker: 10 diff --git a/tests/debug/params/first.py b/tests/debug/params/first.py index 3a9e8a9a..76f63100 100644 --- a/tests/debug/params/first.py +++ b/tests/debug/params/first.py @@ -6,17 +6,17 @@ from typing import Optional from .second import Choice - -# class Choice(Enum): -# """Blah -# -# Attributes: -# pear: help pears -# banana: help bananas # -# """ -# pear = 'pear' -# banana = 'banana' +class Choice(Enum): + """Blah + + Attributes: + pear: help pears + banana: help bananas + + """ + pear = 'pear' + banana = 'banana' @spock @@ -57,6 +57,19 @@ class ClassStuff(Enum): stuff = Stuff +@spock +class NestedListStuff: + """Class enum + + Attributes: + maybe: some val + more: some other value + + """ + maybe: int + more: str + + @spock class Test: """High level docstring that just so happens to be multiline adfjads;lfja;sdlkjfklasjflkasjlkfjal;sdfjlkajsdfl;kja @@ -69,10 +82,15 @@ class Test: test: you are my only hopes most_broken: class stuff enum new_choice: choice type optionality - + one: just a basic parameter + nested_list: Repeated list of a class type """ - new_choice: Optional[Choice] - fail: Tuple[Tuple[int, int], Tuple[int, int]] - test: Optional[List[int]] - most_broken: ClassStuff + # new_choice: Choice + # fail: Tuple[Tuple[int, int], Tuple[int, int]] + test: List[int] + # fail: List[List[int]] + # flipper: bool + # most_broken: ClassStuff + # one: int + nested_list: List[NestedListStuff]