Implement handlers for series literal in cudf-polars (rapidsai#16113)
A query plan can contain a "literal" polars Series; this arises often, for example, when calling a contains-like function. To translate these, introduce a new `LiteralColumn` node to capture the concept and add an evaluation rule (converting from arrow).

Since list-dtype Series need the same casting treatment as in the dataframe-scan case, factor the casting out into a utility, and take the opportunity to handle casting of nested lists correctly.
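As a minimal sketch (not part of the change; `is_in` is assumed here as one example of a contains-like function), a query whose plan carries a literal Series:

```python
import polars as pl

# The values given to `is_in` travel through the query plan as a
# literal Series, which the new `LiteralColumn` node translates.
q = pl.LazyFrame({"a": [1, 2, 3, 4]}).select(pl.col("a").is_in(pl.Series([2, 4])))
```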

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Thomas Li (https://github.com/lithomas1)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: rapidsai#16113
wence- authored Jul 1, 2024
1 parent 3c3edfe commit 599ce95
Showing 6 changed files with 239 additions and 12 deletions.
32 changes: 31 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expr.py
@@ -27,11 +27,12 @@
import cudf._lib.pylibcudf as plc

from cudf_polars.containers import Column, NamedColumn
from cudf_polars.utils import sorting
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    import polars.polars as plrs
    import polars.type_aliases as pl_types

    from cudf_polars.containers import DataFrame
@@ -369,6 +370,29 @@ def do_evaluate(
        return Column(plc.Column.from_scalar(plc.interop.from_arrow(self.value), 1))


class LiteralColumn(Expr):
    """A column literal (a polars Series) in an expression."""

    __slots__ = ("value",)
    _non_child = ("dtype", "value")
    value: pa.Array[Any, Any]
    children: tuple[()]

    def __init__(self, dtype: plc.DataType, value: plrs.PySeries) -> None:
        super().__init__(dtype)
        data = value.to_arrow()
        # Cast to remove (nested) large-list types, which libcudf cannot consume.
        self.value = data.cast(dtypes.downcast_arrow_lists(data.type))

    def do_evaluate(
        self,
        df: DataFrame,
        *,
        context: ExecutionContext = ExecutionContext.FRAME,
        mapping: Mapping[Expr, Column] | None = None,
    ) -> Column:
        """Evaluate this expression given a dataframe for context."""
        # datatype of pyarrow array is correct by construction.
        return Column(plc.interop.from_arrow(self.value))


class Col(Expr):
__slots__ = ("name",)
_non_child = ("dtype", "name")
@@ -1156,6 +1180,12 @@ def __init__(
        super().__init__(dtype)
        self.op = op
        self.children = (left, right)
        if (
            op in (plc.binaryop.BinaryOperator.ADD, plc.binaryop.BinaryOperator.SUB)
            and ({left.dtype.id(), right.dtype.id()}.issubset(dtypes.TIMELIKE_TYPES))
            and not dtypes.have_compatible_resolution(left.dtype.id(), right.dtype.id())
        ):
            raise NotImplementedError("Casting rules for timelike types")

    _MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = {
        pl_expr.Operator.Eq: plc.binaryop.BinaryOperator.EQUAL,
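A sketch illustrating the timelike precondition added to `BinOp.__init__` above, using only the `have_compatible_resolution` behaviour from `utils/dtypes.py` below:

```python
import cudf._lib.pylibcudf as plc

from cudf_polars.utils import dtypes

# Matching timestamp/duration resolutions pass the check...
assert dtypes.have_compatible_resolution(
    plc.TypeId.TIMESTAMP_MILLISECONDS, plc.TypeId.DURATION_MILLISECONDS
)
# ...while mixed resolutions cause BinOp to raise NotImplementedError.
assert not dtypes.have_compatible_resolution(
    plc.TypeId.TIMESTAMP_MILLISECONDS, plc.TypeId.DURATION_NANOSECONDS
)
```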
20 changes: 10 additions & 10 deletions python/cudf_polars/cudf_polars/dsl/ir.py
@@ -29,7 +29,7 @@

import cudf_polars.dsl.expr as expr
from cudf_polars.containers import DataFrame, NamedColumn
from cudf_polars.utils import sorting
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
    from collections.abc import MutableMapping
@@ -130,6 +130,11 @@ class IR:
    schema: Schema
    """Mapping from column names to their data types."""

    def __post_init__(self):
        """Validate preconditions."""
        if any(dtype.id() == plc.TypeId.EMPTY for dtype in self.schema.values()):
            raise NotImplementedError("Cannot make empty columns.")

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """
        Evaluate the node and return a dataframe.
@@ -292,15 +297,10 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        table = pdf.to_arrow()
        schema = table.schema
        for i, field in enumerate(schema):
            # TODO: Nested types
            if field.type == pa.large_string():
                # TODO: goes away when libcudf supports large strings
                schema = schema.set(i, pa.field(field.name, pa.string()))
            elif isinstance(field.type, pa.LargeListType):
                # TODO: goes away when libcudf supports large lists
                schema = schema.set(
                    i, pa.field(field.name, pa.list_(field.type.field(0)))
                )
            schema = schema.set(
                i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
            )
        # No-op if the schema is unchanged.
        table = table.cast(schema)
        df = DataFrame.from_table(
            plc.interop.from_arrow(table), list(self.schema.keys())
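A standalone sketch of what the factored-out cast does to a scanned schema (assuming only the `downcast_arrow_lists` behaviour shown in `utils/dtypes.py` below):

```python
import pyarrow as pa

from cudf_polars.utils import dtypes

schema = pa.schema([("x", pa.large_list(pa.int8()))])
for i, field in enumerate(schema):
    schema = schema.set(
        i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
    )
# Large lists are lowered to plain lists before libcudf consumes the data.
assert schema.field("x").type == pa.list_(pa.int8())
```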
3 changes: 3 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/translate.py
@@ -12,6 +12,7 @@
import pyarrow as pa
from typing_extensions import assert_never

import polars.polars as plrs
from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir

import cudf._lib.pylibcudf as plc
@@ -383,6 +384,8 @@ def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr

@_translate_expr.register
def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    if isinstance(node.value, plrs.PySeries):
        return expr.LiteralColumn(dtype, node.value)
    value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype))
    return expr.Literal(dtype, value)

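The dispatch above keeps scalar literals on the existing `Literal` path and routes Series-valued literals to the new `LiteralColumn`. A hedged sketch of a query exercising both branches:

```python
import polars as pl

q = pl.LazyFrame({}).select(
    scalar=pl.lit(10),  # stays on the expr.Literal path
    series=pl.lit(pl.Series([1, 2, 3])),  # takes the new expr.LiteralColumn path
)
```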
81 changes: 80 additions & 1 deletion python/cudf_polars/cudf_polars/utils/dtypes.py
@@ -7,13 +7,92 @@

from functools import cache

import pyarrow as pa
from typing_extensions import assert_never

import polars as pl

import cudf._lib.pylibcudf as plc

__all__ = ["from_polars"]
__all__ = ["from_polars", "downcast_arrow_lists", "have_compatible_resolution"]


TIMELIKE_TYPES: frozenset[plc.TypeId] = frozenset(
    [
        plc.TypeId.TIMESTAMP_MILLISECONDS,
        plc.TypeId.TIMESTAMP_MICROSECONDS,
        plc.TypeId.TIMESTAMP_NANOSECONDS,
        plc.TypeId.TIMESTAMP_DAYS,
        plc.TypeId.DURATION_MILLISECONDS,
        plc.TypeId.DURATION_MICROSECONDS,
        plc.TypeId.DURATION_NANOSECONDS,
    ]
)


def have_compatible_resolution(lid: plc.TypeId, rid: plc.TypeId) -> bool:
    """
    Do two datetime typeids have matching resolution for a binop.

    Parameters
    ----------
    lid
        Left type id
    rid
        Right type id

    Returns
    -------
    True if resolutions are compatible, False otherwise.

    Notes
    -----
    Polars has different casting rules for combining
    datetimes/durations than libcudf; we don't encode the casting
    rules fully, and instead just reject combinations we can't handle.

    Precondition for correctness: both lid and rid are timelike.
    """
    if lid == rid:
        return True
    # Timestamps are smaller than durations in the libcudf enum.
    lid, rid = sorted([lid, rid])
    if lid == plc.TypeId.TIMESTAMP_MILLISECONDS:
        return rid == plc.TypeId.DURATION_MILLISECONDS
    elif lid == plc.TypeId.TIMESTAMP_MICROSECONDS:
        return rid == plc.TypeId.DURATION_MICROSECONDS
    elif lid == plc.TypeId.TIMESTAMP_NANOSECONDS:
        return rid == plc.TypeId.DURATION_NANOSECONDS
    return False


def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType:
    """
    Sanitize an arrow datatype from polars.

    Parameters
    ----------
    typ
        Arrow type to sanitize

    Returns
    -------
    Sanitized arrow type

    Notes
    -----
    As well as arrow ``ListType``s, polars can produce
    ``LargeListType``s and ``FixedSizeListType``s; these are not
    currently handled by libcudf, so we attempt to cast them all into
    normal ``ListType``s on the arrow side before consuming the arrow
    data.
    """
    if isinstance(typ, pa.LargeListType):
        return pa.list_(downcast_arrow_lists(typ.value_type))
    # We don't have to worry about diving into struct types for now
    # since those are always NotImplemented before we get here.
    assert not isinstance(typ, pa.StructType)
    return typ


@cache
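A quick sketch (not part of the diff) of the recursion in `downcast_arrow_lists`:

```python
import pyarrow as pa

from cudf_polars.utils.dtypes import downcast_arrow_lists

# Nested large lists are rewritten all the way down...
assert downcast_arrow_lists(
    pa.large_list(pa.large_list(pa.int32()))
) == pa.list_(pa.list_(pa.int32()))
# ...and non-list types pass through unchanged.
assert downcast_arrow_lists(pa.string()) == pa.string()
```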
96 changes: 96 additions & 0 deletions python/cudf_polars/tests/expressions/test_literal.py
@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import (
    assert_gpu_result_equal,
    assert_ir_translation_raises,
)
from cudf_polars.utils import dtypes


@pytest.fixture(
    params=[
        None,
        pl.Int8(),
        pl.Int16(),
        pl.Int32(),
        pl.Int64(),
        pl.UInt8(),
        pl.UInt16(),
        pl.UInt32(),
        pl.UInt64(),
    ]
)
def integer(request):
    return pl.lit(10, dtype=request.param)


@pytest.fixture(params=[None, pl.Float32(), pl.Float64()])
def float(request):
    return pl.lit(1.0, dtype=request.param)


def test_numeric_literal(integer, float):
    df = pl.LazyFrame({})

    q = df.select(integer=integer, float_=float, sum_=integer + float)

    assert_gpu_result_equal(q)


@pytest.fixture(
    params=[pl.Date(), pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")]
)
def timestamp(request):
    return pl.lit(10_000, dtype=request.param)


@pytest.fixture(params=[pl.Duration("ms"), pl.Duration("us"), pl.Duration("ns")])
def timedelta(request):
    return pl.lit(9_000, dtype=request.param)


def test_timelike_literal(timestamp, timedelta):
    df = pl.LazyFrame({})

    q = df.select(
        time=timestamp,
        delta=timedelta,
        adjusted=timestamp + timedelta,
        two_delta=timedelta + timedelta,
    )
    schema = q.collect_schema()
    time_type = schema["time"]
    delta_type = schema["delta"]
    if dtypes.have_compatible_resolution(
        dtypes.from_polars(time_type).id(), dtypes.from_polars(delta_type).id()
    ):
        assert_gpu_result_equal(q)
    else:
        assert_ir_translation_raises(q, NotImplementedError)


def test_select_literal_series():
    df = pl.LazyFrame({})

    q = df.select(
        a=pl.Series(["a", "b", "c"], dtype=pl.String()),
        b=pl.Series([[1, 2], [3], None], dtype=pl.List(pl.UInt16())),
        c=pl.Series([[[1]], [], [[1, 2, 3, 4]]], dtype=pl.List(pl.List(pl.Float32()))),
    )

    assert_gpu_result_equal(q)


@pytest.mark.parametrize("expr", [pl.lit(None), pl.lit(10, dtype=pl.Decimal())])
def test_unsupported_literal_raises(expr):
    df = pl.LazyFrame({})

    q = df.select(expr)

    assert_ir_translation_raises(q, NotImplementedError)
19 changes: 19 additions & 0 deletions python/cudf_polars/tests/test_dataframescan.py
@@ -41,3 +41,22 @@ def test_scan_drop_nulls(subset, predicate_pushdown):
    assert_gpu_result_equal(
        q, collect_kwargs={"predicate_pushdown": predicate_pushdown}
    )


def test_can_convert_lists():
    df = pl.LazyFrame(
        {
            "a": pl.Series([[1, 2], [3]], dtype=pl.List(pl.Int8())),
            "b": pl.Series([[1], [2]], dtype=pl.List(pl.UInt16())),
            "c": pl.Series(
                [
                    [["1", "2", "3"], ["4", "567"]],
                    [["8", "9"], []],
                ],
                dtype=pl.List(pl.List(pl.String())),
            ),
            "d": pl.Series([[[1, 2]], []], dtype=pl.List(pl.List(pl.UInt16()))),
        }
    )

    assert_gpu_result_equal(df)
