From 2e89d6aa056377a4ca10eb926d579cbad3a2532e Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sun, 31 Oct 2021 14:20:53 +0800 Subject: [PATCH] add support for f16 --- arrow/Cargo.toml | 4 +++- arrow/src/alloc/types.rs | 4 ++++ arrow/src/array/mod.rs | 8 ++++++++ arrow/src/datatypes/native.rs | 14 ++++++++++++-- arrow/src/datatypes/types.rs | 4 ++++ 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 88623211dab7..e81b23c6cbcd 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -42,6 +42,7 @@ serde_json = { version = "1.0", features = ["preserve_order"] } indexmap = "1.6" rand = { version = "0.8", optional = true } num = "0.4" +half = { version = "1.8", optional = true } csv_crate = { version = "1.1", optional = true, package="csv" } regex = "1.3" lazy_static = "1.4" @@ -57,9 +58,10 @@ multiversion = "0.6.1" bitflags = "1.2.1" [features] -default = ["csv", "ipc", "test_utils"] +default = ["csv", "ipc", "test_utils", "f16"] avx512 = [] csv = ["csv_crate"] +f16 = ["half"] ipc = ["flatbuffers"] simd = ["packed_simd"] prettyprint = ["comfy-table"] diff --git a/arrow/src/alloc/types.rs b/arrow/src/alloc/types.rs index 92a6107f3d54..1ba1d642bb08 100644 --- a/arrow/src/alloc/types.rs +++ b/arrow/src/alloc/types.rs @@ -16,6 +16,8 @@ // under the License. use crate::datatypes::DataType; +#[cfg(feature = "f16")] +use half::f16; /// A type that Rust's custom allocator knows how to allocate and deallocate. 
/// This is implemented for all Arrow's physical types whose in-memory representation @@ -67,5 +69,7 @@ create_native!( i64, DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _) ); +#[cfg(feature = "f16")] +create_native!(f16, DataType::Float16); create_native!(f32, DataType::Float32); create_native!(f64, DataType::Float64); diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index 5d4e57a7625a..93e26e9bc757 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -194,6 +194,14 @@ pub type UInt64Array = PrimitiveArray<UInt64Type>; /// /// # Example: Using `collect` /// ``` +/// # use arrow::array::Float16Array; +/// let arr : Float16Array = [Some(1.0), Some(2.0)].into_iter().collect(); +/// ``` +#[cfg(feature = "f16")] +pub type Float16Array = PrimitiveArray<Float16Type>; +/// +/// # Example: Using `collect` +/// ``` /// # use arrow::array::Float32Array; /// let arr : Float32Array = [Some(1.0), Some(2.0)].into_iter().collect(); /// ``` diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index 6e8cf892237e..36a7430616f0 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -use serde_json::{Number, Value}; - use super::DataType; +#[cfg(feature = "f16")] +use half::f16; +use serde_json::{Number, Value}; /// Trait declaring any type that is serializable to JSON. This includes all primitive types (bool, i32, etc.).
pub trait JsonSerializable: 'static { @@ -293,6 +294,13 @@ impl ArrowNativeType for u64 { } } +#[cfg(feature = "f16")] +impl JsonSerializable for f16 { + fn into_json_value(self) -> Option<Value> { + Number::from_f64(f64::round(f64::from(self) * 1000.0) / 1000.0).map(Value::Number) + } +} + impl JsonSerializable for f32 { fn into_json_value(self) -> Option<Value> { Number::from_f64(f64::round(self as f64 * 1000.0) / 1000.0).map(Value::Number) @@ -305,6 +313,8 @@ impl JsonSerializable for f64 { } } +#[cfg(feature = "f16")] +impl ArrowNativeType for f16 {} impl ArrowNativeType for f32 {} impl ArrowNativeType for f64 {} diff --git a/arrow/src/datatypes/types.rs b/arrow/src/datatypes/types.rs index 30c9aae89565..4384c9a20d59 100644 --- a/arrow/src/datatypes/types.rs +++ b/arrow/src/datatypes/types.rs @@ -16,6 +16,8 @@ // under the License. use super::{ArrowPrimitiveType, DataType, IntervalUnit, TimeUnit}; +#[cfg(feature = "f16")] +use half::f16; // BooleanType is special: its bit-width is not the size of the primitive type, and its `index` // operation assumes bit-packing. @@ -46,6 +48,8 @@ make_type!(UInt8Type, u8, DataType::UInt8); make_type!(UInt16Type, u16, DataType::UInt16); make_type!(UInt32Type, u32, DataType::UInt32); make_type!(UInt64Type, u64, DataType::UInt64); +#[cfg(feature = "f16")] +make_type!(Float16Type, f16, DataType::Float16); make_type!(Float32Type, f32, DataType::Float32); make_type!(Float64Type, f64, DataType::Float64); make_type!(