From d7ded732afb71c13490ea042f41924549b32053e Mon Sep 17 00:00:00 2001 From: Benedikt Reinartz Date: Tue, 12 Mar 2024 01:05:19 +0100 Subject: [PATCH] Implement Decoder/Encoder support for i/u128 --- rustler/src/types/i128.rs | 130 ++++++++++++++++++ rustler/src/types/mod.rs | 1 + rustler_tests/lib/rustler_test.ex | 2 + rustler_tests/native/rustler_test/src/lib.rs | 2 + .../rustler_test/src/test_primitives.rs | 10 ++ rustler_tests/test/primitives_test.exs | 31 +++++ 6 files changed, 176 insertions(+) create mode 100644 rustler/src/types/i128.rs diff --git a/rustler/src/types/i128.rs b/rustler/src/types/i128.rs new file mode 100644 index 00000000..ae76e41b --- /dev/null +++ b/rustler/src/types/i128.rs @@ -0,0 +1,130 @@ +use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; +use std::convert::TryFrom; + +const EXTERNAL_TERM_FORMAT_VERSION: u8 = 131; +const SMALL_BIG_EXT: u8 = 110; + +impl Encoder for i128 { + fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { + if let Ok(int) = i64::try_from(*self) { + int.encode(env) + } else { + let mut etf = [0u8; 4 + 16]; + etf[0] = EXTERNAL_TERM_FORMAT_VERSION; + etf[1] = SMALL_BIG_EXT; + etf[2] = 16; // length in bytes + if *self < 0 { + etf[3] = 1; + let bytes = (-self).to_le_bytes(); + etf[4..].copy_from_slice(&bytes); + } else { + etf[4..].copy_from_slice(&self.to_le_bytes()); + } + let (term, _) = env.binary_to_term(&etf).unwrap(); + term + } + } +} + +impl Encoder for u128 { + fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { + if let Ok(int) = u64::try_from(*self) { + int.encode(env) + } else { + let mut etf = [0u8; 4 + 16]; + etf[0] = EXTERNAL_TERM_FORMAT_VERSION; + etf[1] = SMALL_BIG_EXT; + etf[2] = 16; // length in bytes + etf[4..].copy_from_slice(&self.to_le_bytes()); + let (term, _) = env.binary_to_term(&etf).unwrap(); + term + } + } +} + +impl<'a> Decoder<'a> for i128 { + fn decode(term: Term<'a>) -> NifResult { + if !term.is_integer() { + return Err(Error::BadArg); + } + + if let Ok(int) = term.decode::() { + return Ok(int as i128); + } + + let input = term.to_binary(); + let input = input.as_slice(); + if input.len() < 4 { + return Err(Error::BadArg); + } + + if input[0] != EXTERNAL_TERM_FORMAT_VERSION { + return Err(Error::BadArg); + } + + if input[1] != SMALL_BIG_EXT { + return Err(Error::BadArg); + } + + let n = input[2] as usize; + if n > 16 { + return Err(Error::BadArg); + } + + let mut res = [0u8; 16]; + res[16 - n..].copy_from_slice(&input[4..4 + n]); + let res = i128::from_le_bytes(res); + if res < 0 { + // The stored data is supposed to be unsigned, so if we interpret it as negative here, + // it was too large. + return Err(Error::BadArg); + } + + if input[3] == 0 { + Ok(res) + } else { + Ok(-res) + } + } +} + +impl<'a> Decoder<'a> for u128 { + fn decode(term: Term<'a>) -> NifResult { + if !term.is_integer() { + return Err(Error::BadArg); + } + + if let Ok(int) = term.decode::() { + return Ok(int as u128); + } + + let input = term.to_binary(); + let input = input.as_slice(); + + if input.len() < 4 { + return Err(Error::BadArg); + } + + if input[0] != EXTERNAL_TERM_FORMAT_VERSION { + return Err(Error::BadArg); + } + + if input[1] != SMALL_BIG_EXT { + return Err(Error::BadArg); + } + + let n = input[2] as usize; + if n > 16 { + return Err(Error::BadArg); + } + + if input[3] == 1 { + // Negative value + return Err(Error::BadArg); + } + + let mut res = [0u8; 16]; + res[16 - n..].copy_from_slice(&input[4..4 + n]); + Ok(u128::from_le_bytes(res)) + } +} diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index 5aa0393e..b6205dc2 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -2,6 +2,7 @@ use crate::{Env, Error, NifResult, Term}; #[macro_use] pub mod atom; +pub mod i128; pub use crate::types::atom::Atom; pub mod binary; diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index ce63a45d..0c2dc351 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -24,6 +24,8 @@ defmodule RustlerTest do def add_u32(_, _), do: err() def add_i32(_, _), do: err() def echo_u8(_), do: err() + def echo_u128(_), do: err() + def echo_i128(_), do: err() def option_inc(_), do: err() def erlang_option_inc(_), do: err() def result_to_int(_), do: err() diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index a7a4f700..144353fb 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -23,6 +23,8 @@ rustler::init!( test_primitives::option_inc, test_primitives::erlang_option_inc, test_primitives::result_to_int, + test_primitives::echo_u128, + test_primitives::echo_i128, test_list::sum_list, test_list::make_list, test_term::term_debug, diff --git a/rustler_tests/native/rustler_test/src/test_primitives.rs b/rustler_tests/native/rustler_test/src/test_primitives.rs index bcee1673..f6faeb63 100644 --- a/rustler_tests/native/rustler_test/src/test_primitives.rs +++ b/rustler_tests/native/rustler_test/src/test_primitives.rs @@ -33,3 +33,13 @@ pub fn result_to_int(res: Result) -> Result { Err(errstr) => Err(format!("{}{}", errstr, errstr)), } } + +#[rustler::nif] +pub fn echo_u128(n: u128) -> u128 { + n +} + +#[rustler::nif] +pub fn echo_i128(n: i128) -> i128 { + n +} diff --git a/rustler_tests/test/primitives_test.exs b/rustler_tests/test/primitives_test.exs index 3462ace8..10b8d6fd 100644 --- a/rustler_tests/test/primitives_test.exs +++ b/rustler_tests/test/primitives_test.exs @@ -32,4 +32,35 @@ defmodule RustlerTest.PrimitivesTest do assert {:error, "watwat"} == RustlerTest.result_to_int({:error, "wat"}) assert_raise ArgumentError, fn -> RustlerTest.result_to_int({:great, true}) end end + + test "i128 support" do + import Bitwise + + i = 1 <<< 62 + assert i == RustlerTest.echo_i128(i) + assert -i == RustlerTest.echo_i128(-i) + + i = 1 <<< 126 + assert i == RustlerTest.echo_i128(i) + assert -i == RustlerTest.echo_i128(-i) + + assert_raise ArgumentError, fn -> RustlerTest.echo_i128(:non_int) end + assert_raise ArgumentError, fn -> RustlerTest.echo_i128(123.45) end + assert_raise ArgumentError, fn -> RustlerTest.echo_i128(1 <<< 127) end + assert_raise ArgumentError, fn -> RustlerTest.echo_i128(1 <<< 128) end + end + + test "u128 support" do + import Bitwise + + i = 1 <<< 63 + assert i == RustlerTest.echo_u128(i) + + i = 1 <<< 127 + assert i == RustlerTest.echo_u128(i) + + assert_raise ArgumentError, fn -> RustlerTest.echo_u128(:non_int) end + assert_raise ArgumentError, fn -> RustlerTest.echo_u128(123.45) end + assert_raise ArgumentError, fn -> RustlerTest.echo_i128(1 <<< 128) end + end end