Return dask scalar for store and update from ddf (#437)
fjetter authored Mar 17, 2021
1 parent 3807a02 commit 494732d
Showing 5 changed files with 89 additions and 82 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
@@ -28,6 +28,10 @@ This is a major release of kartothek with breaking API changes.
* Remove `metadata`, `df_serializer`, `overwrite`, `metadata_merger` from :func:`kartothek.io.eager.write_single_partition`
* :func:`~kartothek.io.eager.store_dataframes_as_dataset` now requires a list as an input
* Default value for argument `date_as_object` is now universally set to ``True``. The behaviour for `False` will be deprecated and removed in the next major release
* No longer allow to pass `delete_scope` as a delayed object to :func:`~kartothek.io.dask.dataframe.update_dataset_from_ddf`
* :func:`~kartothek.io.dask.dataframe.update_dataset_from_ddf` and :func:`~kartothek.io.dask.dataframe.store_dataset_from_ddf` now return a `dd.core.Scalar` object. This enables all `dask.DataFrame` graph optimizations by default.
* Remove argument `table_name` from :func:`~kartothek.io.dask.dataframe.collect_dataset_metadata`


Version 3.20.0 (2021-03-15)
===========================
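To illustrate the `dd.core.Scalar` return value described in the changelog entry above, here is a minimal usage sketch. It is not part of this commit; the temporary-directory store factory, the example dataframe, and the `example_uuid` name are assumptions.

import tempfile
from functools import partial

import dask.dataframe as dd
import pandas as pd
from storefact import get_store_from_url

from kartothek.io.dask.dataframe import store_dataset_from_ddf

# Assumption: a local filesystem store backed by a temporary directory.
store_factory = partial(get_store_from_url, f"hfs://{tempfile.mkdtemp()}")

ddf = dd.from_pandas(pd.DataFrame({"A": [1, 2, 3]}), npartitions=2)

# The returned object is now a dd.core.Scalar, so dask's DataFrame graph
# optimizations (e.g. task fusion) apply before anything executes.
graph = store_dataset_from_ddf(
    ddf,
    store_factory,
    dataset_uuid="example_uuid",
    table="table",
)
dataset_metadata = graph.compute()  # the final commit runs as a single task here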
93 changes: 59 additions & 34 deletions kartothek/io/dask/dataframe.py
@@ -32,7 +32,6 @@
from kartothek.io_components.update import update_dataset_from_partitions
from kartothek.io_components.utils import (
_ensure_compatible_indices,
normalize_arg,
normalize_args,
validate_partition_keys,
)
@@ -239,8 +238,23 @@ def _shuffle_docs(func):
return func


def _id(x):
return x


def _commit_update_from_reduction(df_mps, **kwargs):
partitions = pd.Series(df_mps.values.flatten()).dropna()
return update_dataset_from_partitions(partition_list=partitions, **kwargs,)


def _commit_store_from_reduction(df_mps, **kwargs):
partitions = pd.Series(df_mps.values.flatten()).dropna()
return store_dataset_from_partitions(partition_list=partitions, **kwargs,)


@default_docs
@_shuffle_docs
@normalize_args
def store_dataset_from_ddf(
ddf: dd.DataFrame,
store: StoreInput,
@@ -251,7 +265,6 @@ def store_dataset_from_ddf(
repartition_ratio: Optional[SupportsFloat] = None,
num_buckets: int = 1,
sort_partitions_by: Optional[Union[List[str], str]] = None,
delete_scope: Optional[Iterable[Mapping[str, str]]] = None,
metadata: Optional[Mapping] = None,
df_serializer: Optional[DataFrameSerializer] = None,
metadata_merger: Optional[Callable] = None,
@@ -263,12 +276,11 @@ def store_dataset_from_ddf(
"""
Store a dataset from a dask.dataframe.
"""
partition_on = normalize_arg("partition_on", partition_on)
secondary_indices = normalize_arg("secondary_indices", secondary_indices)
sort_partitions_by = normalize_arg("sort_partitions_by", sort_partitions_by)
bucket_by = normalize_arg("bucket_by", bucket_by)
store = normalize_arg("store", store)
delete_scope = dask.delayed(normalize_arg)("delete_scope", delete_scope)
# normalization done by normalize_args but mypy doesn't recognize this
sort_partitions_by = cast(List[str], sort_partitions_by)
secondary_indices = cast(List[str], secondary_indices)
bucket_by = cast(List[str], bucket_by)
partition_on = cast(List[str], partition_on)

if table is None:
raise TypeError("The parameter `table` is not optional.")
@@ -277,9 +289,9 @@

if not overwrite:
raise_if_dataset_exists(dataset_uuid=dataset_uuid, store=store)
mps = _write_dataframe_partitions(
mp_ser = _write_dataframe_partitions(
ddf=ddf,
store=store,
store=ds_factory.store_factory,
dataset_uuid=dataset_uuid,
table=table,
secondary_indices=secondary_indices,
@@ -292,12 +304,18 @@
partition_on=partition_on,
bucket_by=bucket_by,
)
return dask.delayed(store_dataset_from_partitions)(
mps,
store=ds_factory.store_factory if ds_factory else store,
dataset_uuid=ds_factory.dataset_uuid if ds_factory else dataset_uuid,
dataset_metadata=metadata,
metadata_merger=metadata_merger,
return mp_ser.reduction(
chunk=_id,
aggregate=_commit_store_from_reduction,
split_every=False,
token="commit-dataset",
meta=object,
aggregate_kwargs={
"store": ds_factory.store_factory,
"dataset_uuid": ds_factory.dataset_uuid,
"dataset_metadata": metadata,
"metadata_merger": metadata_merger,
},
)


@@ -365,6 +383,7 @@ def _write_dataframe_partitions(

@default_docs
@_shuffle_docs
@normalize_args
def update_dataset_from_ddf(
ddf: dd.DataFrame,
store: Optional[StoreInput] = None,
@@ -391,15 +410,15 @@ def update_dataset_from_ddf(
--------
:ref:`mutating_datasets`
"""
partition_on = normalize_arg("partition_on", partition_on)
secondary_indices = normalize_arg("secondary_indices", secondary_indices)
sort_partitions_by = normalize_arg("sort_partitions_by", sort_partitions_by)
bucket_by = normalize_arg("bucket_by", bucket_by)
store = normalize_arg("store", store)
delete_scope = dask.delayed(normalize_arg)("delete_scope", delete_scope)

if table is None:
raise TypeError("The parameter `table` is not optional.")

# normalization done by normalize_args but mypy doesn't recognize this
sort_partitions_by = cast(List[str], sort_partitions_by)
secondary_indices = cast(List[str], secondary_indices)
bucket_by = cast(List[str], bucket_by)
partition_on = cast(List[str], partition_on)

ds_factory, metadata_version, partition_on = validate_partition_keys(
dataset_uuid=dataset_uuid,
store=store,
@@ -411,9 +430,9 @@
inferred_indices = _ensure_compatible_indices(ds_factory, secondary_indices)
del secondary_indices

mps = _write_dataframe_partitions(
mp_ser = _write_dataframe_partitions(
ddf=ddf,
store=store,
store=ds_factory.store_factory if ds_factory else store,
dataset_uuid=dataset_uuid or ds_factory.dataset_uuid,
table=table,
secondary_indices=inferred_indices,
@@ -426,14 +445,21 @@
partition_on=cast(List[str], partition_on),
bucket_by=bucket_by,
)
return dask.delayed(update_dataset_from_partitions)(
mps,
store_factory=store,
dataset_uuid=dataset_uuid,
ds_factory=ds_factory,
delete_scope=delete_scope,
metadata=metadata,
metadata_merger=metadata_merger,

return mp_ser.reduction(
chunk=_id,
aggregate=_commit_update_from_reduction,
split_every=False,
token="commit-dataset",
meta=object,
aggregate_kwargs={
"store_factory": store,
"dataset_uuid": dataset_uuid,
"ds_factory": ds_factory,
"delete_scope": delete_scope,
"metadata": metadata,
"metadata_merger": metadata_merger,
},
)


@@ -442,7 +468,6 @@ def update_dataset_from_ddf(
def collect_dataset_metadata(
store: Optional[StoreInput] = None,
dataset_uuid: Optional[str] = None,
table_name: str = SINGLE_TABLE,
predicates: Optional[PredicatesType] = None,
frac: float = 1.0,
factory: Optional[DatasetFactory] = None,
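The commit path above relies on `dd.Series.reduction` with an identity chunk step and a single aggregate step. The following is a minimal, self-contained sketch of that pattern using plain dask and pandas; it is illustrative only, does not touch kartothek, and the toy data and names are assumptions.

import dask.dataframe as dd
import pandas as pd


def _identity(part):
    # chunk step: forward each partition's values unchanged
    return part


def _commit(df, **kwargs):
    # aggregate step: runs exactly once (split_every=False) and sees the chunk
    # outputs combined into a single frame, one row per input partition
    partitions = pd.Series(df.values.flatten()).dropna()
    return {"committed": sorted(partitions), **kwargs}


ser = dd.from_pandas(pd.Series(["part-0", None, "part-1"]), npartitions=2)

scalar = ser.reduction(
    chunk=_identity,
    aggregate=_commit,
    split_every=False,          # no tree reduction: one aggregate, i.e. one "commit"
    token="commit-dataset",
    meta=object,
    aggregate_kwargs={"dataset_uuid": "example_uuid"},
)
print(scalar.compute())
# e.g. {'committed': ['part-0', 'part-1'], 'dataset_uuid': 'example_uuid'}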
21 changes: 17 additions & 4 deletions kartothek/io_components/utils.py
@@ -4,7 +4,7 @@
import collections
import inspect
import logging
from typing import Dict, Iterable, List, Optional, TypeVar, Union, overload
from typing import Dict, Iterable, List, Optional, Union, overload

import decorator
import pandas as pd
@@ -165,7 +165,20 @@ def validate_partition_keys(

_NORMALIZE_ARGS = _NORMALIZE_ARGS_LIST + ["store", "dispatch_by"]

T = TypeVar("T")

@overload
def normalize_arg(
arg_name: Literal[
"partition_on",
"delete_scope",
"secondary_indices",
"bucket_by",
"sort_partitions_by",
"dispatch_by",
],
old_value: None,
) -> None:
...


@overload
@@ -178,8 +191,8 @@ def normalize_arg(
"sort_partitions_by",
"dispatch_by",
],
old_value: Optional[Union[T, List[T]]],
) -> List[T]:
old_value: Union[str, List[str]],
) -> List[str]:
...


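The hunk above adds an `@overload` declaring that `None` passes through `normalize_arg` as `None`. Below is a self-contained sketch of that `typing.overload`/`Literal` pattern; it is illustrative only, not the kartothek implementation, and assumes Python 3.8+ for `typing.Literal`.

from typing import List, Literal, Union, overload


@overload
def normalize_arg(arg_name: Literal["partition_on"], old_value: None) -> None:
    ...


@overload
def normalize_arg(
    arg_name: Literal["partition_on"], old_value: Union[str, List[str]]
) -> List[str]:
    ...


def normalize_arg(arg_name, old_value):
    # None passes through unchanged so callers can distinguish "not given"
    # from "given as an empty list"; a single string becomes a one-element list.
    if old_value is None:
        return None
    if isinstance(old_value, str):
        return [old_value]
    return list(old_value)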
37 changes: 9 additions & 28 deletions tests/io/dask/dataframe/test_stats.py
@@ -14,7 +14,6 @@ def test_collect_dataset_metadata(store_session_factory, dataset):
df_stats = collect_dataset_metadata(
store=store_session_factory,
dataset_uuid="dataset_uuid",
table_name="table",
predicates=None,
frac=1,
).compute()
@@ -48,7 +47,6 @@ def test_collect_dataset_metadata_predicates(store_session_factory, dataset):
df_stats = collect_dataset_metadata(
store=store_session_factory,
dataset_uuid="dataset_uuid",
table_name="table",
predicates=predicates,
frac=1,
).compute()
@@ -87,11 +85,7 @@ def test_collect_dataset_metadata_predicates_on_index(store_factory):
predicates = [[("L", "==", "b")]]

df_stats = collect_dataset_metadata(
store=store_factory,
dataset_uuid="dataset_uuid",
table_name="table",
predicates=predicates,
frac=1,
store=store_factory, dataset_uuid="dataset_uuid", predicates=predicates, frac=1,
).compute()

assert "L=b" in df_stats["partition_label"].values[0]
@@ -135,11 +129,7 @@ def test_collect_dataset_metadata_predicates_row_group_size(store_factory):
predicates = [[("L", "==", "a")]]

df_stats = collect_dataset_metadata(
store=store_factory,
dataset_uuid="dataset_uuid",
table_name="table",
predicates=predicates,
frac=1,
store=store_factory, dataset_uuid="dataset_uuid", predicates=predicates, frac=1,
).compute()

for part_label in df_stats["partition_label"]:
@@ -170,10 +160,7 @@ def test_collect_dataset_metadata_predicates_row_group_size(store_factory):

def test_collect_dataset_metadata_frac_smoke(store_session_factory, dataset):
df_stats = collect_dataset_metadata(
store=store_session_factory,
dataset_uuid="dataset_uuid",
table_name="table",
frac=0.8,
store=store_session_factory, dataset_uuid="dataset_uuid", frac=0.8,
).compute()
columns = {
"partition_label",
@@ -195,7 +182,7 @@ def test_collect_dataset_metadata_empty_dataset(store_factory):
store=store_factory, dataset_uuid="dataset_uuid", dfs=[df], partition_on=["A"]
)
df_stats = collect_dataset_metadata(
store=store_factory, dataset_uuid="dataset_uuid", table_name="table",
store=store_factory, dataset_uuid="dataset_uuid"
).compute()
expected = pd.DataFrame(columns=_METADATA_SCHEMA.keys())
expected = expected.astype(_METADATA_SCHEMA)
@@ -209,7 +196,7 @@ def test_collect_dataset_metadata_concat(store_factory):
store=store_factory, dataset_uuid="dataset_uuid", dfs=[df], partition_on=["A"]
)
df_stats1 = collect_dataset_metadata(
store=store_factory, dataset_uuid="dataset_uuid", table_name="table",
store=store_factory, dataset_uuid="dataset_uuid",
).compute()

# Remove all partitions of the dataset
@@ -218,7 +205,7 @@
)

df_stats2 = collect_dataset_metadata(
store=store_factory, dataset_uuid="dataset_uuid", table_name="table",
store=store_factory, dataset_uuid="dataset_uuid",
).compute()
pd.concat([df_stats1, df_stats2])

@@ -234,7 +221,7 @@
)

df_stats = collect_dataset_metadata(
store=store_factory, dataset_uuid="dataset_uuid", table_name="table",
store=store_factory, dataset_uuid="dataset_uuid",
).compute()
expected = pd.DataFrame(columns=_METADATA_SCHEMA)
expected = expected.astype(_METADATA_SCHEMA)
@@ -273,16 +260,10 @@ def test_collect_dataset_metadata_at_least_one_partition(store_factory):
def test_collect_dataset_metadata_invalid_frac(store_session_factory, dataset):
with pytest.raises(ValueError, match="Invalid value for parameter `frac`"):
collect_dataset_metadata(
store=store_session_factory,
dataset_uuid="dataset_uuid",
table_name="table",
frac=1.1,
store=store_session_factory, dataset_uuid="dataset_uuid", frac=1.1,
)

with pytest.raises(ValueError, match="Invalid value for parameter `frac`"):
collect_dataset_metadata(
store=store_session_factory,
dataset_uuid="dataset_uuid",
table_name="table",
frac=0.0,
store=store_session_factory, dataset_uuid="dataset_uuid", frac=0.0,
)
16 changes: 0 additions & 16 deletions tests/io/dask/dataframe/test_update.py
@@ -44,22 +44,6 @@ def _return_none():
return None


def test_delayed_as_delete_scope(store_factory, df_all_types):
# Check that delayed objects are allowed as delete scope.
tasks = update_dataset_from_ddf(
dd.from_pandas(df_all_types, npartitions=1),
store_factory,
dataset_uuid="output_dataset_uuid",
table="core",
delete_scope=dask.delayed(_return_none)(),
)

s = pickle.dumps(tasks, pickle.HIGHEST_PROTOCOL)
tasks = pickle.loads(s)

tasks.compute()


@pytest.mark.parametrize("shuffle", [True, False])
def test_update_dataset_from_ddf_empty(store_factory, shuffle):
with pytest.raises(ValueError, match="Cannot store empty datasets"):
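The test removed above exercised passing `delete_scope` as a dask.delayed object, which is no longer allowed per the CHANGES.rst entry. A hedged sketch of the new calling convention follows; the temporary-directory store factory and the example dataframe are assumptions, and `delete_scope` is passed as a concrete list of predicate dictionaries.

import tempfile
from functools import partial

import dask.dataframe as dd
import pandas as pd
from storefact import get_store_from_url

from kartothek.io.dask.dataframe import update_dataset_from_ddf

store_factory = partial(get_store_from_url, f"hfs://{tempfile.mkdtemp()}")
df = pd.DataFrame({"partition_col": ["a", "b"], "value": [1, 2]})

graph = update_dataset_from_ddf(
    dd.from_pandas(df, npartitions=1),
    store_factory,
    dataset_uuid="output_dataset_uuid",
    table="core",
    partition_on=["partition_col"],
    # delete_scope must now be a plain list of dicts, not a dask.delayed object
    delete_scope=[{"partition_col": "a"}],
)
graph.compute()  # the update is committed in a single final task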
