-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: decouple pandas postprocessing operator (#18710)
- Loading branch information
1 parent
ea12024
commit 8d6aff3
Showing
16 changed files
with
1,364 additions
and
1,002 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from superset.utils.pandas_postprocessing.aggregate import aggregate | ||
from superset.utils.pandas_postprocessing.boxplot import boxplot | ||
from superset.utils.pandas_postprocessing.compare import compare | ||
from superset.utils.pandas_postprocessing.contribution import contribution | ||
from superset.utils.pandas_postprocessing.cum import cum | ||
from superset.utils.pandas_postprocessing.diff import diff | ||
from superset.utils.pandas_postprocessing.geography import ( | ||
geodetic_parse, | ||
geohash_decode, | ||
geohash_encode, | ||
) | ||
from superset.utils.pandas_postprocessing.pivot import pivot | ||
from superset.utils.pandas_postprocessing.prophet import prophet | ||
from superset.utils.pandas_postprocessing.resample import resample | ||
from superset.utils.pandas_postprocessing.rolling import rolling | ||
from superset.utils.pandas_postprocessing.select import select | ||
from superset.utils.pandas_postprocessing.sort import sort | ||
from superset.utils.pandas_postprocessing.utils import _flatten_column_after_pivot | ||
|
||
__all__ = [ | ||
"aggregate", | ||
"boxplot", | ||
"compare", | ||
"contribution", | ||
"cum", | ||
"diff", | ||
"geohash_encode", | ||
"geohash_decode", | ||
"geodetic_parse", | ||
"pivot", | ||
"prophet", | ||
"resample", | ||
"rolling", | ||
"select", | ||
"sort", | ||
"_flatten_column_after_pivot", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from typing import Any, Dict, List | ||
|
||
from pandas import DataFrame | ||
|
||
from superset.utils.pandas_postprocessing.utils import ( | ||
_get_aggregate_funcs, | ||
validate_column_args, | ||
) | ||
|
||
|
||
@validate_column_args("groupby") | ||
def aggregate( | ||
df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]] | ||
) -> DataFrame: | ||
""" | ||
Apply aggregations to a DataFrame. | ||
:param df: Object to aggregate. | ||
:param groupby: columns to aggregate | ||
:param aggregates: A mapping from metric column to the function used to | ||
aggregate values. | ||
:raises QueryObjectValidationError: If the request in incorrect | ||
""" | ||
aggregates = aggregates or {} | ||
aggregate_funcs = _get_aggregate_funcs(df, aggregates) | ||
if groupby: | ||
df_groupby = df.groupby(by=groupby) | ||
else: | ||
df_groupby = df.groupby(lambda _: True) | ||
return df_groupby.agg(**aggregate_funcs).reset_index(drop=not groupby) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union | ||
|
||
import numpy as np | ||
from flask_babel import gettext as _ | ||
from pandas import DataFrame, Series | ||
|
||
from superset.exceptions import QueryObjectValidationError | ||
from superset.utils.core import PostProcessingBoxplotWhiskerType | ||
from superset.utils.pandas_postprocessing.aggregate import aggregate | ||
|
||
|
||
def boxplot( | ||
df: DataFrame, | ||
groupby: List[str], | ||
metrics: List[str], | ||
whisker_type: PostProcessingBoxplotWhiskerType, | ||
percentiles: Optional[ | ||
Union[List[Union[int, float]], Tuple[Union[int, float], Union[int, float]]] | ||
] = None, | ||
) -> DataFrame: | ||
""" | ||
Calculate boxplot statistics. For each metric, the operation creates eight | ||
new columns with the column name suffixed with the following values: | ||
- `__mean`: the mean | ||
- `__median`: the median | ||
- `__max`: the maximum value excluding outliers (see whisker type) | ||
- `__min`: the minimum value excluding outliers (see whisker type) | ||
- `__q1`: the median | ||
- `__q1`: the first quartile (25th percentile) | ||
- `__q3`: the third quartile (75th percentile) | ||
- `__count`: count of observations | ||
- `__outliers`: the values that fall outside the minimum/maximum value | ||
(see whisker type) | ||
:param df: DataFrame containing all-numeric data (temporal column ignored) | ||
:param groupby: The categories to group by (x-axis) | ||
:param metrics: The metrics for which to calculate the distribution | ||
:param whisker_type: The confidence level type | ||
:return: DataFrame with boxplot statistics per groupby | ||
""" | ||
|
||
def quartile1(series: Series) -> float: | ||
return np.nanpercentile(series, 25, interpolation="midpoint") | ||
|
||
def quartile3(series: Series) -> float: | ||
return np.nanpercentile(series, 75, interpolation="midpoint") | ||
|
||
if whisker_type == PostProcessingBoxplotWhiskerType.TUKEY: | ||
|
||
def whisker_high(series: Series) -> float: | ||
upper_outer_lim = quartile3(series) + 1.5 * ( | ||
quartile3(series) - quartile1(series) | ||
) | ||
return series[series <= upper_outer_lim].max() | ||
|
||
def whisker_low(series: Series) -> float: | ||
lower_outer_lim = quartile1(series) - 1.5 * ( | ||
quartile3(series) - quartile1(series) | ||
) | ||
return series[series >= lower_outer_lim].min() | ||
|
||
elif whisker_type == PostProcessingBoxplotWhiskerType.PERCENTILE: | ||
if ( | ||
not isinstance(percentiles, (list, tuple)) | ||
or len(percentiles) != 2 | ||
or not isinstance(percentiles[0], (int, float)) | ||
or not isinstance(percentiles[1], (int, float)) | ||
or percentiles[0] >= percentiles[1] | ||
): | ||
raise QueryObjectValidationError( | ||
_( | ||
"percentiles must be a list or tuple with two numeric values, " | ||
"of which the first is lower than the second value" | ||
) | ||
) | ||
low, high = percentiles[0], percentiles[1] | ||
|
||
def whisker_high(series: Series) -> float: | ||
return np.nanpercentile(series, high) | ||
|
||
def whisker_low(series: Series) -> float: | ||
return np.nanpercentile(series, low) | ||
|
||
else: | ||
whisker_high = np.max | ||
whisker_low = np.min | ||
|
||
def outliers(series: Series) -> Set[float]: | ||
above = series[series > whisker_high(series)] | ||
below = series[series < whisker_low(series)] | ||
return above.tolist() + below.tolist() | ||
|
||
operators: Dict[str, Callable[[Any], Any]] = { | ||
"mean": np.mean, | ||
"median": np.median, | ||
"max": whisker_high, | ||
"min": whisker_low, | ||
"q1": quartile1, | ||
"q3": quartile3, | ||
"count": np.ma.count, | ||
"outliers": outliers, | ||
} | ||
aggregates: Dict[str, Dict[str, Union[str, Callable[..., Any]]]] = { | ||
f"{metric}__{operator_name}": {"column": metric, "operator": operator} | ||
for operator_name, operator in operators.items() | ||
for metric in metrics | ||
} | ||
return aggregate(df, groupby=groupby, aggregates=aggregates) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from typing import List, Optional | ||
|
||
import pandas as pd | ||
from flask_babel import gettext as _ | ||
from pandas import DataFrame | ||
|
||
from superset.constants import PandasPostprocessingCompare | ||
from superset.exceptions import QueryObjectValidationError | ||
from superset.utils.core import TIME_COMPARISION | ||
from superset.utils.pandas_postprocessing.utils import validate_column_args | ||
|
||
|
||
@validate_column_args("source_columns", "compare_columns") | ||
def compare( # pylint: disable=too-many-arguments | ||
df: DataFrame, | ||
source_columns: List[str], | ||
compare_columns: List[str], | ||
compare_type: Optional[PandasPostprocessingCompare], | ||
drop_original_columns: Optional[bool] = False, | ||
precision: Optional[int] = 4, | ||
) -> DataFrame: | ||
""" | ||
Calculate column-by-column changing for select columns. | ||
:param df: DataFrame on which the compare will be based. | ||
:param source_columns: Main query columns | ||
:param compare_columns: Columns being compared | ||
:param compare_type: Type of compare. Choice of `absolute`, `percentage` or `ratio` | ||
:param drop_original_columns: Whether to remove the source columns and | ||
compare columns. | ||
:param precision: Round a change rate to a variable number of decimal places. | ||
:return: DataFrame with compared columns. | ||
:raises QueryObjectValidationError: If the request in incorrect. | ||
""" | ||
if len(source_columns) != len(compare_columns): | ||
raise QueryObjectValidationError( | ||
_("`compare_columns` must have the same length as `source_columns`.") | ||
) | ||
if compare_type not in tuple(PandasPostprocessingCompare): | ||
raise QueryObjectValidationError( | ||
_("`compare_type` must be `difference`, `percentage` or `ratio`") | ||
) | ||
if len(source_columns) == 0: | ||
return df | ||
|
||
for s_col, c_col in zip(source_columns, compare_columns): | ||
if compare_type == PandasPostprocessingCompare.DIFF: | ||
diff_series = df[s_col] - df[c_col] | ||
elif compare_type == PandasPostprocessingCompare.PCT: | ||
diff_series = ( | ||
((df[s_col] - df[c_col]) / df[c_col]).astype(float).round(precision) | ||
) | ||
else: | ||
# compare_type == "ratio" | ||
diff_series = (df[s_col] / df[c_col]).astype(float).round(precision) | ||
diff_df = diff_series.to_frame( | ||
name=TIME_COMPARISION.join([compare_type, s_col, c_col]) | ||
) | ||
df = pd.concat([df, diff_df], axis=1) | ||
|
||
if drop_original_columns: | ||
df = df.drop(source_columns + compare_columns, axis=1) | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from decimal import Decimal | ||
from typing import List, Optional | ||
|
||
from flask_babel import gettext as _ | ||
from pandas import DataFrame | ||
|
||
from superset.exceptions import QueryObjectValidationError | ||
from superset.utils.core import PostProcessingContributionOrientation | ||
from superset.utils.pandas_postprocessing.utils import validate_column_args | ||
|
||
|
||
@validate_column_args("columns") | ||
def contribution( | ||
df: DataFrame, | ||
orientation: Optional[ | ||
PostProcessingContributionOrientation | ||
] = PostProcessingContributionOrientation.COLUMN, | ||
columns: Optional[List[str]] = None, | ||
rename_columns: Optional[List[str]] = None, | ||
) -> DataFrame: | ||
""" | ||
Calculate cell contibution to row/column total for numeric columns. | ||
Non-numeric columns will be kept untouched. | ||
If `columns` are specified, only calculate contributions on selected columns. | ||
:param df: DataFrame containing all-numeric data (temporal column ignored) | ||
:param columns: Columns to calculate values from. | ||
:param rename_columns: The new labels for the calculated contribution columns. | ||
The original columns will not be removed. | ||
:param orientation: calculate by dividing cell with row/column total | ||
:return: DataFrame with contributions. | ||
""" | ||
contribution_df = df.copy() | ||
numeric_df = contribution_df.select_dtypes(include=["number", Decimal]) | ||
# verify column selections | ||
if columns: | ||
numeric_columns = numeric_df.columns.tolist() | ||
for col in columns: | ||
if col not in numeric_columns: | ||
raise QueryObjectValidationError( | ||
_( | ||
'Column "%(column)s" is not numeric or does not ' | ||
"exists in the query results.", | ||
column=col, | ||
) | ||
) | ||
columns = columns or numeric_df.columns | ||
rename_columns = rename_columns or columns | ||
if len(rename_columns) != len(columns): | ||
raise QueryObjectValidationError( | ||
_("`rename_columns` must have the same length as `columns`.") | ||
) | ||
# limit to selected columns | ||
numeric_df = numeric_df[columns] | ||
axis = 0 if orientation == PostProcessingContributionOrientation.COLUMN else 1 | ||
numeric_df = numeric_df / numeric_df.values.sum(axis=axis, keepdims=True) | ||
contribution_df[rename_columns] = numeric_df | ||
return contribution_df |
Oops, something went wrong.