Skip to content

Commit

Permalink
Support decoding decimals in raw decoder (#3820)
Browse files Browse the repository at this point in the history
  • Loading branch information
spebern authored Mar 9, 2023
1 parent 1883bb6 commit 053973a
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
76 changes: 76 additions & 0 deletions arrow-json/src/raw/decimal_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::marker::PhantomData;

use arrow_array::builder::PrimitiveBuilder;
use arrow_array::types::DecimalType;
use arrow_array::Array;
use arrow_cast::parse::parse_decimal;
use arrow_data::ArrayData;
use arrow_schema::ArrowError;

use crate::raw::tape::{Tape, TapeElement};
use crate::raw::{tape_error, ArrayDecoder};

/// An [`ArrayDecoder`] that builds decimal arrays of the decimal type `D`
/// from the elements of a [`Tape`].
pub struct DecimalArrayDecoder<D>
where
    D: DecimalType,
{
    /// Precision handed to `parse_decimal` and applied to the decoded array
    precision: u8,
    /// Scale handed to `parse_decimal` and applied to the decoded array
    scale: i8,
    // No `D` is ever stored; the `fn(D) -> D` marker only ties the decoder to
    // its decimal type while keeping it invariant in `D` and `Send`
    phantom: PhantomData<fn(D) -> D>,
}

impl<D> DecimalArrayDecoder<D>
where
    D: DecimalType,
{
    /// Creates a decoder that parses values using the given `precision` and
    /// `scale`. Neither value is validated here; they are only stored.
    pub fn new(precision: u8, scale: i8) -> Self {
        let phantom = PhantomData;
        Self { precision, scale, phantom }
    }
}

impl<D: DecimalType> ArrayDecoder for DecimalArrayDecoder<D> {
    /// Decodes the tape elements found at each index in `pos` into a decimal
    /// array with this decoder's precision and scale.
    ///
    /// * `Null` elements become null slots.
    /// * `String` and `Number` elements are both read as text and parsed with
    ///   `parse_decimal`.
    ///
    /// # Errors
    ///
    /// Returns an error if an element is of any other kind, if
    /// `parse_decimal` rejects the text, or if applying the precision and
    /// scale to the finished array fails.
    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
        let (precision, scale) = (self.precision, self.scale);
        let mut values = PrimitiveBuilder::<D>::with_capacity(pos.len());

        for &offset in pos {
            match tape.get(offset) {
                TapeElement::Null => values.append_null(),
                // Quoted and bare numbers share the same textual representation
                // on the tape, so both go through the same parser
                TapeElement::String(idx) | TapeElement::Number(idx) => {
                    let text = tape.get_string(idx);
                    values.append_value(parse_decimal::<D>(text, precision, scale)?);
                }
                other => return Err(tape_error(other, "decimal")),
            }
        }

        let array = values.finish().with_precision_and_scale(precision, scale)?;
        Ok(array.into_data())
    }
}
61 changes: 61 additions & 0 deletions arrow-json/src/raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
//! [`Reader`]: crate::reader::Reader

use crate::raw::boolean_array::BooleanArrayDecoder;
use crate::raw::decimal_array::DecimalArrayDecoder;
use crate::raw::list_array::ListArrayDecoder;
use crate::raw::map_array::MapArrayDecoder;
use crate::raw::primitive_array::PrimitiveArrayDecoder;
Expand All @@ -33,6 +34,7 @@ use arrow_schema::{ArrowError, DataType, SchemaRef};
use std::io::BufRead;

mod boolean_array;
mod decimal_array;
mod list_array;
mod map_array;
mod primitive_array;
Expand Down Expand Up @@ -291,6 +293,8 @@ fn make_decoder(
data_type => (primitive_decoder, data_type),
DataType::Float32 => primitive_decoder!(Float32Type, data_type),
DataType::Float64 => primitive_decoder!(Float64Type, data_type),
DataType::Decimal128(p, s) => Ok(Box::new(DecimalArrayDecoder::<Decimal128Type>::new(p, s))),
DataType::Decimal256(p, s) => Ok(Box::new(DecimalArrayDecoder::<Decimal256Type>::new(p, s))),
DataType::Boolean => Ok(Box::<BooleanArrayDecoder>::default()),
DataType::Utf8 => Ok(Box::new(StringArrayDecoder::<i32>::new(coerce_primitive))),
DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::<i64>::new(coerce_primitive))),
Expand Down Expand Up @@ -321,6 +325,7 @@ mod tests {
};
use arrow_array::types::Int32Type;
use arrow_array::Array;
use arrow_buffer::ArrowNativeType;
use arrow_cast::display::{ArrayFormatter, FormatOptions};
use arrow_schema::{DataType, Field, Schema};
use std::fs::File;
Expand Down Expand Up @@ -721,4 +726,60 @@ mod tests {
assert!(col3.is_null(4));
assert!(col3.is_null(5));
}

/// Reads six JSON rows into three nullable columns of `data_type` and checks
/// the nulls and raw native values of each column. The expected values
/// correspond to a scale of 2.
fn test_decimal<T: DecimalType>(data_type: DataType) {
    let buf = r#"
{"a": 1, "b": 2, "c": 38.30}
{"a": 2, "b": 4, "c": 123.456}
{"b": 1337, "a": "2.0452"}
{"b": "5", "a": "11034.2"}
{"b": 40}
{"b": 1234, "a": null}
"#;

    // Every column shares the same decimal type and is nullable
    let fields = ["a", "b", "c"].map(|name| Field::new(name, data_type.clone(), true));
    let schema = Arc::new(Schema::new(Vec::from(fields)));

    let batches = do_read(buf, 1024, true, schema);
    assert_eq!(batches.len(), 1);
    let batch = &batches[0];

    // Column "a": mixes bare numbers and strings, missing in row 4, null in row 5
    let a = as_primitive_array::<T>(batch.column(0));
    assert_eq!(a.null_count(), 2);
    for row in [4, 5] {
        assert!(a.is_null(row));
    }
    let expected_a = [100, 200, 204, 1103420, 0, 0].map(T::Native::usize_as);
    assert_eq!(a.values(), &expected_a);

    // Column "b": present in every row
    let b = as_primitive_array::<T>(batch.column(1));
    assert_eq!(b.null_count(), 0);
    let expected_b = [200, 400, 133700, 500, 4000, 123400].map(T::Native::usize_as);
    assert_eq!(b.values(), &expected_b);

    // Column "c": only present in the first two rows
    let c = as_primitive_array::<T>(batch.column(2));
    assert_eq!(c.null_count(), 4);
    for row in 0..2 {
        assert!(!c.is_null(row));
    }
    for row in 2..6 {
        assert!(c.is_null(row));
    }
    let expected_c = [3830, 12345, 0, 0, 0, 0].map(T::Native::usize_as);
    assert_eq!(c.values(), &expected_c);
}

/// Runs the shared decimal checks for both the 128-bit and 256-bit types.
#[test]
fn test_decimals() {
    // The expected values in `test_decimal` are written for this precision/scale
    let precision = 10;
    let scale = 2;

    test_decimal::<Decimal128Type>(DataType::Decimal128(precision, scale));
    test_decimal::<Decimal256Type>(DataType::Decimal256(precision, scale));
}
}

0 comments on commit 053973a

Please sign in to comment.