Skip to content

Commit

Permalink
feat: add DataFrame.struct.explode to add struct subfields to a Dat…
Browse files Browse the repository at this point in the history
…aFrame (#916)

* feat: add `DataFrame.struct.explode` to add struct subfields to a DataFrame

* add tests for multiple columns and custom separator

* fix system test
  • Loading branch information
tswast authored Aug 23, 2024
1 parent 575a29e commit ad2f75e
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 10 deletions.
36 changes: 36 additions & 0 deletions bigframes/core/explode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for implementing 'explode' functions."""

from typing import cast, Sequence, Union

import bigframes.core.blocks as blocks
import bigframes.core.utils as utils


def check_column(
column: Union[blocks.Label, Sequence[blocks.Label]],
) -> Sequence[blocks.Label]:
if not utils.is_list_like(column):
column_labels = cast(Sequence[blocks.Label], (column,))
else:
column_labels = cast(Sequence[blocks.Label], tuple(column))

if not column_labels:
raise ValueError("column must be nonempty")
if len(column_labels) > len(set(column_labels)):
raise ValueError("column must be unique")

return column_labels
16 changes: 7 additions & 9 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import bigframes.core.block_transforms as block_ops
import bigframes.core.blocks as blocks
import bigframes.core.convert
import bigframes.core.explode
import bigframes.core.expression as ex
import bigframes.core.groupby as groupby
import bigframes.core.guid
Expand All @@ -71,6 +72,7 @@
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops
import bigframes.operations.plotting as plotting
import bigframes.operations.structs
import bigframes.series
import bigframes.series as bf_series
import bigframes.session._io.bigquery
Expand Down Expand Up @@ -2875,15 +2877,7 @@ def explode(
*,
ignore_index: Optional[bool] = False,
) -> DataFrame:
if not utils.is_list_like(column):
column_labels = typing.cast(typing.Sequence[blocks.Label], (column,))
else:
column_labels = typing.cast(typing.Sequence[blocks.Label], tuple(column))

if not column_labels:
raise ValueError("column must be nonempty")
if len(column_labels) > len(set(column_labels)):
raise ValueError("column must be unique")
column_labels = bigframes.core.explode.check_column(column)

column_ids = [self._resolve_label_exact(label) for label in column_labels]
missing = [
Expand Down Expand Up @@ -3751,6 +3745,10 @@ def __matmul__(self, other) -> DataFrame:

__matmul__.__doc__ = inspect.getdoc(vendored_pandas_frame.DataFrame.__matmul__)

@property
def struct(self):
return bigframes.operations.structs.StructFrameAccessor(self)

def _throw_if_null_index(self, opname: str):
if not self._has_index:
raise bigframes.exceptions.NullIndexError(
Expand Down
23 changes: 23 additions & 0 deletions bigframes/operations/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,26 @@ def dtypes(self) -> pd.Series:
],
index=[pa_type.field(i).name for i in range(pa_type.num_fields)],
)


@log_adapter.class_logger
class StructFrameAccessor(vendoracessors.StructFrameAccessor):
__doc__ = vendoracessors.StructAccessor.__doc__

def __init__(self, data: bigframes.dataframe.DataFrame) -> None:
self._parent = data

def explode(self, column, *, separator: str = ".") -> bigframes.dataframe.DataFrame:
df = self._parent
column_labels = bigframes.core.explode.check_column(column)

for label in column_labels:
position = df.columns.to_list().index(label)
df = df.drop(columns=label)
subfields = self._parent[label].struct.explode()
for subfield in reversed(subfields.columns):
df.insert(
position, f"{label}{separator}{subfield}", subfields[subfield]
)

return df
2 changes: 1 addition & 1 deletion tests/data/nested.jsonl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{"rowindex":0,"customer_id":"jkl","day":"2023-12-18","flag":1,"event_sequence":[{"category":"B","timestamp":"2023-12-18 03:43:58","data":[{"key":"x","value":20.2533015856},{"key":"y","value":42.8363462389}]},{"category":"D","timestamp":"2023-12-18 07:15:37","data":[{"key":"x","value":62.0762664928},{"key":"z","value":83.6655402432}]}]}
{"rowindex":0,"customer_id":"jkl","day":"2023-12-18","flag":1,"label":{"key": "my-key","value":"my-value"},"event_sequence":[{"category":"B","timestamp":"2023-12-18 03:43:58","data":[{"key":"x","value":20.2533015856},{"key":"y","value":42.8363462389}]},{"category":"D","timestamp":"2023-12-18 07:15:37","data":[{"key":"x","value":62.0762664928},{"key":"z","value":83.6655402432}]}],"address":{"street":"123 Test Lane","city":"Testerchon"}}
{"rowindex":1,"customer_id":"def","day":"2023-12-18","flag":2,"event_sequence":[{"category":"D","timestamp":"2023-12-18 23:11:11","data":[{"key":"w","value":36.1388065179}]},{"category":"B","timestamp":"2023-12-18 07:12:50","data":[{"key":"z","value":68.7673488304}]},{"category":"D","timestamp":"2023-12-18 09:09:03","data":[{"key":"x","value":57.4139647019}]},{"category":"C","timestamp":"2023-12-18 13:05:30","data":[{"key":"z","value":36.087871201}]}]}
{"rowindex":2,"customer_id":"abc","day":"2023-12-6","flag":0,"event_sequence":[{"category":"C","timestamp":"2023-12-06 10:37:11","data":[]},{"category":"A","timestamp":"2023-12-06 03:35:44","data":[]},{"category":"D","timestamp":"2023-12-06 13:10:57","data":[{"key":"z","value":21.8487807658}]},{"category":"B","timestamp":"2023-12-06 01:39:16","data":[{"key":"y","value":1.6380505139}]}]}
{"rowindex":3,"customer_id":"mno","day":"2023-12-16","flag":2,"event_sequence":[]}
Expand Down
28 changes: 28 additions & 0 deletions tests/data/nested_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@
"name": "flag",
"type": "INTEGER"
},
{
"fields": [
{
"name": "key",
"type": "STRING"
},
{
"name": "value",
"type": "STRING"
}
],
"name": "label",
"type": "RECORD"
},
{
"fields": [
{
Expand Down Expand Up @@ -52,5 +66,19 @@
"mode": "REPEATED",
"name": "event_sequence",
"type": "RECORD"
},
{
"fields": [
{
"name": "street",
"type": "STRING"
},
{
"name": "city",
"type": "STRING"
}
],
"name": "address",
"type": "RECORD"
}
]
40 changes: 40 additions & 0 deletions tests/system/small/operations/test_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def test_dataframe_struct_explode_multiple_columns(nested_df):
got = nested_df.struct.explode(["label", "address"])
assert got.columns.to_list() == [
"customer_id",
"day",
"flag",
"label.key",
"label.value",
"event_sequence",
"address.street",
"address.city",
]


def test_dataframe_struct_explode_separator(nested_df):
got = nested_df.struct.explode("label", separator="__sep__")
assert got.columns.to_list() == [
"customer_id",
"day",
"flag",
"label__sep__key",
"label__sep__value",
"event_sequence",
"address",
]
16 changes: 16 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,14 @@ def test_get_dtypes_array_struct_table(nested_df):
"customer_id": pd.StringDtype(storage="pyarrow"),
"day": pd.ArrowDtype(pa.date32()),
"flag": pd.Int64Dtype(),
"label": pd.ArrowDtype(
pa.struct(
[
("key", pa.string()),
("value", pa.string()),
]
),
),
"event_sequence": pd.ArrowDtype(
pa.list_(
pa.struct(
Expand All @@ -1457,6 +1465,14 @@ def test_get_dtypes_array_struct_table(nested_df):
),
),
),
"address": pd.ArrowDtype(
pa.struct(
[
("street", pa.string()),
("city", pa.string()),
]
),
),
}
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,56 @@ def dtypes(self):
A *pandas* Series with the data type of all child fields.
"""
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)


class StructFrameAccessor:
"""
Accessor object for structured data properties of the DataFrame values.
"""

def explode(self, column, *, separator: str = "."):
"""
Extract all child fields of struct column(s) and add to the DataFrame.
**Examples:**
>>> import bigframes.pandas as bpd
>>> import pyarrow as pa
>>> bpd.options.display.progress_bar = None
>>> countries = bpd.Series(["cn", "es", "us"])
>>> files = bpd.Series(
... [
... {"version": 1, "project": "pandas"},
... {"version": 2, "project": "pandas"},
... {"version": 1, "project": "numpy"},
... ],
... dtype=bpd.ArrowDtype(pa.struct(
... [("version", pa.int64()), ("project", pa.string())]
... ))
... )
>>> downloads = bpd.Series([100, 200, 300])
>>> df = bpd.DataFrame({"country": countries, "file": files, "download_count": downloads})
>>> df.struct.explode("file")
country file.version file.project download_count
0 cn 1 pandas 100
1 es 2 pandas 200
2 us 1 numpy 300
<BLANKLINE>
[3 rows x 4 columns]
Args:
column:
Column(s) to explode. For multiple columns, specify a non-empty
list with each element be str or tuple, and all specified
columns their list-like data on same row of the frame must
have matching length.
separator:
Separator/delimiter to use to separate the original column name
from the sub-field column name.
Returns:
DataFrame:
Original DataFrame with exploded struct column(s).
"""
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

0 comments on commit ad2f75e

Please sign in to comment.