Skip to content

Commit

Permalink
Implement round (#1015)
Browse files Browse the repository at this point in the history
* Implement `roundTiesEven` method in JS

* Implement bindings and tests for CPU and GPU round

* Fix JS binding definition
  • Loading branch information
dfellis authored Dec 18, 2024
1 parent dc0b09b commit 66f9496
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 3 deletions.
10 changes: 10 additions & 0 deletions alan/src/compile/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,16 @@ test_gpgpu!(gpu_insertbits => r#"
stdout "[4, 7]\n";
);

test_gpgpu!(gpu_round => r#"
export fn main {
let b = GBuffer(
[1.5.f32, 1.75.f32, 2.5.f32, 2.75.f32, (-1.5).f32, (-1.75).f32, (-2.5).f32, (-2.75).f32]
);
b.map(fn (val: gf32) = val.round).read{f32}.print;
}"#;
stdout "[2, 2, 2, 3, -2, -2, -2, -3]\n";
);

// TODO: Fix u64 numeric constants to get u64 bitwise tests in the new test suite
test!(u64_bitwise => r#"
prefix u64 as ~ precedence 10
Expand Down
15 changes: 14 additions & 1 deletion alan/src/std/root.ln
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ type Js = Env{"ALAN_OUTPUT_LANG"} == "js";

// Importing the Root Scope backing implementation and supporting 3rd party libraries
type{Rs} RootBacking = Rust{"alan_std" @ "https://github.com/alantech/alan.git"};
type{Js} RootBacking = Node{"alan_std" @ "https://github.com/alantech/alan.git"};
type{Js} RootBacking = Node{"alan_std" @ "https://github.com/alantech/alan.git#round"};

// Defining derived types
type void = ();
Expand Down Expand Up @@ -608,6 +608,8 @@ fn coth(x: f32) = 1.0.f32 / tanh(x);
fn asech(x: f32) = ln((1.0.f32 + sqrt(1.0.f32 - x ** 2.0.f32)) / x);
fn acsch(x: f32) = ln(1.0.f32 / x + sqrt(1.0.f32 / x ** 2.0.f32 + 1.0.f32));
fn acoth(x: f32) = ln((x + 1.0.f32) / (x - 1.0.f32)) / 2.0.f32;
fn{Rs} round Method{"round_ties_even"} :: Deref{f32} -> f32;
fn{Js} round Method{"roundTiesEven"} :: f32 -> f32;

fn{Rs} add Infix{"+"} :: (f64, f64) -> f64;
fn{Js} add "((a, b) => new alan_std.F64(a.val + b.val))" <- RootBacking :: (f64, f64) -> f64;
Expand Down Expand Up @@ -690,6 +692,8 @@ fn coth(x: f64) = 1.0 / tanh(x);
fn asech(x: f64) = ln((1.0 + sqrt(1.0 - x ** 2.0)) / x);
fn acsch(x: f64) = ln(1.0 / x + sqrt(1.0 / x ** 2.0 + 1.0));
fn acoth(x: f64) = ln((x + 1.0) / (x - 1.0)) / 2.0;
fn{Rs} round Method{"round_ties_even"} :: Deref{f64} -> f64;
fn{Js} round Method{"roundTiesEven"} :: f64 -> f64;

/// Unsigned Integer-related functions and function bindings
fn{Rs} add Method{"wrapping_add"} :: (u8, Deref{u8}) -> u8;
Expand Down Expand Up @@ -3895,6 +3899,15 @@ fn cross(a: gvec3f, b: gvec3f) {
return gvec3f(varName, statements, buffers);
}

fn gRound{I}(v: I) {
let varName = 'round('.concat(v.varName).concat(')');
return {I}(varName, v.statements, v.buffers);
}
fn round(v: gf32) = gRound(v);
fn round(v: gvec2f) = gRound(v);
fn round(v: gvec3f) = gRound(v);
fn round(v: gvec4f) = gRound(v);

fn gadd{A, B}(a: A, b: B) {
let varName = '('.concat(a.varName).concat(' + ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
Expand Down
22 changes: 20 additions & 2 deletions alan/test.ln
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,16 @@ export fn{Test} main {
.it("floor")
.assert(eq, 2.5.f32.floor, 2.f32)
.it("ceil")
.assert(eq, 2.5.f32.ceil, 3.f32);
.assert(eq, 2.5.f32.ceil, 3.f32)
.it('round', fn (test: Mut{Testing}) = test
.assert(eq, 1.5.f32.round, 2.f32)
.assert(eq, 1.75.f32.round, 2.f32)
.assert(eq, 2.5.f32.round, 2.f32)
.assert(eq, 2.75.f32.round, 3.f32)
.assert(eq, (-1.5).f32.round, (-2).f32)
.assert(eq, (-1.75).f32.round, (-2).f32)
.assert(eq, (-2.5).f32.round, (-2).f32)
.assert(eq, (-2.75).f32.round, (-3).f32));

test.describe("Basic math tests f64")
.it("add")
Expand All @@ -450,7 +459,16 @@ export fn{Test} main {
.it("floor")
.assert(eq, 2.5.floor, 2.0)
.it("ceil")
.assert(eq, 2.5.ceil, 3.0);
.assert(eq, 2.5.ceil, 3.0)
.it('round', fn (test: Mut{Testing}) = test
.assert(eq, 1.5.round, 2.0)
.assert(eq, 1.75.round, 2.0)
.assert(eq, 2.5.round, 2.0)
.assert(eq, 2.75.round, 3.0)
.assert(eq, -1.5.round, -2.0)
.assert(eq, -1.75.round, -2.0)
.assert(eq, -2.5.round, -2.0)
.assert(eq, -2.75.round, -3.0));

test.describe("Basic math tests")
.it("grouping")
Expand Down
13 changes: 13 additions & 0 deletions alan_std.js
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,19 @@ export class Float {
this.bits = bits;
}

roundTiesEven() {
// Matches the behavior of WGSL rounding logic (and my high school chemistry teacher years ago)
// by rounding `.5` towards the even number, so 1.5 becomes 2.0 and 2.5 becomes 2.0, which seems
// odd at first glance, but eliminates rounding direction bias that affects calculations across
// a dataset, so it really makes sense to be the default rounding rule.
let floored = Math.floor(this.val);
if (this.val - floored == 0.5) {
return this.build(floored % 2 == 0 ? floored : floored + 1);
} else {
return this.build(Math.round(this.val));
}
}

valueOf() {
return this.val;
}
Expand Down
18 changes: 18 additions & 0 deletions alan_std.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,21 @@ assert.deepEqual(alanStd.cross(
[new alanStd.F64(1), new alanStd.F64(0), new alanStd.F64(0)],
), [new alanStd.F64(0), new alanStd.F64(0), new alanStd.F64(-1)],
"cross([0, 1, 0], [1, 0, 0]) = [0, 0, -1]");

assert.equal(new alanStd.F32(1.5).roundTiesEven(), 2.0, "roundTiesEvenF32 1.5 == 2.0");
assert.equal(new alanStd.F32(1.75).roundTiesEven(), 2.0, "roundTiesEvenF32 1.75 == 2.0");
assert.equal(new alanStd.F32(2.5).roundTiesEven(), 2.0, "roundTiesEvenF32 2.5 == 2.0");
assert.equal(new alanStd.F32(2.75).roundTiesEven(), 3.0, "roundTiesEvenF32 2.75 == 3.0");
assert.equal(new alanStd.F32(-1.5).roundTiesEven(), -2.0, "roundTiesEvenF32 -1.5 == -2.0");
assert.equal(new alanStd.F32(-1.75).roundTiesEven(), -2.0, "roundTiesEvenF32 -1.75 == -2.0");
assert.equal(new alanStd.F32(-2.5).roundTiesEven(), -2.0, "roundTiesEvenF32 -2.5 == -2.0");
assert.equal(new alanStd.F32(-2.75).roundTiesEven(), -3.0, "roundTiesEvenF32 -2.75 == -3.0");

assert.equal(new alanStd.F64(1.5).roundTiesEven(), 2.0, "roundTiesEvenF64 1.5 == 2.0");
assert.equal(new alanStd.F64(1.75).roundTiesEven(), 2.0, "roundTiesEvenF64 1.75 == 2.0");
assert.equal(new alanStd.F64(2.5).roundTiesEven(), 2.0, "roundTiesEvenF64 2.5 == 2.0");
assert.equal(new alanStd.F64(2.75).roundTiesEven(), 3.0, "roundTiesEvenF64 2.75 == 3.0");
assert.equal(new alanStd.F64(-1.5).roundTiesEven(), -2.0, "roundTiesEvenF64 -1.5 == -2.0");
assert.equal(new alanStd.F64(-1.75).roundTiesEven(), -2.0, "roundTiesEvenF64 -1.75 == -2.0");
assert.equal(new alanStd.F64(-2.5).roundTiesEven(), -2.0, "roundTiesEvenF64 -2.5 == -2.0");
assert.equal(new alanStd.F64(-2.75).roundTiesEven(), -3.0, "roundTiesEvenF64 -2.75 == -3.0");

0 comments on commit 66f9496

Please sign in to comment.