Skip to content

Commit

Permalink
feat: addition/subtraction acceleration via inline assembly
Browse files Browse the repository at this point in the history
  • Loading branch information
twiby committed Jul 8, 2024
1 parent a25836e commit 1d9ddc6
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 5 deletions.
103 changes: 101 additions & 2 deletions src/biguint/addition.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::{BigUint, IntDigits};
#[cfg(target_arch = "x86_64")]
use std::arch::asm;

use crate::big_digit::{self, BigDigit};
use crate::UsizePromotion;
Expand Down Expand Up @@ -45,6 +47,96 @@ fn adc(carry: u8, lhs: BigDigit, rhs: BigDigit, out: &mut BigDigit) -> u8 {
u8::from(b || d)
}

/// Performs a part of the addition. Returns a tuple containing the carry state
/// and the number of integers that were added
///
/// By using as many registers as possible, we treat digits 5 by 5
#[cfg(target_arch = "x86_64")]
unsafe fn schoolbook_add_assign_x86_64(
lhs: *mut u64,
rhs: *const u64,
mut size: usize,
) -> (bool, usize) {
size /= 5;
if size == 0 {
return (false, 0);
}

let mut c: u8;
let mut idx = 0;

asm!(
// Clear the carry flag
"clc",

"3:",

// Copy a in registers
"mov {a_tmp1}, qword ptr [{a} + 8*{idx}]",
"mov {a_tmp2}, qword ptr [{a} + 8*{idx} + 8]",
"mov {a_tmp3}, qword ptr [{a} + 8*{idx} + 16]",
"mov {a_tmp4}, qword ptr [{a} + 8*{idx} + 24]",
"mov {a_tmp5}, qword ptr [{a} + 8*{idx} + 32]",

// Copy b in registers
"mov {b_tmp1}, qword ptr [{b} + 8*{idx}]",
"mov {b_tmp2}, qword ptr [{b} + 8*{idx} + 8]",
"mov {b_tmp3}, qword ptr [{b} + 8*{idx} + 16]",
"mov {b_tmp4}, qword ptr [{b} + 8*{idx} + 24]",
"mov {b_tmp5}, qword ptr [{b} + 8*{idx} + 32]",

// Perform the addition
"adc {a_tmp1}, {b_tmp1}",
"adc {a_tmp2}, {b_tmp2}",
"adc {a_tmp3}, {b_tmp3}",
"adc {a_tmp4}, {b_tmp4}",
"adc {a_tmp5}, {b_tmp5}",

// Copy the return values
"mov qword ptr [{a} + 8*{idx}], {a_tmp1}",
"mov qword ptr [{a} + 8*{idx} + 8], {a_tmp2}",
"mov qword ptr [{a} + 8*{idx} + 16], {a_tmp3}",
"mov qword ptr [{a} + 8*{idx} + 24], {a_tmp4}",
"mov qword ptr [{a} + 8*{idx} + 32], {a_tmp5}",

// Increment loop counter
// `inc` and `dec` aren't modifying carry flag
"inc {idx}",
"inc {idx}",
"inc {idx}",
"inc {idx}",
"inc {idx}",
"dec {size}",
"jnz 3b",

// Output carry flag and clear
"setc {c}",
"clc",

size = in(reg) size,
a = in(reg) lhs,
b = in(reg) rhs,
c = lateout(reg_byte) c,
idx = inout(reg) idx,

a_tmp1 = out(reg) _,
a_tmp2 = out(reg) _,
a_tmp3 = out(reg) _,
a_tmp4 = out(reg) _,
a_tmp5 = out(reg) _,

b_tmp1 = out(reg) _,
b_tmp2 = out(reg) _,
b_tmp3 = out(reg) _,
b_tmp4 = out(reg) _,
b_tmp5 = out(reg) _,

options(nostack),
);

(c > 0, idx)
}

/// Two argument addition of raw slices, `a += b`, returning the carry.
///
/// This is used when the data `Vec` might need to resize to push a non-zero carry, so we perform
Expand All @@ -55,10 +147,17 @@ fn adc(carry: u8, lhs: BigDigit, rhs: BigDigit, out: &mut BigDigit) -> u8 {
pub(super) fn __add2(a: &mut [BigDigit], b: &[BigDigit]) -> BigDigit {
debug_assert!(a.len() >= b.len());

let mut carry = 0;
let (a_lo, a_hi) = a.split_at_mut(b.len());

for (a, b) in a_lo.iter_mut().zip(b) {
// On x86_64 machine, perform most of the addition via inline assembly
#[cfg(target_arch = "x86_64")]
let (c, done) = unsafe { schoolbook_add_assign_x86_64(a_lo.as_mut_ptr(), b.as_ptr(), b.len()) };
#[cfg(not(target_arch = "x86_64"))]
let (c, done) = (false, 0);

let mut carry = c as u8;

for (a, b) in a_lo[done..].iter_mut().zip(b[done..].iter()) {
carry = adc(carry, *a, *b, a);
}

Expand Down
104 changes: 101 additions & 3 deletions src/biguint/subtraction.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::BigUint;
#[cfg(target_arch = "x86_64")]
use std::arch::asm;

use crate::big_digit::{self, BigDigit};
use crate::UsizePromotion;
Expand Down Expand Up @@ -45,14 +47,110 @@ fn sbb(borrow: u8, lhs: BigDigit, rhs: BigDigit, out: &mut BigDigit) -> u8 {
u8::from(b || d)
}

pub(super) fn sub2(a: &mut [BigDigit], b: &[BigDigit]) {
let mut borrow = 0;
/// Performs a part of the subtraction. Returns a tuple containing the carry state
/// and the number of integers that were subtracted
///
/// By using as many registers as possible, we treat digits 5 by 5
#[cfg(target_arch = "x86_64")]
unsafe fn schoolbook_sub_assign_x86_64(
lhs: *mut u64,
rhs: *const u64,
mut size: usize,
) -> (bool, usize) {
size /= 5;
if size == 0 {
return (false, 0);
}

let mut c: u8;
let mut idx = 0;

asm!(
// Clear carry flag
"clc",

"3:",

// Copy a in registers
"mov {a_tmp1}, qword ptr [{a} + 8*{idx}]",
"mov {a_tmp2}, qword ptr [{a} + 8*{idx} + 8]",
"mov {a_tmp3}, qword ptr [{a} + 8*{idx} + 16]",
"mov {a_tmp4}, qword ptr [{a} + 8*{idx} + 24]",
"mov {a_tmp5}, qword ptr [{a} + 8*{idx} + 32]",

// Copy b in registers
"mov {b_tmp1}, qword ptr [{b} + 8*{idx}]",
"mov {b_tmp2}, qword ptr [{b} + 8*{idx} + 8]",
"mov {b_tmp3}, qword ptr [{b} + 8*{idx} + 16]",
"mov {b_tmp4}, qword ptr [{b} + 8*{idx} + 24]",
"mov {b_tmp5}, qword ptr [{b} + 8*{idx} + 32]",

// Perform the subtraction
"sbb {a_tmp1}, {b_tmp1}",
"sbb {a_tmp2}, {b_tmp2}",
"sbb {a_tmp3}, {b_tmp3}",
"sbb {a_tmp4}, {b_tmp4}",
"sbb {a_tmp5}, {b_tmp5}",

// Copy the return values
"mov qword ptr [{a} + 8*{idx}], {a_tmp1}",
"mov qword ptr [{a} + 8*{idx} + 8], {a_tmp2}",
"mov qword ptr [{a} + 8*{idx} + 16], {a_tmp3}",
"mov qword ptr [{a} + 8*{idx} + 24], {a_tmp4}",
"mov qword ptr [{a} + 8*{idx} + 32], {a_tmp5}",

// Increment loop counter
// `inc` and `dec` aren't modifying carry flag
"inc {idx}",
"inc {idx}",
"inc {idx}",
"inc {idx}",
"inc {idx}",
"dec {size}",
"jnz 3b",

// Output carry flag and clear
"setc {c}",
"clc",

size = in(reg) size,
a = in(reg) lhs,
b = in(reg) rhs,
c = lateout(reg_byte) c,
idx = inout(reg) idx,

a_tmp1 = out(reg) _,
a_tmp2 = out(reg) _,
a_tmp3 = out(reg) _,
a_tmp4 = out(reg) _,
a_tmp5 = out(reg) _,

b_tmp1 = out(reg) _,
b_tmp2 = out(reg) _,
b_tmp3 = out(reg) _,
b_tmp4 = out(reg) _,
b_tmp5 = out(reg) _,

options(nostack),
);

(c > 0, idx)
}

pub(super) fn sub2(a: &mut [BigDigit], b: &[BigDigit]) {
let len = Ord::min(a.len(), b.len());
let (a_lo, a_hi) = a.split_at_mut(len);
let (b_lo, b_hi) = b.split_at(len);

for (a, b) in a_lo.iter_mut().zip(b_lo) {
// On x86_64 machine, perform most of the subtraction via inline assembly
#[cfg(target_arch = "x86_64")]
let (b, done) = unsafe { schoolbook_sub_assign_x86_64(a_lo.as_mut_ptr(), b_lo.as_ptr(), len) };
#[cfg(not(target_arch = "x86_64"))]
let (b, done) = (false, 0);

let mut borrow = b as u8;

for (a, b) in a_lo[done..].iter_mut().zip(b_lo[done..].iter()) {
borrow = sbb(borrow, *a, *b, a);
}

Expand Down

0 comments on commit 1d9ddc6

Please sign in to comment.