Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove NTT_BOUND_NATIVE and INVNTT_BOUND_NATIVE #645

Merged
merged 3 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 0 additions & 70 deletions examples/monolithic_build/mlkem_native_monobuild.c
Original file line number Diff line number Diff line change
Expand Up @@ -378,46 +378,11 @@
#undef CASSERT
#endif

/* mlkem/debug/debug.h */
#if defined(MLKEM_CONCAT)
#undef MLKEM_CONCAT
#endif

/* mlkem/debug/debug.h */
#if defined(MLKEM_CONCAT_)
#undef MLKEM_CONCAT_
#endif

/* mlkem/debug/debug.h */
#if defined(MLKEM_DEBUG_H)
#undef MLKEM_DEBUG_H
#endif

/* mlkem/debug/debug.h */
#if defined(MLKEM_STATIC_ASSERT_ADD_ERROR)
#undef MLKEM_STATIC_ASSERT_ADD_ERROR
#endif

/* mlkem/debug/debug.h */
#if defined(MLKEM_STATIC_ASSERT_ADD_LINE0)
#undef MLKEM_STATIC_ASSERT_ADD_LINE0
#endif

/* mlkem/debug/debug.h */
#if defined(MLKEM_STATIC_ASSERT_ADD_LINE1)
#undef MLKEM_STATIC_ASSERT_ADD_LINE1
#endif

/* mlkem/debug/debug.h */
#if defined(MLKEM_STATIC_ASSERT_ADD_LINE2)
#undef MLKEM_STATIC_ASSERT_ADD_LINE2
#endif

/* mlkem/debug/debug.h */
#if defined(MLKEM_STATIC_ASSERT_DEFINE)
#undef MLKEM_STATIC_ASSERT_DEFINE
#endif

/* mlkem/debug/debug.h */
#if defined(POLYVEC_BOUND)
#undef POLYVEC_BOUND
Expand Down Expand Up @@ -493,11 +458,6 @@
#undef STATIC_ASSERT
#endif

/* mlkem/debug/debug.h */
#if defined(STATIC_ASSERT)
#undef STATIC_ASSERT
#endif

/* mlkem/debug/debug.h */
#if defined(UBOUND)
#undef UBOUND
Expand Down Expand Up @@ -973,11 +933,6 @@
#undef rej_uniform_table
#endif

/* mlkem/native/aarch64/src/clean_impl.h */
#if defined(INVNTT_BOUND_NATIVE)
#undef INVNTT_BOUND_NATIVE
#endif

/* mlkem/native/aarch64/src/clean_impl.h */
#if defined(MLKEM_NATIVE_ARITH_PROFILE_IMPL_H)
#undef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H
Expand Down Expand Up @@ -1038,11 +993,6 @@
#undef zetas_mulcache_twisted_native
#endif

/* mlkem/native/aarch64/src/opt_impl.h */
#if defined(INVNTT_BOUND_NATIVE)
#undef INVNTT_BOUND_NATIVE
#endif

/* mlkem/native/aarch64/src/opt_impl.h */
#if defined(MLKEM_NATIVE_ARITH_PROFILE_IMPL_H)
#undef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H
Expand Down Expand Up @@ -1088,11 +1038,6 @@
#undef MLKEM_USE_NATIVE_REJ_UNIFORM
#endif

/* mlkem/native/aarch64/src/opt_impl.h */
#if defined(NTT_BOUND_NATIVE)
#undef NTT_BOUND_NATIVE
#endif

/* mlkem/native/aarch64/src/rej_uniform_table.c */
#if defined(empty_cu_aarch64_rej_uniform_table)
#undef empty_cu_aarch64_rej_uniform_table
Expand Down Expand Up @@ -1408,11 +1353,6 @@
#undef qdata
#endif

/* mlkem/native/x86_64/src/default_impl.h */
#if defined(INVNTT_BOUND_NATIVE)
#undef INVNTT_BOUND_NATIVE
#endif

/* mlkem/native/x86_64/src/default_impl.h */
#if defined(MLKEM_NATIVE_ARITH_PROFILE_IMPL_H)
#undef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H
Expand Down Expand Up @@ -1468,11 +1408,6 @@
#undef MLKEM_USE_NATIVE_REJ_UNIFORM
#endif

/* mlkem/native/x86_64/src/default_impl.h */
#if defined(NTT_BOUND_NATIVE)
#undef NTT_BOUND_NATIVE
#endif

/* mlkem/native/x86_64/src/rej_uniform_avx2.c */
#if defined(empty_cu_rej_uniform_avx2)
#undef empty_cu_rej_uniform_avx2
Expand All @@ -1483,11 +1418,6 @@
#undef empty_cu_avx2_rej_uniform_table
#endif

/* mlkem/ntt.c */
#if defined(INVNTT_BOUND_REF)
#undef INVNTT_BOUND_REF
#endif

/* mlkem/ntt.c */
#if defined(invntt_layer)
#undef invntt_layer
Expand Down
21 changes: 0 additions & 21 deletions mlkem/debug/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,27 +154,6 @@ void mlkem_debug_check_bounds(const char *file, int line,
"polyvec unsigned bound for " #ptr ".vec[i]"); \
} while (0)

#define MLKEM_CONCAT_(left, right) left##right
#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right)

/* Following AWS-LC to define a C99-compliant static assert */
#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \
typedef struct \
{ \
unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \
} MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \
__attribute__((unused));

#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \
MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix))
#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix))
#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix)
#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix))
#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error)

#else /* MLKEM_DEBUG */

#define CASSERT(val, msg) \
Expand Down
15 changes: 0 additions & 15 deletions mlkem/indcpa.c
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,6 @@ __contract__(
}
}



STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_enc_bound_0)

MLKEM_NATIVE_INTERNAL_API
void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES],
Expand Down Expand Up @@ -458,7 +454,6 @@ void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
matvec_mul(&pkpv, a, &skpv, &skpv_cache);
polyvec_tomont(&pkpv);

/* Arithmetic cannot overflow, see static assertion at the top */
polyvec_add(&pkpv, &e);
polyvec_reduce(&pkpv);
polyvec_reduce(&skpv);
Expand All @@ -468,11 +463,6 @@ void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
}


/* Check that the arithmetic in indcpa_enc() does not overflow */
STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0)
STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX,
indcpa_enc_bound_1)

MLKEM_NATIVE_INTERNAL_API
void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
const uint8_t m[MLKEM_INDCPA_MSGBYTES],
Expand Down Expand Up @@ -519,7 +509,6 @@ void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
polyvec_invntt_tomont(&b);
poly_invntt_tomont(&v);

/* Arithmetic cannot overflow, see static assertion at the top */
polyvec_add(&b, &ep);
poly_add(&v, &epp);
poly_add(&v, &k);
Expand All @@ -530,9 +519,6 @@ void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
pack_ciphertext(c, &b, &v);
}

/* Check that the arithmetic in indcpa_dec() does not overflow */
STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0)

MLKEM_NATIVE_INTERNAL_API
void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
const uint8_t c[MLKEM_INDCPA_BYTES],
Expand All @@ -548,7 +534,6 @@ void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
polyvec_basemul_acc_montgomery(&sb, &skpv, &b);
poly_invntt_tomont(&sb);

/* Arithmetic cannot overflow, see static assertion at the top */
poly_sub(&v, &sb);
poly_reduce(&v);

Expand Down
1 change: 0 additions & 1 deletion mlkem/native/aarch64/src/clean_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ static INLINE void ntt_native(poly *data)
aarch64_ntt_zetas_layer56);
}

#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q)
static INLINE void intt_native(poly *data)
{
intt_asm_clean(data->coeffs, aarch64_invntt_zetas_layer01234,
Expand Down
2 changes: 0 additions & 2 deletions mlkem/native/aarch64/src/opt_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@
#define MLKEM_USE_NATIVE_POLY_TOBYTES
#define MLKEM_USE_NATIVE_REJ_UNIFORM

#define NTT_BOUND_NATIVE (6 * MLKEM_Q)
static INLINE void ntt_native(poly *data)
{
ntt_asm_opt(data->coeffs, aarch64_ntt_zetas_layer01234,
aarch64_ntt_zetas_layer56);
}

#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q)
static INLINE void intt_native(poly *data)
{
intt_asm_opt(data->coeffs, aarch64_invntt_zetas_layer01234,
Expand Down
3 changes: 0 additions & 3 deletions mlkem/native/x86_64/src/default_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
#define MLKEM_USE_NATIVE_POLY_TOBYTES
#define MLKEM_USE_NATIVE_POLY_FROMBYTES

#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q)
#define NTT_BOUND_NATIVE (8 * MLKEM_Q)

static INLINE void poly_permute_bitrev_to_custom(poly *data)
{
nttunpack_avx2((__m256i *)(data->coeffs), qdata.vec);
Expand Down
16 changes: 3 additions & 13 deletions mlkem/ntt.c
Original file line number Diff line number Diff line change
Expand Up @@ -148,24 +148,17 @@ void poly_ntt(poly *p)
}
#else /* MLKEM_USE_NATIVE_NTT */

/* Check that bound for native NTT implies contractual bound */
STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, invntt_bound)

MLKEM_NATIVE_INTERNAL_API
void poly_ntt(poly *p)
{
POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input");
ntt_native(p);
POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output");
POLY_BOUND_MSG(p, NTT_BOUND, "native ntt output");
}
#endif /* MLKEM_USE_NATIVE_NTT */

#if !defined(MLKEM_USE_NATIVE_INTT)

/* Check that bound for reference invNTT implies contractual bound */
#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4)
STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound)

/* Compute one layer of inverse NTT */
static void invntt_layer(int16_t *r, int len, int layer)
__contract__(
Expand Down Expand Up @@ -232,18 +225,15 @@ void poly_invntt_tomont(poly *p)
invntt_layer(p->coeffs, len, layer);
}

POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output");
POLY_BOUND_MSG(p, INVNTT_BOUND, "ref intt output");
}
#else /* MLKEM_USE_NATIVE_INTT */

/* Check that bound for native invNTT implies contractual bound */
STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound)

MLKEM_NATIVE_INTERNAL_API
void poly_invntt_tomont(poly *p)
{
intt_native(p);
POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output");
POLY_BOUND_MSG(p, INVNTT_BOUND, "native intt output");
}
#endif /* MLKEM_USE_NATIVE_INTT */

Expand Down
Loading