From 7e51f843868372d04f53794205da943780d45d90 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 24 Aug 2022 21:18:23 -0700 Subject: [PATCH] Upgrades the `@does` decorator to be more generally usable and more readable. Adds the following features for #186: 1. The ability to have more complex arguments 2. The ability to have an argument mapping All of this is documented. Note this also fixes it breaking with optional dependencies (#185) --- decorators.md | 51 +++++++--- hamilton/dev_utils/deprecation.py | 4 +- hamilton/function_modifiers.py | 162 ++++++++++++++++++++++-------- hamilton/node.py | 19 ++++ tests/test_function_modifiers.py | 134 +++++++++++++++++++++--- 5 files changed, 299 insertions(+), 71 deletions(-) diff --git a/decorators.md b/decorators.md index 217c9033..04e9d3ff 100644 --- a/decorators.md +++ b/decorators.md @@ -165,31 +165,58 @@ def my_func(...) -> pd.DataFrame: ``` ## @does -`@does` is a decorator that essentially allows you to run a function over all the input parameters. So you can't pass -any old function to `@does`, instead the function passed has to take any amount of inputs and process them all in the same way. +`@does` is a decorator that allows you to replace the decorated function with the behavior from another +function. This allows for easy code-reuse when building repeated logic. You do this by decorating a +function with`@does`, which takes in two parameters: +1. `replacing_function` Required -- a function that takes in a "compatible" set of arguments. This means that it +will work when passing the corresponding keyword arguments to the decorated function. +2. `**argument_mapping` -- a mapping of arguments from the replacing function to the replacing function. This makes for easy reuse of +functions. + ```python import pandas as pd from hamilton.function_modifiers import does -import internal_package_with_logic -def sum_series(**series: pd.Series) -> pd.Series: +def _sum_series(**series: pd.Series) -> pd.Series: """This function takes any number of inputs and sums them all together.""" - ... + return sum(series) -@does(sum_series) +@does(_sum_series) def D_XMAS_GC_WEIGHTED_BY_DAY(D_XMAS_GC_WEIGHTED_BY_DAY_1: pd.Series, D_XMAS_GC_WEIGHTED_BY_DAY_2: pd.Series) -> pd.Series: """Adds D_XMAS_GC_WEIGHTED_BY_DAY_1 and D_XMAS_GC_WEIGHTED_BY_DAY_2""" pass +``` + +In the above example `@does` applies `_sum_series` to the function `D_XMAS_GC_WEIGHTED_BY_DAY`. +Note we don't need any parameter replacement as `_sum_series` takes in just `**kwargs`, enabling it +to work with any set of parameters (and thus any old function). -@does(internal_package_with_logic.identity_function) -def copy_of_x(x: pd.Series) -> pd.Series: - """Just returns x""" +```python +import pandas as pd +from hamilton.function_modifiers import does + +import internal_company_logic + +def _load_data(db: str, table: str) -> pd.DataFrame: + """Helper function to load data using your internal company logic""" + return internal_company_logic.read_table(db=db, table=table) + +@does(_load_data, db='marketing_spend_db', table='marketing_spend_table') +def marketing_spend_data(marketing_spend_db: str, marketing_spend_table: str) -> pd.Series: + """Loads marketing spend data from the database""" + pass + +@does(_load_data, db='client_acquisition_db', table='client_acquisition_table') +def client_acquisition_data(client_acquisition_db: str, client_acquisition_table: str) -> pd.Series: + """Loads client acquisition data from the database""" pass ``` -The example here is a function, that all that it does, is sum all the parameters together. So we can annotate it with -the `@does` decorator and pass it the `sum_series` function. -The `@does` decorator is currently limited to just allow functions that consist only of one argument, a generic `**kwargs`. + +In the above example, `@does` applies our internal function `_load_data`, which applies custom +logic to load a table from a database in the data warehouse. Note that we map the parameters -- in the first example, +the value of the parameter `marketing_spend_db` is passed to `db`, and the value of the parameter `marketing_spend_table` +is passed to `table`. ## @model `@model` allows you to abstract a function that is a model. You will need to implement models that make sense for diff --git a/hamilton/dev_utils/deprecation.py b/hamilton/dev_utils/deprecation.py index fe3f70d7..1062e461 100644 --- a/hamilton/dev_utils/deprecation.py +++ b/hamilton/dev_utils/deprecation.py @@ -159,8 +159,8 @@ def __call__(self, fn: Callable): TODO -- use @singledispatchmethod when we no longer support 3.6/3.7 https://docs.python.org/3/library/functools.html#functools.singledispatchmethod - @param fn: function (or class) to decorate - @return: The decorated function. + :param fn: function (or class) to decorate + :return: The decorated function. """ # In this case we just do a standard decorator if isinstance(fn, types.FunctionType): diff --git a/hamilton/function_modifiers.py b/hamilton/function_modifiers.py index 2b96f537..01287117 100644 --- a/hamilton/function_modifiers.py +++ b/hamilton/function_modifiers.py @@ -72,8 +72,8 @@ def value(literal_value: Any) -> LiteralDependency: """Specifies that a parameterized dependency comes from a "literal" source. E.G. value("foo") means that the value is actually the string value "foo" - @param literal_value: Python literal value to use - @return: A LiteralDependency object -- a signifier to the internal framework of the dependency type + :param literal_value: Python literal value to use + :return: A LiteralDependency object -- a signifier to the internal framework of the dependency type """ if isinstance(literal_value, LiteralDependency): return literal_value @@ -85,8 +85,8 @@ def source(dependency_on: Any) -> UpstreamDependency: This means that it comes from a node somewhere else. E.G. source("foo") means that it should be assigned the value that "foo" outputs. - @param dependency_on: Upstream node to come from - @return:An UpstreamDependency object -- a signifier to the internal framework of the dependency type. + :param dependency_on: Upstream node to come from + :return:An UpstreamDependency object -- a signifier to the internal framework of the dependency type. """ if isinstance(dependency_on, UpstreamDependency): return dependency_on @@ -112,7 +112,7 @@ def __init__( def concat(upstream_parameter: str, literal_parameter: str) -> Any: return f'{upstream_parameter}{literal_parameter}' - @param parametrization: **kwargs with one of two things: + :param parametrization: **kwargs with one of two things: - a tuple of assignments (consisting of literals/upstream specifications), and docstring - just assignments, in which case it parametrizes the existing docstring """ @@ -626,71 +626,151 @@ def ensure_function_empty(fn: Callable): class does(function_modifiers_base.NodeCreator): - def __init__(self, replacing_function: Callable): - """ - Constructor for a modifier that replaces the annotated functions functionality with something else. + def __init__(self, replacing_function: Callable, **argument_mapping: Union[str, List[str]]): + """Constructor for a modifier that replaces the annotated functions functionality with something else. Right now this has a very strict validation requirements to make compliance with the framework easy. + :param replacing_function: The function to replace the original function with + :param argument_mapping: A mapping of argument name in the replacing function to argument name in the decorating function """ self.replacing_function = replacing_function + self.argument_mapping = argument_mapping @staticmethod - def ensure_output_types_match(fn: Callable, todo: Callable): - """ - Ensures that the output types of two functions match. + def ensure_output_types_match(og_function: Callable, replacing_function: Callable): + """Ensures that the output types of two functions match. + :param og_function: Function we're decorating + :param replacing_function: Function that'll replace it with functionality + :return: True if they match, false otherwise """ - annotation_fn = inspect.signature(fn).return_annotation - annotation_todo = inspect.signature(todo).return_annotation + annotation_fn = inspect.signature(og_function).return_annotation + annotation_todo = inspect.signature(replacing_function).return_annotation if not type_utils.custom_subclass_check(annotation_fn, annotation_todo): raise InvalidDecoratorException( f"Output types: {annotation_fn} and {annotation_todo} are not compatible" ) @staticmethod - def ensure_function_kwarg_only(fn: Callable): + def map_kwargs(kwargs: Dict[str, Any], argument_mapping: Dict[str, str]) -> Dict[str, Any]: + """Maps kwargs using the argument mapping. + This does 2 things: + 1. Replaces all kwargs in passed_in_kwargs with their mapping + 2. Injects all defaults from the origin function signature + + :param kwargs: Keyword arguments that will be passed into a hamilton function. + :param argument_mapping: Mapping of those arguments to a replacing function's arguments. + :return: The new kwargs for the replacing function's arguments. """ - Ensures that a function is kwarg only. Meaning that it only has one parameter similar to **kwargs. + output = {**kwargs} + for arg_mapped_to, original_arg in argument_mapping.items(): + if original_arg in kwargs and arg_mapped_to not in argument_mapping.values(): + del output[original_arg] + # Note that if it is not there it could be a **kwarg + output[arg_mapped_to] = kwargs[original_arg] + return output + + @staticmethod + def test_function_signatures_compatible( + fn_signature: inspect.Signature, + replace_with_signature: inspect.Signature, + argument_mapping: Dict[str, str], + ) -> bool: + """Tests whether a function signature and the signature of the replacing function are compatible. + + :param fn_signature: + :param replace_with_signature: + :param argument_mapping: + :return: True if they're compatible, False otherwise + """ + # The easy (and robust) way to do this is to use the bind with a set of dummy arguments and test if it breaks. + # This way we're not reinventing the wheel. + SENTINEL_ARG_VALUE = ... # does not matter as we never use it + # We initialize as the default values, as they'll always be injected in + dummy_param_values = { + key: SENTINEL_ARG_VALUE + for key, param_spec in fn_signature.parameters.items() + if param_spec.default != inspect.Parameter.empty + } + # Then we update with the dummy values. Again, replacing doesn't matter (we'll be mimicking it later) + dummy_param_values.update({key: SENTINEL_ARG_VALUE for key in fn_signature.parameters}) + dummy_param_values = does.map_kwargs(dummy_param_values, argument_mapping) + try: + # Python signatures have a bind() capability which does exactly what we want to do + # Throws a type error if it is not valid + replace_with_signature.bind(**dummy_param_values) + except TypeError: + return False + return True + + @staticmethod + def ensure_function_signature_compatible( + og_function: Callable, replacing_function: Callable, argument_mapping: Dict[str, str] + ): + """Ensures that a function signature is compatible with the replacing function, given the argument mapping + + :param og_function: Function that's getting replaced (decorated with `@does`) + :param replacing_function: A function that gets called in its place (passed in by `@does`) + :param argument_mapping: The mapping of arguments from fn to replace_with + :return: """ - parameters = inspect.signature(fn).parameters - if len(parameters) > 1: + fn_parameters = inspect.signature(og_function).parameters + invalid_fn_parameters = [] + for param_name, param_spec in fn_parameters.items(): + if param_spec.kind not in { + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + }: + invalid_fn_parameters.append(param_name) + + if invalid_fn_parameters: raise InvalidDecoratorException( - "Too many parameters -- for now @does can only use **kwarg functions. " - f"Found params: {parameters}" + f"Decorated function for @does (and really, all of hamilton), " + f"can only consist of keyword-friendly arguments. " + f"The following parameters for {og_function.__name__} are not keyword-friendly: {invalid_fn_parameters}" ) - ((_, parameter),) = parameters.items() - if not parameter.kind == inspect.Parameter.VAR_KEYWORD: + if not does.test_function_signatures_compatible( + inspect.signature(og_function), inspect.signature(replacing_function), argument_mapping + ): raise InvalidDecoratorException( - f"Must have only one parameter, and that parameter must be a **kwargs " - f"parameter. Instead, found: {parameter}" + f"The following function signatures are not compatible for use with @does: " + f"{og_function.__name__} with signature {inspect.signature(og_function)} " + f"and replacing function {replacing_function.__name__} with signature {inspect.signature(replacing_function)}. " + f"Mapping for arguments provided was: {argument_mapping}. You can fix this by either adjusting " + f"the signature for the replacing function *or* adjusting the mapping." ) def validate(self, fn: Callable): - """ - Validates that the function: + """Validates that the function: - Is empty (we don't want to be overwriting actual code) - - is keyword argument only (E.G. has just **kwargs in its argument list) + - Has a compatible return type + - Matches the function signature with the appropriate mapping :param fn: Function to validate :raises: InvalidDecoratorException """ ensure_function_empty(fn) - does.ensure_function_kwarg_only(self.replacing_function) + does.ensure_function_signature_compatible( + fn, self.replacing_function, self.argument_mapping + ) does.ensure_output_types_match(fn, self.replacing_function) def generate_node(self, fn: Callable, config) -> node.Node: + """Returns one node which has the replaced functionality + :param fn: Function to decorate + :param config: Configuration (not used in this) + :return: A node with the function in `@does` injected, + and the same parameters/types as the original function. """ - Returns one node which has the replaced functionality - :param fn: - :param config: - :return: - """ - fn_signature = inspect.signature(fn) - return node.Node( - fn.__name__, - typ=fn_signature.return_annotation, - doc_string=fn.__doc__ if fn.__doc__ is not None else "", - callabl=self.replacing_function, - input_types={key: value.annotation for key, value in fn_signature.parameters.items()}, - tags=get_default_tags(fn), - ) + + def replacing_function(__fn=fn, **kwargs): + final_kwarg_values = { + key: param_spec.default + for key, param_spec in inspect.signature(fn).parameters.items() + if param_spec.default != inspect.Parameter.empty + } + final_kwarg_values.update(kwargs) + final_kwarg_values = does.map_kwargs(final_kwarg_values, self.argument_mapping) + return self.replacing_function(**final_kwarg_values) + + return node.Node.from_fn(fn).copy_with(callabl=replacing_function) class dynamic_transform(function_modifiers_base.NodeCreator): diff --git a/hamilton/node.py b/hamilton/node.py index 47c06d2a..bbd060bb 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -188,3 +188,22 @@ def from_fn(fn: Callable, name: str = None) -> "Node": callabl=fn, tags={"module": module}, ) + + def copy_with(self, **overrides) -> "Node": + """Copies a node with the specified overrides for the constructor arguments. + Utility function for creating a node -- useful for modifying it. + + :param kwargs: kwargs to use in place of the node. Passed to the constructor. + :return: A node copied from self with the specified keyword arguments replaced. + """ + constructor_args = dict( + name=self.name, + typ=self.type, + doc_string=self.documentation, + callabl=self.callable, + node_source=self.node_source, + input_types=self.input_types, + tags=self.tags, + ) + constructor_args.update(**overrides) + return Node(**constructor_args) diff --git a/tests/test_function_modifiers.py b/tests/test_function_modifiers.py index 987f2e2d..99caf3fa 100644 --- a/tests/test_function_modifiers.py +++ b/tests/test_function_modifiers.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Dict, List, Set import numpy as np @@ -308,29 +309,91 @@ def yes_code(): ensure_function_empty(yes_code) -def test_fn_kwarg_only_validator(): - def kwarg_only(**kwargs): - pass +## Functions for @does -- these are the functions we're "replacing" +def _no_params() -> int: + pass - def more_args(param1, param2, *args, **kwargs): - pass - def kwargs_and_args(*args, **kwargs): - pass +def _one_param(a: int) -> int: + pass - def args_only(*args): - pass - with pytest.raises(function_modifiers.InvalidDecoratorException): - does.ensure_function_kwarg_only(more_args) +def _two_params(a: int, b: int) -> int: + pass - with pytest.raises(function_modifiers.InvalidDecoratorException): - does.ensure_function_kwarg_only(kwargs_and_args) - with pytest.raises(function_modifiers.InvalidDecoratorException): - does.ensure_function_kwarg_only(args_only) +def _three_params(a: int, b: int, c: int) -> int: + pass + + +def _three_params_with_defaults(a: int, b: int = 1, c: int = 2) -> int: + pass + + +## functions we can/can't replace them with +def _empty() -> int: + return 1 + - does.ensure_function_kwarg_only(kwarg_only) +def _kwargs(**kwargs: int) -> int: + return sum(kwargs.values()) + + +def _kwargs_with_a(a: int, **kwargs: int) -> int: + return a + sum(kwargs.values()) + + +def _just_a(a: int) -> int: + return a + + +def _just_b(b: int) -> int: + return b + + +def _a_b_c(a: int, b: int, c: int) -> int: + return a + b + c + + +@pytest.mark.parametrize( + "fn,replace_with,argument_mapping,matches", + [ + (_no_params, _empty, {}, True), + (_no_params, _kwargs, {}, True), + (_no_params, _kwargs_with_a, {}, False), + (_no_params, _just_a, {}, False), + (_no_params, _a_b_c, {}, False), + (_one_param, _empty, {}, False), + (_one_param, _kwargs, {}, True), + (_one_param, _kwargs_with_a, {}, True), + (_one_param, _just_a, {}, True), + (_one_param, _just_b, {}, False), + (_one_param, _just_b, {"b": "a"}, True), # Replacing a with b makes the signatures match + (_one_param, _just_b, {"c": "a"}, False), # Replacing a with b makes the signatures match + (_two_params, _empty, {}, False), + (_two_params, _kwargs, {}, True), + (_two_params, _kwargs_with_a, {}, True), # b gets fed to kwargs + (_two_params, _kwargs_with_a, {"foo": "b"}, True), # Any kwargs work + (_two_params, _kwargs_with_a, {"bar": "a"}, False), # No param bar + (_two_params, _just_a, {}, False), + (_two_params, _just_b, {}, False), + (_three_params, _a_b_c, {}, True), + (_three_params, _a_b_c, {"d": "a"}, False), + (_three_params, _a_b_c, {}, True), + (_three_params, _a_b_c, {"a": "b", "b": "a"}, True), # Weird case but why not? + (_three_params, _kwargs_with_a, {}, True), + (_three_params_with_defaults, _a_b_c, {}, True), + (_three_params_with_defaults, _a_b_c, {"d": "a"}, False), + (_three_params_with_defaults, _a_b_c, {}, True), + ], +) +def test_ensure_function_signatures_compatible(fn, replace_with, argument_mapping, matches): + assert ( + does.test_function_signatures_compatible( + inspect.signature(fn), inspect.signature(replace_with), argument_mapping + ) + == matches + ) def test_compatible_return_types(): @@ -376,6 +439,45 @@ def to_modify(param1: List[int], param2: List[int]) -> int: assert node.documentation == to_modify.__doc__ +def test_does_function_modifier_optionals(): + def sum_(param0: int, **kwargs: int) -> int: + return sum(kwargs.values()) + + def to_modify(param0: int, param1: int = 1, param2: int = 2) -> int: + """This sums the inputs it gets...""" + pass + + annotation = does(sum_) + node_ = annotation.generate_node(to_modify, {}) + assert node_.name == "to_modify" + assert node_.input_types["param0"][1] == DependencyType.REQUIRED + assert node_.input_types["param1"][1] == DependencyType.OPTIONAL + assert node_.input_types["param2"][1] == DependencyType.OPTIONAL + assert node_.callable(param0=0) == 3 + assert node_.callable(param0=0, param1=0, param2=0) == 0 + assert node_.documentation == to_modify.__doc__ + + +def test_does_with_argument_mapping(): + def _sum_multiply(param0: int, param1: int, param2: int) -> int: + return param0 + param1 * param2 + + def to_modify(parama: int, paramb: int = 1, paramc: int = 2) -> int: + """This sums the inputs it gets...""" + pass + + annotation = does(_sum_multiply, param0="parama", param1="paramb", param2="paramc") + node = annotation.generate_node(to_modify, {}) + assert node.name == "to_modify" + assert node.input_types["parama"][1] == DependencyType.REQUIRED + assert node.input_types["paramb"][1] == DependencyType.OPTIONAL + assert node.input_types["paramc"][1] == DependencyType.OPTIONAL + assert node.callable(parama=0) == 2 + assert node.callable(parama=0, paramb=1, paramc=2) == 2 + assert node.callable(parama=1, paramb=4) == 9 + assert node.documentation == to_modify.__doc__ + + def test_model_modifier(): config = { "my_column_model_params": {