Skip to content

Commit

Permalink
cmd: move threshold check to CLI level (#3297)
Browse files Browse the repository at this point in the history
A [recent commit](98b84e1) introduced a misbehavior when omitting the optional `--threshold` flag of `create dkg` and `create cluster` commands.
Because the threshold configuration is tested before the threshold variable is assigned to the default value `ceil(2*n/3)`, the flag is not optional anymore.

This PR fixes this bug by moving the checks at the CLI level and by updating the corresponding tests accordingly.

It also adds an input validation check on the [`ThresholdSplit`](https://github.com/ObolNetwork/charon/blob/ced30abb5a8c168b358a9bfc976fbe23927d72de/tbls/herumi.go#L133) and [`ThresholdSplitInsecure`](https://github.com/ObolNetwork/charon/blob/ced30abb5a8c168b358a9bfc976fbe23927d72de/tbls/herumi.go#L83) functions to ensure they are called with a threshold parameter greater than 1.

category: bug
ticket: none
  • Loading branch information
KaloyanTanev committed Nov 19, 2024
1 parent 159d6f2 commit 72549c6
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 63 deletions.
27 changes: 16 additions & 11 deletions cmd/createcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"net/url"
"os"
"path"
Expand Down Expand Up @@ -100,6 +99,22 @@ func newCreateClusterCmd(runFunc func(context.Context, io.Writer, clusterConfig)
bindClusterFlags(cmd.Flags(), &conf)
bindInsecureFlags(cmd.Flags(), &conf.InsecureKeys)

wrapPreRunE(cmd, func(cmd *cobra.Command, _ []string) error {
thresholdPresent := cmd.Flags().Lookup("threshold").Changed

if thresholdPresent {
if conf.Threshold < minThreshold {
return errors.New("threshold must be greater than 1", z.Int("threshold", conf.Threshold), z.Int("min", minThreshold))
}
if conf.Threshold > conf.NumNodes {
return errors.New("threshold cannot be greater than number of operators",
z.Int("threshold", conf.Threshold), z.Int("operators", conf.NumNodes))
}
}

return nil
})

return cmd
}

Expand Down Expand Up @@ -374,16 +389,6 @@ func validateCreateConfig(ctx context.Context, conf clusterConfig) error {
return errors.New("number of operators is below minimum", z.Int("operators", conf.NumNodes), z.Int("min", minNodes))
}

// Check for threshold parameter
minThreshold := int(math.Ceil(float64(conf.NumNodes*2) / 3))
if conf.Threshold < minThreshold {
return errors.New("threshold cannot be smaller than BFT quorum", z.Int("threshold", conf.Threshold), z.Int("min", minThreshold))
}
if conf.Threshold > conf.NumNodes {
return errors.New("threshold cannot be greater than number of operators",
z.Int("threshold", conf.Threshold), z.Int("operators", conf.NumNodes))
}

return nil
}

Expand Down
96 changes: 76 additions & 20 deletions cmd/createcluster_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,26 +250,6 @@ func TestCreateCluster(t *testing.T) {
},
},
},
{
Name: "threshold greater than the number of operators",
Config: clusterConfig{
NumNodes: 4,
Threshold: 5,
NumDVs: 1,
Network: defaultNetwork,
},
expectedErr: "threshold cannot be greater than number of operators",
},
{
Name: "threshold smaller than BFT quorum",
Config: clusterConfig{
NumNodes: 4,
Threshold: 2,
NumDVs: 1,
Network: defaultNetwork,
},
expectedErr: "threshold cannot be smaller than BFT quorum",
},
{
Name: "test with number of nodes below minimum",
Config: clusterConfig{
Expand Down Expand Up @@ -788,6 +768,82 @@ func TestPublish(t *testing.T) {
})
}

func TestClusterCLI(t *testing.T) {
feeRecipientArg := "--fee-recipient-addresses=" + validEthAddr
withdrawalArg := "--withdrawal-addresses=" + validEthAddr

tests := []struct {
name string
network string
nodes string
numValidators string
feeRecipient string
withdrawal string
threshold string
expectedErr string
cleanup func(*testing.T)
}{
{
name: "threshold below minimum",
nodes: "--nodes=3",
network: "--network=holesky",
numValidators: "--num-validators=1",
feeRecipient: feeRecipientArg,
withdrawal: withdrawalArg,
threshold: "--threshold=1",
expectedErr: "threshold must be greater than 1",
},
{
name: "threshold above maximum",
nodes: "--nodes=4",
network: "--network=holesky",
numValidators: "--num-validators=1",
feeRecipient: feeRecipientArg,
withdrawal: withdrawalArg,
threshold: "--threshold=5",
expectedErr: "threshold cannot be greater than number of operators",
},
{
name: "no threshold provided",
nodes: "--nodes=3",
network: "--network=holesky",
numValidators: "--num-validators=1",
feeRecipient: feeRecipientArg,
withdrawal: withdrawalArg,
threshold: "",
expectedErr: "",
cleanup: func(t *testing.T) {
t.Helper()
require.NoError(t, os.RemoveAll("node0"))
require.NoError(t, os.RemoveAll("node1"))
require.NoError(t, os.RemoveAll("node2"))
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cmd := newCreateCmd(newCreateClusterCmd(runCreateCluster))
if test.threshold != "" {
cmd.SetArgs([]string{"cluster", test.nodes, test.feeRecipient, test.withdrawal, test.network, test.numValidators, test.threshold})
} else {
cmd.SetArgs([]string{"cluster", test.nodes, test.feeRecipient, test.withdrawal, test.network, test.numValidators})
}

err := cmd.Execute()
if test.expectedErr != "" {
require.ErrorContains(t, err, test.expectedErr)
} else {
require.NoError(t, err)
}

if test.cleanup != nil {
test.cleanup(t)
}
})
}
}

// mockKeymanagerReq is a mock keymanager request for use in tests.
type mockKeymanagerReq struct {
Keystores []string `json:"keystores"`
Expand Down
33 changes: 19 additions & 14 deletions cmd/createdkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
crand "crypto/rand"
"encoding/json"
"math"
"os"
"path"

Expand Down Expand Up @@ -50,6 +49,22 @@ func newCreateDKGCmd(runFunc func(context.Context, createDKGConfig) error) *cobr

bindCreateDKGFlags(cmd, &config)

wrapPreRunE(cmd, func(cmd *cobra.Command, _ []string) error {
thresholdPresent := cmd.Flags().Lookup("threshold").Changed

if thresholdPresent {
if config.Threshold < minThreshold {
return errors.New("threshold must be greater than 1", z.Int("threshold", config.Threshold), z.Int("min", minThreshold))
}
if config.Threshold > len(config.OperatorENRs) {
return errors.New("threshold cannot be greater than number of operators",
z.Int("threshold", config.Threshold), z.Int("operators", len(config.OperatorENRs)))
}
}

return nil
})

return cmd
}

Expand Down Expand Up @@ -82,7 +97,7 @@ func runCreateDKG(ctx context.Context, conf createDKGConfig) (err error) {
conf.Network = eth2util.Goerli.Name
}

if err = validateDKGConfig(conf.Threshold, len(conf.OperatorENRs), conf.Network, conf.DepositAmounts); err != nil {
if err = validateDKGConfig(len(conf.OperatorENRs), conf.Network, conf.DepositAmounts); err != nil {
return err
}

Expand Down Expand Up @@ -115,7 +130,7 @@ func runCreateDKG(ctx context.Context, conf createDKGConfig) (err error) {
safeThreshold := cluster.Threshold(len(conf.OperatorENRs))
if conf.Threshold == 0 {
conf.Threshold = safeThreshold
} else if conf.Threshold != safeThreshold {
} else {
log.Warn(ctx, "Non standard `--threshold` flag provided, this will affect cluster safety", nil, z.Int("threshold", conf.Threshold), z.Int("safe_threshold", safeThreshold))
}

Expand Down Expand Up @@ -181,22 +196,12 @@ func validateWithdrawalAddrs(addrs []string, network string) error {
}

// validateDKGConfig returns an error if any of the provided config parameter is invalid.
func validateDKGConfig(threshold, numOperators int, network string, depositAmounts []int) error {
func validateDKGConfig(numOperators int, network string, depositAmounts []int) error {
// Don't allow cluster size to be less than 3.
if numOperators < minNodes {
return errors.New("number of operators is below minimum", z.Int("operators", numOperators), z.Int("min", minNodes))
}

// Ensure threshold setting is sound
minThreshold := int(math.Ceil(float64(numOperators*2) / 3))
if threshold < minThreshold {
return errors.New("threshold cannot be smaller than BFT quorum", z.Int("threshold", threshold), z.Int("min", minThreshold))
}
if threshold > numOperators {
return errors.New("threshold cannot be greater than length of operators",
z.Int("threshold", threshold), z.Int("operators", numOperators))
}

if !eth2util.ValidNetwork(network) {
return errors.New("unsupported network", z.Str("network", network))
}
Expand Down
109 changes: 91 additions & 18 deletions cmd/createdkg_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,36 +184,109 @@ func TestValidateWithdrawalAddr(t *testing.T) {
}

func TestValidateDKGConfig(t *testing.T) {
t.Run("threshold exceeds numOperators", func(t *testing.T) {
threshold := 5
numOperators := 4
err := validateDKGConfig(threshold, numOperators, "", nil)
require.ErrorContains(t, err, "threshold cannot be greater than length of operators")
})

t.Run("threshold equals 1", func(t *testing.T) {
threshold := 1
numOperators := 3
err := validateDKGConfig(threshold, numOperators, "", nil)
require.ErrorContains(t, err, "threshold cannot be smaller than BFT quorum")
})

t.Run("insufficient ENRs", func(t *testing.T) {
threshold := 2
numOperators := 2
err := validateDKGConfig(threshold, numOperators, "", nil)
err := validateDKGConfig(numOperators, "", nil)
require.ErrorContains(t, err, "number of operators is below minimum")
})

t.Run("invalid network", func(t *testing.T) {
threshold := 3
numOperators := 4
err := validateDKGConfig(threshold, numOperators, "cosmos", nil)
err := validateDKGConfig(numOperators, "cosmos", nil)
require.ErrorContains(t, err, "unsupported network")
})

t.Run("wrong deposit amounts sum", func(t *testing.T) {
err := validateDKGConfig(3, 4, "goerli", []int{8, 16})
err := validateDKGConfig(4, "goerli", []int{8, 16})
require.ErrorContains(t, err, "sum of partial deposit amounts must sum up to 32ETH")
})
}

func TestDKGCLI(t *testing.T) {
var enrs []string
for range minNodes {
enrs = append(enrs, "enr:-JG4QG472ZVvl8ySSnUK9uNVDrP_hjkUrUqIxUC75aayzmDVQedXkjbqc7QKyOOS71VmlqnYzri_taV8ZesFYaoQSIOGAYHtv1WsgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQKwwq_CAld6oVKOrixE-JzMtvvNgb9yyI-_rwq4NFtajIN0Y3CCDhqDdWRwgg4u")
}
enrArg := "--operator-enrs=" + strings.Join(enrs, ",")
feeRecipientArg := "--fee-recipient-addresses=" + validEthAddr
withdrawalArg := "--withdrawal-addresses=" + validEthAddr
outputDirArg := "--output-dir=.charon"

tests := []struct {
name string
enr string
feeRecipient string
withdrawal string
outputDir string
threshold string
expectedErr string
prepare func(*testing.T)
cleanup func(*testing.T)
}{
{
name: "threshold below minimum",
enr: enrArg,
feeRecipient: feeRecipientArg,
withdrawal: withdrawalArg,
outputDir: outputDirArg,
threshold: "--threshold=1",
expectedErr: "threshold must be greater than 1",
},
{
name: "threshold above maximum",
enr: enrArg,
feeRecipient: feeRecipientArg,
withdrawal: withdrawalArg,
outputDir: outputDirArg,
threshold: "--threshold=4",
expectedErr: "threshold cannot be greater than number of operators",
},
{
name: "no threshold provided",
enr: enrArg,
feeRecipient: feeRecipientArg,
withdrawal: withdrawalArg,
outputDir: outputDirArg,
threshold: "",
expectedErr: "",
prepare: func(t *testing.T) {
t.Helper()
charonDir := testutil.CreateTempCharonDir(t)
b := []byte("sample definition")
require.NoError(t, os.WriteFile(path.Join(charonDir, "cluster-definition.json"), b, 0o600))
},
cleanup: func(t *testing.T) {
t.Helper()
err := os.RemoveAll(".charon")
require.NoError(t, err)
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.prepare != nil {
test.prepare(t)
}

cmd := newCreateCmd(newCreateDKGCmd(runCreateDKG))
if test.threshold != "" {
cmd.SetArgs([]string{"dkg", test.enr, test.feeRecipient, test.withdrawal, test.outputDir, test.threshold})
} else {
cmd.SetArgs([]string{"dkg", test.enr, test.feeRecipient, test.withdrawal, test.outputDir})
}

err := cmd.Execute()
if test.expectedErr != "" {
require.ErrorContains(t, err, test.expectedErr)
} else {
require.NoError(t, err)
}

if test.cleanup != nil {
test.cleanup(t)
}
})
}
}
8 changes: 8 additions & 0 deletions tbls/herumi.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ func (Herumi) ThresholdSplitInsecure(t *testing.T, secret PrivateKey, total uint
t.Helper()
var p bls.SecretKey

if threshold <= 1 {
return nil, errors.New("threshold has to be greater than 1")
}

if err := p.Deserialize(secret[:]); err != nil {
return nil, errors.Wrap(err, "cannot unmarshal bytes into Herumi secret key")
}
Expand Down Expand Up @@ -133,6 +137,10 @@ func (Herumi) ThresholdSplitInsecure(t *testing.T, secret PrivateKey, total uint
func (Herumi) ThresholdSplit(secret PrivateKey, total uint, threshold uint) (map[int]PrivateKey, error) {
var p bls.SecretKey

if threshold <= 1 {
return nil, errors.New("threshold has to be greater than 1")
}

if err := p.Deserialize(secret[:]); err != nil {
return nil, errors.Wrap(err, "cannot unmarshal bytes into Herumi secret key")
}
Expand Down

0 comments on commit 72549c6

Please sign in to comment.