Skip to content

Commit

Permalink
Improved support for pyarrow strings (#10000)
Browse files Browse the repository at this point in the history
  • Loading branch information
j-bennet authored Mar 8, 2023
1 parent 5f1fc42 commit 15ba4b3
Show file tree
Hide file tree
Showing 33 changed files with 357 additions and 87 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ jobs:
path: env.yaml

- name: Run tests
# pyarrow dtypes currently fail, so we allow continuing on error for that specific build
# TODO: Remove the `continue-on-error` line below once tests are all passing
continue-on-error: ${{ matrix.extra == 'pyarrow' }}
run: source continuous_integration/scripts/run_tests.sh

- name: Coverage
Expand Down
19 changes: 19 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@ def pytest_runtest_setup(item):
pytest.skip("need --runslow option to run")


skip_with_pyarrow_strings = pytest.mark.skipif(
bool(dask.config.get("dataframe.convert_string")),
reason="No need to run with pyarrow strings",
)

xfail_with_pyarrow_strings = pytest.mark.xfail(
bool(dask.config.get("dataframe.convert_string")),
reason="Known failure with pyarrow strings",
)


def pytest_collection_modifyitems(config, items):
for item in items:
if "skip_with_pyarrow_strings" in item.keywords:
item.add_marker(skip_with_pyarrow_strings)
if "xfail_with_pyarrow_strings" in item.keywords:
item.add_marker(xfail_with_pyarrow_strings)


pytest.register_assert_rewrite(
"dask.array.utils", "dask.dataframe.utils", "dask.bag.utils"
)
Expand Down
1 change: 1 addition & 0 deletions dask/bag/tests/test_bag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,7 @@ def test_dask_layers_to_delayed(optimize):
db.Item(arr.dask, (arr.name,), layer="foo")


@pytest.mark.skip_with_pyarrow_strings # test checks graph layers
def test_to_dataframe_optimize_graph():
pytest.importorskip("dask.dataframe")
from dask.dataframe.utils import assert_eq as assert_eq_df
Expand Down
8 changes: 7 additions & 1 deletion dask/bytes/tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,13 @@ def test_open_glob(dir_server):


@pytest.mark.network
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
@pytest.mark.parametrize(
"engine",
(
"pyarrow",
pytest.param("fastparquet", marks=pytest.mark.xfail_with_pyarrow_strings),
),
)
def test_parquet(engine):
pytest.importorskip("requests", minversion="2.21.0")
dd = pytest.importorskip("dask.dataframe")
Expand Down
24 changes: 21 additions & 3 deletions dask/bytes/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,13 @@ def test_modification_time_read_bytes(s3, s3so):
assert [aa._key for aa in concat(a)] != [cc._key for cc in concat(c)]


@pytest.mark.parametrize("engine", ["pyarrow", "fastparquet"])
@pytest.mark.parametrize(
"engine",
[
"pyarrow",
pytest.param("fastparquet", marks=pytest.mark.xfail_with_pyarrow_strings),
],
)
@pytest.mark.parametrize("metadata_file", [True, False])
def test_parquet(s3, engine, s3so, metadata_file):
dd = pytest.importorskip("dask.dataframe")
Expand Down Expand Up @@ -536,7 +542,13 @@ def _open(*args, check=True, **kwargs):
dd.utils.assert_eq(data, df4)


@pytest.mark.parametrize("engine", ["pyarrow", "fastparquet"])
@pytest.mark.parametrize(
"engine",
[
"pyarrow",
pytest.param("fastparquet", marks=pytest.mark.xfail_with_pyarrow_strings),
],
)
def test_parquet_append(s3, engine, s3so):
pytest.importorskip(engine)
dd = pytest.importorskip("dask.dataframe")
Expand Down Expand Up @@ -591,7 +603,13 @@ def test_parquet_append(s3, engine, s3so):
)


@pytest.mark.parametrize("engine", ["pyarrow", "fastparquet"])
@pytest.mark.parametrize(
"engine",
[
"pyarrow",
pytest.param("fastparquet", marks=pytest.mark.xfail_with_pyarrow_strings),
],
)
def test_parquet_wstoragepars(s3, s3so, engine):
pytest.importorskip(engine)
dd = pytest.importorskip("dask.dataframe")
Expand Down
7 changes: 6 additions & 1 deletion dask/dataframe/_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ def is_pyarrow_string_dtype(dtype):

def is_object_string_dtype(dtype):
"""Determine if input is a non-pyarrow string dtype"""
return pd.api.types.is_string_dtype(dtype) and not is_pyarrow_string_dtype(dtype)
# in pandas < 2.0, is_string_dtype(DecimalDtype()) returns True
return (
pd.api.types.is_string_dtype(dtype)
and not is_pyarrow_string_dtype(dtype)
and not pd.api.types.is_dtype_equal(dtype, "decimal")
)


def is_object_string_index(x):
Expand Down
7 changes: 5 additions & 2 deletions dask/dataframe/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,11 @@ def _delegate_property(obj, accessor, attr):

@staticmethod
def _delegate_method(obj, accessor, attr, args, kwargs):
out = getattr(getattr(obj, accessor, obj), attr)(*args, **kwargs)
return maybe_wrap_pandas(obj, out)
with warnings.catch_warnings():
# Falling back on a non-pyarrow code path which may decrease performance
warnings.simplefilter("ignore", pd.errors.PerformanceWarning)
out = getattr(getattr(obj, accessor, obj), attr)(*args, **kwargs)
return maybe_wrap_pandas(obj, out)

def _property_map(self, attr):
meta = self._delegate_property(self._series._meta, self._accessor_name, attr)
Expand Down
4 changes: 2 additions & 2 deletions dask/dataframe/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def categorize(df, columns=None, index=None, split_every=None, **kwargs):
"""
meta = df._meta
if columns is None:
columns = list(meta.select_dtypes(["object", "category"]).columns)
columns = list(meta.select_dtypes(["object", "string", "category"]).columns)
elif is_scalar(columns):
columns = [columns]

Expand All @@ -114,7 +114,7 @@ def categorize(df, columns=None, index=None, split_every=None, **kwargs):
if is_categorical_dtype(meta.index):
index = not has_known_categories(meta.index)
elif index is None:
index = meta.index.dtype == object
index = str(meta.index.dtype) in ("object", "string")

# Nothing to do
if not len(columns) and index is False:
Expand Down
7 changes: 5 additions & 2 deletions dask/dataframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3694,8 +3694,11 @@ def _raise_if_object_series(x, funcname):
Utility function to raise an error if an object column does not support
a certain operation like `mean`.
"""
if isinstance(x, Series) and hasattr(x, "dtype") and x.dtype == object:
raise ValueError("`%s` not supported with object series" % funcname)
if isinstance(x, Series) and hasattr(x, "dtype"):
if x.dtype == object:
raise ValueError("`%s` not supported with object series" % funcname)
elif pd.api.types.is_dtype_equal(x.dtype, "string"):
raise ValueError("`%s` not supported with string series" % funcname)


class Series(_Frame):
Expand Down
48 changes: 34 additions & 14 deletions dask/dataframe/io/tests/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
text_blocks_to_pandas,
)
from dask.dataframe.optimize import optimize_dataframe_getitem
from dask.dataframe.utils import assert_eq, has_known_categories
from dask.dataframe.utils import (
assert_eq,
get_string_dtype,
has_known_categories,
pyarrow_strings_enabled,
)
from dask.layers import DataFrameIOLayer
from dask.utils import filetext, filetexts, tmpdir, tmpfile
from dask.utils_test import hlg_layer
Expand Down Expand Up @@ -120,7 +125,18 @@ def parse_filename(path):
),
}

expected = pd.concat([pd.read_csv(BytesIO(csv_files[k])) for k in sorted(csv_files)])

def read_files(file_names=csv_files):
df = pd.concat([pd.read_csv(BytesIO(csv_files[k])) for k in sorted(file_names)])
df = df.astype({"name": get_string_dtype(), "amount": int, "id": int})
return df


def read_files_with(file_names, handler, **kwargs):
df = pd.concat([handler(n, **kwargs) for n in sorted(file_names)])
df = df.astype({"name": get_string_dtype(), "amount": int, "id": int})
return df


comment_header = b"""# some header lines
# that may be present
Expand Down Expand Up @@ -193,7 +209,7 @@ def test_text_blocks_to_pandas_simple(reader, files):
values = text_blocks_to_pandas(reader, blocks, header, head, kwargs)
assert isinstance(values, dd.DataFrame)
assert hasattr(values, "dask")
assert len(values.dask) == 3
assert len(values.dask) == 6 if pyarrow_strings_enabled() else 3

assert_eq(df.amount.sum(), 100 + 200 + 300 + 400 + 500 + 600)

Expand All @@ -214,6 +230,7 @@ def test_text_blocks_to_pandas_kwargs(reader, files):

@csv_and_table
def test_text_blocks_to_pandas_blocked(reader, files):
expected = read_files()
header = files["2014-01-01.csv"].split(b"\n")[0] + b"\n"
blocks = []
for k in sorted(files):
Expand Down Expand Up @@ -248,7 +265,7 @@ def test_skiprows(dd_read, pd_read, files):
skip = len(comment_header.splitlines())
with filetexts(files, mode="b"):
df = dd_read("2014-01-*.csv", skiprows=skip)
expected_df = pd.concat([pd_read(n, skiprows=skip) for n in sorted(files)])
expected_df = read_files_with(files, pd_read, skiprows=skip)
assert_eq(df, expected_df, check_dtype=False)


Expand All @@ -260,12 +277,12 @@ def test_comment(dd_read, pd_read, files):
files = {
name: comment_header
+ b"\n"
+ content.replace(b"\n", b" # just some comment\n", 1)
+ content.replace(b"\n", b"# just some comment\n", 1)
for name, content in files.items()
}
with filetexts(files, mode="b"):
df = dd_read("2014-01-*.csv", comment="#")
expected_df = pd.concat([pd_read(n, comment="#") for n in sorted(files)])
expected_df = read_files_with(files, pd_read, comment="#")
assert_eq(df, expected_df, check_dtype=False)


Expand All @@ -278,9 +295,7 @@ def test_skipfooter(dd_read, pd_read, files):
skip = len(comment_footer.splitlines())
with filetexts(files, mode="b"):
df = dd_read("2014-01-*.csv", skipfooter=skip, engine="python")
expected_df = pd.concat(
[pd_read(n, skipfooter=skip, engine="python") for n in sorted(files)]
)
expected_df = read_files_with(files, pd_read, skipfooter=skip, engine="python")
assert_eq(df, expected_df, check_dtype=False)


Expand All @@ -299,7 +314,7 @@ def test_skiprows_as_list(dd_read, pd_read, files, units):
skip = [0, 1, 2, 3, 5]
with filetexts(files, mode="b"):
df = dd_read("2014-01-*.csv", skiprows=skip)
expected_df = pd.concat([pd_read(n, skiprows=skip) for n in sorted(files)])
expected_df = read_files_with(files, pd_read, skiprows=skip)
assert_eq(df, expected_df, check_dtype=False)


Expand Down Expand Up @@ -412,6 +427,7 @@ def test_read_csv_skiprows_only_in_first_partition(dd_read, pd_read, text, skip)
[(dd.read_csv, pd.read_csv, csv_files), (dd.read_table, pd.read_table, tsv_files)],
)
def test_read_csv_files(dd_read, pd_read, files):
expected = read_files()
with filetexts(files, mode="b"):
df = dd_read("2014-01-*.csv")
assert_eq(df, expected, check_dtype=False)
Expand All @@ -429,7 +445,7 @@ def test_read_csv_files(dd_read, pd_read, files):
def test_read_csv_files_list(dd_read, pd_read, files):
with filetexts(files, mode="b"):
subset = sorted(files)[:2] # Just first 2
sol = pd.concat([pd_read(BytesIO(files[k])) for k in subset])
sol = read_files(subset)
res = dd_read(subset)
assert_eq(res, sol, check_dtype=False)

Expand Down Expand Up @@ -616,11 +632,11 @@ def test_consistent_dtypes_2():
Frank,600
"""
)

string_dtype = get_string_dtype()
with filetexts({"foo.1.csv": text1, "foo.2.csv": text2}):
df = dd.read_csv("foo.*.csv", blocksize=25)
assert df.name.dtype == object
assert df.name.compute().dtype == object
assert df.name.dtype == string_dtype
assert df.name.compute().dtype == string_dtype


def test_categorical_dtypes():
Expand Down Expand Up @@ -759,6 +775,8 @@ def test_read_csv_sensitive_to_enforce():
def test_read_csv_compression(fmt, blocksize):
if fmt and fmt not in compress:
pytest.skip("compress function not provided for %s" % fmt)

expected = read_files()
suffix = {"gzip": ".gz", "bz2": ".bz2", "zip": ".zip", "xz": ".xz"}.get(fmt, "")
files2 = valmap(compress[fmt], csv_files) if fmt else csv_files
renamed_files = {k + suffix: v for k, v in files2.items()}
Expand Down Expand Up @@ -1246,6 +1264,7 @@ def test_robust_column_mismatch():
assert_eq(ddf, ddf)


@pytest.mark.xfail_with_pyarrow_strings # needs a follow-up
def test_different_columns_are_allowed():
files = csv_files.copy()
k = sorted(files)[-1]
Expand Down Expand Up @@ -1723,6 +1742,7 @@ def test_csv_getitem_column_order(tmpdir):
assert_eq(df1[columns], df2)


@pytest.mark.skip_with_pyarrow_strings # checks graph layers
def test_getitem_optimization_after_filter():
with filetext(timeseries) as fn:
expect = pd.read_csv(fn)
Expand Down
6 changes: 4 additions & 2 deletions dask/dataframe/io/tests/test_demo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pandas as pd
import pytest

import dask
import dask.dataframe as dd
from dask.blockwise import Blockwise, optimize_blockwise
from dask.dataframe._compat import tm
from dask.dataframe.optimize import optimize_dataframe_getitem
from dask.dataframe.utils import assert_eq
from dask.dataframe.utils import assert_eq, get_string_dtype


def test_make_timeseries():
Expand All @@ -18,7 +19,7 @@ def test_make_timeseries():
tm.assert_index_equal(df.columns, pd.Index(["A", "B", "C"]))
assert df["A"].head().dtype == float
assert df["B"].head().dtype == int
assert df["C"].head().dtype == object
assert df["C"].head().dtype == get_string_dtype()
assert df.index.name == "timestamp"
assert df.head().index.name == df.index.name
assert df.divisions == tuple(pd.date_range(start="2000", end="2015", freq="6M"))
Expand Down Expand Up @@ -80,6 +81,7 @@ def test_make_timeseries_no_args():
assert len(set(df.dtypes)) > 1


@pytest.mark.skip_with_pyarrow_strings # checks graph layers
def test_make_timeseries_blockwise():
df = dd.demo.make_timeseries()
df = df[["x", "y"]]
Expand Down
4 changes: 4 additions & 0 deletions dask/dataframe/io/tests/test_hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from dask.layers import DataFrameIOLayer
from dask.utils import dependency_depth, tmpdir, tmpfile

# there's no support in upstream for writing HDF with extension dtypes yet.
# see https://github.com/pandas-dev/pandas/issues/31199
pytestmark = pytest.mark.skip_with_pyarrow_strings


def test_to_hdf():
pytest.importorskip("tables")
Expand Down
Loading

0 comments on commit 15ba4b3

Please sign in to comment.