@staticmethod
def pandas_index_types(
    outputs: Dict[str, Any]
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]], Dict[str, List[str]]]:
    """Group the requested outputs by the type of pandas index they carry.

    Three dictionaries are built:
        1. index type key -> names of every output filed under that key (this one also
           contains the ``no-index`` key).
        2. index type key -> names of outputs whose index is an instance of the pandas
           extension-index parent class (time series / categorical style indexes).
        3. ``no-index`` -> names of outputs that are not a DataFrame, Series or Index.

    :param outputs: the dict we're trying to create a result from.
    :return: dict of all index types, dict of time series/categorical index types, dict of outputs
        without an index.
    """
    all_index_types = collections.defaultdict(list)
    time_indexes = collections.defaultdict(list)
    no_indexes = collections.defaultdict(list)

    def index_key_name(pd_object: Union[pd.DataFrame, pd.Series]) -> str:
        """Creates a string identifying the index class and its dtype.
        Useful for disambiguating time related indexes."""
        return f"{pd_object.index.__class__.__name__}:::{pd_object.index.dtype}"

    # Resolve the parent class once instead of on every loop iteration. Preference order is kept:
    # NDArrayBackedExtensionIndex first, then ExtensionIndex, else None when neither attribute exists.
    time_index_type = getattr(
        pd_extension,
        "NDArrayBackedExtensionIndex",
        getattr(pd_extension, "ExtensionIndex", None),
    )

    for output_name, output_value in outputs.items():
        if isinstance(output_value, (pd.DataFrame, pd.Series)):
            # it has an index -- let's grab its type
            dict_key = index_key_name(output_value)
            # Guard against None: isinstance(x, None) raises TypeError rather than returning False.
            if time_index_type is not None and isinstance(output_value.index, time_index_type):
                # it's a time index -- these will produce garbage if not aligned properly.
                time_indexes[dict_key].append(output_name)
        elif isinstance(output_value, pd.Index):
            # a bare Index has no index of its own -- file it under the key of a plain integer index.
            int_index = pd.Series([1, 2, 3], index=[0, 1, 2])  # dummy to get right values for string.
            dict_key = index_key_name(int_index)
        else:
            dict_key = "no-index"
            no_indexes[dict_key].append(output_name)
        all_index_types[dict_key].append(output_name)
    return all_index_types, time_indexes, no_indexes


@staticmethod
def check_pandas_index_types_match(
    all_index_types: Dict[str, List[str]],
    time_indexes: Dict[str, List[str]],
    no_indexes: Dict[str, List[str]],
) -> bool:
    """Checks that pandas index types match.

    This only logs warnings, and if debug is enabled, a debug statement listing the index types.

    :param all_index_types: index type key -> output names, including any ``no-index`` key.
    :param time_indexes: time series / categorical index type key -> output names.
    :param no_indexes: ``no-index`` -> names of outputs without an index.
    :return: True if a single index type was found (outputs without an index are tolerated
        alongside it), False otherwise.
    """
    no_index_length = len(no_indexes)
    time_indexes_length = len(time_indexes)
    all_indexes_length = len(all_index_types)
    number_with_indexes = all_indexes_length - no_index_length
    # len(no_indexes) counts dict keys; the warning below needs the number of actual outputs.
    outputs_without_index = sum(len(names) for names in no_indexes.values())
    types_match = True  # default to True
    # if there is more than one time index
    if time_indexes_length > 1:
        logger.warning(
            "WARNING: Time/Categorical index type mismatches detected - check output to ensure Pandas "
            "is doing what you intend to do. Else change the index types to match. Set logger to debug "
            "to see index types."
        )
        types_match = False
    # if there is more than one index type and it's not explained by the time indexes then warn.
    # Compare indexed types only: the `no-index` key must not count as an extra index type.
    if number_with_indexes > 1 and number_with_indexes > time_indexes_length:
        logger.warning(
            "WARNING: Multiple index types detected - check output to ensure Pandas is "
            "doing what you intend to do. Else change the index types to match. Set logger to debug to "
            "see index types."
        )
        types_match = False
    elif number_with_indexes == 1 and no_index_length > 0:
        logger.warning(
            "WARNING: a single pandas index was found, but there are also %s outputs without "
            "an index. Those values will be made constants throughout the values of the index.",
            outputs_without_index,
        )
        # Strictly speaking the index types match -- there is only one -- so setting to True.
        types_match = True
    # if all indexes matches no indexes
    elif no_index_length == all_indexes_length:
        logger.warning(
            "It appears no Pandas index type was detected. This will likely break when trying to "
            "create a DataFrame. E.g. are you requesting all scalar values? Use a different result "
            "builder or return at least one Pandas object with an index."
        )
        types_match = False
    if logger.isEnabledFor(logging.DEBUG):
        import pprint

        pretty_string = pprint.pformat(dict(all_index_types))
        logger.debug("Index types encountered:\n%s.", pretty_string)
    return types_match
# this will now error if index types mismatch. + """ + + @staticmethod + def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame: + # TODO check inputs are pd.Series, arrays, or scalars -- else error + output_index_type_tuple = PandasDataFrameResult.pandas_index_types(outputs) + indexes_match = PandasDataFrameResult.check_pandas_index_types_match( + *output_index_type_tuple + ) + if not indexes_match: + import pprint + + pretty_string = pprint.pformat(dict(output_index_type_tuple[0])) + raise ValueError( + "Error: pandas index types did not match exactly. " + f"Found the following indexes:\n{pretty_string}" + ) + + return PandasDataFrameResult.build_result(**outputs) + + class NumpyMatrixResult(ResultMixin): """Mixin for building a Numpy Matrix from the result of walking the graph. @@ -61,7 +212,7 @@ class NumpyMatrixResult(ResultMixin): """ @staticmethod - def build_result(**outputs: typing.Dict[str, typing.Any]) -> np.matrix: + def build_result(**outputs: Dict[str, Any]) -> np.matrix: """Builds a numpy matrix from the passed in, inputs. :param outputs: function_name -> np.array. @@ -108,7 +259,7 @@ class HamiltonGraphAdapter(ResultMixin): @staticmethod @abc.abstractmethod - def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool: + def check_input_type(node_type: Type, input_value: Any) -> bool: """Used to check whether the user inputs match what the execution strategy & functions can handle. :param node_type: The type of the node. @@ -119,7 +270,7 @@ def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool: @staticmethod @abc.abstractmethod - def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool: + def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: """Used to check whether two types are equivalent. 
This is used when the function graph is being created and we're statically type checking the annotations @@ -132,7 +283,7 @@ def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) pass @abc.abstractmethod - def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any: + def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: """Given a node that represents a hamilton function, execute it. Note, in some adapters this might just return some type of "future". @@ -147,8 +298,8 @@ class SimplePythonDataFrameGraphAdapter(HamiltonGraphAdapter, PandasDataFrameRes """This is the default (original Hamilton) graph adapter. It uses plain python and builds a dataframe result.""" @staticmethod - def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool: - if node_type == typing.Any: + def check_input_type(node_type: Type, input_value: Any) -> bool: + if node_type == Any: return True elif inspect.isclass(node_type) and isinstance(input_value, node_type): return True @@ -171,10 +322,10 @@ def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool: return False @staticmethod - def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool: + def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: return node_type == input_type - def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any: + def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: return node.callable(**kwargs) @@ -186,6 +337,6 @@ def __init__(self, result_builder: ResultMixin): if self.result_builder is None: raise ValueError("You must provide a ResultMixin object for `result_builder`.") - def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: + def build_result(self, **outputs: Dict[str, Any]) -> Any: """Delegates to the result builder function supplied.""" return 
# Expected keys are derived from pandas itself instead of being hard-coded: the class name of a plain
# integer index and the dtype of a DatetimeIndex differ between pandas releases.
_TS_INDEX = pd.DatetimeIndex(["2022-01", "2022-02", "2022-03"], freq="MS")
_TS_KEY = f"DatetimeIndex:::{_TS_INDEX.dtype}"
# Built the same way pandas_index_types builds its key for a bare pd.Index output.
_PLAIN_INT_INDEX = pd.Series([1, 2, 3], index=[0, 1, 2]).index
_INDEX_OBJECT_KEY = f"{type(_PLAIN_INT_INDEX).__name__}:::{_PLAIN_INT_INDEX.dtype}"


@pytest.mark.parametrize(
    "outputs,expected_result",
    [
        ({"a": pd.Series([1, 2, 3])}, ({"RangeIndex:::int64": ["a"]}, {}, {})),
        (
            {"a": pd.Series([1, 2, 3]), "b": pd.Series([3, 4, 5])},
            ({"RangeIndex:::int64": ["a", "b"]}, {}, {}),
        ),
        (
            {"b": pd.Series([3, 4, 5], index=_TS_INDEX)},
            ({_TS_KEY: ["b"]}, {_TS_KEY: ["b"]}, {}),
        ),
        ({"c": 1}, ({"no-index": ["c"]}, {}, {"no-index": ["c"]})),
        (
            {
                "a": pd.Series([1, 2, 3]),
                "b": 1,
                "c": pd.Series([3, 4, 5], index=_TS_INDEX),
            },
            (
                {
                    _TS_KEY: ["c"],
                    "RangeIndex:::int64": ["a"],
                    "no-index": ["b"],
                },
                {_TS_KEY: ["c"]},
                {"no-index": ["b"]},
            ),
        ),
        ({"a": pd.DataFrame({"a": [1, 2, 3]})}, ({"RangeIndex:::int64": ["a"]}, {}, {})),
        ({"a": pd.Series([1, 2, 3]).index}, ({_INDEX_OBJECT_KEY: ["a"]}, {}, {})),
    ],
    ids=[
        "int-index",
        "int-index-double",
        "ts-index",
        "no-index",
        "multiple-different-indexes",
        "df-index",
        "index-object",
    ],
)
def test_PandasDataFrameResult_pandas_index_types(outputs, expected_result):
    """Tests exercising the function to return pandas index types from outputs"""
    pdfr = base.PandasDataFrameResult()
    actual = pdfr.pandas_index_types(outputs)
    # pandas_index_types returns defaultdicts; convert so the comparison is against plain dicts.
    expected_all, expected_time, expected_none = expected_result
    assert dict(actual[0]) == expected_all
    assert dict(actual[1]) == expected_time
    assert dict(actual[2]) == expected_none
@pytest.mark.parametrize(
    "outputs",
    [
        (
            {
                "a": pd.Series([1, 2, 3], index=[0, 1, 2]),
                "b": pd.Series([1, 2, 3], index=[0.0, 1.0, 2.0]),
            }
        ),
        (
            {
                "series1": pd.Series(
                    [1, 2, 3], index=pd.DatetimeIndex(["2022-01", "2022-02", "2022-03"], freq="MS")
                ),
                # Period indexes are built from strings: the year=/month=/day= constructor keywords
                # are deprecated in recent pandas. Periods and frequencies are unchanged.
                "series2": pd.Series(
                    [4, 5, 6],
                    index=pd.PeriodIndex(["2022-01", "2022-02", "2022-03"], freq="M"),
                ),
                "series3": pd.Series(
                    [4, 5, 6],
                    index=pd.PeriodIndex(["2022-01-03", "2022-01-04", "2022-01-05"], freq="B"),
                ),
                "series4": pd.Series(
                    [4, 5, 6],
                    index=pd.PeriodIndex(["2022-01-04", "2022-01-11", "2022-01-18"], freq="W"),
                ),
            }
        ),
    ],
    ids=[
        "test-int-float",
        "test-different-ts-indexes",
    ],
)
def test_StrictIndexTypePandasDataFrameResult_build_result_errors(outputs):
    """Tests the error case of StrictIndexTypePandasDataFrameResult.build_result()"""
    sitpdfr = base.StrictIndexTypePandasDataFrameResult()
    # Mismatching index types must be rejected with a ValueError by the strict builder.
    with pytest.raises(ValueError):
        sitpdfr.build_result(**outputs)