Skip to content

Commit

Permalink
BUG/CLN: Decouple Series/DataFrame.transform (pandas-dev#35964)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhshadrach authored and Kevin D Smith committed Nov 2, 2020
1 parent 9fc0db2 commit 621ca9f
Show file tree
Hide file tree
Showing 10 changed files with 507 additions and 169 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ Other
^^^^^
- Bug in :meth:`DataFrame.replace` and :meth:`Series.replace` incorrectly raising ``AssertionError`` instead of ``ValueError`` when invalid parameter combinations are passed (:issue:`36045`)
- Bug in :meth:`DataFrame.replace` and :meth:`Series.replace` with numeric values and string ``to_replace`` (:issue:`34789`)
- Bug in :meth:`Series.transform` would give incorrect results or raise when the argument ``func`` was dictionary (:issue:`35811`)
-

.. ---------------------------------------------------------------------------
Expand Down
98 changes: 97 additions & 1 deletion pandas/core/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
Union,
)

from pandas._typing import AggFuncType, FrameOrSeries, Label
from pandas._typing import AggFuncType, Axis, FrameOrSeries, Label

from pandas.core.dtypes.common import is_dict_like, is_list_like
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries

from pandas.core.base import SpecificationError
import pandas.core.common as com
Expand Down Expand Up @@ -384,3 +385,98 @@ def validate_func_kwargs(
if not columns:
raise TypeError(no_arg_message)
return columns, func


def transform(
obj: FrameOrSeries, func: AggFuncType, axis: Axis, *args, **kwargs,
) -> FrameOrSeries:
"""
Transform a DataFrame or Series
Parameters
----------
obj : DataFrame or Series
Object to compute the transform on.
func : string, function, list, or dictionary
Function(s) to compute the transform with.
axis : {0 or 'index', 1 or 'columns'}
Axis along which the function is applied:
* 0 or 'index': apply function to each column.
* 1 or 'columns': apply function to each row.
Returns
-------
DataFrame or Series
Result of applying ``func`` along the given axis of the
Series or DataFrame.
Raises
------
ValueError
If the transform function fails or does not transform.
"""
from pandas.core.reshape.concat import concat

is_series = obj.ndim == 1

if obj._get_axis_number(axis) == 1:
assert not is_series
return transform(obj.T, func, 0, *args, **kwargs).T

if isinstance(func, list):
if is_series:
func = {com.get_callable_name(v) or v: v for v in func}
else:
func = {col: func for col in obj}

if isinstance(func, dict):
if not is_series:
cols = sorted(set(func.keys()) - set(obj.columns))
if len(cols) > 0:
raise SpecificationError(f"Column(s) {cols} do not exist")

if any(isinstance(v, dict) for v in func.values()):
# GH 15931 - deprecation of renaming keys
raise SpecificationError("nested renamer is not supported")

results = {}
for name, how in func.items():
colg = obj._gotitem(name, ndim=1)
try:
results[name] = transform(colg, how, 0, *args, **kwargs)
except Exception as e:
if str(e) == "Function did not transform":
raise e

# combine results
if len(results) == 0:
raise ValueError("Transform function failed")
return concat(results, axis=1)

# func is either str or callable
try:
if isinstance(func, str):
result = obj._try_aggregate_string_function(func, *args, **kwargs)
else:
f = obj._get_cython_func(func)
if f and not args and not kwargs:
result = getattr(obj, f)()
else:
try:
result = obj.apply(func, args=args, **kwargs)
except Exception:
result = func(obj, *args, **kwargs)
except Exception:
raise ValueError("Transform function failed")

# Functions that transform may return empty Series/DataFrame
# when the dtype is not appropriate
if isinstance(result, (ABCSeries, ABCDataFrame)) and result.empty:
raise ValueError("Transform function failed")
if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals(
obj.index
):
raise ValueError("Function did not transform")

return result
4 changes: 2 additions & 2 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import builtins
import textwrap
from typing import Any, Dict, FrozenSet, List, Optional, Union
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -560,7 +560,7 @@ def _aggregate_multiple_funcs(self, arg, _axis):
) from err
return result

def _get_cython_func(self, arg: str) -> Optional[str]:
def _get_cython_func(self, arg: Callable) -> Optional[str]:
"""
if we define an internal function for this argument, return it
"""
Expand Down
16 changes: 9 additions & 7 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from pandas._libs import algos as libalgos, lib, properties
from pandas._libs.lib import no_default
from pandas._typing import (
AggFuncType,
ArrayLike,
Axes,
Axis,
Expand Down Expand Up @@ -116,7 +117,7 @@

from pandas.core import algorithms, common as com, nanops, ops
from pandas.core.accessor import CachedAccessor
from pandas.core.aggregation import reconstruct_func, relabel_result
from pandas.core.aggregation import reconstruct_func, relabel_result, transform
from pandas.core.arrays import Categorical, ExtensionArray
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin as DatetimeLikeArray
from pandas.core.arrays.sparse import SparseFrameAccessor
Expand Down Expand Up @@ -7462,15 +7463,16 @@ def _aggregate(self, arg, axis=0, *args, **kwargs):
agg = aggregate

@doc(
NDFrame.transform,
_shared_docs["transform"],
klass=_shared_doc_kwargs["klass"],
axis=_shared_doc_kwargs["axis"],
)
def transform(self, func, axis=0, *args, **kwargs) -> DataFrame:
axis = self._get_axis_number(axis)
if axis == 1:
return self.T.transform(func, *args, **kwargs).T
return super().transform(func, *args, **kwargs)
def transform(
self, func: AggFuncType, axis: Axis = 0, *args, **kwargs
) -> DataFrame:
result = transform(self, func, axis, *args, **kwargs)
assert isinstance(result, DataFrame)
return result

def apply(self, func, axis=0, raw=False, result_type=None, args=(), **kwds):
"""
Expand Down
74 changes: 0 additions & 74 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10648,80 +10648,6 @@ def ewm(
times=times,
)

@doc(klass=_shared_doc_kwargs["klass"], axis="")
def transform(self, func, *args, **kwargs):
"""
Call ``func`` on self producing a {klass} with transformed values.
Produced {klass} will have same axis length as self.
Parameters
----------
func : function, str, list or dict
Function to use for transforming the data. If a function, must either
work when passed a {klass} or when passed to {klass}.apply.
Accepted combinations are:
- function
- string function name
- list of functions and/or function names, e.g. ``[np.exp, 'sqrt']``
- dict of axis labels -> functions, function names or list of such.
{axis}
*args
Positional arguments to pass to `func`.
**kwargs
Keyword arguments to pass to `func`.
Returns
-------
{klass}
A {klass} that must have the same length as self.
Raises
------
ValueError : If the returned {klass} has a different length than self.
See Also
--------
{klass}.agg : Only perform aggregating type operations.
{klass}.apply : Invoke function on a {klass}.
Examples
--------
>>> df = pd.DataFrame({{'A': range(3), 'B': range(1, 4)}})
>>> df
A B
0 0 1
1 1 2
2 2 3
>>> df.transform(lambda x: x + 1)
A B
0 1 2
1 2 3
2 3 4
Even though the resulting {klass} must have the same length as the
input {klass}, it is possible to provide several input functions:
>>> s = pd.Series(range(3))
>>> s
0 0
1 1
2 2
dtype: int64
>>> s.transform([np.sqrt, np.exp])
sqrt exp
0 0.000000 1.000000
1 1.000000 2.718282
2 1.414214 7.389056
"""
result = self.agg(func, *args, **kwargs)
if is_scalar(result) or len(result) != len(self):
raise ValueError("transforms cannot produce aggregated results")

return result

# ----------------------------------------------------------------------
# Misc methods

Expand Down
14 changes: 9 additions & 5 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pandas._libs import lib, properties, reshape, tslibs
from pandas._libs.lib import no_default
from pandas._typing import (
AggFuncType,
ArrayLike,
Axis,
DtypeObj,
Expand Down Expand Up @@ -89,6 +90,7 @@
from pandas.core.indexes.timedeltas import TimedeltaIndex
from pandas.core.indexing import check_bool_indexer
from pandas.core.internals import SingleBlockManager
from pandas.core.shared_docs import _shared_docs
from pandas.core.sorting import ensure_key_mapped
from pandas.core.strings import StringMethods
from pandas.core.tools.datetimes import to_datetime
Expand Down Expand Up @@ -4081,14 +4083,16 @@ def aggregate(self, func=None, axis=0, *args, **kwargs):
agg = aggregate

@doc(
NDFrame.transform,
_shared_docs["transform"],
klass=_shared_doc_kwargs["klass"],
axis=_shared_doc_kwargs["axis"],
)
def transform(self, func, axis=0, *args, **kwargs):
# Validate the axis parameter
self._get_axis_number(axis)
return super().transform(func, *args, **kwargs)
def transform(
self, func: AggFuncType, axis: Axis = 0, *args, **kwargs
) -> FrameOrSeriesUnion:
from pandas.core.aggregation import transform

return transform(self, func, axis, *args, **kwargs)

def apply(self, func, convert_dtype=True, args=(), **kwds):
"""
Expand Down
69 changes: 69 additions & 0 deletions pandas/core/shared_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,72 @@
1 b B E 3
2 c B E 5
"""

_shared_docs[
"transform"
] = """\
Call ``func`` on self producing a {klass} with transformed values.
Produced {klass} will have same axis length as self.
Parameters
----------
func : function, str, list or dict
Function to use for transforming the data. If a function, must either
work when passed a {klass} or when passed to {klass}.apply.
Accepted combinations are:
- function
- string function name
- list of functions and/or function names, e.g. ``[np.exp, 'sqrt']``
- dict of axis labels -> functions, function names or list of such.
{axis}
*args
Positional arguments to pass to `func`.
**kwargs
Keyword arguments to pass to `func`.
Returns
-------
{klass}
A {klass} that must have the same length as self.
Raises
------
ValueError : If the returned {klass} has a different length than self.
See Also
--------
{klass}.agg : Only perform aggregating type operations.
{klass}.apply : Invoke function on a {klass}.
Examples
--------
>>> df = pd.DataFrame({{'A': range(3), 'B': range(1, 4)}})
>>> df
A B
0 0 1
1 1 2
2 2 3
>>> df.transform(lambda x: x + 1)
A B
0 1 2
1 2 3
2 3 4
Even though the resulting {klass} must have the same length as the
input {klass}, it is possible to provide several input functions:
>>> s = pd.Series(range(3))
>>> s
0 0
1 1
2 2
dtype: int64
>>> s.transform([np.sqrt, np.exp])
sqrt exp
0 0.000000 1.000000
1 1.000000 2.718282
2 1.414214 7.389056
"""
Loading

0 comments on commit 621ca9f

Please sign in to comment.