//by Aashish Dugar #ifndef __primefield_header #define __primefield_header #include <gmp.h> /** * Adds two numbers which are in the prime field * * This is similar to normal addition except that the result * is bound between 0 and p. * See https://www.johannes-bauer.com/compsci/ecc/#anchor03 for details. * * res is the return variable. It must be initialized. * a and b are the numbers to add. They have to be within the prime field. * p is the prime number defining the field. */ void prime_field_add(mpz_t res, mpz_t a, mpz_t b, mpz_t p) { mpz_t tmp; mpz_init(tmp); mpz_add(tmp, a, b); if (mpz_cmp(tmp, p) >= 0) mpz_sub(res, tmp, p); else if (mpz_cmp_ui(tmp, 0UL) < 0) mpz_add(res, tmp, p); else mpz_set(res, tmp); mpz_clear(tmp); } /** * Subtracts two numbers which are in the prime field * * This is similar to normal subtraction except that the result * is bound between 0 and p. * See https://www.johannes-bauer.com/compsci/ecc/#anchor03 for details. * * res is the return variable. It must be initialized. * a and b are the numbers to subtract. They have to be within the prime field. * p is the prime number defining the field. */ void prime_field_sub(mpz_t res, mpz_t a, mpz_t b, mpz_t p) { mpz_t tmp; mpz_init(tmp); mpz_neg(tmp, b); prime_field_add(res, a, tmp, p); mpz_clear(tmp); } /** * Multiplies two numbers which are in the prime field * * The function loops copies b into a throwaway variable and loops * over the bits of b, starting with most significant bit. If the * bit is set, it adds the value of the copied throwaway to the result. * Then it doubles the value of the throwaway. All operations are * prime field operations. * See https://www.johannes-bauer.com/compsci/ecc/#anchor05 for details. * * res is the return variable. It must be initialized. * a and b are the numbers to multiply. They have to be within the prime field. * p is the prime number defining the field. */ void prime_field_mul(mpz_t res, mpz_t a, mpz_t b, mpz_t p) { mpz_t copy; mpz_t tmp; mpz_init_set(copy, a); mpz_init(tmp); mpz_set_ui(res, 0UL); char *bits = mpz_get_str(NULL, 2, b); size_t bitlength = strlen(bits); int i; for (i = bitlength - 1; i >= 0; i--) { if (bits[i] == '1') { prime_field_add(tmp, res, copy, p); mpz_set(res, tmp); } prime_field_add(tmp, copy, copy, p); mpz_set(copy, tmp); } mpz_clear(copy); mpz_clear(tmp); free(bits); } /** * Divides two numbers which are in the prime field * * The function first calculates the inverse of b in the prime field, * and then multiplies a with that number to get the result. * See https://www.johannes-bauer.com/compsci/ecc/#anchor07 for details. * * res is the return variable. It must be initialized. * a is the dividend and b is the divisor. Both must be in the prime field. * p is the prime number defining the field. */ void prime_field_div(mpz_t res, mpz_t a, mpz_t b, mpz_t p) { mpz_t q, r, s, t, u, v, copy_b, copy_p, u_new, v_new, tmp; mpz_init(q); mpz_init(r); mpz_init_set_ui(s, 1UL); mpz_init_set_ui(t, 0UL); mpz_init_set_ui(u, 0UL); mpz_init_set_ui(v, 1UL); mpz_init_set(copy_b, b); mpz_init_set(copy_p, p); mpz_init(u_new); mpz_init(v_new); mpz_init(tmp); while (mpz_cmp_ui(copy_p, 0UL) != 0) { mpz_fdiv_qr(q, r, a, copy_p); mpz_set(u_new, s); mpz_set(v_new, t); mpz_mul(tmp, q, s); mpz_sub(s, u, tmp); mpz_mul(tmp, q, t); mpz_sub(t, v, tmp); mpz_set(copy_b, copy_p); mpz_set(copy_p, r); mpz_set(u, u_new); mpz_set(v, v_new); } prime_field_mul(res, a, u, p); mpz_clear(q); mpz_clear(r); mpz_clear(s); mpz_clear(t); mpz_clear(u); mpz_clear(v); mpz_clear(copy_p); mpz_clear(copy_b); mpz_clear(u_new); mpz_clear(v_new); mpz_clear(tmp); } /** * Squares a number in the prime field * * This is uses the same approach as multiplication. * See https://www.johannes-bauer.com/compsci/ecc/#anchor09 for details * * res is the return variable. It must be initialized. * a is the number to square. * p is the prime number defining the field. */ void prime_field_sq(mpz_t res, mpz_t a, mpz_t p) { mpz_t copy; mpz_t tmp; mpz_init_set(copy, a); mpz_init(tmp); mpz_set_ui(res, 1UL); char *bits = "10"; int i; for (i = 1; i >= 0; i--) { if (bits[i] == '1') { prime_field_mul(tmp, res, copy, p); mpz_set(res, tmp); } prime_field_mul(tmp, copy, copy, p); mpz_set(copy, tmp); } mpz_clear(tmp); mpz_clear(copy); } /** * Converts a hex-string representation of a scalar to * a GMP integer * * scalar is an uninitialized pointer to the result * str is the hex string containing the number */ int str_to_scalar(mpz_t scalar, const char *str) { return mpz_init_set_str(scalar, str, 16); } /** * Returns the hex-string for the given scalar * * The string is null terminated but the calculated length * excludes the null terminator. * * scalar is the number to convert * *len is a pointer which will hold the length of the result */ char *scalar_to_str(mpz_t scalar, size_t *len) { *len = mpz_sizeinbase(scalar, 16) + 2; char *str = malloc((*len) * sizeof(*str)); mpz_get_str(str, 16, scalar); *len = strlen(str); return str; } #endif