@@ -1603,30 +1603,13 @@ def cast(self, dtype: Dtype) -> ColumnBase:
1603
1603
result .dtype .precision = dtype .precision # type: ignore[union-attr]
1604
1604
return result
1605
1605
1606
- def astype (self , dtype : Dtype , copy : bool = False ) -> ColumnBase :
1607
- if len (self ) == 0 :
1608
- dtype = cudf .dtype (dtype )
1609
- if self .dtype == dtype :
1610
- result = self
1611
- else :
1612
- result = column_empty (0 , dtype = dtype )
1613
- elif dtype == "category" :
1614
- # TODO: Figure out why `cudf.dtype("category")`
1615
- # astype's different than just the string
1616
- result = self .as_categorical_column (dtype )
1617
- elif (
1618
- isinstance (dtype , str )
1619
- and dtype == "interval"
1620
- and isinstance (self .dtype , IntervalDtype )
1621
- ):
1622
- # astype("interval") (the string only) should no-op
1606
+ def astype (self , dtype : DtypeObj , copy : bool = False ) -> ColumnBase :
1607
+ if self .dtype == dtype :
1623
1608
result = self
1609
+ elif len (self ) == 0 :
1610
+ result = column_empty (0 , dtype = dtype )
1624
1611
else :
1625
- was_object = dtype == object or dtype == np .dtype (object )
1626
- dtype = cudf .dtype (dtype )
1627
- if self .dtype == dtype :
1628
- result = self
1629
- elif isinstance (dtype , CategoricalDtype ):
1612
+ if isinstance (dtype , CategoricalDtype ):
1630
1613
result = self .as_categorical_column (dtype )
1631
1614
elif isinstance (dtype , IntervalDtype ):
1632
1615
result = self .as_interval_column (dtype )
@@ -1643,11 +1626,6 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
1643
1626
elif dtype .kind == "m" :
1644
1627
result = self .as_timedelta_column (dtype )
1645
1628
elif dtype .kind == "O" :
1646
- if cudf .get_option ("mode.pandas_compatible" ) and was_object :
1647
- raise ValueError (
1648
- f"Casting to { dtype } is not supported, use "
1649
- "`.astype('str')` instead."
1650
- )
1651
1629
result = self .as_string_column ()
1652
1630
else :
1653
1631
result = self .as_numerical_column (dtype )
@@ -1656,19 +1634,13 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
1656
1634
return result .copy ()
1657
1635
return result
1658
1636
1659
- def as_categorical_column (self , dtype ) -> ColumnBase :
1660
- if isinstance (dtype , pd .CategoricalDtype ):
1661
- dtype = cudf .CategoricalDtype .from_pandas (dtype )
1662
- if isinstance (dtype , cudf .CategoricalDtype ):
1663
- ordered = dtype .ordered
1664
- else :
1665
- ordered = False
1637
+ def as_categorical_column (
1638
+ self , dtype : cudf .CategoricalDtype
1639
+ ) -> cudf .core .column .categorical .CategoricalColumn :
1640
+ ordered = dtype .ordered
1666
1641
1667
1642
# Re-label self w.r.t. the provided categories
1668
- if (
1669
- isinstance (dtype , cudf .CategoricalDtype )
1670
- and dtype ._categories is not None
1671
- ):
1643
+ if dtype ._categories is not None :
1672
1644
cat_col = dtype ._categories
1673
1645
codes = self ._label_encoding (cats = cat_col )
1674
1646
codes = cudf .core .column .categorical .as_unsigned_codes (
@@ -1704,31 +1676,31 @@ def as_categorical_column(self, dtype) -> ColumnBase:
1704
1676
)
1705
1677
1706
1678
def as_numerical_column (
1707
- self , dtype : Dtype
1708
- ) -> " cudf.core.column.NumericalColumn" :
1679
+ self , dtype : np . dtype
1680
+ ) -> cudf .core .column .NumericalColumn :
1709
1681
raise NotImplementedError
1710
1682
1711
1683
def as_datetime_column (
1712
- self , dtype : Dtype
1684
+ self , dtype : np . dtype
1713
1685
) -> cudf .core .column .DatetimeColumn :
1714
1686
raise NotImplementedError
1715
1687
1716
1688
def as_interval_column (
1717
- self , dtype : Dtype
1718
- ) -> " cudf.core.column.IntervalColumn" :
1689
+ self , dtype : IntervalDtype
1690
+ ) -> cudf .core .column .IntervalColumn :
1719
1691
raise NotImplementedError
1720
1692
1721
1693
def as_timedelta_column (
1722
- self , dtype : Dtype
1694
+ self , dtype : np . dtype
1723
1695
) -> cudf .core .column .TimeDeltaColumn :
1724
1696
raise NotImplementedError
1725
1697
1726
1698
def as_string_column (self ) -> cudf .core .column .StringColumn :
1727
1699
raise NotImplementedError
1728
1700
1729
1701
def as_decimal_column (
1730
- self , dtype : Dtype
1731
- ) -> " cudf.core.column.decimal.DecimalBaseColumn" :
1702
+ self , dtype : DecimalDtype
1703
+ ) -> cudf .core .column .decimal .DecimalBaseColumn :
1732
1704
raise NotImplementedError
1733
1705
1734
1706
def apply_boolean_mask (self , mask ) -> ColumnBase :
@@ -2001,7 +1973,7 @@ def _label_encoding(
2001
1973
cats : ColumnBase ,
2002
1974
dtype : Dtype | None = None ,
2003
1975
na_sentinel : pa .Scalar | None = None ,
2004
- ):
1976
+ ) -> NumericalColumn :
2005
1977
"""
2006
1978
Convert each value in `self` into an integer code, with `cats`
2007
1979
providing the mapping between codes and values.
@@ -2070,7 +2042,7 @@ def _return_sentinel_column():
2070
2042
plc_codes = sorting .sort_by_key (
2071
2043
[codes ], [left_gather_map ], [True ], ["last" ], stable = True
2072
2044
)[0 ]
2073
- return ColumnBase .from_pylibcudf (plc_codes ).fillna (na_sentinel )
2045
+ return ColumnBase .from_pylibcudf (plc_codes ).fillna (na_sentinel ) # type: ignore[return-value]
2074
2046
2075
2047
@acquire_spill_lock ()
2076
2048
def copy_if_else (
@@ -2828,10 +2800,16 @@ def as_column(
2828
2800
}:
2829
2801
if isinstance (dtype , (CategoricalDtype , IntervalDtype )):
2830
2802
dtype = dtype .to_pandas ()
2803
+ if isinstance (dtype , pd .IntervalDtype ):
2804
+ # pd.Series(arbitrary) may already be inferred as IntervalDtype, so cast with astype instead of passing dtype=
2805
+ ser = pd .Series (arbitrary ).astype (dtype )
2806
+ else :
2807
+ ser = pd .Series (arbitrary , dtype = dtype )
2831
2808
elif dtype == object and not cudf .get_option ("mode.pandas_compatible" ):
2832
2809
# Unlike pandas, interpret object as "str" instead of "python object"
2833
- dtype = "str"
2834
- ser = pd .Series (arbitrary , dtype = dtype )
2810
+ ser = pd .Series (arbitrary , dtype = "str" )
2811
+ else :
2812
+ ser = pd .Series (arbitrary , dtype = dtype )
2835
2813
return as_column (ser , nan_as_null = nan_as_null )
2836
2814
elif isinstance (dtype , (cudf .StructDtype , cudf .ListDtype )):
2837
2815
try :
0 commit comments