diff --git a/batch-submitter/drivers/proposer/driver.go b/batch-submitter/drivers/proposer/driver.go index ef6ad2a44..4a58b4cd3 100644 --- a/batch-submitter/drivers/proposer/driver.go +++ b/batch-submitter/drivers/proposer/driver.go @@ -657,8 +657,8 @@ func (d *Driver) RequestTssSignature(requestType uint64, start, offsetStartsAtIn var tssResponse tssClient.TssResponse tssReqParams := tss_types.SignStateRequest{ Type: requestType, - StartBlock: start.String(), - OffsetStartsAtIndex: offsetStartsAtIndex.String(), + StartBlock: start, + OffsetStartsAtIndex: offsetStartsAtIndex, Challenge: challenge, StateRoots: stateRoots, } diff --git a/tss/common/types.go b/tss/common/types.go index 7d870ff02..977016274 100644 --- a/tss/common/types.go +++ b/tss/common/types.go @@ -3,6 +3,7 @@ package common import ( "encoding/hex" "fmt" + "math/big" "github.com/ethereum/go-ethereum/common" ) @@ -29,8 +30,8 @@ func (m Method) String() string { type SignStateRequest struct { Type uint64 `json:"type"` - StartBlock string `json:"start_block"` - OffsetStartsAtIndex string `json:"offset_starts_at_index"` + StartBlock *big.Int `json:"start_block"` + OffsetStartsAtIndex *big.Int `json:"offset_starts_at_index"` Challenge string `json:"challenge"` StateRoots [][32]byte `json:"state_roots"` ElectionId uint64 `json:"election_id"` @@ -41,7 +42,7 @@ func (ssr SignStateRequest) String() string { for _, sr := range ssr.StateRoots { srs = srs + hex.EncodeToString(sr[:]) + " " } - return fmt.Sprintf("start_block: %s, offset_starts_at_index: %s, election_id: %d, state_roots: %s", ssr.StartBlock, ssr.OffsetStartsAtIndex, ssr.ElectionId, srs) + return fmt.Sprintf("start_block: %v, offset_starts_at_index: %v, election_id: %d, state_roots: %s", ssr.StartBlock, ssr.OffsetStartsAtIndex, ssr.ElectionId, srs) } type SlashRequest struct { @@ -51,7 +52,7 @@ type SlashRequest struct { } type RollBackRequest struct { - StartBlock string `json:"start_block"` + StartBlock *big.Int `json:"start_block"` } type AskResponse struct { diff --git a/tss/manager/manage.go b/tss/manager/manage.go index b5abac78e..ba23d894f 100644 --- a/tss/manager/manage.go +++ b/tss/manager/manage.go @@ -154,8 +154,7 @@ func (m *Manager) recoverGenerateKey() { func (m *Manager) SignStateBatch(request tss.SignStateRequest) ([]byte, error) { log.Info("received sign state request", "start block", request.StartBlock, "len", len(request.StateRoots), "index", request.OffsetStartsAtIndex) - offsetStartsAtIndex, _ := new(big.Int).SetString(request.OffsetStartsAtIndex, 10) - digestBz, err := tss.StateBatchHash(request.StateRoots, offsetStartsAtIndex) + digestBz, err := tss.StateBatchHash(request.StateRoots, request.OffsetStartsAtIndex) if err != nil { return nil, err } @@ -222,9 +221,8 @@ func (m *Manager) SignStateBatch(request tss.SignStateRequest) ([]byte, error) { //change unApprovals to approvals to do sign ctx = ctx.WithApprovers(ctx.UnApprovers()) rollback = true - startBlock, _ := new(big.Int).SetString(request.StartBlock, 10) rollBackRequest := tss.RollBackRequest{StartBlock: request.StartBlock} - rollBackBz, err := tss.RollBackHash(startBlock) + rollBackBz, err := tss.RollBackHash(request.StartBlock) if err != nil { return nil, err } @@ -313,10 +311,8 @@ func (m *Manager) SignRollBack(request tss.SignStateRequest) ([]byte, error) { } var resp tss.SignResponse - - startBlock, _ := new(big.Int).SetString(request.StartBlock, 10) rollBackRequest := tss.RollBackRequest{StartBlock: request.StartBlock} - rollBackBz, err := tss.RollBackHash(startBlock) + rollBackBz, err := tss.RollBackHash(request.StartBlock) if err != nil { return nil, err } diff --git a/tss/manager/router/registry.go b/tss/manager/router/registry.go index 0cc0fde73..aa750bce8 100644 --- a/tss/manager/router/registry.go +++ b/tss/manager/router/registry.go @@ -35,15 +35,11 @@ func (registry *Registry) SignStateHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, errors.New("invalid request body")) return } - - _, succ := new(big.Int).SetString(request.StartBlock, 10) - if !succ { - c.JSON(http.StatusBadRequest, errors.New("wrong StartBlock, can not be converted to number")) - return - } - _, succ = new(big.Int).SetString(request.OffsetStartsAtIndex, 10) - if !succ { - c.JSON(http.StatusBadRequest, errors.New("wrong OffsetStartsAtIndex, can not be converted to number")) + if request.StartBlock == nil || + request.OffsetStartsAtIndex == nil || + request.StartBlock.Cmp(big.NewInt(0)) < 0 || + request.OffsetStartsAtIndex.Cmp(big.NewInt(0)) < 0 { + c.JSON(http.StatusBadRequest, errors.New("StartBlock and OffsetStartsAtIndex must not be nil or negative")) return } var signature []byte diff --git a/tss/manager/setup_test.go b/tss/manager/setup_test.go index ea768628a..32ceaa885 100644 --- a/tss/manager/setup_test.go +++ b/tss/manager/setup_test.go @@ -1,6 +1,7 @@ package manager import ( + "math/big" "time" tss "github.com/mantlenetworkio/mantle/tss/common" @@ -52,8 +53,8 @@ func setup(afterMsgSent afterMsgSendFunc, queryAliveNodes queryAliveNodesFunc) ( cpkConfirmTimeout: 5 * time.Second, } request := tss.SignStateRequest{ - StartBlock: "1", - OffsetStartsAtIndex: "1", + StartBlock: big.NewInt(1), + OffsetStartsAtIndex: big.NewInt(1), StateRoots: [][32]byte{}, } return manager, request diff --git a/tss/node/signer/sign.go b/tss/node/signer/sign.go index a893e1599..45bfa8e8f 100644 --- a/tss/node/signer/sign.go +++ b/tss/node/signer/sign.go @@ -52,11 +52,20 @@ func (p *Processor) Sign() { } var requestBody tsscommon.SignStateRequest if err := json.Unmarshal(rawMsg, &requestBody); err != nil { - logger.Error().Msg("failed to umarshal ask's params request body") + logger.Error().Msg("failed to unmarshal asker's params request body") RpcResponse := tdtypes.NewRPCErrorResponse(req.ID, 201, "failed", err.Error()) p.wsClient.SendMsg(RpcResponse) continue } + if requestBody.StartBlock == nil || + requestBody.OffsetStartsAtIndex == nil || + requestBody.StartBlock.Cmp(big.NewInt(0)) < 0 || + requestBody.OffsetStartsAtIndex.Cmp(big.NewInt(0)) < 0 { + logger.Error().Msg("StartBlock and OffsetStartsAtIndex must not be nil or negative") + RpcResponse := tdtypes.NewRPCErrorResponse(req.ID, 201, "failed", "StartBlock and OffsetStartsAtIndex must not be nil or negative") + p.wsClient.SendMsg(RpcResponse) + continue + } nodeSignRequest.RequestBody = requestBody go p.SignGo(req.ID.(tdtypes.JSONRPCStringID), nodeSignRequest, logger) @@ -203,8 +212,7 @@ func (p *Processor) checkMessages(sign tsscommon.SignStateRequest) (err error, h } func signMsgToHash(msg tsscommon.SignStateRequest) ([]byte, error) { - offsetStartsAtIndex, _ := new(big.Int).SetString(msg.OffsetStartsAtIndex, 10) - return tsscommon.StateBatchHash(msg.StateRoots, offsetStartsAtIndex) + return tsscommon.StateBatchHash(msg.StateRoots, msg.OffsetStartsAtIndex) } func (p *Processor) removeWaitEvent(key string) { diff --git a/tss/node/signer/sign_rollback.go b/tss/node/signer/sign_rollback.go index c2678b870..4d9a9fd69 100644 --- a/tss/node/signer/sign_rollback.go +++ b/tss/node/signer/sign_rollback.go @@ -45,9 +45,15 @@ func (p *Processor) SignRollBack() { p.wsClient.SendMsg(RpcResponse) continue } + if requestBody.StartBlock == nil || + requestBody.StartBlock.Cmp(big.NewInt(0)) < 0 { + logger.Error().Msg("StartBlock must not be nil or negative") + RpcResponse := tdtypes.NewRPCErrorResponse(req.ID, 201, "failed", "StartBlock must not be nil or negative") + p.wsClient.SendMsg(RpcResponse) + continue + } nodeSignRequest.RequestBody = requestBody - startBlock, _ := new(big.Int).SetString(requestBody.StartBlock, 10) - hashTx, err := tsscommon.RollBackHash(startBlock) + hashTx, err := tsscommon.RollBackHash(requestBody.StartBlock) if err != nil { logger.Err(err).Msg("failed to encode roll back msg") RpcResponse := tdtypes.NewRPCErrorResponse(req.ID, 201, "failed", err.Error()) diff --git a/tss/node/signer/verify.go b/tss/node/signer/verify.go index 9b1dc8ccb..cc742d4b3 100644 --- a/tss/node/signer/verify.go +++ b/tss/node/signer/verify.go @@ -36,6 +36,15 @@ func (p *Processor) Verify() { p.wsClient.SendMsg(RpcResponse) continue } + if askRequest.StartBlock == nil || + askRequest.OffsetStartsAtIndex == nil || + askRequest.StartBlock.Cmp(big.NewInt(0)) < 0 || + askRequest.OffsetStartsAtIndex.Cmp(big.NewInt(0)) < 0 { + logger.Error().Msg("StartBlock and OffsetStartsAtIndex must not be nil or negative") + RpcResponse = tdtypes.NewRPCErrorResponse(req.ID, 201, "invalid askRequest", "StartBlock and OffsetStartsAtIndex must not be nil or negative") + p.wsClient.SendMsg(RpcResponse) + return + } var resId = req.ID var size = len(askRequest.StateRoots) logger.Info().Msgf("stateroots size %d ", size) @@ -79,12 +88,11 @@ func (p *Processor) Verify() { }() } -func (p *Processor) verify(start string, index int, stateRoot [32]byte, logger zerolog.Logger, wg *sync.WaitGroup) (bool, error) { +func (p *Processor) verify(start *big.Int, index int, stateRoot [32]byte, logger zerolog.Logger, wg *sync.WaitGroup) (bool, error) { defer wg.Done() offset := new(big.Int).SetInt64(int64(index)) - startBig, _ := new(big.Int).SetString(start, 10) - blockNumber := offset.Add(offset, startBig) + blockNumber := offset.Add(offset, start) logger.Info().Msgf("start to query block by number %d", blockNumber) value, ok := p.GetVerify(blockNumber.String())