diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 694d31d067f7..7a3aad5e995e 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -43,6 +43,7 @@ serde_json = { version = "1.0", features = ["preserve_order"] } indexmap = "1.6" rand = { version = "0.8", optional = true } num = "0.4" +half = "1.8" csv_crate = { version = "1.1", optional = true, package="csv" } regex = "1.3" lazy_static = "1.4" @@ -61,6 +62,7 @@ bitflags = "1.2.1" default = ["csv", "ipc", "test_utils"] avx512 = [] csv = ["csv_crate"] +f16 = ["half"] ipc = ["flatbuffers"] simd = ["packed_simd"] prettyprint = ["comfy-table"] diff --git a/arrow/src/alloc/types.rs b/arrow/src/alloc/types.rs index 92a6107f3d54..026e1241f46b 100644 --- a/arrow/src/alloc/types.rs +++ b/arrow/src/alloc/types.rs @@ -16,6 +16,7 @@ // under the License. use crate::datatypes::DataType; +use half::f16; /// A type that Rust's custom allocator knows how to allocate and deallocate. /// This is implemented for all Arrow's physical types whose in-memory representation @@ -67,5 +68,6 @@ create_native!( i64, DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _) ); +create_native!(f16, DataType::Float16); create_native!(f32, DataType::Float32); create_native!(f64, DataType::Float64); diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index fcf4647666e8..34cdb73f7166 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -240,7 +240,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::UInt16 => Arc::new(UInt16Array::from(data)) as ArrayRef, DataType::UInt32 => Arc::new(UInt32Array::from(data)) as ArrayRef, DataType::UInt64 => Arc::new(UInt64Array::from(data)) as ArrayRef, - DataType::Float16 => panic!("Float16 datatype not supported"), + DataType::Float16 => Arc::new(Float16Array::from(data)) as ArrayRef, DataType::Float32 => Arc::new(Float32Array::from(data)) as ArrayRef, DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef, DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef, @@ -393,7 +393,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { DataType::UInt8 => new_null_sized_array::(data_type, length), DataType::Int16 => new_null_sized_array::(data_type, length), DataType::UInt16 => new_null_sized_array::(data_type, length), - DataType::Float16 => unreachable!(), + DataType::Float16 => new_null_sized_array::(data_type, length), DataType::Int32 => new_null_sized_array::(data_type, length), DataType::UInt32 => new_null_sized_array::(data_type, length), DataType::Float32 => new_null_sized_array::(data_type, length), diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index 40a8beebf51f..0fc04ef671a3 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -18,10 +18,6 @@ //! Contains `ArrayData`, a generic representation of Arrow array data which encapsulates //! common attributes and operations for Arrow array. -use std::convert::TryInto; -use std::mem; -use std::sync::Arc; - use crate::datatypes::{DataType, IntervalUnit}; use crate::error::{ArrowError, Result}; use crate::{bitmap::Bitmap, datatypes::ArrowNativeType}; @@ -29,6 +25,9 @@ use crate::{ buffer::{Buffer, MutableBuffer}, util::bit_util, }; +use std::convert::TryInto; +use std::mem; +use std::sync::Arc; use super::equal::equal; @@ -89,6 +88,10 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, ], + DataType::Float16 => [ + MutableBuffer::new(capacity * mem::size_of::()), + empty_buffer, + ], DataType::Float32 => [ MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, @@ -178,7 +181,6 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff ], _ => unreachable!(), }, - DataType::Float16 => unreachable!(), DataType::FixedSizeList(_, _) | DataType::Struct(_) => { [empty_buffer, MutableBuffer::new(0)] } @@ -319,7 +321,7 @@ impl ArrayData { buffers: Vec, child_data: Vec, ) -> Result { - // Safetly justification: `validate` is (will be) called below + // Safety justification: `validate` is (will be) called below let new_self = unsafe { Self::new_unchecked( data_type, @@ -519,6 +521,7 @@ impl ArrayData { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::Date32 @@ -554,7 +557,6 @@ impl ArrayData { DataType::Dictionary(_, data_type) => { vec![Self::new_empty(data_type)] } - DataType::Float16 => unreachable!(), }; // Data was constructed correctly above diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 15d41a0d67d6..dfceff394053 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -251,7 +251,12 @@ fn equal_values( ), _ => unreachable!(), }, - DataType::Float16 => unreachable!(), + DataType::Float16 => { + use half::f16; + primitive_equal::( + lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, + ) + } DataType::Map(_, _) => { list_equal::(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) } diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index 1298e58edc3f..b351d040f76b 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -194,6 +194,14 @@ pub type UInt64Array = PrimitiveArray; /// /// # Example: Using `collect` /// ``` +/// # use arrow::array::Float16Array; +/// use half::f16; +/// let arr : Float16Array = [Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))].into_iter().collect(); +/// ``` +pub type Float16Array = PrimitiveArray; +/// +/// # Example: Using `collect` +/// ``` /// # use arrow::array::Float32Array; /// let arr : Float32Array = [Some(1.0), Some(2.0)].into_iter().collect(); /// ``` diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index 2c1884861f68..9ad3dbf7c13b 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -15,20 +15,20 @@ // specific language governing permissions and limitations // under the License. +use super::{ + data::{into_buffers, new_buffers}, + ArrayData, ArrayDataBuilder, +}; +use crate::array::StringOffsetSizeTrait; use crate::{ buffer::MutableBuffer, datatypes::DataType, error::{ArrowError, Result}, util::bit_util, }; +use half::f16; use std::mem; -use super::{ - data::{into_buffers, new_buffers}, - ArrayData, ArrayDataBuilder, -}; -use crate::array::StringOffsetSizeTrait; - mod boolean; mod fixed_binary; mod list; @@ -266,7 +266,7 @@ fn build_extend(array: &ArrayData) -> Extend { DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), - DataType::Float16 => unreachable!(), + DataType::Float16 => primitive::build_extend::(array), /* DataType::FixedSizeList(_, _) => {} DataType::Union(_) => {} @@ -315,7 +315,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { }, DataType::Struct(_) => structure::extend_nulls, DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls, - DataType::Float16 => unreachable!(), + DataType::Float16 => primitive::extend_nulls::, /* DataType::FixedSizeList(_, _) => {} DataType::Union(_) => {} @@ -429,6 +429,7 @@ impl<'a> MutableArrayData<'a> { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::Date32 @@ -467,7 +468,6 @@ impl<'a> MutableArrayData<'a> { } // the dictionary type just appends keys and clones the values. DataType::Dictionary(_, _) => vec![], - DataType::Float16 => unreachable!(), DataType::Struct(fields) => match capacities { Capacities::Struct(capacity, Some(ref child_capacities)) => { array_capacity = capacity; diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index 6e8cf892237e..18d593b72980 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use serde_json::{Number, Value}; - use super::DataType; +use half::f16; +use serde_json::{Number, Value}; /// Trait declaring any type that is serializable to JSON. This includes all primitive types (bool, i32, etc.). pub trait JsonSerializable: 'static { @@ -293,6 +293,12 @@ impl ArrowNativeType for u64 { } } +impl JsonSerializable for f16 { + fn into_json_value(self) -> Option { + Number::from_f64(f64::round(f64::from(self) * 1000.0) / 1000.0).map(Value::Number) + } +} + impl JsonSerializable for f32 { fn into_json_value(self) -> Option { Number::from_f64(f64::round(self as f64 * 1000.0) / 1000.0).map(Value::Number) @@ -305,6 +311,7 @@ impl JsonSerializable for f64 { } } +impl ArrowNativeType for f16 {} impl ArrowNativeType for f32 {} impl ArrowNativeType for f64 {} diff --git a/arrow/src/datatypes/types.rs b/arrow/src/datatypes/types.rs index 30c9aae89565..2731e3d46658 100644 --- a/arrow/src/datatypes/types.rs +++ b/arrow/src/datatypes/types.rs @@ -16,6 +16,7 @@ // under the License. use super::{ArrowPrimitiveType, DataType, IntervalUnit, TimeUnit}; +use half::f16; // BooleanType is special: its bit-width is not the size of the primitive type, and its `index` // operation assumes bit-packing. @@ -46,6 +47,7 @@ make_type!(UInt8Type, u8, DataType::UInt8); make_type!(UInt16Type, u16, DataType::UInt16); make_type!(UInt32Type, u32, DataType::UInt32); make_type!(UInt64Type, u64, DataType::UInt64); +make_type!(Float16Type, f16, DataType::Float16); make_type!(Float32Type, f32, DataType::Float32); make_type!(Float64Type, f64, DataType::Float64); make_type!(