Skip to content

Commit

Permalink
Add stricter typing and validation to ColumnAccessor (#16602)
Browse files Browse the repository at this point in the history
* Added typing annotations that are generally a little stricter on when `Column`s should be passed. Added error handling for these cases.
* Moved some argument checking that was performed on `DataFrame` to `ColumnAccessor`
* Added more `verify=False` to `ColumnAccessor` calls and preserved `.label_dtype` more when we're just selecting columns from the prior `ColumnAccessor`.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #16602
  • Loading branch information
mroeschke authored Aug 20, 2024
1 parent 28fee97 commit 58799d6
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 75 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/csv.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def read_csv(
# Set index if the index_col parameter is passed
if index_col is not None and index_col is not False:
if isinstance(index_col, int):
index_col_name = df._data.select_by_index(index_col).names[0]
index_col_name = df._data.get_labels_by_index(index_col)[0]
df = df.set_index(index_col_name)
if isinstance(index_col_name, str) and \
names is None and orig_header == "infer":
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,7 +1698,7 @@ def join(
# in case of MultiIndex
if isinstance(lhs, cudf.MultiIndex):
on = (
lhs._data.select_by_index(level).names[0]
lhs._data.get_labels_by_index(level)[0]
if isinstance(level, int)
else level
)
Expand Down
114 changes: 64 additions & 50 deletions python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(
rangeindex: bool = False,
label_dtype: Dtype | None = None,
verify: bool = True,
):
) -> None:
if isinstance(data, ColumnAccessor):
self._data = data._data
self._level_names = data.level_names
Expand Down Expand Up @@ -147,10 +147,10 @@ def __iter__(self):
def __getitem__(self, key: Any) -> ColumnBase:
return self._data[key]

def __setitem__(self, key: Any, value: Any):
def __setitem__(self, key: Any, value: ColumnBase) -> None:
self.set_by_label(key, value)

def __delitem__(self, key: Any):
def __delitem__(self, key: Any) -> None:
old_ncols = len(self._data)
del self._data[key]
new_ncols = len(self._data)
Expand All @@ -174,7 +174,7 @@ def __repr__(self) -> str:

def _from_columns_like_self(
self, columns: abc.Iterable[ColumnBase], verify: bool = True
):
) -> Self:
"""
Return a new ColumnAccessor with columns and the properties of self.
Expand Down Expand Up @@ -250,7 +250,7 @@ def _grouped_data(self) -> abc.MutableMapping:
else:
return self._data

def _clear_cache(self, old_ncols: int, new_ncols: int):
def _clear_cache(self, old_ncols: int, new_ncols: int) -> None:
"""
Clear cached attributes.
Expand Down Expand Up @@ -310,16 +310,14 @@ def to_pandas_index(self) -> pd.Index:
)
return result

def insert(
self, name: Any, value: Any, loc: int = -1, validate: bool = True
):
def insert(self, name: Any, value: ColumnBase, loc: int = -1) -> None:
"""
Insert column into the ColumnAccessor at the specified location.
Parameters
----------
name : Name corresponding to the new column
value : column-like
value : ColumnBase
loc : int, optional
The location to insert the new value at.
Must be (0 <= loc <= ncols). By default, the column is added
Expand All @@ -330,33 +328,35 @@ def insert(
None, this function operates in-place.
"""
name = self._pad_key(name)
if name in self._data:
raise ValueError(f"Cannot insert '{name}', already exists")

old_ncols = len(self._data)
if loc == -1:
loc = old_ncols
if not (0 <= loc <= old_ncols):
elif not (0 <= loc <= old_ncols):
raise ValueError(
f"insert: loc out of bounds: must be 0 <= loc <= {old_ncols}"
)

if not isinstance(value, column.ColumnBase):
raise ValueError("value must be a Column")
elif old_ncols > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")

# TODO: we should move all insert logic here
if name in self._data:
raise ValueError(f"Cannot insert '{name}', already exists")
if loc == old_ncols:
if validate:
value = column.as_column(value)
if old_ncols > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")
self._data[name] = value
else:
new_keys = self.names[:loc] + (name,) + self.names[loc:]
new_values = self.columns[:loc] + (value,) + self.columns[loc:]
self._data = self._data.__class__(zip(new_keys, new_values))
self._data = dict(zip(new_keys, new_values))
self._clear_cache(old_ncols, old_ncols + 1)
if old_ncols == 0:
# The type(name) may no longer match the prior label_dtype
self.label_dtype = None

def copy(self, deep=False) -> ColumnAccessor:
def copy(self, deep: bool = False) -> Self:
"""
Make a copy of this ColumnAccessor.
"""
Expand All @@ -373,7 +373,7 @@ def copy(self, deep=False) -> ColumnAccessor:
verify=False,
)

def select_by_label(self, key: Any) -> ColumnAccessor:
def select_by_label(self, key: Any) -> Self:
"""
Return a subset of this column accessor,
composed of the keys specified by `key`.
Expand All @@ -389,7 +389,7 @@ def select_by_label(self, key: Any) -> ColumnAccessor:
if isinstance(key, slice):
return self._select_by_label_slice(key)
elif pd.api.types.is_list_like(key) and not isinstance(key, tuple):
return self._select_by_label_list_like(key)
return self._select_by_label_list_like(tuple(key))
else:
if isinstance(key, tuple):
if any(isinstance(k, slice) for k in key):
Expand Down Expand Up @@ -427,9 +427,13 @@ def get_labels_by_index(self, index: Any) -> tuple:
# TODO: Doesn't handle on-device columns
return tuple(n for n, keep in zip(self.names, index) if keep)
else:
if len(set(index)) != len(index):
raise NotImplementedError(
"Selecting duplicate column labels is not supported."
)
return tuple(self.names[i] for i in index)

def select_by_index(self, index: Any) -> ColumnAccessor:
def select_by_index(self, index: Any) -> Self:
"""
Return a ColumnAccessor composed of the columns
specified by index.
Expand All @@ -445,13 +449,15 @@ def select_by_index(self, index: Any) -> ColumnAccessor:
"""
keys = self.get_labels_by_index(index)
data = {k: self._data[k] for k in keys}
return self.__class__(
return type(self)(
data,
multiindex=self.multiindex,
level_names=self.level_names,
label_dtype=self.label_dtype,
verify=False,
)

def swaplevel(self, i=-2, j=-1):
def swaplevel(self, i=-2, j=-1) -> Self:
"""
Swap level i with level j.
Calling this method does not change the ordering of the values.
Expand All @@ -467,6 +473,10 @@ def swaplevel(self, i=-2, j=-1):
-------
ColumnAccessor
"""
if not self.multiindex:
raise ValueError(
"swaplevel is only valid for self.multiindex=True"
)

i = _get_level(i, self.nlevels, self.level_names)
j = _get_level(j, self.nlevels, self.level_names)
Expand All @@ -486,40 +496,38 @@ def swaplevel(self, i=-2, j=-1):
new_names = list(self.level_names)
new_names[i], new_names[j] = new_names[j], new_names[i]

return self.__class__(
return type(self)(
new_data,
multiindex=True,
multiindex=self.multiindex,
level_names=new_names,
rangeindex=self.rangeindex,
label_dtype=self.label_dtype,
verify=False,
)

def set_by_label(self, key: Any, value: Any, validate: bool = True):
def set_by_label(self, key: Any, value: ColumnBase) -> None:
"""
Add (or modify) column by name.
Parameters
----------
key
name of the column
value : column-like
value : Column
The value to insert into the column.
validate : bool
If True, the provided value will be coerced to a column and
validated before setting (Default value = True).
"""
key = self._pad_key(key)
if validate:
value = column.as_column(value)
if len(self._data) > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")
if not isinstance(value, column.ColumnBase):
raise ValueError("value must be a Column")
if len(self) > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")

old_ncols = len(self._data)
self._data[key] = value
new_ncols = len(self._data)
self._clear_cache(old_ncols, new_ncols)

def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:
# Might be a generator
key = tuple(key)
def _select_by_label_list_like(self, key: tuple) -> Self:
# Special-casing for boolean mask
if (bn := len(key)) > 0 and all(map(is_bool, key)):
if bn != (n := len(self.names)):
Expand All @@ -539,19 +547,22 @@ def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:
)
if self.multiindex:
data = dict(_to_flat_dict_inner(data))
return self.__class__(
return type(self)(
data,
multiindex=self.multiindex,
level_names=self.level_names,
label_dtype=self.label_dtype,
verify=False,
)

def _select_by_label_grouped(self, key: Any) -> ColumnAccessor:
def _select_by_label_grouped(self, key: Any) -> Self:
result = self._grouped_data[key]
if isinstance(result, column.ColumnBase):
# self._grouped_data[key] = self._data[key] so skip validation
return self.__class__(
return type(self)(
data={key: result},
multiindex=self.multiindex,
label_dtype=self.label_dtype,
verify=False,
)
else:
Expand All @@ -563,9 +574,10 @@ def _select_by_label_grouped(self, key: Any) -> ColumnAccessor:
result,
multiindex=self.nlevels - len(key) > 1,
level_names=self.level_names[len(key) :],
verify=False,
)

def _select_by_label_slice(self, key: slice) -> ColumnAccessor:
def _select_by_label_slice(self, key: slice) -> Self:
start, stop = key.start, key.stop
if key.step is not None:
raise TypeError("Label slicing with step is not supported")
Expand All @@ -585,19 +597,22 @@ def _select_by_label_slice(self, key: slice) -> ColumnAccessor:
stop_idx = len(self.names) - idx
break
keys = self.names[start_idx:stop_idx]
return self.__class__(
return type(self)(
{k: self._data[k] for k in keys},
multiindex=self.multiindex,
level_names=self.level_names,
label_dtype=self.label_dtype,
verify=False,
)

def _select_by_label_with_wildcard(self, key: Any) -> ColumnAccessor:
def _select_by_label_with_wildcard(self, key: tuple) -> Self:
key = self._pad_key(key, slice(None))
return self.__class__(
{k: self._data[k] for k in self._data if _keys_equal(k, key)},
data = {k: self._data[k] for k in self.names if _keys_equal(k, key)}
return type(self)(
data,
multiindex=self.multiindex,
level_names=self.level_names,
label_dtype=self.label_dtype,
verify=False,
)

Expand All @@ -614,7 +629,7 @@ def _pad_key(self, key: Any, pad_value="") -> Any:

def rename_levels(
self, mapper: Mapping[Any, Any] | Callable, level: int | None = None
) -> ColumnAccessor:
) -> Self:
"""
Rename the specified levels of the given ColumnAccessor
Expand Down Expand Up @@ -686,7 +701,7 @@ def rename_column(x):
verify=False,
)

def droplevel(self, level):
def droplevel(self, level) -> None:
# drop the nth level
if level < 0:
level += self.nlevels
Expand All @@ -701,9 +716,8 @@ def droplevel(self, level):
self._level_names[:level] + self._level_names[level + 1 :]
)

if (
len(self._level_names) == 1
): # can't use nlevels, as it depends on multiindex
if len(self._level_names) == 1:
# can't use nlevels, as it depends on multiindex
self.multiindex = False
self._clear_cache(old_ncols, new_ncols)

Expand Down
14 changes: 6 additions & 8 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,19 +382,19 @@ def _setitem_tuple_arg(self, key, value):
value = as_column(value, length=length)

if isinstance(value, ColumnBase):
new_col = cudf.Series._from_column(value, index=idx)
new_ser = cudf.Series._from_column(value, index=idx)
else:
new_col = cudf.Series(value, index=idx)
new_ser = cudf.Series(value, index=idx)
if len(self._frame.index) != 0:
new_col = new_col._align_to_index(
new_ser = new_ser._align_to_index(
self._frame.index, how="right"
)

if len(self._frame.index) == 0:
self._frame.index = (
idx if idx is not None else cudf.RangeIndex(len(new_col))
idx if idx is not None else cudf.RangeIndex(len(new_ser))
)
self._frame._data.insert(key[1], new_col)
self._frame._data.insert(key[1], new_ser._column)
else:
if is_scalar(value):
for col in columns_df._column_names:
Expand Down Expand Up @@ -981,6 +981,7 @@ def _init_from_series_list(self, data, columns, index):
self._data.rangeindex = isinstance(
columns, (range, cudf.RangeIndex, pd.RangeIndex)
)
self._data.label_dtype = pd.Index(columns).dtype
else:
self._data.rangeindex = True

Expand Down Expand Up @@ -3272,9 +3273,6 @@ def _insert(self, loc, name, value, nan_as_null=None, ignore_index=True):
If False, a reindexing operation is performed if
`value.index` is not equal to `self.index`.
"""
if name in self._data:
raise NameError(f"duplicated column name {name}")

num_cols = self._num_columns
if loc < 0:
loc += num_cols + 1
Expand Down
4 changes: 1 addition & 3 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,9 +1010,7 @@ def _copy_type_metadata(self: Self, other: Self) -> Self:
See `ColumnBase._with_type_metadata` for more information.
"""
for (name, col), (_, dtype) in zip(self._data.items(), other._dtypes):
self._data.set_by_label(
name, col._with_type_metadata(dtype), validate=False
)
self._data.set_by_label(name, col._with_type_metadata(dtype))

return self

Expand Down
4 changes: 0 additions & 4 deletions python/cudf/cudf/core/indexing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,6 @@ def destructure_dataframe_iloc_indexer(
column_names: ColumnLabels = list(
frame._data.get_labels_by_index(cols)
)
if len(set(column_names)) != len(column_names):
raise NotImplementedError(
"cudf DataFrames do not support repeated column names"
)
except TypeError:
raise TypeError(
"Column indices must be integers, slices, "
Expand Down
Loading

0 comments on commit 58799d6

Please sign in to comment.