From 494732d891d758d5b3d58cb4ad119405f9a1b844 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Wed, 17 Mar 2021 17:35:55 +0100 Subject: [PATCH] Return dask scalar for store and update from ddf (#437) --- CHANGES.rst | 4 ++ kartothek/io/dask/dataframe.py | 93 ++++++++++++++++---------- kartothek/io_components/utils.py | 21 ++++-- tests/io/dask/dataframe/test_stats.py | 37 +++------- tests/io/dask/dataframe/test_update.py | 16 ----- 5 files changed, 89 insertions(+), 82 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 39de0ad8..a6144aba 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -28,6 +28,10 @@ This is a major release of kartothek with breaking API changes. * Remove `metadata`, `df_serializer`, `overwrite`, `metadata_merger` from :func:`kartothek.io.eager.write_single_partition` * :func:`~kartothek.io.eager.store_dataframes_as_dataset` now requires a list as an input * Default value for argument `date_as_object` is now universally set to ``True``. The behaviour for `False` will be deprecated and removed in the next major release +* No longer allow to pass `delete_scope` as a delayed object to :func:`~kartothek.io.dask.dataframe.update_dataset_from_ddf` +* :func:`~kartothek.io.dask.dataframe.update_dataset_from_ddf` and :func:`~kartothek.io.dask.dataframe.store_dataset_from_ddf` now return a `dd.core.Scalar` object. This enables all `dask.DataFrame` graph optimizations by default. +* Remove argument `table_name` from :func:`~kartothek.io.dask.dataframe.collect_dataset_metadata` + Version 3.20.0 (2021-03-15) =========================== diff --git a/kartothek/io/dask/dataframe.py b/kartothek/io/dask/dataframe.py index 610bea30..dfe48879 100644 --- a/kartothek/io/dask/dataframe.py +++ b/kartothek/io/dask/dataframe.py @@ -32,7 +32,6 @@ from kartothek.io_components.update import update_dataset_from_partitions from kartothek.io_components.utils import ( _ensure_compatible_indices, - normalize_arg, normalize_args, validate_partition_keys, ) @@ -239,8 +238,23 @@ def _shuffle_docs(func): return func +def _id(x): + return x + + +def _commit_update_from_reduction(df_mps, **kwargs): + partitions = pd.Series(df_mps.values.flatten()).dropna() + return update_dataset_from_partitions(partition_list=partitions, **kwargs,) + + +def _commit_store_from_reduction(df_mps, **kwargs): + partitions = pd.Series(df_mps.values.flatten()).dropna() + return store_dataset_from_partitions(partition_list=partitions, **kwargs,) + + @default_docs @_shuffle_docs +@normalize_args def store_dataset_from_ddf( ddf: dd.DataFrame, store: StoreInput, @@ -251,7 +265,6 @@ def store_dataset_from_ddf( repartition_ratio: Optional[SupportsFloat] = None, num_buckets: int = 1, sort_partitions_by: Optional[Union[List[str], str]] = None, - delete_scope: Optional[Iterable[Mapping[str, str]]] = None, metadata: Optional[Mapping] = None, df_serializer: Optional[DataFrameSerializer] = None, metadata_merger: Optional[Callable] = None, @@ -263,12 +276,11 @@ def store_dataset_from_ddf( """ Store a dataset from a dask.dataframe. 
""" - partition_on = normalize_arg("partition_on", partition_on) - secondary_indices = normalize_arg("secondary_indices", secondary_indices) - sort_partitions_by = normalize_arg("sort_partitions_by", sort_partitions_by) - bucket_by = normalize_arg("bucket_by", bucket_by) - store = normalize_arg("store", store) - delete_scope = dask.delayed(normalize_arg)("delete_scope", delete_scope) + # normalization done by normalize_args but mypy doesn't recognize this + sort_partitions_by = cast(List[str], sort_partitions_by) + secondary_indices = cast(List[str], secondary_indices) + bucket_by = cast(List[str], bucket_by) + partition_on = cast(List[str], partition_on) if table is None: raise TypeError("The parameter `table` is not optional.") @@ -277,9 +289,9 @@ def store_dataset_from_ddf( if not overwrite: raise_if_dataset_exists(dataset_uuid=dataset_uuid, store=store) - mps = _write_dataframe_partitions( + mp_ser = _write_dataframe_partitions( ddf=ddf, - store=store, + store=ds_factory.store_factory, dataset_uuid=dataset_uuid, table=table, secondary_indices=secondary_indices, @@ -292,12 +304,18 @@ def store_dataset_from_ddf( partition_on=partition_on, bucket_by=bucket_by, ) - return dask.delayed(store_dataset_from_partitions)( - mps, - store=ds_factory.store_factory if ds_factory else store, - dataset_uuid=ds_factory.dataset_uuid if ds_factory else dataset_uuid, - dataset_metadata=metadata, - metadata_merger=metadata_merger, + return mp_ser.reduction( + chunk=_id, + aggregate=_commit_store_from_reduction, + split_every=False, + token="commit-dataset", + meta=object, + aggregate_kwargs={ + "store": ds_factory.store_factory, + "dataset_uuid": ds_factory.dataset_uuid, + "dataset_metadata": metadata, + "metadata_merger": metadata_merger, + }, ) @@ -365,6 +383,7 @@ def _write_dataframe_partitions( @default_docs @_shuffle_docs +@normalize_args def update_dataset_from_ddf( ddf: dd.DataFrame, store: Optional[StoreInput] = None, @@ -391,15 +410,15 @@ def update_dataset_from_ddf( -------- :ref:`mutating_datasets` """ - partition_on = normalize_arg("partition_on", partition_on) - secondary_indices = normalize_arg("secondary_indices", secondary_indices) - sort_partitions_by = normalize_arg("sort_partitions_by", sort_partitions_by) - bucket_by = normalize_arg("bucket_by", bucket_by) - store = normalize_arg("store", store) - delete_scope = dask.delayed(normalize_arg)("delete_scope", delete_scope) - if table is None: raise TypeError("The parameter `table` is not optional.") + + # normalization done by normalize_args but mypy doesn't recognize this + sort_partitions_by = cast(List[str], sort_partitions_by) + secondary_indices = cast(List[str], secondary_indices) + bucket_by = cast(List[str], bucket_by) + partition_on = cast(List[str], partition_on) + ds_factory, metadata_version, partition_on = validate_partition_keys( dataset_uuid=dataset_uuid, store=store, @@ -411,9 +430,9 @@ def update_dataset_from_ddf( inferred_indices = _ensure_compatible_indices(ds_factory, secondary_indices) del secondary_indices - mps = _write_dataframe_partitions( + mp_ser = _write_dataframe_partitions( ddf=ddf, - store=store, + store=ds_factory.store_factory if ds_factory else store, dataset_uuid=dataset_uuid or ds_factory.dataset_uuid, table=table, secondary_indices=inferred_indices, @@ -426,14 +445,21 @@ def update_dataset_from_ddf( partition_on=cast(List[str], partition_on), bucket_by=bucket_by, ) - return dask.delayed(update_dataset_from_partitions)( - mps, - store_factory=store, - dataset_uuid=dataset_uuid, - ds_factory=ds_factory, - 
delete_scope=delete_scope, - metadata=metadata, - metadata_merger=metadata_merger, + + return mp_ser.reduction( + chunk=_id, + aggregate=_commit_update_from_reduction, + split_every=False, + token="commit-dataset", + meta=object, + aggregate_kwargs={ + "store_factory": store, + "dataset_uuid": dataset_uuid, + "ds_factory": ds_factory, + "delete_scope": delete_scope, + "metadata": metadata, + "metadata_merger": metadata_merger, + }, ) @@ -442,7 +468,6 @@ def update_dataset_from_ddf( def collect_dataset_metadata( store: Optional[StoreInput] = None, dataset_uuid: Optional[str] = None, - table_name: str = SINGLE_TABLE, predicates: Optional[PredicatesType] = None, frac: float = 1.0, factory: Optional[DatasetFactory] = None, diff --git a/kartothek/io_components/utils.py b/kartothek/io_components/utils.py index ea8d7c92..211f0591 100644 --- a/kartothek/io_components/utils.py +++ b/kartothek/io_components/utils.py @@ -4,7 +4,7 @@ import collections import inspect import logging -from typing import Dict, Iterable, List, Optional, TypeVar, Union, overload +from typing import Dict, Iterable, List, Optional, Union, overload import decorator import pandas as pd @@ -165,7 +165,20 @@ def validate_partition_keys( _NORMALIZE_ARGS = _NORMALIZE_ARGS_LIST + ["store", "dispatch_by"] -T = TypeVar("T") + +@overload +def normalize_arg( + arg_name: Literal[ + "partition_on", + "delete_scope", + "secondary_indices", + "bucket_by", + "sort_partitions_by", + "dispatch_by", + ], + old_value: None, +) -> None: + ... @overload @@ -178,8 +191,8 @@ def normalize_arg( "sort_partitions_by", "dispatch_by", ], - old_value: Optional[Union[T, List[T]]], -) -> List[T]: + old_value: Union[str, List[str]], +) -> List[str]: ... diff --git a/tests/io/dask/dataframe/test_stats.py b/tests/io/dask/dataframe/test_stats.py index 6d7f5d2e..751c6cde 100644 --- a/tests/io/dask/dataframe/test_stats.py +++ b/tests/io/dask/dataframe/test_stats.py @@ -14,7 +14,6 @@ def test_collect_dataset_metadata(store_session_factory, dataset): df_stats = collect_dataset_metadata( store=store_session_factory, dataset_uuid="dataset_uuid", - table_name="table", predicates=None, frac=1, ).compute() @@ -48,7 +47,6 @@ def test_collect_dataset_metadata_predicates(store_session_factory, dataset): df_stats = collect_dataset_metadata( store=store_session_factory, dataset_uuid="dataset_uuid", - table_name="table", predicates=predicates, frac=1, ).compute() @@ -87,11 +85,7 @@ def test_collect_dataset_metadata_predicates_on_index(store_factory): predicates = [[("L", "==", "b")]] df_stats = collect_dataset_metadata( - store=store_factory, - dataset_uuid="dataset_uuid", - table_name="table", - predicates=predicates, - frac=1, + store=store_factory, dataset_uuid="dataset_uuid", predicates=predicates, frac=1, ).compute() assert "L=b" in df_stats["partition_label"].values[0] @@ -135,11 +129,7 @@ def test_collect_dataset_metadata_predicates_row_group_size(store_factory): predicates = [[("L", "==", "a")]] df_stats = collect_dataset_metadata( - store=store_factory, - dataset_uuid="dataset_uuid", - table_name="table", - predicates=predicates, - frac=1, + store=store_factory, dataset_uuid="dataset_uuid", predicates=predicates, frac=1, ).compute() for part_label in df_stats["partition_label"]: @@ -170,10 +160,7 @@ def test_collect_dataset_metadata_predicates_row_group_size(store_factory): def test_collect_dataset_metadata_frac_smoke(store_session_factory, dataset): df_stats = collect_dataset_metadata( - store=store_session_factory, - dataset_uuid="dataset_uuid", - 
table_name="table", - frac=0.8, + store=store_session_factory, dataset_uuid="dataset_uuid", frac=0.8, ).compute() columns = { "partition_label", @@ -195,7 +182,7 @@ def test_collect_dataset_metadata_empty_dataset(store_factory): store=store_factory, dataset_uuid="dataset_uuid", dfs=[df], partition_on=["A"] ) df_stats = collect_dataset_metadata( - store=store_factory, dataset_uuid="dataset_uuid", table_name="table", + store=store_factory, dataset_uuid="dataset_uuid" ).compute() expected = pd.DataFrame(columns=_METADATA_SCHEMA.keys()) expected = expected.astype(_METADATA_SCHEMA) @@ -209,7 +196,7 @@ def test_collect_dataset_metadata_concat(store_factory): store=store_factory, dataset_uuid="dataset_uuid", dfs=[df], partition_on=["A"] ) df_stats1 = collect_dataset_metadata( - store=store_factory, dataset_uuid="dataset_uuid", table_name="table", + store=store_factory, dataset_uuid="dataset_uuid", ).compute() # Remove all partitions of the dataset @@ -218,7 +205,7 @@ def test_collect_dataset_metadata_concat(store_factory): ) df_stats2 = collect_dataset_metadata( - store=store_factory, dataset_uuid="dataset_uuid", table_name="table", + store=store_factory, dataset_uuid="dataset_uuid", ).compute() pd.concat([df_stats1, df_stats2]) @@ -234,7 +221,7 @@ def test_collect_dataset_metadata_delete_dataset(store_factory): ) df_stats = collect_dataset_metadata( - store=store_factory, dataset_uuid="dataset_uuid", table_name="table", + store=store_factory, dataset_uuid="dataset_uuid", ).compute() expected = pd.DataFrame(columns=_METADATA_SCHEMA) expected = expected.astype(_METADATA_SCHEMA) @@ -273,16 +260,10 @@ def test_collect_dataset_metadata_at_least_one_partition(store_factory): def test_collect_dataset_metadata_invalid_frac(store_session_factory, dataset): with pytest.raises(ValueError, match="Invalid value for parameter `frac`"): collect_dataset_metadata( - store=store_session_factory, - dataset_uuid="dataset_uuid", - table_name="table", - frac=1.1, + store=store_session_factory, dataset_uuid="dataset_uuid", frac=1.1, ) with pytest.raises(ValueError, match="Invalid value for parameter `frac`"): collect_dataset_metadata( - store=store_session_factory, - dataset_uuid="dataset_uuid", - table_name="table", - frac=0.0, + store=store_session_factory, dataset_uuid="dataset_uuid", frac=0.0, ) diff --git a/tests/io/dask/dataframe/test_update.py b/tests/io/dask/dataframe/test_update.py index 14addbc9..9a9d66f9 100644 --- a/tests/io/dask/dataframe/test_update.py +++ b/tests/io/dask/dataframe/test_update.py @@ -44,22 +44,6 @@ def _return_none(): return None -def test_delayed_as_delete_scope(store_factory, df_all_types): - # Check that delayed objects are allowed as delete scope. - tasks = update_dataset_from_ddf( - dd.from_pandas(df_all_types, npartitions=1), - store_factory, - dataset_uuid="output_dataset_uuid", - table="core", - delete_scope=dask.delayed(_return_none)(), - ) - - s = pickle.dumps(tasks, pickle.HIGHEST_PROTOCOL) - tasks = pickle.loads(s) - - tasks.compute() - - @pytest.mark.parametrize("shuffle", [True, False]) def test_update_dataset_from_ddf_empty(store_factory, shuffle): with pytest.raises(ValueError, match="Cannot store empty datasets"):
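
Addendum (illustrative, not part of the patch): after this change, :func:`~kartothek.io.dask.dataframe.store_dataset_from_ddf` and :func:`~kartothek.io.dask.dataframe.update_dataset_from_ddf` return a ``dd.core.Scalar`` that is part of the dask graph instead of a ``dask.delayed`` object. A minimal usage sketch, modelled on the removed ``test_delayed_as_delete_scope`` test above; the store factory is a placeholder:

    import dask.dataframe as dd
    import pandas as pd

    from kartothek.io.dask.dataframe import update_dataset_from_ddf

    ddf = dd.from_pandas(pd.DataFrame({"A": [1, 2, 3]}), npartitions=1)

    # The returned object is a dask Scalar, so the write is scheduled as part
    # of the dask.dataframe graph and benefits from its optimizations.
    graph = update_dataset_from_ddf(
        ddf,
        store_factory,  # placeholder store factory (e.g. a storefact factory)
        dataset_uuid="output_dataset_uuid",
        table="core",
    )
    # Computing the scalar writes the partitions and commits the update;
    # the result is the updated dataset's metadata object.
    dataset_metadata = graph.compute()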
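
The commit step itself is now attached to the graph via ``dd.Series.reduction`` (see ``_commit_store_from_reduction`` / ``_commit_update_from_reduction`` in the diff above): each written partition's ``MetaPartition`` passes through an identity ``chunk`` and a single ``aggregate`` call performs the commit. A toy sketch of the same pattern, with a plain sum standing in for the commit:

    import dask.dataframe as dd
    import pandas as pd

    def _identity(part):
        # chunk step: pass each partition through unchanged
        return part

    def _aggregate(parts):
        # aggregate step: dask hands over the concatenated chunk outputs;
        # flatten them to plain values, as the commit helpers above do
        return pd.Series(parts.values.flatten()).sum()

    s = dd.from_pandas(pd.Series(range(10)), npartitions=2)
    scalar = s.reduction(
        chunk=_identity,
        aggregate=_aggregate,
        split_every=False,  # aggregate everything in a single step, like the commit
        meta=object,
    )
    print(scalar.compute())  # 45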