Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Decimal types in write_csv/write_json #14209

Merged
merged 4 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions crates/polars-arrow/src/compute/decimal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use std::sync::atomic::{AtomicBool, Ordering};

use atoi::FromRadix10SignedChecked;
use num_traits::Euclid;

/// Process-wide switch consulted by decimal formatting callers; it starts out
/// disabled (`false`).
static TRIM_DECIMAL_ZEROS: AtomicBool = AtomicBool::new(false);

/// Returns the current value of the global trim-decimal-zeros switch.
pub fn get_trim_decimal_zeros() -> bool {
    TRIM_DECIMAL_ZEROS.load(Ordering::Relaxed)
}

/// Updates the global trim-decimal-zeros switch.
///
/// `Some(flag)` stores `flag`; `None` resets the switch to `false`.
pub fn set_trim_decimal_zeros(trim: Option<bool>) {
    // Only an explicit `Some(true)` enables trimming; `None` and `Some(false)` both disable it.
    let enabled = matches!(trim, Some(true));
    TRIM_DECIMAL_ZEROS.store(enabled, Ordering::Relaxed);
}

/// Count the number of b'0's at the beginning of a slice.
fn leading_zeros(bytes: &[u8]) -> u8 {
bytes.iter().take_while(|byte| **byte == b'0').count() as u8
Expand Down
9 changes: 5 additions & 4 deletions crates/polars-core/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::borrow::Cow;
use std::fmt::{Debug, Display, Formatter, Write};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::RwLock;
use std::{fmt, str};

Expand Down Expand Up @@ -44,7 +44,6 @@ pub enum FloatFmt {
static FLOAT_PRECISION: RwLock<Option<usize>> = RwLock::new(None);
static FLOAT_FMT: AtomicU8 = AtomicU8::new(FloatFmt::Mixed as u8);

static TRIM_DECIMAL_ZEROS: AtomicBool = AtomicBool::new(false);
static THOUSANDS_SEPARATOR: AtomicU8 = AtomicU8::new(b'\0');
static DECIMAL_SEPARATOR: AtomicU8 = AtomicU8::new(b'.');

Expand All @@ -70,8 +69,9 @@ pub fn get_thousands_separator() -> String {
sep.to_string()
}
}
#[cfg(feature = "dtype-decimal")]
pub fn get_trim_decimal_zeros() -> bool {
TRIM_DECIMAL_ZEROS.load(Ordering::Relaxed)
arrow::compute::decimal::get_trim_decimal_zeros()
}

// Numeric formatting setters
Expand All @@ -87,8 +87,9 @@ pub fn set_decimal_separator(dec: Option<char>) {
pub fn set_thousands_separator(sep: Option<char>) {
THOUSANDS_SEPARATOR.store(sep.unwrap_or('\0') as u8, Ordering::Relaxed)
}
#[cfg(feature = "dtype-decimal")]
pub fn set_trim_decimal_zeros(trim: Option<bool>) {
TRIM_DECIMAL_ZEROS.store(trim.unwrap_or(false), Ordering::Relaxed)
arrow::compute::decimal::set_trim_decimal_zeros(trim)
}

/// Parses an environment variable value.
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-core/src/named_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ impl_named_from_owned!(Vec<i8>, Int8Type);
impl_named_from_owned!(Vec<i16>, Int16Type);
impl_named_from_owned!(Vec<i32>, Int32Type);
impl_named_from_owned!(Vec<i64>, Int64Type);
#[cfg(feature = "dtype-decimal")]
impl_named_from_owned!(Vec<i128>, Int128Type);
#[cfg(feature = "dtype-u8")]
impl_named_from_owned!(Vec<u8>, UInt8Type);
#[cfg(feature = "dtype-u16")]
Expand Down Expand Up @@ -77,6 +79,8 @@ impl_named_from!([i8], Int8Type, from_slice);
impl_named_from!([i16], Int16Type, from_slice);
impl_named_from!([i32], Int32Type, from_slice);
impl_named_from!([i64], Int64Type, from_slice);
#[cfg(feature = "dtype-decimal")]
impl_named_from!([i128], Int128Type, from_slice);
impl_named_from!([f32], Float32Type, from_slice);
impl_named_from!([f64], Float64Type, from_slice);
impl_named_from!([Option<String>], StringType, from_slice_options);
Expand All @@ -94,6 +98,8 @@ impl_named_from!([Option<i8>], Int8Type, from_slice_options);
impl_named_from!([Option<i16>], Int16Type, from_slice_options);
impl_named_from!([Option<i32>], Int32Type, from_slice_options);
impl_named_from!([Option<i64>], Int64Type, from_slice_options);
#[cfg(feature = "dtype-decimal")]
impl_named_from!([Option<i128>], Int128Type, from_slice_options);
impl_named_from!([Option<f32>], Float32Type, from_slice_options);
impl_named_from!([Option<f64>], Float64Type, from_slice_options);

Expand Down
12 changes: 12 additions & 0 deletions crates/polars-core/src/serde/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ impl Serialize for Series {
let ca = self.time().unwrap();
ca.serialize(serializer)
},
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(_, _) => {
let ca = self.decimal().unwrap();
ca.serialize(serializer)
},
dt => {
with_match_physical_numeric_polars_type!(dt, |$T| {
let ca: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref();
Expand Down Expand Up @@ -194,6 +199,13 @@ impl<'de> Deserialize<'de> for Series {
let values: Vec<Option<i64>> = map.next_value()?;
Ok(Series::new(&name, values).cast(&DataType::Time).unwrap())
},
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(precision, Some(scale)) => {
let values: Vec<Option<i128>> = map.next_value()?;
Ok(ChunkedArray::from_slice_options(&name, &values)
.into_decimal_unchecked(precision, scale)
.into_series())
},
DataType::Boolean => {
let values: Vec<Option<bool>> = map.next_value()?;
Ok(Series::new(&name, values))
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@ timezones = [
"chrono-tz",
"dtype-datetime",
"arrow/timezones",
"polars-json?/chrono-tz",
]
dtype-time = ["polars-core/dtype-time", "polars-core/temporal", "polars-time/dtype-time"]
dtype-struct = ["polars-core/dtype-struct"]
dtype-decimal = ["polars-core/dtype-decimal"]
dtype-decimal = ["polars-core/dtype-decimal", "polars-json?/dtype-decimal"]
fmt = ["polars-core/fmt"]
lazy = []
parquet = ["polars-parquet", "polars-parquet/compression"]
Expand Down
30 changes: 30 additions & 0 deletions crates/polars-io/src/csv/write/write_impl/serializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,24 @@ fn bool_serializer<const QUOTE_NON_NULL: bool>(array: &BooleanArray) -> impl Ser
})
}

/// Builds a serializer that writes the `i128` values of a decimal array as text.
///
/// Each value is rendered by `format_decimal` using `scale` and the global
/// trim-decimal-zeros setting, then appended to the output buffer as bytes.
/// The `false` const argument is presumably `make_serializer`'s quote flag
/// (NOTE(review): confirm against `make_serializer`); callers that need quoting
/// appear to wrap the result in `quote_serializer` instead.
#[cfg(feature = "dtype-decimal")]
fn decimal_serializer(array: &PrimitiveArray<i128>, scale: usize) -> impl Serializer {
    // Read the global flag once, up front, rather than once per value.
    let trim_zeros = arrow::compute::decimal::get_trim_decimal_zeros();

    let write_decimal = move |&raw, buf: &mut Vec<u8>, _options: &SerializeOptions| {
        let text = arrow::compute::decimal::format_decimal(raw, scale, trim_zeros);
        buf.extend_from_slice(text.as_str().as_bytes());
    };

    make_serializer::<_, _, false>(write_decimal, array.iter(), |array| {
        // A later array handed to this serializer must again be an i128 primitive array.
        let decimals = array
            .as_any()
            .downcast_ref::<PrimitiveArray<i128>>()
            .expect(ARRAY_MISMATCH_MSG);
        decimals.iter()
    })
}

#[cfg(any(
feature = "dtype-date",
feature = "dtype-time",
Expand Down Expand Up @@ -666,6 +684,18 @@ pub(super) fn serializer_for<'a>(
array,
)
},
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(_, scale) => {
let array = array.as_any().downcast_ref().unwrap();
match options.quote_style {
QuoteStyle::Never => Box::new(decimal_serializer(array, scale.unwrap_or(0)))
as Box<dyn Serializer + Send>,
_ => Box::new(quote_serializer(decimal_serializer(
array,
scale.unwrap_or(0),
))),
}
},
_ => polars_bail!(ComputeError: "datatype {dtype} cannot be written to csv"),
};
Ok(serializer)
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-json/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ polars-utils = { workspace = true }
ahash = { workspace = true }
arrow = { workspace = true }
chrono = { workspace = true }
chrono-tz = { workspace = true, optional = true }
fallible-streaming-iterator = { version = "0.1" }
hashbrown = { workspace = true }
indexmap = { workspace = true }
Expand All @@ -25,4 +26,6 @@ simd-json = { workspace = true }
streaming-iterator = { workspace = true }

[features]
chrono-tz = ["dep:chrono-tz", "arrow/chrono-tz"]
dtype-decimal = ["arrow/dtype-decimal"]
timezones = ["arrow/chrono-tz"]
25 changes: 25 additions & 0 deletions crates/polars-json/src/json/write/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::io::Write;

use arrow::array::*;
use arrow::bitmap::utils::ZipValidity;
#[cfg(feature = "dtype-decimal")]
use arrow::compute::decimal::{format_decimal, get_trim_decimal_zeros};
use arrow::datatypes::{ArrowDataType, IntegerType, TimeUnit};
use arrow::io::iterator::BufStreamingIterator;
use arrow::offset::Offset;
Expand Down Expand Up @@ -112,6 +114,25 @@ where
materialize_serializer(f, array.iter(), offset, take)
}

/// Builds a streaming JSON serializer for a decimal (`i128`) array.
///
/// Non-null values are rendered with `format_decimal` (using `scale` and the
/// global trim-decimal-zeros setting) and emitted through `utf8::write_str`;
/// null slots are emitted as the bare literal `null`. `offset` and `take` are
/// forwarded unchanged to `materialize_serializer`.
#[cfg(feature = "dtype-decimal")]
fn decimal_serializer<'a>(
    array: &'a PrimitiveArray<i128>,
    scale: usize,
    offset: usize,
    take: usize,
) -> Box<dyn StreamingIterator<Item = [u8]> + 'a + Send + Sync> {
    // Read the global flag once so every value in this array is formatted consistently.
    let trim_zeros = get_trim_decimal_zeros();

    let write_item = move |item: Option<&i128>, buf: &mut Vec<u8>| match item {
        // NOTE(review): the unwrap assumes writing into an in-memory Vec cannot fail — confirm.
        Some(value) => {
            utf8::write_str(buf, format_decimal(*value, scale, trim_zeros).as_str()).unwrap()
        },
        None => buf.extend(b"null"),
    };

    materialize_serializer(write_item, array.iter(), offset, take)
}

fn dictionary_utf8view_serializer<'a, K: DictionaryKey>(
array: &'a DictionaryArray<K>,
offset: usize,
Expand Down Expand Up @@ -419,6 +440,10 @@ pub(crate) fn new_serializer<'a>(
ArrowDataType::Float64 => {
float_serializer::<f64>(array.as_any().downcast_ref().unwrap(), offset, take)
},
#[cfg(feature = "dtype-decimal")]
ArrowDataType::Decimal(_, scale) => {
decimal_serializer(array.as_any().downcast_ref().unwrap(), *scale, offset, take)
},
ArrowDataType::LargeUtf8 => {
utf8_serializer::<i64>(array.as_any().downcast_ref().unwrap(), offset, take)
},
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/dataframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io
from datetime import date, datetime, timedelta
from decimal import Decimal as D
from typing import TYPE_CHECKING, Any

import pytest
Expand Down Expand Up @@ -91,6 +92,12 @@ def test_df_serde_enum() -> None:
],
pl.Array(pl.Datetime, shape=3),
),
(
[[D("1.0"), D("2.0"), D("3.0")], [None, None, None]],
# We have to specify the precision explicitly, because `AnonymousListBuilder::finish`
# uses `ArrowDataType`, which remaps a `None` precision to `38`.
pl.Array(pl.Decimal(precision=38, scale=1), shape=3),
),
],
)
def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None:
Expand Down
34 changes: 18 additions & 16 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import textwrap
import zlib
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal as D
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, TypedDict

Expand Down Expand Up @@ -1805,34 +1806,35 @@ class TemporalFormats(TypedDict):
"date": [dt, None, dt],
"datetime": [None, dtm, dtm],
"time": [tm, tm, None],
"decimal": [D("1.0"), D("2.0"), None],
}
)

assert df.write_csv(quote_style="always", **temporal_formats) == (
'"float","string","int","bool","date","datetime","time"\n'
'"1.0","a","1","true","2077-07-05","","03:01:00"\n'
'"2.0","a,bc","2","false","","2077-07-05T03:01:00","03:01:00"\n'
'"","""hello","3","","2077-07-05","2077-07-05T03:01:00",""\n'
'"float","string","int","bool","date","datetime","time","decimal"\n'
'"1.0","a","1","true","2077-07-05","","03:01:00","1.0"\n'
'"2.0","a,bc","2","false","","2077-07-05T03:01:00","03:01:00","2.0"\n'
'"","""hello","3","","2077-07-05","2077-07-05T03:01:00","",""\n'
)
assert df.write_csv(quote_style="necessary", **temporal_formats) == (
"float,string,int,bool,date,datetime,time\n"
"1.0,a,1,true,2077-07-05,,03:01:00\n"
'2.0,"a,bc",2,false,,2077-07-05T03:01:00,03:01:00\n'
',"""hello",3,,2077-07-05,2077-07-05T03:01:00,\n'
"float,string,int,bool,date,datetime,time,decimal\n"
'1.0,a,1,true,2077-07-05,,03:01:00,"1.0"\n'
'2.0,"a,bc",2,false,,2077-07-05T03:01:00,03:01:00,"2.0"\n'
',"""hello",3,,2077-07-05,2077-07-05T03:01:00,,""\n'
)
assert df.write_csv(quote_style="never", **temporal_formats) == (
"float,string,int,bool,date,datetime,time\n"
"1.0,a,1,true,2077-07-05,,03:01:00\n"
"2.0,a,bc,2,false,,2077-07-05T03:01:00,03:01:00\n"
',"hello,3,,2077-07-05,2077-07-05T03:01:00,\n'
"float,string,int,bool,date,datetime,time,decimal\n"
"1.0,a,1,true,2077-07-05,,03:01:00,1.0\n"
"2.0,a,bc,2,false,,2077-07-05T03:01:00,03:01:00,2.0\n"
',"hello,3,,2077-07-05,2077-07-05T03:01:00,,\n'
)
assert df.write_csv(
quote_style="non_numeric", quote_char="8", **temporal_formats
) == (
"8float8,8string8,8int8,8bool8,8date8,8datetime8,8time8\n"
"1.0,8a8,1,8true8,82077-07-058,,803:01:008\n"
"2.0,8a,bc8,2,8false8,,82077-07-05T03:01:008,803:01:008\n"
',8"hello8,3,,82077-07-058,82077-07-05T03:01:008,\n'
"8float8,8string8,8int8,8bool8,8date8,8datetime8,8time8,8decimal8\n"
"1.0,8a8,1,8true8,82077-07-058,,803:01:008,81.08\n"
"2.0,8a,bc8,2,8false8,,82077-07-05T03:01:008,803:01:008,82.08\n"
',8"hello8,3,,82077-07-058,82077-07-05T03:01:008,,88\n'
)


Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/io/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import typing
from collections import OrderedDict
from decimal import Decimal as D
from io import BytesIO
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -53,6 +54,14 @@ def test_write_json_duration() -> None:
assert value == expected


def test_write_json_decimal() -> None:
    """Decimal columns can be written with `write_json`, including null entries."""
    decimals = pl.Series([D("1.00"), D("2.00"), None])
    df = pl.DataFrame({"a": decimals})

    # We don't guarantee this exact format, only that it round-trips.
    expected = """[{"a":"1.00"},{"a":"2.00"},{"a":null}]"""
    assert df.write_json() == expected


def test_json_infer_schema_length_11148() -> None:
response = [{"col1": 1}] * 2 + [{"col1": 1, "col2": 2}] * 1
result = pl.read_json(json.dumps(response).encode(), infer_schema_length=2)
Expand Down