Skip to content

Commit

Permalink
Add typing to callables in transform decorators (#78)
Browse files (browse the repository at this point in the history)
* typing: transform decorator Input type

* typing: add polars dataframe return types to polars decorator

* typing: add pandas dataframe return type to pandas decorator

* typing: add pyspark dataframe return type to transform_df decorator

* typing: add return types to regular transform decorator

* typing: add Input and Output types to transform decorator
  • Loading branch information
JimLundin authored Oct 21, 2024
1 parent 0ec0b3c commit 75d28f8
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions libs/transforms/src/transforms/api/_decorators.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,6 +16,10 @@
if TYPE_CHECKING:
from collections.abc import Callable

import pandas as pd
import polars as pl
import pyspark.sql


def lightweight(
_maybe_transform: Transform | None = None,
Expand Down Expand Up @@ -113,7 +117,7 @@ def _lightweight(transform: Transform) -> Transform:
return _lightweight if _maybe_transform is None else _lightweight(_maybe_transform)


def transform_polars(output: Output, **inputs) -> Callable[[Callable], Transform]:
def transform_polars(output: Output, **inputs: Input) -> Callable[[Callable[..., pl.DataFrame]], Transform]:
"""Register the wrapped compute function as a Polars transform.
Note:
Expand Down Expand Up @@ -143,7 +147,7 @@ def transform_polars(output: Output, **inputs) -> Callable[[Callable], Transform
**inputs (Input): kwargs comprised of named :class:`Input` specs.
"""

def _transform_polars(compute_function: Callable) -> Transform:
def _transform_polars(compute_function: Callable[..., pl.DataFrame]) -> Transform:
return Transform(
compute_function,
outputs={"output": output},
Expand All @@ -154,7 +158,7 @@ def _transform_polars(compute_function: Callable) -> Transform:
return _transform_polars


def transform_df(output: Output, **inputs) -> Callable[[Callable], Transform]:
def transform_df(output: Output, **inputs: Input) -> Callable[[Callable[..., pyspark.sql.DataFrame]], Transform]:
"""Register the wrapped compute function as a dataframe transform.
The ``transform_df`` decorator is used to construct a :class:`Transform` object from
Expand All @@ -180,13 +184,13 @@ def transform_df(output: Output, **inputs) -> Callable[[Callable], Transform]:
**inputs (Input): kwargs comprised of named :class:`Input` specs.
"""

def _transform_df(compute_func: Callable) -> Transform:
def _transform_df(compute_func: Callable[..., pyspark.sql.DataFrame]) -> Transform:
return Transform(compute_func, {"output": output}, inputs=inputs, decorator="spark")

return _transform_df


def transform_pandas(output: Output, **inputs) -> Callable[[Callable], Transform]:
def transform_pandas(output: Output, **inputs: Input) -> Callable[[Callable[..., pd.DataFrame]], Transform]:
"""Register the wrapped compute function as a Pandas transform.
The ``transform_pandas`` decorator is used to construct a :class:`Transform` object from
Expand All @@ -211,13 +215,13 @@ def transform_pandas(output: Output, **inputs) -> Callable[[Callable], Transform
**inputs (Input): kwargs comprised of named :class:`Input` specs.
"""

def _transform_pandas(compute_func: Callable) -> Transform:
def _transform_pandas(compute_func: Callable[..., pd.DataFrame]) -> Transform:
return Transform(compute_func, {"output": output}, inputs=inputs, decorator="pandas")

return _transform_pandas


def transform(**kwargs) -> Callable[[Callable], Transform]:
def transform(**kwargs: Input | Output) -> Callable[[Callable[..., None]], Transform]:
"""Wrap up a compute function as a Transform object.
>>> from transforms.api import transform, Input, Output
Expand All @@ -239,7 +243,7 @@ def transform(**kwargs) -> Callable[[Callable], Transform]:
The compute function is responsible for writing data to its outputs.
"""

def _transform(compute_func: Callable) -> Transform:
def _transform(compute_func: Callable[..., None]) -> Transform:
return Transform(
compute_func,
outputs={k: v for k, v in kwargs.items() if isinstance(v, Output)},
Expand Down

0 comments on commit 75d28f8

Please sign in to comment.