Skip to content

Commit

Permalink
feat(python)!: Support decimals by default when converting from Arrow (
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jun 4, 2024
1 parent b61d4e6 commit 9912af0
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 124 deletions.
9 changes: 0 additions & 9 deletions crates/polars-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,6 @@ pub(crate) const FMT_TABLE_INLINE_COLUMN_DATA_TYPE: &str =
pub(crate) const FMT_TABLE_ROUNDED_CORNERS: &str = "POLARS_FMT_TABLE_ROUNDED_CORNERS";
pub(crate) const FMT_TABLE_CELL_LIST_LEN: &str = "POLARS_FMT_TABLE_CELL_LIST_LEN";

// Other env vars
#[cfg(all(feature = "dtype-decimal", feature = "python"))]
pub(crate) const DECIMAL_ACTIVE: &str = "POLARS_ACTIVATE_DECIMAL";

#[cfg(all(feature = "dtype-decimal", feature = "python"))]
pub(crate) fn decimal_is_active() -> bool {
std::env::var(DECIMAL_ACTIVE).as_deref().unwrap_or("") == "1"
}

/// Reports whether verbose diagnostics are enabled.
///
/// Returns `true` only when the `POLARS_VERBOSE` environment variable is set
/// to exactly `"1"`. An unset variable, a value that is not valid Unicode, or
/// any other value yields `false`.
pub fn verbose() -> bool {
    // `as_deref` turns `Result<String, _>` into `Result<&str, _>` so the value
    // can be compared against a string literal pattern without allocating.
    matches!(std::env::var("POLARS_VERBOSE").as_deref(), Ok("1"))
}
Expand Down
45 changes: 8 additions & 37 deletions crates/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ use crate::chunked_array::object::extension::EXTENSION_NAME;
use crate::chunked_array::temporal::parse_fixed_offset;
#[cfg(feature = "timezones")]
use crate::chunked_array::temporal::validate_time_zone;
#[cfg(all(feature = "dtype-decimal", feature = "python"))]
use crate::config::decimal_is_active;
use crate::config::verbose;
use crate::prelude::*;

Expand Down Expand Up @@ -449,41 +447,14 @@ impl Series {
#[cfg(feature = "dtype-decimal")]
ArrowDataType::Decimal(precision, scale)
| ArrowDataType::Decimal256(precision, scale) => {
#[cfg(feature = "python")]
{
let (precision, scale) = (Some(*precision), *scale);
let chunks =
cast_chunks(&chunks, &DataType::Decimal(precision, Some(scale)), false)
.unwrap();
if decimal_is_active() {
Ok(Int128Chunked::from_chunks(name, chunks)
.into_decimal_unchecked(precision, scale)
.into_series())
} else {
if verbose() {
eprintln!(
"Activate beta decimal types to read as decimal. Current behavior casts to Float64."
);
}
Ok(Float64Chunked::from_chunks(
name,
cast_chunks(&chunks, &DataType::Float64, true).unwrap(),
)
.into_series())
}
}

#[cfg(not(feature = "python"))]
{
let (precision, scale) = (Some(*precision), *scale);
let chunks =
cast_chunks(&chunks, &DataType::Decimal(precision, Some(scale)), false)
.unwrap();
// or DecimalChunked?
Ok(Int128Chunked::from_chunks(name, chunks)
.into_decimal_unchecked(precision, scale)
.into_series())
}
let (precision, scale) = (Some(*precision), *scale);
let chunks =
cast_chunks(&chunks, &DataType::Decimal(precision, Some(scale)), false)
.unwrap();
// or DecimalChunked?
Ok(Int128Chunked::from_chunks(name, chunks)
.into_decimal_unchecked(precision, scale)
.into_series())
},
#[allow(unreachable_patterns)]
ArrowDataType::Decimal256(_, _) | ArrowDataType::Decimal(_, _) => {
Expand Down
6 changes: 0 additions & 6 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1101,18 +1101,12 @@ mod test {
let mut s1 = s1.clone();
s1.append(&s2).unwrap();
assert_eq!(s1.len(), 3);
#[cfg(feature = "python")]
assert_eq!(s1.get(2).unwrap(), AnyValue::Float64(3.0));
#[cfg(not(feature = "python"))]
assert_eq!(s1.get(2).unwrap(), AnyValue::Decimal(300, 2));
}

{
let mut s2 = s2.clone();
s2.extend(&s1).unwrap();
#[cfg(feature = "python")]
assert_eq!(s2.get(2).unwrap(), AnyValue::Float64(2.29)); // 2.3 == 2.2999999999999998
#[cfg(not(feature = "python"))]
assert_eq!(s2.get(2).unwrap(), AnyValue::Decimal(2, 0));
}
}
Expand Down
22 changes: 15 additions & 7 deletions py-polars/polars/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, get_args

from polars._utils.deprecation import deprecate_nonkeyword_arguments
from polars._utils.deprecation import (
deprecate_nonkeyword_arguments,
issue_deprecation_warning,
)
from polars._utils.various import normalize_filepath
from polars.dependencies import json

Expand Down Expand Up @@ -44,7 +47,6 @@
# and/or unstable settings that should not be saved or reset with the Config vars.
_POLARS_CFG_ENV_VARS = {
"POLARS_WARN_UNSTABLE",
"POLARS_ACTIVATE_DECIMAL",
"POLARS_AUTO_STRUCTIFY",
"POLARS_FMT_MAX_COLS",
"POLARS_FMT_MAX_ROWS",
Expand Down Expand Up @@ -353,13 +355,19 @@ def activate_decimals(cls, active: bool | None = True) -> type[Config]:
"""
Activate `Decimal` data types.
.. deprecated:: 1.0.0
Decimals are now always active and this function is a no-op.
This setting will be removed in the next major version.
This is a temporary setting that will be removed once the `Decimal` type
stabilizes (`Decimal` is currently considered to be in beta testing).
stabilizes (`Decimal` is currently considered *unstable*).
"""
if not active:
os.environ.pop("POLARS_ACTIVATE_DECIMAL", None)
else:
os.environ["POLARS_ACTIVATE_DECIMAL"] = str(int(active))
issue_deprecation_warning(
"`Config.activate_decimals` is deprecated and will be removed in the next major version."
" Decimals are now always active and this function is a no-op."
" Remove the call to `activate_decimals` to silence this warning.",
version="1.0.0",
)
return cls

@classmethod
Expand Down
2 changes: 0 additions & 2 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,6 @@ def test_error_string_dtypes() -> None:

def test_init_structured_objects(monkeypatch: Any) -> None:
# validate init from dataclass, namedtuple, and pydantic model objects
monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1")

@dataclasses.dataclass
class TradeDC:
timestamp: datetime
Expand Down
1 change: 0 additions & 1 deletion py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def test_mixed_sequence_selection() -> None:


def test_from_arrow(monkeypatch: Any) -> None:
monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1")
tbl = pa.table(
{
"a": pa.array([1, 2], pa.timestamp("s")),
Expand Down
72 changes: 30 additions & 42 deletions py-polars/tests/unit/datatypes/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def test_series_from_pydecimal_and_ints(
def test_frame_from_pydecimal_and_ints(
permutations_int_dec_none: list[tuple[D | int | None, ...]], monkeypatch: Any
) -> None:
monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1")

class X(NamedTuple):
a: int | D | None

Expand Down Expand Up @@ -152,7 +150,6 @@ def test_decimal_cast_no_scale() -> None:


def test_decimal_scale_precision_roundtrip(monkeypatch: Any) -> None:
monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1")
assert pl.from_arrow(pl.Series("dec", [D("10.0")]).to_arrow()).item() == D("10.0")


Expand Down Expand Up @@ -184,7 +181,6 @@ def test_string_to_decimal() -> None:


def test_read_csv_decimal(monkeypatch: Any) -> None:
monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1")
csv = """a,b
123.12,a
1.1,a
Expand Down Expand Up @@ -407,41 +403,36 @@ def test_decimal_write_parquet_12375() -> None:


def test_decimal_list_get_13847() -> None:
with pl.Config() as cfg:
cfg.activate_decimals()
df = pl.DataFrame({"a": [[D("1.1"), D("1.2")], [D("2.1")]]})
out = df.select(pl.col("a").list.get(0))
expected = pl.DataFrame({"a": [D("1.1"), D("2.1")]})
assert_frame_equal(out, expected)
df = pl.DataFrame({"a": [[D("1.1"), D("1.2")], [D("2.1")]]})
out = df.select(pl.col("a").list.get(0))
expected = pl.DataFrame({"a": [D("1.1"), D("2.1")]})
assert_frame_equal(out, expected)


def test_decimal_explode() -> None:
with pl.Config() as cfg:
cfg.activate_decimals()
nested_decimal_df = pl.DataFrame(
{
"bar": [[D("3.4"), D("3.4")], [D("4.5")]],
}
)
df = nested_decimal_df.explode("bar")
expected_df = pl.DataFrame(
{
"bar": [D("3.4"), D("3.4"), D("4.5")],
}
)
assert_frame_equal(df, expected_df)

nested_decimal_df = pl.DataFrame(
{
"bar": [[D("3.4"), D("3.4")], [D("4.5")]],
}
)
df = nested_decimal_df.explode("bar")
expected_df = pl.DataFrame(
{
"bar": [D("3.4"), D("3.4"), D("4.5")],
}
)
assert_frame_equal(df, expected_df)

# test group-by head #15330
df = pl.DataFrame(
{
"foo": [1, 1, 2],
"bar": [D("3.4"), D("3.4"), D("4.5")],
}
)
head_df = df.group_by("foo", maintain_order=True).head(1)
expected_df = pl.DataFrame({"foo": [1, 2], "bar": [D("3.4"), D("4.5")]})
assert_frame_equal(head_df, expected_df)
# test group-by head #15330
df = pl.DataFrame(
{
"foo": [1, 1, 2],
"bar": [D("3.4"), D("3.4"), D("4.5")],
}
)
head_df = df.group_by("foo", maintain_order=True).head(1)
expected_df = pl.DataFrame({"foo": [1, 2], "bar": [D("3.4"), D("4.5")]})
assert_frame_equal(head_df, expected_df)


def test_decimal_streaming() -> None:
Expand All @@ -464,10 +455,7 @@ def test_decimal_streaming() -> None:


def test_decimal_supertype() -> None:
with pl.Config() as cfg:
cfg.activate_decimals()
pl.Config.activate_decimals()
q = pl.LazyFrame([0.12345678]).select(
pl.col("column_0").cast(pl.Decimal(scale=6)) * 1
)
assert q.collect().dtypes[0].is_decimal()
q = pl.LazyFrame([0.12345678]).select(
pl.col("column_0").cast(pl.Decimal(scale=6)) * 1
)
assert q.collect().dtypes[0].is_decimal()
1 change: 0 additions & 1 deletion py-polars/tests/unit/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,6 @@ def test_nested_struct_read_12610() -> None:
assert_frame_equal(expect, actual)


@pl.Config(activate_decimals=True)
@pytest.mark.write_disk()
def test_decimal_parquet(tmp_path: Path) -> None:
path = tmp_path / "foo.parquet"
Expand Down
23 changes: 11 additions & 12 deletions py-polars/tests/unit/sql/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,18 @@ def test_numeric_decimal_type(
expected_value: D,
expected_dtype: PolarsDataType,
) -> None:
with pl.Config(activate_decimals=True):
df = pl.DataFrame({"n": [value]})
with pl.SQLContext(df=df) as ctx:
result = ctx.execute(
f"""
SELECT n::{sqltype}{prec_scale} AS "dec" FROM df
"""
)
expected = pl.LazyFrame(
data={"dec": [expected_value]},
schema={"dec": expected_dtype},
df = pl.DataFrame({"n": [value]})
with pl.SQLContext(df=df) as ctx:
result = ctx.execute(
f"""
SELECT n::{sqltype}{prec_scale} AS "dec" FROM df
"""
)
assert_frame_equal(result, expected)
expected = pl.LazyFrame(
data={"dec": [expected_value]},
schema={"dec": expected_dtype},
)
assert_frame_equal(result, expected)


@pytest.mark.parametrize(
Expand Down
10 changes: 3 additions & 7 deletions py-polars/tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,12 +733,9 @@ def test_config_state_env_only() -> None:
assert "set_fmt_float" not in state_env_only


def test_activate_decimals() -> None:
with pl.Config() as cfg:
cfg.activate_decimals(True)
assert os.environ.get("POLARS_ACTIVATE_DECIMAL") == "1"
cfg.activate_decimals(False)
assert "POLARS_ACTIVATE_DECIMAL" not in os.environ
def test_activate_decimals_deprecated() -> None:
with pytest.deprecated_call():
pl.Config().activate_decimals(True)


def test_set_streaming_chunk_size() -> None:
Expand Down Expand Up @@ -776,7 +773,6 @@ def test_warn_unstable(recwarn: pytest.WarningsRecorder) -> None:
@pytest.mark.parametrize(
("environment_variable", "config_setting", "value", "expected"),
[
("POLARS_ACTIVATE_DECIMAL", "activate_decimals", True, "1"),
("POLARS_AUTO_STRUCTIFY", "set_auto_structify", True, "1"),
("POLARS_FMT_MAX_COLS", "set_tbl_cols", 12, "12"),
("POLARS_FMT_MAX_ROWS", "set_tbl_rows", 3, "3"),
Expand Down

0 comments on commit 9912af0

Please sign in to comment.