Skip to content

Commit

Permalink
AArch64: Remove literal pools from native code
Browse files Browse the repository at this point in the history
This commit removes all literal pools from the native AArch64 assembly.
Literal pools are slightly easier to read, but they impede formal
verification with HOL Light. Instead, constant vectors are now prepared by
loading immediates into GPRs and copying/broadcasting them into the target vector.

Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
  • Loading branch information
hanno-becker committed Jan 17, 2025
1 parent e4ff720 commit 1427630
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 263 deletions.
40 changes: 10 additions & 30 deletions mlkem/native/aarch64/src/intt_clean.S
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@

inp .req x3
count .req x4
xtmp .req x5
wtmp .req w5

data0 .req v8
data1 .req v9
Expand Down Expand Up @@ -193,40 +193,20 @@
t3 .req v28

ninv .req v29
q_ninv .req q29
ninv_tw .req v30
q_ninv_tw .req q30

/* Literal pool */
.macro dup8h c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.endm

.p2align 4
c_consts: .short 3329
.short 20159
.short 0
.short 0
.short 0
.short 0
.short 0
.short 0
c_ninv: dup8h 512
c_ninv_tw: dup8h 5040

MLKEM_ASM_NAMESPACE(intt_asm_clean):
push_stack

ldr q_consts, c_consts
ldr q_ninv, c_ninv
ldr q_ninv_tw, c_ninv_tw
// Setup constants
mov wtmp, #3329
mov consts.h[0], wtmp
mov wtmp, #20159
mov consts.h[1], wtmp
mov wtmp, #512
dup ninv.8h, wtmp
mov wtmp, #5040
dup ninv_tw.8h, wtmp

mov inp, in
mov count, #8
Expand Down
40 changes: 10 additions & 30 deletions mlkem/native/aarch64/src/intt_opt.S
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@

inp .req x3
count .req x4
xtmp .req x5
wtmp .req w5

data0 .req v8
data1 .req v9
Expand Down Expand Up @@ -193,40 +193,20 @@
t3 .req v28

ninv .req v29
q_ninv .req q29
ninv_tw .req v30
q_ninv_tw .req q30

/* Literal pool */
.macro dup8h c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.endm

.p2align 4
c_consts: .short 3329
.short 20159
.short 0
.short 0
.short 0
.short 0
.short 0
.short 0
c_ninv: dup8h 512
c_ninv_tw: dup8h 5040

MLKEM_ASM_NAMESPACE(intt_asm_opt):
push_stack

ldr q_consts, c_consts
ldr q_ninv, c_ninv
ldr q_ninv_tw, c_ninv_tw
// Setup constants
mov wtmp, #3329
mov consts.h[0], wtmp
mov wtmp, #20159
mov consts.h[1], wtmp
mov wtmp, #512
dup ninv.8h, wtmp
mov wtmp, #5040
dup ninv_tw.8h, wtmp

mov inp, in
mov count, #8
Expand Down
20 changes: 6 additions & 14 deletions mlkem/native/aarch64/src/ntt_clean.S
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@

inp .req x3
count .req x4
xtmp .req x5
wtmp .req w5

data0 .req v8
data1 .req v9
Expand Down Expand Up @@ -167,21 +167,13 @@
.text
.global MLKEM_ASM_NAMESPACE(ntt_asm_clean)

/* Literal pool */
.p2align 4
c_consts:
.short 3329
.short 20159
.short 0
.short 0
.short 0
.short 0
.short 0
.short 0

MLKEM_ASM_NAMESPACE(ntt_asm_clean):
push_stack
ldr q_consts, c_consts

mov wtmp, #3329
mov consts.h[0], wtmp
mov wtmp, #20159
mov consts.h[1], wtmp

mov inp, in
mov count, #4
Expand Down
20 changes: 6 additions & 14 deletions mlkem/native/aarch64/src/ntt_opt.S
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@

inp .req x3
count .req x4
xtmp .req x5
wtmp .req w5

data0 .req v8
data1 .req v9
Expand Down Expand Up @@ -167,21 +167,13 @@
.text
.global MLKEM_ASM_NAMESPACE(ntt_asm_opt)

/* Literal pool */
.p2align 4
c_consts:
.short 3329
.short 20159
.short 0
.short 0
.short 0
.short 0
.short 0
.short 0

MLKEM_ASM_NAMESPACE(ntt_asm_opt):
push_stack
ldr q_consts, c_consts

mov wtmp, #3329
mov consts.h[0], wtmp
mov wtmp, #20159
mov consts.h[1], wtmp

mov inp, in
mov count, #4
Expand Down
78 changes: 27 additions & 51 deletions mlkem/native/aarch64/src/poly_clean.S
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,6 @@
#include "../../../common.h"
#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN)

/* We use a single literal pool for all functions in this file.
* This is OK even when the file gets expanded through SLOTHY,
* since PC-relative offsets are up to 1MB in AArch64.
*
* The use of dup8h to build constant vectors in memory
* is slightly wasteful and could be avoided with a GPR-load
* followed by Neon `dup`, but we're ultimately only talking
* about 64 bytes, so it seems OK.
*/

.macro dup8h c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.short \c
.endm

.p2align 4
c_modulus: dup8h 3329 // ML-KEM modulus
c_modulus_twisted: dup8h 20159 // Barrett twist of 1 wrt 2^27
c_mont_constant: dup8h -1044 // 2^16 % 3329
c_barrett_twist: dup8h -10276 // Barrett twist of -1044 (wrt 2^16)

/*
* Some modular arithmetic macros
*/
Expand Down Expand Up @@ -70,21 +43,23 @@ c_barrett_twist: dup8h -10276 // Barrett twist of -1044 (wrt 2^16)

ptr .req x0
count .req x1
wtmp .req w2

data .req v0
q_data .req q0

tmp .req v1
mask .req v2
modulus .req v3
q_modulus .req q3
modulus_twisted .req v4
q_modulus_twisted .req q4

MLKEM_ASM_NAMESPACE(poly_reduce_asm_clean):

ldr q_modulus, c_modulus
ldr q_modulus_twisted, c_modulus_twisted
mov wtmp, #3329 // ML-KEM modulus
dup modulus.8h, wtmp

mov wtmp, #20159 // Barrett twist of 1 wrt 2^27
dup modulus_twisted.8h, wtmp

mov count, #8
loop_start:
Expand Down Expand Up @@ -115,16 +90,15 @@ loop_start:

.unreq ptr
.unreq count
.unreq wtmp

.unreq data
.unreq q_data

.unreq tmp
.unreq mask
.unreq modulus
.unreq q_modulus
.unreq modulus_twisted
.unreq q_modulus_twisted

/********************************************
* poly_mulcache_compute() *
Expand All @@ -137,6 +111,7 @@ loop_start:
zeta_ptr .req x2
zeta_twisted_ptr .req x3
count .req x4
wtmp .req w5

data_odd .req v0
zeta .req v1
Expand All @@ -152,13 +127,14 @@ loop_start:
q_dst .req q5

modulus .req v6
q_modulus .req q6
modulus_twisted .req v7
q_modulus_twisted .req q7

MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_clean):
ldr q_modulus, c_modulus
ldr q_modulus_twisted, c_modulus_twisted
mov wtmp, #3329
dup modulus.8h, wtmp

mov wtmp, #20159
dup modulus_twisted.8h, wtmp

mov count, #16
mulcache_compute_loop_start:
Expand All @@ -185,6 +161,7 @@ mulcache_compute_loop_start:
.unreq zeta_ptr
.unreq zeta_twisted_ptr
.unreq count
.unreq wtmp

.unreq data_odd
.unreq zeta
Expand All @@ -200,9 +177,7 @@ mulcache_compute_loop_start:
.unreq q_dst

.unreq modulus
.unreq q_modulus
.unreq modulus_twisted
.unreq q_modulus_twisted

/********************************************
* poly_tobytes() *
Expand Down Expand Up @@ -261,29 +236,33 @@ poly_tobytes_asm_clean_asm_loop_start:

src .req x0
count .req x1
wtmp .req w2

data .req v0
q_data .req q0
res .req v1
q_res .req q1

factor .req v2
q_factor .req q2
factor_t .req v3
q_factor_t .req q3
modulus .req v4
q_modulus .req q4
modulus_twisted .req v5
q_modulus_twisted .req q5

tmp0 .req v6

MLKEM_ASM_NAMESPACE(poly_tomont_asm_clean):

ldr q_modulus, c_modulus
ldr q_modulus_twisted, c_modulus_twisted
ldr q_factor, c_mont_constant
ldr q_factor_t, c_barrett_twist
mov wtmp, #3329 // ML-KEM modulus
dup modulus.8h, wtmp

mov wtmp, #20159 // Barrett twist of 1 wrt 2^27
dup modulus_twisted.8h, wtmp

mov wtmp, #-1044 // 2^16 % 3329
dup factor.8h, wtmp

mov wtmp, #-10276 // Barrett twist of -1044 (wrt 2^16)
dup factor_t.8h, wtmp

mov count, #8
poly_tomont_asm_loop:
Expand Down Expand Up @@ -311,20 +290,17 @@ poly_tomont_asm_loop:

.unreq src
.unreq count
.unreq wtmp

.unreq data
.unreq q_data
.unreq res
.unreq q_res

.unreq factor
.unreq q_factor
.unreq factor_t
.unreq q_factor_t
.unreq modulus
.unreq q_modulus
.unreq modulus_twisted
.unreq q_modulus_twisted

.unreq tmp0

Expand Down
Loading

0 comments on commit 1427630

Please sign in to comment.