Skip to content

Commit a3b1f25

Browse files
authored
Ensure dtype objects are passed within Column.astype (rapidsai#17978)
Continuation of rapidsai#17918.

* Some of the modified `astype` calls might be on `Series`/`Index`/etc. objects, but IMO it's OK to be a bit stricter and pass dtype objects when calling those methods too.
* Added stricter typings to the `Column.as_*_column` methods, since they are called by `Column.astype`.

Authors:
- Matthew Roeschke (https://github.com/mroeschke)
- Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
- Vyas Ramasubramani (https://github.com/vyasr)

URL: rapidsai#17978
1 parent f38445a commit a3b1f25

20 files changed

+127
-126
lines changed

docs/cudf/source/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,7 @@ def on_missing_reference(app, env, node, contnode):
583583
("py:obj", "cudf.Index.to_flat_index"),
584584
("py:obj", "cudf.MultiIndex.to_flat_index"),
585585
("py:meth", "pyarrow.Table.to_pandas"),
586+
("py:class", "abc.Hashable"),
586587
("py:class", "pd.DataFrame"),
587588
("py:class", "pandas.core.indexes.frozen.FrozenList"),
588589
("py:class", "pa.Array"),

python/cudf/cudf/core/column/categorical.py

+6-19
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def codes(self) -> cudf.Series:
150150
return cudf.Series._from_column(self._column.codes, index=index)
151151

152152
@property
153-
def ordered(self) -> bool:
153+
def ordered(self) -> bool | None:
154154
"""
155155
Whether the categories have an ordered relationship.
156156
"""
@@ -620,7 +620,7 @@ def codes(self) -> NumericalColumn:
620620
return self._codes
621621

622622
@property
623-
def ordered(self) -> bool:
623+
def ordered(self) -> bool | None:
624624
return self.dtype.ordered
625625

626626
def __setitem__(self, key, value):
@@ -1086,20 +1086,7 @@ def is_monotonic_increasing(self) -> bool:
10861086
def is_monotonic_decreasing(self) -> bool:
10871087
return bool(self.ordered) and self.codes.is_monotonic_decreasing
10881088

1089-
def as_categorical_column(self, dtype: Dtype) -> Self:
1090-
if isinstance(dtype, str) and dtype == "category":
1091-
return self
1092-
if isinstance(dtype, pd.CategoricalDtype):
1093-
dtype = cudf.CategoricalDtype.from_pandas(dtype)
1094-
if (
1095-
isinstance(dtype, cudf.CategoricalDtype)
1096-
and dtype.categories is None
1097-
and dtype.ordered is None
1098-
):
1099-
return self
1100-
elif not isinstance(dtype, CategoricalDtype):
1101-
raise ValueError("dtype must be CategoricalDtype")
1102-
1089+
def as_categorical_column(self, dtype: cudf.CategoricalDtype) -> Self:
11031090
if not isinstance(self.categories, type(dtype.categories._column)):
11041091
if isinstance(
11051092
self.categories.dtype, cudf.StructDtype
@@ -1130,16 +1117,16 @@ def as_categorical_column(self, dtype: Dtype) -> Self:
11301117
new_categories=dtype.categories, ordered=bool(dtype.ordered)
11311118
)
11321119

1133-
def as_numerical_column(self, dtype: Dtype) -> NumericalColumn:
1120+
def as_numerical_column(self, dtype: np.dtype) -> NumericalColumn:
11341121
return self._get_decategorized_column().as_numerical_column(dtype)
11351122

11361123
def as_string_column(self) -> StringColumn:
11371124
return self._get_decategorized_column().as_string_column()
11381125

1139-
def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn:
1126+
def as_datetime_column(self, dtype: np.dtype) -> DatetimeColumn:
11401127
return self._get_decategorized_column().as_datetime_column(dtype)
11411128

1142-
def as_timedelta_column(self, dtype: Dtype) -> TimeDeltaColumn:
1129+
def as_timedelta_column(self, dtype: np.dtype) -> TimeDeltaColumn:
11431130
return self._get_decategorized_column().as_timedelta_column(dtype)
11441131

11451132
def _get_decategorized_column(self) -> ColumnBase:

python/cudf/cudf/core/column/column.py

+28-50
Original file line numberDiff line numberDiff line change
@@ -1603,30 +1603,13 @@ def cast(self, dtype: Dtype) -> ColumnBase:
16031603
result.dtype.precision = dtype.precision # type: ignore[union-attr]
16041604
return result
16051605

1606-
def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
1607-
if len(self) == 0:
1608-
dtype = cudf.dtype(dtype)
1609-
if self.dtype == dtype:
1610-
result = self
1611-
else:
1612-
result = column_empty(0, dtype=dtype)
1613-
elif dtype == "category":
1614-
# TODO: Figure out why `cudf.dtype("category")`
1615-
# astype's different than just the string
1616-
result = self.as_categorical_column(dtype)
1617-
elif (
1618-
isinstance(dtype, str)
1619-
and dtype == "interval"
1620-
and isinstance(self.dtype, IntervalDtype)
1621-
):
1622-
# astype("interval") (the string only) should no-op
1606+
def astype(self, dtype: DtypeObj, copy: bool = False) -> ColumnBase:
1607+
if self.dtype == dtype:
16231608
result = self
1609+
elif len(self) == 0:
1610+
result = column_empty(0, dtype=dtype)
16241611
else:
1625-
was_object = dtype == object or dtype == np.dtype(object)
1626-
dtype = cudf.dtype(dtype)
1627-
if self.dtype == dtype:
1628-
result = self
1629-
elif isinstance(dtype, CategoricalDtype):
1612+
if isinstance(dtype, CategoricalDtype):
16301613
result = self.as_categorical_column(dtype)
16311614
elif isinstance(dtype, IntervalDtype):
16321615
result = self.as_interval_column(dtype)
@@ -1643,11 +1626,6 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
16431626
elif dtype.kind == "m":
16441627
result = self.as_timedelta_column(dtype)
16451628
elif dtype.kind == "O":
1646-
if cudf.get_option("mode.pandas_compatible") and was_object:
1647-
raise ValueError(
1648-
f"Casting to {dtype} is not supported, use "
1649-
"`.astype('str')` instead."
1650-
)
16511629
result = self.as_string_column()
16521630
else:
16531631
result = self.as_numerical_column(dtype)
@@ -1656,19 +1634,13 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
16561634
return result.copy()
16571635
return result
16581636

1659-
def as_categorical_column(self, dtype) -> ColumnBase:
1660-
if isinstance(dtype, pd.CategoricalDtype):
1661-
dtype = cudf.CategoricalDtype.from_pandas(dtype)
1662-
if isinstance(dtype, cudf.CategoricalDtype):
1663-
ordered = dtype.ordered
1664-
else:
1665-
ordered = False
1637+
def as_categorical_column(
1638+
self, dtype: cudf.CategoricalDtype
1639+
) -> cudf.core.column.categorical.CategoricalColumn:
1640+
ordered = dtype.ordered
16661641

16671642
# Re-label self w.r.t. the provided categories
1668-
if (
1669-
isinstance(dtype, cudf.CategoricalDtype)
1670-
and dtype._categories is not None
1671-
):
1643+
if dtype._categories is not None:
16721644
cat_col = dtype._categories
16731645
codes = self._label_encoding(cats=cat_col)
16741646
codes = cudf.core.column.categorical.as_unsigned_codes(
@@ -1704,31 +1676,31 @@ def as_categorical_column(self, dtype) -> ColumnBase:
17041676
)
17051677

17061678
def as_numerical_column(
1707-
self, dtype: Dtype
1708-
) -> "cudf.core.column.NumericalColumn":
1679+
self, dtype: np.dtype
1680+
) -> cudf.core.column.NumericalColumn:
17091681
raise NotImplementedError
17101682

17111683
def as_datetime_column(
1712-
self, dtype: Dtype
1684+
self, dtype: np.dtype
17131685
) -> cudf.core.column.DatetimeColumn:
17141686
raise NotImplementedError
17151687

17161688
def as_interval_column(
1717-
self, dtype: Dtype
1718-
) -> "cudf.core.column.IntervalColumn":
1689+
self, dtype: IntervalDtype
1690+
) -> cudf.core.column.IntervalColumn:
17191691
raise NotImplementedError
17201692

17211693
def as_timedelta_column(
1722-
self, dtype: Dtype
1694+
self, dtype: np.dtype
17231695
) -> cudf.core.column.TimeDeltaColumn:
17241696
raise NotImplementedError
17251697

17261698
def as_string_column(self) -> cudf.core.column.StringColumn:
17271699
raise NotImplementedError
17281700

17291701
def as_decimal_column(
1730-
self, dtype: Dtype
1731-
) -> "cudf.core.column.decimal.DecimalBaseColumn":
1702+
self, dtype: DecimalDtype
1703+
) -> cudf.core.column.decimal.DecimalBaseColumn:
17321704
raise NotImplementedError
17331705

17341706
def apply_boolean_mask(self, mask) -> ColumnBase:
@@ -2001,7 +1973,7 @@ def _label_encoding(
20011973
cats: ColumnBase,
20021974
dtype: Dtype | None = None,
20031975
na_sentinel: pa.Scalar | None = None,
2004-
):
1976+
) -> NumericalColumn:
20051977
"""
20061978
Convert each value in `self` into an integer code, with `cats`
20071979
providing the mapping between codes and values.
@@ -2070,7 +2042,7 @@ def _return_sentinel_column():
20702042
plc_codes = sorting.sort_by_key(
20712043
[codes], [left_gather_map], [True], ["last"], stable=True
20722044
)[0]
2073-
return ColumnBase.from_pylibcudf(plc_codes).fillna(na_sentinel)
2045+
return ColumnBase.from_pylibcudf(plc_codes).fillna(na_sentinel) # type: ignore[return-value]
20742046

20752047
@acquire_spill_lock()
20762048
def copy_if_else(
@@ -2828,10 +2800,16 @@ def as_column(
28282800
}:
28292801
if isinstance(dtype, (CategoricalDtype, IntervalDtype)):
28302802
dtype = dtype.to_pandas()
2803+
if isinstance(dtype, pd.IntervalDtype):
2804+
# pd.Series(arbitrary) might be already inferred as IntervalDtype
2805+
ser = pd.Series(arbitrary).astype(dtype)
2806+
else:
2807+
ser = pd.Series(arbitrary, dtype=dtype)
28312808
elif dtype == object and not cudf.get_option("mode.pandas_compatible"):
28322809
# Unlike pandas, interpret object as "str" instead of "python object"
2833-
dtype = "str"
2834-
ser = pd.Series(arbitrary, dtype=dtype)
2810+
ser = pd.Series(arbitrary, dtype="str")
2811+
else:
2812+
ser = pd.Series(arbitrary, dtype=dtype)
28352813
return as_column(ser, nan_as_null=nan_as_null)
28362814
elif isinstance(dtype, (cudf.StructDtype, cudf.ListDtype)):
28372815
try:

python/cudf/cudf/core/column/datetime.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ def normalize_binop_value( # type: ignore[override]
597597

598598
return NotImplemented
599599

600-
def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn:
600+
def as_datetime_column(self, dtype: np.dtype) -> DatetimeColumn:
601601
if dtype == self.dtype:
602602
return self
603603
elif isinstance(dtype, pd.DatetimeTZDtype):
@@ -607,13 +607,13 @@ def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn:
607607
)
608608
return self.cast(dtype=dtype) # type: ignore[return-value]
609609

610-
def as_timedelta_column(self, dtype: Dtype) -> None: # type: ignore[override]
610+
def as_timedelta_column(self, dtype: np.dtype) -> None: # type: ignore[override]
611611
raise TypeError(
612612
f"cannot astype a datetimelike from {self.dtype} to {dtype}"
613613
)
614614

615615
def as_numerical_column(
616-
self, dtype: Dtype
616+
self, dtype: np.dtype
617617
) -> cudf.core.column.NumericalColumn:
618618
col = cudf.core.column.NumericalColumn(
619619
data=self.base_data, # type: ignore[arg-type]
@@ -1132,7 +1132,9 @@ def strftime(self, format: str) -> cudf.core.column.StringColumn:
11321132
def as_string_column(self) -> cudf.core.column.StringColumn:
11331133
return self._local_time.as_string_column()
11341134

1135-
def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn:
1135+
def as_datetime_column(
1136+
self, dtype: np.dtype | pd.DatetimeTZDtype
1137+
) -> DatetimeColumn:
11361138
if isinstance(dtype, pd.DatetimeTZDtype) and dtype != self.dtype:
11371139
if dtype.unit != self.time_unit:
11381140
# TODO: Doesn't check that new unit is valid.

python/cudf/cudf/core/column/decimal.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __cuda_array_interface__(self):
7272

7373
def as_decimal_column(
7474
self,
75-
dtype: Dtype,
75+
dtype: DecimalDtype,
7676
) -> DecimalBaseColumn:
7777
if isinstance(dtype, DecimalDtype) and dtype.scale < self.dtype.scale:
7878
warnings.warn(
@@ -234,7 +234,7 @@ def normalize_binop_value(self, other) -> Self | cudf.Scalar:
234234
return NotImplemented
235235

236236
def as_numerical_column(
237-
self, dtype: Dtype
237+
self, dtype: np.dtype
238238
) -> cudf.core.column.NumericalColumn:
239239
return self.cast(dtype=dtype) # type: ignore[return-value]
240240

python/cudf/cudf/core/column/lists.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def concat(self, dropna=True) -> ParentType:
875875
self._column.concatenate_list_elements(dropna)
876876
)
877877

878-
def astype(self, dtype):
878+
def astype(self, dtype: Dtype):
879879
"""
880880
Return a new list Series with the leaf values casted
881881
to the specified data type.
@@ -899,6 +899,6 @@ def astype(self, dtype):
899899
"""
900900
return self._return_or_inplace(
901901
self._column._transform_leaves(
902-
lambda col, dtype: col.astype(dtype), dtype
902+
lambda col, dtype: col.astype(cudf.dtype(dtype)), dtype
903903
)
904904
)

python/cudf/cudf/core/column/numerical.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
from cudf.core.buffer import Buffer
4545
from cudf.core.column import DecimalBaseColumn
46+
from cudf.core.dtypes import DecimalDtype
4647

4748

4849
class NumericalColumn(NumericalBaseColumn):
@@ -384,7 +385,7 @@ def as_string_column(self) -> cudf.core.column.StringColumn:
384385
)
385386

386387
def as_datetime_column(
387-
self, dtype: Dtype
388+
self, dtype: np.dtype
388389
) -> cudf.core.column.DatetimeColumn:
389390
return cudf.core.column.DatetimeColumn(
390391
data=self.astype(np.dtype(np.int64)).base_data, # type: ignore[arg-type]
@@ -395,7 +396,7 @@ def as_datetime_column(
395396
)
396397

397398
def as_timedelta_column(
398-
self, dtype: Dtype
399+
self, dtype: np.dtype
399400
) -> cudf.core.column.TimeDeltaColumn:
400401
return cudf.core.column.TimeDeltaColumn(
401402
data=self.astype(np.dtype(np.int64)).base_data, # type: ignore[arg-type]
@@ -405,7 +406,7 @@ def as_timedelta_column(
405406
size=self.size,
406407
)
407408

408-
def as_decimal_column(self, dtype: Dtype) -> DecimalBaseColumn:
409+
def as_decimal_column(self, dtype: DecimalDtype) -> DecimalBaseColumn:
409410
return self.cast(dtype=dtype) # type: ignore[return-value]
410411

411412
def as_numerical_column(self, dtype: Dtype) -> NumericalColumn:

python/cudf/cudf/core/column/string.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from cudf.core.column.lists import ListColumn
5050
from cudf.core.column.numerical import NumericalColumn
51+
from cudf.core.dtypes import DecimalDtype
5152

5253

5354
def _is_supported_regex_flags(flags: int) -> bool:
@@ -6026,7 +6027,7 @@ def strptime(
60266027
return result_col # type: ignore[return-value]
60276028

60286029
def as_datetime_column(
6029-
self, dtype: Dtype
6030+
self, dtype: np.dtype
60306031
) -> cudf.core.column.DatetimeColumn:
60316032
not_null = self.apply_boolean_mask(self.notnull())
60326033
if len(not_null) == 0:
@@ -6039,13 +6040,13 @@ def as_datetime_column(
60396040
return self.strptime(dtype, format) # type: ignore[return-value]
60406041

60416042
def as_timedelta_column(
6042-
self, dtype: Dtype
6043+
self, dtype: np.dtype
60436044
) -> cudf.core.column.TimeDeltaColumn:
60446045
return self.strptime(dtype, "%D days %H:%M:%S") # type: ignore[return-value]
60456046

60466047
@acquire_spill_lock()
60476048
def as_decimal_column(
6048-
self, dtype: Dtype
6049+
self, dtype: DecimalDtype
60496050
) -> cudf.core.column.DecimalBaseColumn:
60506051
plc_column = plc.strings.convert.convert_fixed_point.to_fixed_point(
60516052
self.to_pylibcudf(mode="read"),

python/cudf/cudf/core/column/timedelta.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def round(self, freq: str) -> ColumnBase:
325325
raise NotImplementedError("round is currently not implemented")
326326

327327
def as_numerical_column(
328-
self, dtype: Dtype
328+
self, dtype: np.dtype
329329
) -> cudf.core.column.NumericalColumn:
330330
col = cudf.core.column.NumericalColumn(
331331
data=self.base_data, # type: ignore[arg-type]
@@ -336,7 +336,7 @@ def as_numerical_column(
336336
)
337337
return cast("cudf.core.column.NumericalColumn", col.astype(dtype))
338338

339-
def as_datetime_column(self, dtype: Dtype) -> None: # type: ignore[override]
339+
def as_datetime_column(self, dtype: np.dtype) -> None: # type: ignore[override]
340340
raise TypeError(
341341
f"cannot astype a timedelta from {self.dtype} to {dtype}"
342342
)
@@ -358,7 +358,7 @@ def strftime(self, format: str) -> cudf.core.column.StringColumn:
358358
def as_string_column(self) -> cudf.core.column.StringColumn:
359359
return self.strftime("%D days %H:%M:%S")
360360

361-
def as_timedelta_column(self, dtype: Dtype) -> TimeDeltaColumn:
361+
def as_timedelta_column(self, dtype: np.dtype) -> TimeDeltaColumn:
362362
if dtype == self.dtype:
363363
return self
364364
return self.cast(dtype=dtype) # type: ignore[return-value]

0 commit comments

Comments
 (0)