diff --git a/src/gmalglib/core/sm9curve.c b/src/gmalglib/core/sm9curve.c index d0b91d3..0086c2c 100644 --- a/src/gmalglib/core/sm9curve.c +++ b/src/gmalglib/core/sm9curve.c @@ -983,7 +983,6 @@ int SM9FP2_MontHasSqrt(const SM9FP2Mont* x, SM9FP2Mont* y) static void SM9FP2_MontInv(const SM9FP2Mont* x, SM9FP2Mont* y) { - // use method of undetermined coefficients const SM9FP1Mont* x1 = x->fp1 + 1; const SM9FP1Mont* x0 = x->fp1; @@ -991,36 +990,19 @@ void SM9FP2_MontInv(const SM9FP2Mont* x, SM9FP2Mont* y) SM9FP1Mont* y1 = y_tmp.fp1 + 1; SM9FP1Mont* y0 = y_tmp.fp1; - if (UInt256_IsZero(x0)) - { - // (-1 / 2x1, 0) - UInt256_SetZero(y0); - SM9FP1_Add(x1, x1, y1); - SM9FP1_MontInv(y1, y1); - SM9FP1_Neg(y1, y1); - } - else if (UInt256_IsZero(x1)) - { - // (0, 1 / x0) - UInt256_SetZero(y1); - SM9FP1_MontInv(x0, y0); - } - else - { - // t = 1 / (x0^2 + 2(x1^2)) - SM9FP1_MontMul(x0, x0, y0); // x0^2 - SM9FP1_MontMul(x1, x1, y1); // x1^2 - SM9FP1_Add(y0, y1, y0); - SM9FP1_Add(y0, y1, y0); // x0^2 + 2(x1^2) - SM9FP1_MontInv(y0, y0); // 1 / (x0^2 + 2(x1^2)) - - // y1 = -tx1 - SM9FP1_MontMul(y0, x1, y1); - SM9FP1_Neg(y1, y1); - - // y0 = tx0 - SM9FP1_MontMul(y0, x0, y0); - } + // det = x0^2 + 2(x1^2) + SM9FP1_MontMul(x0, x0, y0); + SM9FP1_MontMul(x1, x1, y1); + SM9FP1_Add(y0, y1, y0); + SM9FP1_Add(y0, y1, y0); + SM9FP1_MontInv(y0, y0); + + // y1 = -x1 / det + SM9FP1_MontMul(y0, x1, y1); + SM9FP1_Neg(y1, y1); + + // y0 = x0 / det + SM9FP1_MontMul(y0, x0, y0); *y = y_tmp; }