Skip to content

Commit

Permalink
refactor(distribution): move ValidateBasic logic to msgServer
Browse files Browse the repository at this point in the history
  • Loading branch information
julienrbrt committed Apr 7, 2023
1 parent 9a5413d commit 16d36c7
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 209 deletions.
109 changes: 74 additions & 35 deletions x/distribution/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ package keeper
import (
"context"

"cosmossdk.io/errors"
"github.com/armon/go-metrics"

errorsmod "cosmossdk.io/errors"

"github.com/cosmos/cosmos-sdk/telemetry"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/errors"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/x/distribution/types"
govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"
)
Expand All @@ -27,16 +26,17 @@ func NewMsgServerImpl(keeper Keeper) types.MsgServer {
}

func (k msgServer) SetWithdrawAddress(goCtx context.Context, msg *types.MsgSetWithdrawAddress) (*types.MsgSetWithdrawAddressResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

delegatorAddress, err := k.authKeeper.StringToBytes(msg.DelegatorAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid delegator address: %s", err)
}

withdrawAddress, err := k.authKeeper.StringToBytes(msg.WithdrawAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid withdraw address: %s", err)
}

ctx := sdk.UnwrapSDKContext(goCtx)
err = k.SetWithdrawAddr(ctx, delegatorAddress, withdrawAddress)
if err != nil {
return nil, err
Expand All @@ -46,16 +46,17 @@ func (k msgServer) SetWithdrawAddress(goCtx context.Context, msg *types.MsgSetWi
}

func (k msgServer) WithdrawDelegatorReward(goCtx context.Context, msg *types.MsgWithdrawDelegatorReward) (*types.MsgWithdrawDelegatorRewardResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

valAddr, err := sdk.ValAddressFromBech32(msg.ValidatorAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid validator address: %s", err)
}

delegatorAddress, err := k.authKeeper.StringToBytes(msg.DelegatorAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid delegator address: %s", err)
}

ctx := sdk.UnwrapSDKContext(goCtx)
amount, err := k.WithdrawDelegationRewards(ctx, delegatorAddress, valAddr)
if err != nil {
return nil, err
Expand All @@ -77,12 +78,12 @@ func (k msgServer) WithdrawDelegatorReward(goCtx context.Context, msg *types.Msg
}

func (k msgServer) WithdrawValidatorCommission(goCtx context.Context, msg *types.MsgWithdrawValidatorCommission) (*types.MsgWithdrawValidatorCommissionResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

valAddr, err := sdk.ValAddressFromBech32(msg.ValidatorAddress)
if err != nil {
return nil, err
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid validator address: %s", err)
}

ctx := sdk.UnwrapSDKContext(goCtx)
amount, err := k.Keeper.WithdrawValidatorCommission(ctx, valAddr)
if err != nil {
return nil, err
Expand All @@ -104,71 +105,85 @@ func (k msgServer) WithdrawValidatorCommission(goCtx context.Context, msg *types
}

func (k msgServer) FundCommunityPool(goCtx context.Context, msg *types.MsgFundCommunityPool) (*types.MsgFundCommunityPoolResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

depositer, err := k.authKeeper.StringToBytes(msg.Depositor)
depositor, err := k.authKeeper.StringToBytes(msg.Depositor)
if err != nil {
return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid depositor address: %s", err)
}

if err := validateAmount(msg.Amount); err != nil {
return nil, err
}
if err := k.Keeper.FundCommunityPool(ctx, msg.Amount, depositer); err != nil {

ctx := sdk.UnwrapSDKContext(goCtx)
if err := k.Keeper.FundCommunityPool(ctx, msg.Amount, depositor); err != nil {
return nil, err
}

return &types.MsgFundCommunityPoolResponse{}, nil
}

func (k msgServer) UpdateParams(goCtx context.Context, req *types.MsgUpdateParams) (*types.MsgUpdateParamsResponse, error) {
if k.authority != req.Authority {
return nil, errorsmod.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", k.authority, req.Authority)
func (k msgServer) UpdateParams(goCtx context.Context, msg *types.MsgUpdateParams) (*types.MsgUpdateParamsResponse, error) {
if err := k.validateAuthority(msg.Authority); err != nil {
return nil, err
}

if (!msg.Params.BaseProposerReward.IsNil() && !msg.Params.BaseProposerReward.IsZero()) || //nolint:staticcheck // deprecated but kept for backwards compatibility
(!msg.Params.BonusProposerReward.IsNil() && !msg.Params.BonusProposerReward.IsZero()) { //nolint:staticcheck // deprecated but kept for backwards compatibility
return nil, errors.Wrapf(sdkerrors.ErrInvalidRequest, "cannot update base or bonus proposer reward because these are deprecated fields")
}

if (!req.Params.BaseProposerReward.IsNil() && !req.Params.BaseProposerReward.IsZero()) || //nolint:staticcheck // deprecated but kept for backwards compatibility
(!req.Params.BonusProposerReward.IsNil() && !req.Params.BonusProposerReward.IsZero()) { //nolint:staticcheck // deprecated but kept for backwards compatibility
return nil, errorsmod.Wrapf(errors.ErrInvalidRequest, "cannot update base or bonus proposer reward because these are deprecated fields")
if err := msg.Params.ValidateBasic(); err != nil {
return nil, err
}

ctx := sdk.UnwrapSDKContext(goCtx)
if err := k.SetParams(ctx, req.Params); err != nil {
if err := k.SetParams(ctx, msg.Params); err != nil {
return nil, err
}

return &types.MsgUpdateParamsResponse{}, nil
}

func (k msgServer) CommunityPoolSpend(goCtx context.Context, req *types.MsgCommunityPoolSpend) (*types.MsgCommunityPoolSpendResponse, error) {
if k.authority != req.Authority {
return nil, errorsmod.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", k.authority, req.Authority)
func (k msgServer) CommunityPoolSpend(goCtx context.Context, msg *types.MsgCommunityPoolSpend) (*types.MsgCommunityPoolSpendResponse, error) {
if err := k.validateAuthority(msg.Authority); err != nil {
return nil, err
}

ctx := sdk.UnwrapSDKContext(goCtx)
if err := validateAmount(msg.Amount); err != nil {
return nil, err
}

recipient, err := k.authKeeper.StringToBytes(req.Recipient)
recipient, err := k.authKeeper.StringToBytes(msg.Recipient)
if err != nil {
return nil, err
}

if k.bankKeeper.BlockedAddr(recipient) {
return nil, errorsmod.Wrapf(errors.ErrUnauthorized, "%s is not allowed to receive external funds", req.Recipient)
return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive external funds", msg.Recipient)
}

if err := k.DistributeFromFeePool(ctx, req.Amount, recipient); err != nil {
ctx := sdk.UnwrapSDKContext(goCtx)
if err := k.DistributeFromFeePool(ctx, msg.Amount, recipient); err != nil {
return nil, err
}

logger := k.Logger(ctx)
logger.Info("transferred from the community pool to recipient", "amount", req.Amount.String(), "recipient", req.Recipient)
logger.Info("transferred from the community pool to recipient", "amount", msg.Amount.String(), "recipient", msg.Recipient)

return &types.MsgCommunityPoolSpendResponse{}, nil
}

func (k msgServer) DepositValidatorRewardsPool(goCtx context.Context, req *types.MsgDepositValidatorRewardsPool) (*types.MsgDepositValidatorRewardsPoolResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
if err := k.validateAuthority(req.Authority); err != nil {
return nil, err
}

authority, err := k.authKeeper.StringToBytes(req.Authority)
if err != nil {
return nil, err
}

ctx := sdk.UnwrapSDKContext(goCtx)
// deposit coins from sender's account to the distribution module
if err := k.bankKeeper.SendCoinsFromAccountToModule(ctx, authority, types.ModuleName, req.Amount); err != nil {
return nil, err
Expand All @@ -181,7 +196,7 @@ func (k msgServer) DepositValidatorRewardsPool(goCtx context.Context, req *types

validator := k.stakingKeeper.Validator(ctx, valAddr)
if validator == nil {
return nil, errorsmod.Wrapf(types.ErrNoValidatorExists, valAddr.String())
return nil, errors.Wrapf(types.ErrNoValidatorExists, valAddr.String())
}

// Allocate tokens from the distribution module to the validator, which are
Expand All @@ -199,3 +214,27 @@ func (k msgServer) DepositValidatorRewardsPool(goCtx context.Context, req *types

return &types.MsgDepositValidatorRewardsPoolResponse{}, nil
}

func (k *Keeper) validateAuthority(authority string) error {
if _, err := k.authKeeper.StringToBytes(authority); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid authority address: %s", err)
}

if k.authority != authority {
return errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", k.authority, authority)
}

return nil
}

func validateAmount(amount sdk.Coins) error {
if amount == nil {
return errors.Wrap(sdkerrors.ErrInvalidCoins, "amount cannot be nil")
}

if err := amount.Validate(); err != nil {
return errors.Wrap(sdkerrors.ErrInvalidCoins, amount.String())
}

return nil
}
79 changes: 0 additions & 79 deletions x/distribution/types/msg.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
package types

import (
"errors"

errorsmod "cosmossdk.io/errors"

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
)

Expand Down Expand Up @@ -46,18 +41,6 @@ func (msg MsgSetWithdrawAddress) GetSignBytes() []byte {
return sdk.MustSortJSON(bz)
}

// quick validity check
func (msg MsgSetWithdrawAddress) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.DelegatorAddress); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid delegator address: %s", err)
}
if _, err := sdk.AccAddressFromBech32(msg.WithdrawAddress); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid withdraw address: %s", err)
}

return nil
}

func NewMsgWithdrawDelegatorReward(delAddr sdk.AccAddress, valAddr sdk.ValAddress) *MsgWithdrawDelegatorReward {
return &MsgWithdrawDelegatorReward{
DelegatorAddress: delAddr.String(),
Expand All @@ -77,17 +60,6 @@ func (msg MsgWithdrawDelegatorReward) GetSignBytes() []byte {
return sdk.MustSortJSON(bz)
}

// quick validity check
func (msg MsgWithdrawDelegatorReward) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.DelegatorAddress); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid delegator address: %s", err)
}
if _, err := sdk.ValAddressFromBech32(msg.ValidatorAddress); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid validator address: %s", err)
}
return nil
}

func NewMsgWithdrawValidatorCommission(valAddr sdk.ValAddress) *MsgWithdrawValidatorCommission {
return &MsgWithdrawValidatorCommission{
ValidatorAddress: valAddr.String(),
Expand All @@ -106,14 +78,6 @@ func (msg MsgWithdrawValidatorCommission) GetSignBytes() []byte {
return sdk.MustSortJSON(bz)
}

// quick validity check
func (msg MsgWithdrawValidatorCommission) ValidateBasic() error {
if _, err := sdk.ValAddressFromBech32(msg.ValidatorAddress); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid validator address: %s", err)
}
return nil
}

// NewMsgFundCommunityPool returns a new MsgFundCommunityPool with a sender and
// a funding amount.
func NewMsgFundCommunityPool(amount sdk.Coins, depositor sdk.AccAddress) *MsgFundCommunityPool {
Expand All @@ -137,17 +101,6 @@ func (msg MsgFundCommunityPool) GetSignBytes() []byte {
return sdk.MustSortJSON(bz)
}

// ValidateBasic performs basic MsgFundCommunityPool message validation.
func (msg MsgFundCommunityPool) ValidateBasic() error {
if !msg.Amount.IsValid() {
return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, msg.Amount.String())
}
if _, err := sdk.AccAddressFromBech32(msg.Depositor); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid depositor address: %s", err)
}
return nil
}

// GetSigners returns the signer addresses that are expected to sign the result
// of GetSignBytes.
func (msg MsgUpdateParams) GetSigners() []sdk.AccAddress {
Expand All @@ -162,20 +115,6 @@ func (msg MsgUpdateParams) GetSignBytes() []byte {
return sdk.MustSortJSON(bz)
}

// ValidateBasic performs basic MsgUpdateParams message validation.
func (msg MsgUpdateParams) ValidateBasic() error {
if (!msg.Params.BaseProposerReward.IsNil() && !msg.Params.BaseProposerReward.IsZero()) ||
(!msg.Params.BonusProposerReward.IsNil() && !msg.Params.BonusProposerReward.IsZero()) {
return errors.New("base and bonus proposer reward are deprecated fields and should not be used")
}

if _, err := sdk.AccAddressFromBech32(msg.Authority); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid authority address: %s", err)
}

return msg.Params.ValidateBasic()
}

// GetSigners returns the signer addresses that are expected to sign the result
// of GetSignBytes, which is the authority.
func (msg MsgCommunityPoolSpend) GetSigners() []sdk.AccAddress {
Expand All @@ -190,15 +129,6 @@ func (msg MsgCommunityPoolSpend) GetSignBytes() []byte {
return sdk.MustSortJSON(bz)
}

// ValidateBasic performs basic MsgCommunityPoolSpend message validation.
func (msg MsgCommunityPoolSpend) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.Authority); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid authority address: %s", err)
}

return msg.Amount.Validate()
}

// NewMsgDepositValidatorRewardsPool returns a new MsgDepositValidatorRewardsPool
// with a sender and a funding amount.
func NewMsgDepositValidatorRewardsPool(depositor sdk.AccAddress, valAddr sdk.ValAddress, amount sdk.Coins) *MsgDepositValidatorRewardsPool {
Expand All @@ -222,12 +152,3 @@ func (msg MsgDepositValidatorRewardsPool) GetSignBytes() []byte {
bz := ModuleCdc.MustMarshalJSON(&msg)
return sdk.MustSortJSON(bz)
}

// ValidateBasic performs basic MsgDepositValidatorRewardsPool message validation.
func (msg MsgDepositValidatorRewardsPool) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.Authority); err != nil {
return sdkerrors.ErrInvalidAddress.Wrapf("invalid authority address: %s", err)
}

return msg.Amount.Validate()
}
Loading

0 comments on commit 16d36c7

Please sign in to comment.