diff --git a/ethereum/eip712/domain.go b/ethereum/eip712/domain.go index 7e2c1c2000..d768a17ab8 100644 --- a/ethereum/eip712/domain.go +++ b/ethereum/eip712/domain.go @@ -16,19 +16,22 @@ package eip712 import ( + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/signer/core/apitypes" ) // createEIP712Domain creates the typed data domain for the given chainID. func createEIP712Domain(chainID uint64) apitypes.TypedDataDomain { + contractStr := "cosmos" + contractAddr := common.BytesToAddress([]byte(contractStr)) domain := apitypes.TypedDataDomain{ Name: "Carbon", Version: "1.0.0", ChainId: math.NewHexOrDecimal256(int64(chainID)), - VerifyingContract: "cosmos", + VerifyingContract: contractAddr.Hex(), Salt: "1", } - + return domain } diff --git a/ethereum/eip712/encoding.go b/ethereum/eip712/encoding.go index 772b8eb376..c5006c9412 100644 --- a/ethereum/eip712/encoding.go +++ b/ethereum/eip712/encoding.go @@ -18,11 +18,13 @@ package eip712 import ( "errors" "fmt" + "strings" "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" sdk "github.com/cosmos/cosmos-sdk/types" txTypes "github.com/cosmos/cosmos-sdk/types/tx" + evmkeeper "github.com/evmos/ethermint/x/evm/keeper" sdktestutil "github.com/cosmos/cosmos-sdk/types/module/testutil" apitypes "github.com/ethereum/go-ethereum/signer/core/apitypes" @@ -84,6 +86,31 @@ func isValidEIP712Payload(typedData apitypes.TypedData) bool { return len(typedData.Message) != 0 && len(typedData.Types) != 0 && typedData.PrimaryType != "" && typedData.Domain != apitypes.TypedDataDomain{} } +// getChainIDFromMemo extracts the signed chain ID from the memo field of a cross chain signing transaction. +// This is for cross-chain EIP712. +func getChainIDFromMemo(memo string) (string, error) { + splitMemoCrossChain := strings.Split(memo, "|CROSSCHAIN-SIGNING|") + if len(splitMemoCrossChain) != 2 { + return "", errors.New("invalid memo") + } + memoSuffix := splitMemoCrossChain[1] + memoChainIDs := strings.Split(memoSuffix, ";") + + if len(memoChainIDs) != 2 { + return "", errors.New("invalid memo") + } + signedChainSplit := strings.Split(memoChainIDs[0], ":") + carbonChainSplit := strings.Split(memoChainIDs[1], ":") + if len(signedChainSplit) != 2 || len(carbonChainSplit) != 2 { + return "", errors.New("invalid memo") + } + // Check the carbon chain ID to prevent replay attack + if carbonChainSplit[1] != evmkeeper.EvmChainId { + return "", fmt.Errorf("invalid chainId, expected %v, got %v", evmkeeper.EvmChainId, carbonChainSplit[1]) + } + return signedChainSplit[1], nil +} + // decodeAminoSignDoc attempts to decode the provided sign doc (bytes) as an Amino payload // and returns a signable EIP-712 TypedData object. func decodeAminoSignDoc(signDocBytes []byte) (apitypes.TypedData, error) { @@ -116,11 +143,20 @@ func decodeAminoSignDoc(signDocBytes []byte) (apitypes.TypedData, error) { return apitypes.TypedData{}, err } + // Cross chain signing portion to convert aminoDoc.ChainId to the chainID of the signed chain + memo := aminoDoc.Memo + if strings.Contains(memo, "|CROSSCHAIN-SIGNING|") { + signedChainID, err := getChainIDFromMemo(memo) + if err != nil { + return apitypes.TypedData{}, err + } + aminoDoc.ChainID = signedChainID + } + chainID, err := types.ParseChainID(aminoDoc.ChainID) if err != nil { return apitypes.TypedData{}, errors.New("invalid chain ID passed as argument") } - typedData, err := WrapTxToTypedData( chainID.Uint64(), signDocBytes, @@ -180,11 +216,6 @@ func decodeProtobufSignDoc(signDocBytes []byte) (apitypes.TypedData, error) { signerInfo := authInfo.SignerInfos[0] - chainID, err := types.ParseChainID(signDoc.ChainId) - if err != nil { - return apitypes.TypedData{}, fmt.Errorf("invalid chain ID passed as argument: %w", err) - } - stdFee := &legacytx.StdFee{ Amount: authInfo.Fee.Amount, Gas: authInfo.Fee.GasLimit, @@ -201,6 +232,21 @@ func decodeProtobufSignDoc(signDocBytes []byte) (apitypes.TypedData, error) { body.Memo, ) + // Cross chain signing portion to convert signDoc.ChainId to the chainID of the signed chain + memo := body.Memo + if strings.Contains(memo, "|CROSSCHAIN-SIGNING|") { + signedChainID, err := getChainIDFromMemo(memo) + if err != nil { + return apitypes.TypedData{}, err + } + signDoc.ChainId = signedChainID + } + + chainID, err := types.ParseChainID(signDoc.ChainId) + if err != nil { + return apitypes.TypedData{}, fmt.Errorf("invalid chain ID passed as argument: %w", err) + } + typedData, err := WrapTxToTypedData( chainID.Uint64(), signBytes,