Skip to content

Commit

Permalink
solver, solverrpc, csppsolver: Add RootFactors
Browse files Browse the repository at this point in the history
The RootFactors function provides the result of the polynomial factorization,
including too few results or repeated roots, without erroring for these
conditions.
  • Loading branch information
jrick committed May 6, 2024
1 parent 09ca221 commit 3c81672
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 20 deletions.
25 changes: 20 additions & 5 deletions cmd/csppsolver/solver.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,38 @@ type Args struct {

type Result struct {
Roots []*big.Int
Exponents []int
RepeatedRoot *big.Int
}

func (*Solver) RootFactors(args Args, res *Result) error {
roots, exps, err := solver.RootFactors(args.A, args.F)
if err != nil {
return err
}
res.Roots = roots
res.Exponents = exps
return nil
}

type repeatedRoot interface {
RepeatedRoot() *big.Int
}

func (*Solver) Roots(args Args, res *Result) error {
roots, err := solver.Roots(args.A, args.F)
if rr, ok := err.(repeatedRoot); ok {
res.RepeatedRoot = rr.RepeatedRoot()
return nil // error set by client package
}
roots, exps, err := solver.RootFactors(args.A, args.F)
if err != nil {
return err
}
for i, exp := range exps {
if exp != 1 {
res.RepeatedRoot = roots[i]
return nil // error set by client package
}
}

res.Roots = roots
res.Exponents = exps
return nil
}

Expand Down
44 changes: 29 additions & 15 deletions solver/solver.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,12 @@ func factorPoly(fac *C.fmpz_mod_poly_factor_struct, i uintptr) *C.fmpz_mod_poly_
return (*C.fmpz_mod_poly_struct)(unsafe.Pointer(uintptr(unsafe.Pointer(fac.poly)) + i*C.sizeof_fmpz_mod_poly_struct))
}

type repeatedRoot big.Int

func (r *repeatedRoot) Error() string { return "repeated roots" }
func (r *repeatedRoot) RepeatedRoot() *big.Int { return (*big.Int)(r) }

// Roots solves for len(a)-1 roots of the polynomial with coefficients a (mod F).
// Repeated roots are considered an error for the purposes of unique slot assignment.
func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) {
// RootFactors returns the roots and their number of solutions in the
// factorized polynomial. Repeated roots are an error in the mixing protocol
// but unlike the Roots function are not returned as an error here.
func RootFactors(a []*big.Int, F *big.Int) ([]*big.Int, []int, error) {
if len(a) < 2 {
return nil, errors.New("too few coefficients")
return nil, nil, errors.New("too few coefficients")
}

var mod C.fmpz_t
Expand Down Expand Up @@ -80,6 +76,7 @@ func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) {
C.fmpz_mod_poly_factor(&factor[0], &poly[0], &modctx[0])

roots := make([]*big.Int, 0, len(a)-1)
exps := make([]int, 0, len(a)-1)
var m C.fmpz_t
C.fmpz_init(&m[0])
defer C.fmpz_clear(&m[0])
Expand All @@ -93,19 +90,36 @@ func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) {

b, ok := new(big.Int).SetString(str, base)
if !ok {
return nil, errors.New("failed to read fmpz")
return nil, nil, errors.New("failed to read fmpz")
}
b.Neg(b)
b.Mod(b, F)

if factorExp(&factor[0], uintptr(i)) != 1 {
return nil, (*repeatedRoot)(b)
}
roots = append(roots, b)
exps = append(exps, int(factorExp(&factor[0], uintptr(i))))
}

return roots, exps, nil
}

type repeatedRoot big.Int

func (r *repeatedRoot) Error() string { return "repeated roots" }
func (r *repeatedRoot) RepeatedRoot() *big.Int { return (*big.Int)(r) }

// Roots solves for len(a)-1 roots of the polynomial with coefficients a (mod F).
// Repeated roots are considered an error for the purposes of unique slot
// assignment, and an error with method RepeatedRoot() *big.Int is returned.
func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) {
roots, exps, err := RootFactors(a, F)
if err != nil {
return roots, err
}

if len(roots) != len(a)-1 {
return nil, errors.New("too few roots")
for i, exp := range exps {
if exp != 1 {
return nil, (*repeatedRoot)(roots[i])
}
}

return roots, nil
Expand Down
27 changes: 27 additions & 0 deletions solver/solver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,33 @@ func TestRoots(t *testing.T) {
}
}

func TestRootFactors(t *testing.T) {
for i := range tests {
roots, exps, err := RootFactors(tests[i].coeffs, tests[i].field)
if err != nil {
t.Error(err)
continue
}
for i, exp := range exps {
if exp != 1 {
t.Errorf("repeated root %v at index %v", roots[i], i)
continue
}
}
if len(roots) != len(tests[i].messages) {
t.Error("wrong root count")
continue
}
sortBig(tests[i].messages)
sortBig(roots)
for j := range roots {
if roots[j].Cmp(tests[i].messages[j]) != 0 {
t.Error("recovered wrong message")
}
}
}
}

func BenchmarkRoots(b *testing.B) {
for i := range tests {
b.Run(fmt.Sprintf("%d", tests[i].n), func(b *testing.B) {
Expand Down
25 changes: 25 additions & 0 deletions solverrpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,31 @@ func StartSolver() error {
return onceErr
}

// RootFactors returns the roots and their number of solutions in the
// factorized polynomial. Repeated roots are an error in the mixing protocol
// but unlike the Roots function are not returned as an error here.
func RootFactors(a []*big.Int, F *big.Int) ([]*big.Int, []int, error) {
if err := StartSolver(); err != nil {
return nil, nil, err
}

var args struct {
A []*big.Int
F *big.Int
}
args.A = a
args.F = F
var result struct {
Roots []*big.Int
Exponents []int
}
err := client.Call("Solver.RootFactors", args, &result)
if err != nil {
return nil, nil, err
}
return result.Roots, result.Exponents, nil
}

type repeatedRoot big.Int

func (r *repeatedRoot) Error() string { return "repeated roots" }
Expand Down
27 changes: 27 additions & 0 deletions solverrpc/solver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,33 @@ func TestRoots(t *testing.T) {
}
}

func TestRootFactors(t *testing.T) {
for i := range tests {
roots, exps, err := RootFactors(tests[i].coeffs, tests[i].field)
if err != nil {
t.Error(err)
continue
}
for i, exp := range exps {
if exp != 1 {
t.Errorf("repeated root %v at index %v", roots[i], i)
continue
}
}
if len(roots) != len(tests[i].messages) {
t.Error("wrong root count")
continue
}
sortBig(tests[i].messages)
sortBig(roots)
for j := range roots {
if roots[j].Cmp(tests[i].messages[j]) != 0 {
t.Error("recovered wrong message")
}
}
}
}

func BenchmarkRoots(b *testing.B) {
for i := range tests {
b.Run(fmt.Sprintf("%d", tests[i].n), func(b *testing.B) {
Expand Down

0 comments on commit 3c81672

Please sign in to comment.