Skip to content

Commit

Permalink
perf: avoid redundant copies when resetting index for pandas
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Oct 25, 2024
1 parent 21c20f8 commit f4777e3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
10 changes: 6 additions & 4 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import reset_index_no_copy
from narwhals.utils import Implementation
from narwhals.utils import remove_prefix

Expand Down Expand Up @@ -187,13 +188,14 @@ def agg_pandas( # noqa: PLR0915
]
result_simple_aggs = result_simple_aggs.rename(
columns=name_mapping, copy=False
).reset_index()
)
reset_index_no_copy(result_simple_aggs, native_namespace)
if nunique_aggs:
result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique(
dropna=False
)
result_nunique_aggs.columns = list(nunique_aggs.keys())
result_nunique_aggs = result_nunique_aggs.reset_index()
reset_index_no_copy(result_nunique_aggs, native_namespace)
if simple_aggs and nunique_aggs:
if (
set(result_simple_aggs.columns)
Expand Down Expand Up @@ -259,6 +261,6 @@ def func(df: Any) -> Any:
else: # pragma: no cover
result_complex = grouped.apply(func)

result = result_complex.reset_index()
reset_index_no_copy(result_complex, native_namespace)

return from_dataframe(result.loc[:, output_names])
return from_dataframe(result_complex.loc[:, output_names])
4 changes: 3 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from narwhals._pandas_like.utils import narwhals_to_native_dtype
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._pandas_like.utils import reset_index_no_copy
from narwhals._pandas_like.utils import set_axis
from narwhals._pandas_like.utils import to_datetime
from narwhals._pandas_like.utils import validate_column_comparand
Expand Down Expand Up @@ -598,7 +599,8 @@ def value_counts(
dropna=False,
sort=False,
normalize=normalize,
).reset_index()
)
reset_index_no_copy(val_count, self.__native_namespace__())

val_count.columns = [index_name_, value_name_]

Expand Down
13 changes: 13 additions & 0 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,16 @@ def calculate_timestamp_date(s: pd.Series, time_unit: str) -> pd.Series:
else:
result = s * 1_000
return result


def reset_index_no_copy(native_series_or_frame: Any, native_namespace: Any) -> None:
    """Turn every index level into a leading column and reset the index, in place.

    The object is mutated rather than replaced: each index level becomes a
    column named after that level, placed before the existing columns in level
    order (the same layout ``reset_index()`` produces), and the index is then
    replaced by ``native_namespace.RangeIndex(len(obj))``.

    Args:
        native_series_or_frame: Object to mutate. The body reads ``.columns``
            and calls ``.insert``, so it must be dataframe-like.
            NOTE(review): despite the parameter name, a Series does not look
            usable here - confirm what callers actually pass.
        native_namespace: Module providing ``RangeIndex`` (e.g. ``pandas``).

    Raises:
        ValueError: If an index level name is already a column name, or if two
            index levels share a name. Raised before anything is modified.

    NOTE(review): unnamed levels (name ``None``) become a column labelled
    ``None`` rather than ``"index"``/``"level_N"`` as ``reset_index()`` would
    produce - confirm callers always pass named levels.
    """
    index = native_series_or_frame.index
    level_names = list(index.names)

    # Validate everything up front so a failure never leaves the frame
    # half-converted. A level name clashes if it matches an existing column or
    # an earlier level name.
    taken = set(native_series_or_frame.columns)
    clashes = []
    for name in level_names:
        if name in taken:
            clashes.append(name)
        taken.add(name)
    if clashes:
        # Report the clashing level names themselves; `index.name` is None for
        # a multi-level index and would hide which name is the problem.
        msg = f"Cannot insert column with name {', '.join(map(str, clashes))} into dataframe with columns {native_series_or_frame.columns}"
        raise ValueError(msg)

    # Insert at positions 0..n-1 so the levels lead the frame in level order;
    # callers that relabel columns positionally expect the index first.
    for position, name in enumerate(level_names):
        native_series_or_frame.insert(
            position, name, index.get_level_values(position)
        )

    native_series_or_frame.index = native_namespace.RangeIndex(
        len(native_series_or_frame)
    )

0 comments on commit f4777e3

Please sign in to comment.