diff --git a/Cargo.toml b/Cargo.toml index 1711b36e..c7a3b0c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,6 @@ include = ["/src", "/CHANGELOG.md", "/README.md"] [dependencies] flatbuffers = "24.3.25" -libm = "0.2.6" rayon = "1.7.0" smallvec = { version = "1.10.0", features = ["union", "const_generics", "const_new"] } rten-tensor = { path = "./rten-tensor", version = "0.13.1" } @@ -55,6 +54,7 @@ memmap2 = { version = "0.9.4", optional = true } num_cpus = "1.16.0" [dev-dependencies] +libm = "0.2.6" rten = { path = ".", features = ["mmap", "random"] } rten-bench = { path = "./rten-bench" } serde_json = { workspace = true } diff --git a/src/ops/unary_elementwise.rs b/src/ops/unary_elementwise.rs index 77442949..9c02aec8 100644 --- a/src/ops/unary_elementwise.rs +++ b/src/ops/unary_elementwise.rs @@ -576,7 +576,7 @@ unary_float_op!(Relu, relu, relu_in_place, |val: f32| val.max(0.)); /// Round float values to the nearest integer. Values with a fractional part /// of 0.5 are rounded to the nearest even number, like `round` in Python and -/// unlike `f32::round` in Rust. +/// unlike [`f32::round`] in Rust. #[derive(Debug)] pub struct Round {} impl UnaryFloatOp for Round { @@ -585,8 +585,7 @@ impl UnaryFloatOp for Round { } fn map_element(&self, val: f32) -> f32 { - // Replace this with `f32::round_ties_even` when that is stabilized. - libm::rintf(val) + val.round_ties_even() } }