From 93f9e580bca245aa213a0febdfc9b2dc4b6a2067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dobaczewski?= Date: Sat, 17 Feb 2024 01:07:01 +0100 Subject: [PATCH] v0.6.0 (#7) * Use pointers to pass tx and calls in RPC client, add chain ID tx modifier * Add context to key methods, add RPC key * Allow to pass contract code in constructor * Add GetUncle* methods * Add GetBlockReceipts method * Add support for web3 and net RPC methods * Add safe and finalized block tags * Add filter RPC methods * Use receiver only channels in RPC client * Do not use cache map in ChainIDProvider * Update README * Allow to set fixed chain ID in ChainIDProvider * Create response channel before writing to stream --- README.md | 127 ++--- abi/constructor.go | 32 +- abi/constructor_test.go | 6 +- examples/call-abi/main.go | 2 +- examples/call/main.go | 2 +- examples/custom-type-advanced/main.go | 29 -- examples/custom-type-advenced/main.go | 122 +++++ examples/custom-type-simple/main.go | 111 +---- examples/events/main.go | 2 +- examples/send-tx/main.go | 12 +- examples/subscription/main.go | 2 +- rpc/base.go | 183 +++++++- rpc/base_test.go | 642 +++++++++++++++++++++++++- rpc/client.go | 150 +++--- rpc/client_test.go | 22 +- rpc/mocks_test.go | 10 +- rpc/rpc.go | 126 +++-- rpc/transport/stream.go | 8 +- txmodifier/chainid.go | 82 ++++ txmodifier/chainid_test.go | 89 ++++ txmodifier/gaslimit.go | 2 +- txmodifier/gaslimit_test.go | 6 +- txmodifier/txmodifier_test.go | 9 +- types/types.go | 86 +++- types/types_test.go | 169 ++++++- wallet/key.go | 28 +- wallet/key_priv.go | 15 +- wallet/key_rpc.go | 60 +++ 28 files changed, 1720 insertions(+), 414 deletions(-) delete mode 100644 examples/custom-type-advanced/main.go create mode 100644 examples/custom-type-advenced/main.go create mode 100644 txmodifier/chainid.go create mode 100644 txmodifier/chainid_test.go create mode 100644 wallet/key_rpc.go diff --git a/README.md b/README.md index a278f1c..26e124a 100644 --- a/README.md +++ b/README.md @@ -32,12 +32,12 @@ Some of key features include: * [Decoding method return values](#decoding-method-return-values) * [Events / Logs](#events--logs) * [Decoding events](#decoding-events) - * [Errors](#errors) - * [Reverts](#reverts) - * [Panics](#panics) * [Contract ABI](#contract-abi) * [JSON-ABI](#json-abi) * [Human-Readable ABI](#human-readable-abi) + * [Errors](#errors) + * [Reverts](#reverts) + * [Panics](#panics) * [Signature parser syntax](#signature-parser-syntax) * [Custom types](#custom-types) * [Simple types](#simple-types) @@ -147,7 +147,7 @@ func main() { SetInput(calldata) // Call balanceOf. - b, _, err := c.Call(context.Background(), *call, types.LatestBlockNumber) + b, _, err := c.Call(context.Background(), call, types.LatestBlockNumber) if err != nil { panic(err) } @@ -249,7 +249,7 @@ func main() { SetInput(calldata) // Call the contract. - b, _, err := c.Call(context.Background(), *call, types.LatestBlockNumber) + b, _, err := c.Call(context.Background(), call, types.LatestBlockNumber) if err != nil { panic(err) } @@ -320,10 +320,6 @@ func main() { // does not have a 'From' field set. rpc.WithDefaultAddress(key.Address()), - // Specify a chain ID for SendTransaction when the transaction - // does not have a 'ChainID' field set. - rpc.WithChainID(1), - // TX modifiers enable modifications to the transaction before signing // and sending to the node. While not mandatory, without them, transaction // parameters like gas limit, gas price, and nonce must be set manually. @@ -345,6 +341,12 @@ func main() { txmodifier.NewNonceProvider(txmodifier.NonceProviderOptions{ UsePendingBlock: false, }), + + // ChainIDProvider automatically sets the chain ID for the transaction. + txmodifier.NewChainIDProvider(txmodifier.ChainIDProviderOptions{ + Replace: false, + Cache: true, + }), ), ) if err != nil { @@ -362,7 +364,7 @@ func main() { SetTo(types.MustAddressFromHex("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48")). SetInput(calldata) - txHash, _, err := c.SendTransaction(context.Background(), *tx) + txHash, _, err := c.SendTransaction(context.Background(), tx) if err != nil { panic(err) } @@ -422,7 +424,7 @@ func main() { SetTopics([]types.Hash{transfer.Topic0()}) // Fetch logs for WETH transfer events. - logs, err := c.SubscribeLogs(ctx, *query) + logs, err := c.SubscribeLogs(ctx, query) if err != nil { panic(err) } @@ -445,15 +447,15 @@ func main() { To connect to a node, it is necessary to choose a suitable transport method. The transport is responsible for executing a low-level communication protocol with the node. The `go-eth` package offers the following transport options: -| Transport | Description | Subscriptions | -|-----------|---------------------------------------------------------------------------------------------|-----------------| -| HTTP | Connects to a node using the HTTP protocol. | No | -| WebSocket | Connects to a node using the WebSocket protocol. | Yes | -| IPC | Connects to a node using the IPC protocol. | Yes | -| Retry | Wraps a transport and retries requests in case of an error. | Yes2 | -| Combined | Wraps two transports and uses one for requests and the other for subscriptions.1 | Yes | +| Transport | Description | Subscriptions | +|-----------|--------------------------------------------------------------------------------------------|-----------------| +| HTTP | Connects to a node using the HTTP protocol. | No | +| WebSocket | Connects to a node using the WebSocket protocol. | Yes | +| IPC | Connects to a node using the IPC protocol. | Yes | +| Retry | Wraps a transport and retries requests in case of an error. | Yes2 | +| Combined | Wraps two transports and uses one for methods and the other for subscriptions.1 | Yes | -1. It is recommended by some RPC providers to use HTTP for requests and WebSocket for subscriptions. +1. It is recommended by some RPC providers to use HTTP for methods and WebSocket for subscriptions. 2. Only if the underlying transport supports subscriptions. Transports can be created using the `transport.New*` functions. It is also possible to create custom transport by @@ -470,6 +472,7 @@ The `go-eth` package provides support for the following wallet types: | JSON key file1 | `key, err := wallet.NewKeyFromJSON(path, password)` | | JSON key content1 | `key, err := wallet.NewKeyFromJSONContent(jsonContent, password)` | | Mnemonic | `key, err := wallet.NewKeyFromMnemonic(mnemonic, password, account, index)` | +| Remote RPC | `key := wallet.NewKeyRPC(client, address)` | 1. Only V3 JSON keys are supported. @@ -562,8 +565,7 @@ func main() { In the example above, data is encoded and decoded using a struct. The `abi` tags map the struct fields to the corresponding tuple or struct fields. These tags are optional. If absent, fields are mapped by name, with the first -consecutive uppercase letters converted to lowercase. For instance, the `Number` struct field maps to the `number` -field, and the `DAPPName` field maps to the `dappName` field. +consecutive uppercase letters converted to lowercase. It is also possible to encode and decode values to a separate variables: @@ -602,8 +604,8 @@ func main() { } ``` -Note that in both examples above, similarly named functions are used to encode and decode data. The only difference is -that the second example uses the plural form of the function. The plural form is used to encode and decode data from +**Note that in both examples above, similarly named functions are used to encode and decode data. The only difference is +that the second example uses the plural form of the function.** The plural form is used to encode and decode data from separate variables, while the singular form is used for structs or maps. This is a common pattern in the `go-eth` package. @@ -680,8 +682,8 @@ When mapping between Go and Solidity types, the following rules apply: * ✓ - Supported * ✗ - Not supported -1. Destination type must be able to hold the value of the source type. For example, `uint16` can be mapped to `uint8`, - but only if the value is less than 256. +1. Destination type must be able to hold the value of the source type. Otherwise, the mapping will result in an error. + For example, `uint16` can be mapped to `uint8`, but only if the value is less than 256. 2. Mapping of negative values is supported only if both types support negative values. 3. Only mapping from/to `bytes32` is supported. 4. Only mapping from/to `bytes20` is supported. @@ -783,8 +785,7 @@ func main() { To decode contract events, the `abi.Event` structure needs to be created. Events may be created using different methods: - `abi.ParseEvent` / `abi.MustParseEvent` - creates a new event by parsing an event signature. -- - - `abi.NewEvent(name, inputs)` - creates a new event using provided arguments. +- `abi.NewEvent(name, inputs)` - creates a new event using provided arguments. - Using the `abi.Contract` struct (see [Contract ABI](#contract-abi) section). #### Decoding events @@ -828,7 +829,7 @@ func main() { SetTopics([]types.Hash{transfer.Topic0()}) // Fetch logs for WETH transfer events. - logs, err := c.GetLogs(context.Background(), *query) + logs, err := c.GetLogs(context.Background(), query) if err != nil { panic(err) } @@ -843,41 +844,18 @@ func main() { } ``` -### Errors - -To decode custom contract errors, first a `abi.Error` struct must be created. Errors may be created using different -methods: - -- `abi.ParseError` / `abi.MustParseError` - creates a new error by parsing an error signature. -- `abi.NewError(name, inputs)` - creates a new error using provided arguments. -- Using the `abi.Contract` struct (see [Contract ABI](#contract-abi) section). - -Custom errors may be decoded from errors returned by the `Call` function using the `abi.Error.HandleError` method. - -When using a `abi.Contract`, errors may be decoded from call errors using the `abi.Contract.HandleError` method. This -method will try to decode the error using all errors defined in the contract, also including reverts and panics. - -### Reverts - -Reverts are special errors returned by the EVM when a contract call fails. Reverts are ABI-encoded errors with -the `Error(string)` signature. The `abi.DecodeRevert` function can be used to decode reverts. Optionally, the `abi` -package provides `abi.Revert`, a predefined error type that can be used to decode reverts. - -To verify if an error is a revert, use the `abi.IsRevert` function. - -### Panics - -Similar to reverts, panics are special errors returned by the EVM when a contract call fails. Panics are ABI-encoded -errors with the `Panic(uint256)` signature. The `abi.DecodePanic` function can be used to decode panics. Optionally, the -`abi` package also provides `abi.Panic`, a predefined error type that can be used to decode panics. - -To verify if an error is a panic, use the `abi.IsPanic` function. - ### Contract ABI The `abi.Contract` structure is a utility that provides an interface to a contract. It can be created using a JSON-ABI file or by supplying a list of signatures (also known as a Human-Readable ABI). +To create a contract struct, the following methods may be used: + +- `abi.LoadJSON` / `abi.MustLoadJSON` - creates a new contract by loading a JSON-ABI file. +- `abi.ParseJSON` / `abi.MustParseJSON` - creates a new contract by parsing a JSON-ABI string. +- `abi.ParseSignatures` / `abi.MustParseSignatures` - creates a new contract by parsing a list of signatures ( + Human-Readable ABI). + #### JSON-ABI @@ -956,6 +934,36 @@ func main() { } ``` +### Errors + +To decode custom contract errors, first a `abi.Error` struct must be created. Errors may be created using different +methods: + +- `abi.ParseError` / `abi.MustParseError` - creates a new error by parsing an error signature. +- `abi.NewError(name, inputs)` - creates a new error using provided arguments. +- Using the `abi.Contract` struct (see [Contract ABI](#contract-abi) section). + +Custom errors may be decoded from errors returned by the `Call` function using the `abi.Error.HandleError` method. + +When using a `abi.Contract`, errors may be decoded from call errors using the `abi.Contract.HandleError` method. This +method will try to decode the error using all errors defined in the contract, also including reverts and panics. + +### Reverts + +Reverts are special errors returned by the EVM when a contract call fails. Reverts are ABI-encoded errors with +the `Error(string)` signature. The `abi.DecodeRevert` function can be used to decode reverts. Optionally, the `abi` +package provides `abi.Revert`, a predefined error type that can be used to decode reverts. + +To verify if an error is a revert, use the `abi.IsRevert` function. + +### Panics + +Similar to reverts, panics are special errors returned by the EVM when a contract call fails. Panics are ABI-encoded +errors with the `Panic(uint256)` signature. The `abi.DecodePanic` function can be used to decode panics. Optionally, the +`abi` package also provides `abi.Panic`, a predefined error type that can be used to decode panics. + +To verify if an error is a panic, use the `abi.IsPanic` function. + ### Signature parser syntax The parser is based on Solidity grammar, but it allows for the omission of argument names, as well as the `returns` @@ -974,7 +982,8 @@ Examples of signatures that are accepted by the parser: ### Custom types -It is possible to add custom types to the `abi` package. +The `go-eth` package allows for the creation of custom types that can be used with the ABI encoder and decoder and with +the signature parser. #### Simple types diff --git a/abi/constructor.go b/abi/constructor.go index c19086d..ef6c806 100644 --- a/abi/constructor.go +++ b/abi/constructor.go @@ -70,38 +70,46 @@ func (m *Constructor) Inputs() *TupleType { return m.inputs } -// EncodeArg encodes arguments for a constructor call using a provided map or -// structure. The map or structure must have fields with the same names as -// the constructor arguments. -func (m *Constructor) EncodeArg(arg any) ([]byte, error) { +// EncodeArg encodes an argument for a contract deployment. +// The map or structure must have fields with the same names as the +// constructor arguments. +func (m *Constructor) EncodeArg(code []byte, arg any) ([]byte, error) { encoded, err := m.abi.EncodeValue(m.inputs, arg) if err != nil { return nil, err } - return encoded, nil + input := make([]byte, len(code)+len(encoded)) + copy(input, code) + copy(input[len(code):], encoded) + return input, nil } // MustEncodeArg is like EncodeArg but panics on error. -func (m *Constructor) MustEncodeArg(arg any) []byte { - encoded, err := m.EncodeArg(arg) +func (m *Constructor) MustEncodeArg(code []byte, arg any) []byte { + encoded, err := m.EncodeArg(code, arg) if err != nil { panic(err) } return encoded } -// EncodeArgs encodes arguments for a constructor call. -func (m *Constructor) EncodeArgs(args ...any) ([]byte, error) { +// EncodeArgs encodes arguments for a contract deployment. +// The map or structure must have fields with the same names as the +// constructor arguments. +func (m *Constructor) EncodeArgs(code []byte, args ...any) ([]byte, error) { encoded, err := m.abi.EncodeValues(m.inputs, args...) if err != nil { return nil, err } - return encoded, nil + input := make([]byte, len(code)+len(encoded)) + copy(input, code) + copy(input[len(code):], encoded) + return input, nil } // MustEncodeArgs is like EncodeArgs but panics on error. -func (m *Constructor) MustEncodeArgs(args ...any) []byte { - encoded, err := m.EncodeArgs(args...) +func (m *Constructor) MustEncodeArgs(code []byte, args ...any) []byte { + encoded, err := m.EncodeArgs(code, args...) if err != nil { panic(err) } diff --git a/abi/constructor_test.go b/abi/constructor_test.go index 1b05935..a31ae37 100644 --- a/abi/constructor_test.go +++ b/abi/constructor_test.go @@ -44,14 +44,14 @@ func TestConstructor_EncodeArgs(t *testing.T) { arg []any expected string }{ - {signature: "constructor()", arg: nil, expected: ""}, - {signature: "constructor(uint256)", arg: []any{1}, expected: "0000000000000000000000000000000000000000000000000000000000000001"}, + {signature: "constructor()", arg: nil, expected: "aabb"}, + {signature: "constructor(uint256)", arg: []any{1}, expected: "aabb0000000000000000000000000000000000000000000000000000000000000001"}, } for n, tt := range tests { t.Run(fmt.Sprintf("case-%d", n+1), func(t *testing.T) { c, err := ParseConstructor(tt.signature) require.NoError(t, err) - enc, err := c.EncodeArgs(tt.arg...) + enc, err := c.EncodeArgs([]byte{0xAA, 0xBB}, tt.arg...) require.NoError(t, err) assert.Equal(t, tt.expected, hex.EncodeToString(enc)) }) diff --git a/examples/call-abi/main.go b/examples/call-abi/main.go index c3eabef..07ebfcd 100644 --- a/examples/call-abi/main.go +++ b/examples/call-abi/main.go @@ -78,7 +78,7 @@ func main() { SetInput(calldata) // Call the contract. - b, _, err := c.Call(context.Background(), *call, types.LatestBlockNumber) + b, _, err := c.Call(context.Background(), call, types.LatestBlockNumber) if err != nil { panic(err) } diff --git a/examples/call/main.go b/examples/call/main.go index 1217fba..1acbe17 100644 --- a/examples/call/main.go +++ b/examples/call/main.go @@ -36,7 +36,7 @@ func main() { SetInput(calldata) // Call balanceOf. - b, _, err := c.Call(context.Background(), *call, types.LatestBlockNumber) + b, _, err := c.Call(context.Background(), call, types.LatestBlockNumber) if err != nil { panic(err) } diff --git a/examples/custom-type-advanced/main.go b/examples/custom-type-advanced/main.go deleted file mode 100644 index 6283301..0000000 --- a/examples/custom-type-advanced/main.go +++ /dev/null @@ -1,29 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/defiweb/go-eth/abi" - "github.com/defiweb/go-eth/hexutil" -) - -type Point struct { - X int - Y int -} - -func main() { - // Add custom type. - abi.Default.Types["Point"] = abi.MustParseStruct("struct {int256 x; int256 y;}") - - // Generate calldata. - addTriangle := abi.MustParseMethod("addTriangle(Point a, Point b, Point c)") - calldata := addTriangle.MustEncodeArgs( - Point{X: 1, Y: 2}, - Point{X: 3, Y: 4}, - Point{X: 5, Y: 6}, - ) - - // Print the calldata. - fmt.Printf("Calldata: %s\n", hexutil.BytesToHex(calldata)) -} diff --git a/examples/custom-type-advenced/main.go b/examples/custom-type-advenced/main.go new file mode 100644 index 0000000..8109113 --- /dev/null +++ b/examples/custom-type-advenced/main.go @@ -0,0 +1,122 @@ +package main + +import ( + "fmt" + + "github.com/defiweb/go-eth/abi" + "github.com/defiweb/go-eth/hexutil" +) + +// BoolFlagsType is a custom type that represents a 256-bit bitfield. +// +// It must implement the abi.Type interface. +type BoolFlagsType struct{} + +// IsDynamic returns true if the type is dynamic-length, like string or bytes. +func (b BoolFlagsType) IsDynamic() bool { + return false +} + +// CanonicalType is the type as it would appear in the ABI. +// It must only use the types defined in the ABI specification: +// https://docs.soliditylang.org/en/latest/abi-spec.html +func (b BoolFlagsType) CanonicalType() string { + return "bytes32" +} + +// String returns the custom type name. +func (b BoolFlagsType) String() string { + return "BoolFlags" +} + +// Value returns the zero value for this type. +func (b BoolFlagsType) Value() abi.Value { + return &BoolFlagsValue{} +} + +// BoolFlagsValue is the value of the custom type. +// +// It must implement the abi.Value interface. +type BoolFlagsValue [256]bool + +// IsDynamic returns true if the type is dynamic-length, like string or bytes. +func (b BoolFlagsValue) IsDynamic() bool { + return false +} + +// EncodeABI encodes the value to the ABI format. +func (b BoolFlagsValue) EncodeABI() (abi.Words, error) { + var w abi.Word + for i, v := range b { + if v { + w[i/8] |= 1 << uint(i%8) + } + } + return abi.Words{w}, nil +} + +// DecodeABI decodes the value from the ABI format. +func (b *BoolFlagsValue) DecodeABI(words abi.Words) (int, error) { + if len(words) == 0 { + return 0, fmt.Errorf("abi: cannot decode BytesFlags from empty data") + } + for i, v := range words[0] { + for j := 0; j < 8; j++ { + b[i*8+j] = v&(1< 256 { + return fmt.Errorf("abi: cannot map []bool of length %d to BytesFlags", len(src)) + } + for i, v := range src { + b[i] = v + } + } + return nil +} + +// MapTo maps value to a different type. +func (b *BoolFlagsValue) MapTo(_ abi.Mapper, dst any) error { + switch dst := dst.(type) { + case *[256]bool: + *dst = *b + case *[]bool: + *dst = make([]bool, 256) + for i, v := range b { + (*dst)[i] = v + } + } + return nil +} + +func main() { + // Add custom type. + abi.Default.Types["BoolFlags"] = &BoolFlagsType{} + + // Generate calldata. + setFlags := abi.MustParseMethod("setFlags(BoolFlags flags)") + calldata, _ := setFlags.EncodeArgs( + []bool{true, false, true, true, false, true, false, true}, + ) + + // Print the calldata. + fmt.Printf("Calldata: %s\n", hexutil.BytesToHex(calldata)) +} diff --git a/examples/custom-type-simple/main.go b/examples/custom-type-simple/main.go index 8109113..6283301 100644 --- a/examples/custom-type-simple/main.go +++ b/examples/custom-type-simple/main.go @@ -7,114 +7,21 @@ import ( "github.com/defiweb/go-eth/hexutil" ) -// BoolFlagsType is a custom type that represents a 256-bit bitfield. -// -// It must implement the abi.Type interface. -type BoolFlagsType struct{} - -// IsDynamic returns true if the type is dynamic-length, like string or bytes. -func (b BoolFlagsType) IsDynamic() bool { - return false -} - -// CanonicalType is the type as it would appear in the ABI. -// It must only use the types defined in the ABI specification: -// https://docs.soliditylang.org/en/latest/abi-spec.html -func (b BoolFlagsType) CanonicalType() string { - return "bytes32" -} - -// String returns the custom type name. -func (b BoolFlagsType) String() string { - return "BoolFlags" -} - -// Value returns the zero value for this type. -func (b BoolFlagsType) Value() abi.Value { - return &BoolFlagsValue{} -} - -// BoolFlagsValue is the value of the custom type. -// -// It must implement the abi.Value interface. -type BoolFlagsValue [256]bool - -// IsDynamic returns true if the type is dynamic-length, like string or bytes. -func (b BoolFlagsValue) IsDynamic() bool { - return false -} - -// EncodeABI encodes the value to the ABI format. -func (b BoolFlagsValue) EncodeABI() (abi.Words, error) { - var w abi.Word - for i, v := range b { - if v { - w[i/8] |= 1 << uint(i%8) - } - } - return abi.Words{w}, nil -} - -// DecodeABI decodes the value from the ABI format. -func (b *BoolFlagsValue) DecodeABI(words abi.Words) (int, error) { - if len(words) == 0 { - return 0, fmt.Errorf("abi: cannot decode BytesFlags from empty data") - } - for i, v := range words[0] { - for j := 0; j < 8; j++ { - b[i*8+j] = v&(1< 256 { - return fmt.Errorf("abi: cannot map []bool of length %d to BytesFlags", len(src)) - } - for i, v := range src { - b[i] = v - } - } - return nil -} - -// MapTo maps value to a different type. -func (b *BoolFlagsValue) MapTo(_ abi.Mapper, dst any) error { - switch dst := dst.(type) { - case *[256]bool: - *dst = *b - case *[]bool: - *dst = make([]bool, 256) - for i, v := range b { - (*dst)[i] = v - } - } - return nil +type Point struct { + X int + Y int } func main() { // Add custom type. - abi.Default.Types["BoolFlags"] = &BoolFlagsType{} + abi.Default.Types["Point"] = abi.MustParseStruct("struct {int256 x; int256 y;}") // Generate calldata. - setFlags := abi.MustParseMethod("setFlags(BoolFlags flags)") - calldata, _ := setFlags.EncodeArgs( - []bool{true, false, true, true, false, true, false, true}, + addTriangle := abi.MustParseMethod("addTriangle(Point a, Point b, Point c)") + calldata := addTriangle.MustEncodeArgs( + Point{X: 1, Y: 2}, + Point{X: 3, Y: 4}, + Point{X: 5, Y: 6}, ) // Print the calldata. diff --git a/examples/events/main.go b/examples/events/main.go index 977c68e..77eb0bc 100644 --- a/examples/events/main.go +++ b/examples/events/main.go @@ -34,7 +34,7 @@ func main() { SetTopics([]types.Hash{transfer.Topic0()}) // Fetch logs for WETH transfer events. - logs, err := c.GetLogs(context.Background(), *query) + logs, err := c.GetLogs(context.Background(), query) if err != nil { panic(err) } diff --git a/examples/send-tx/main.go b/examples/send-tx/main.go index 542a492..3b6545f 100644 --- a/examples/send-tx/main.go +++ b/examples/send-tx/main.go @@ -40,10 +40,6 @@ func main() { // does not have a 'From' field set. rpc.WithDefaultAddress(key.Address()), - // Specify a chain ID for SendTransaction when the transaction - // does not have a 'ChainID' field set. - rpc.WithChainID(1), - // TX modifiers enable modifications to the transaction before signing // and sending to the node. While not mandatory, without them, transaction // parameters like gas limit, gas price, and nonce must be set manually. @@ -65,6 +61,12 @@ func main() { txmodifier.NewNonceProvider(txmodifier.NonceProviderOptions{ UsePendingBlock: false, }), + + // ChainIDProvider automatically sets the chain ID for the transaction. + txmodifier.NewChainIDProvider(txmodifier.ChainIDProviderOptions{ + Replace: false, + Cache: true, + }), ), ) if err != nil { @@ -82,7 +84,7 @@ func main() { SetTo(types.MustAddressFromHex("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48")). SetInput(calldata) - txHash, _, err := c.SendTransaction(context.Background(), *tx) + txHash, _, err := c.SendTransaction(context.Background(), tx) if err != nil { panic(err) } diff --git a/examples/subscription/main.go b/examples/subscription/main.go index 9e1edf2..82f32a9 100644 --- a/examples/subscription/main.go +++ b/examples/subscription/main.go @@ -41,7 +41,7 @@ func main() { SetTopics([]types.Hash{transfer.Topic0()}) // Fetch logs for WETH transfer events. - logs, err := c.SubscribeLogs(ctx, *query) + logs, err := c.SubscribeLogs(ctx, query) if err != nil { panic(err) } diff --git a/rpc/base.go b/rpc/base.go index 5817e8b..1fd6dd8 100644 --- a/rpc/base.go +++ b/rpc/base.go @@ -17,6 +17,60 @@ type baseClient struct { transport transport.Transport } +// ClientVersion implements the RPC interface. +func (c *baseClient) ClientVersion(ctx context.Context) (string, error) { + var res string + if err := c.transport.Call(ctx, &res, "web3_clientVersion"); err != nil { + return "", err + } + return res, nil +} + +// Listening implements the RPC interface. +func (c *baseClient) Listening(ctx context.Context) (bool, error) { + var res bool + if err := c.transport.Call(ctx, &res, "net_listening"); err != nil { + return false, err + } + return res, nil +} + +// PeerCount implements the RPC interface. +func (c *baseClient) PeerCount(ctx context.Context) (uint64, error) { + var res types.Number + if err := c.transport.Call(ctx, &res, "net_peerCount"); err != nil { + return 0, err + } + return res.Big().Uint64(), nil +} + +// ProtocolVersion implements the RPC interface. +func (c *baseClient) ProtocolVersion(ctx context.Context) (uint64, error) { + var res types.Number + if err := c.transport.Call(ctx, &res, "eth_protocolVersion"); err != nil { + return 0, err + } + return res.Big().Uint64(), nil +} + +// Syncing implements the RPC interface. +func (c *baseClient) Syncing(ctx context.Context) (*types.SyncStatus, error) { + var res types.SyncStatus + if err := c.transport.Call(ctx, &res, "eth_syncing"); err != nil { + return nil, err + } + return &res, nil +} + +// NetworkID implements the RPC interface. +func (c *baseClient) NetworkID(ctx context.Context) (uint64, error) { + var res types.Number + if err := c.transport.Call(ctx, &res, "net_version"); err != nil { + return 0, err + } + return res.Big().Uint64(), nil +} + // ChainID implements the RPC interface. func (c *baseClient) ChainID(ctx context.Context) (uint64, error) { var res types.Number @@ -153,7 +207,10 @@ func (c *baseClient) Sign(ctx context.Context, account types.Address, data []byt } // SignTransaction implements the RPC interface. -func (c *baseClient) SignTransaction(ctx context.Context, tx types.Transaction) ([]byte, *types.Transaction, error) { +func (c *baseClient) SignTransaction(ctx context.Context, tx *types.Transaction) ([]byte, *types.Transaction, error) { + if tx == nil { + return nil, nil, errors.New("rpc client: transaction is nil") + } var res signTransactionResult if err := c.transport.Call(ctx, &res, "eth_signTransaction", tx); err != nil { return nil, nil, err @@ -162,12 +219,15 @@ func (c *baseClient) SignTransaction(ctx context.Context, tx types.Transaction) } // SendTransaction implements the RPC interface. -func (c *baseClient) SendTransaction(ctx context.Context, tx types.Transaction) (*types.Hash, *types.Transaction, error) { +func (c *baseClient) SendTransaction(ctx context.Context, tx *types.Transaction) (*types.Hash, *types.Transaction, error) { + if tx == nil { + return nil, nil, errors.New("rpc client: transaction is nil") + } var res types.Hash if err := c.transport.Call(ctx, &res, "eth_sendTransaction", tx); err != nil { return nil, nil, err } - return &res, &tx, nil + return &res, tx, nil } // SendRawTransaction implements the RPC interface. @@ -180,24 +240,30 @@ func (c *baseClient) SendRawTransaction(ctx context.Context, data []byte) (*type } // Call implements the RPC interface. -func (c *baseClient) Call(ctx context.Context, call types.Call, block types.BlockNumber) ([]byte, *types.Call, error) { +func (c *baseClient) Call(ctx context.Context, call *types.Call, block types.BlockNumber) ([]byte, *types.Call, error) { + if call == nil { + return nil, nil, errors.New("rpc client: call is nil") + } var res types.Bytes if err := c.transport.Call(ctx, &res, "eth_call", call, block); err != nil { return nil, nil, err } - return res, &call, nil + return res, call, nil } // EstimateGas implements the RPC interface. -func (c *baseClient) EstimateGas(ctx context.Context, call types.Call, block types.BlockNumber) (uint64, error) { +func (c *baseClient) EstimateGas(ctx context.Context, call *types.Call, block types.BlockNumber) (uint64, *types.Call, error) { + if call == nil { + return 0, nil, errors.New("rpc client: call is nil") + } var res types.Number if err := c.transport.Call(ctx, &res, "eth_estimateGas", call, block); err != nil { - return 0, err + return 0, nil, err } if !res.Big().IsUint64() { - return 0, errors.New("gas estimate is too big") + return 0, nil, errors.New("gas estimate is too big") } - return res.Big().Uint64(), nil + return res.Big().Uint64(), call, nil } // BlockByHash implements the RPC interface. @@ -254,8 +320,99 @@ func (c *baseClient) GetTransactionReceipt(ctx context.Context, hash types.Hash) return &res, nil } +// GetBlockReceipts implements the RPC interface. +func (c *baseClient) GetBlockReceipts(ctx context.Context, block types.BlockNumber) ([]*types.TransactionReceipt, error) { + var res []*types.TransactionReceipt + if err := c.transport.Call(ctx, &res, "eth_getBlockReceipts", block); err != nil { + return nil, err + } + return res, nil +} + +// GetUncleByBlockHashAndIndex implements the RPC interface. +func (c *baseClient) GetUncleByBlockHashAndIndex(ctx context.Context, hash types.Hash, index uint64) (*types.Block, error) { + var res types.Block + if err := c.transport.Call(ctx, &res, "eth_getUncleByBlockHashAndIndex", hash, types.NumberFromUint64(index)); err != nil { + return nil, err + } + return &res, nil +} + +// GetUncleByBlockNumberAndIndex implements the RPC interface. +func (c *baseClient) GetUncleByBlockNumberAndIndex(ctx context.Context, number types.BlockNumber, index uint64) (*types.Block, error) { + var res types.Block + if err := c.transport.Call(ctx, &res, "eth_getUncleByBlockNumberAndIndex", number, types.NumberFromUint64(index)); err != nil { + return nil, err + } + return &res, nil +} + +// NewFilter implements the RPC interface. +func (c *baseClient) NewFilter(ctx context.Context, query *types.FilterLogsQuery) (*big.Int, error) { + var res *types.Number + if err := c.transport.Call(ctx, &res, "eth_newFilter", query); err != nil { + return nil, err + } + return res.Big(), nil +} + +// NewBlockFilter implements the RPC interface. +func (c *baseClient) NewBlockFilter(ctx context.Context) (*big.Int, error) { + var res *types.Number + if err := c.transport.Call(ctx, &res, "eth_newBlockFilter"); err != nil { + return nil, err + } + return res.Big(), nil + +} + +// NewPendingTransactionFilter implements the RPC interface. +func (c *baseClient) NewPendingTransactionFilter(ctx context.Context) (*big.Int, error) { + var res *types.Number + if err := c.transport.Call(ctx, &res, "eth_newPendingTransactionFilter"); err != nil { + return nil, err + } + return res.Big(), nil +} + +// UninstallFilter implements the RPC interface. +func (c *baseClient) UninstallFilter(ctx context.Context, id *big.Int) (bool, error) { + var res bool + if err := c.transport.Call(ctx, &res, "eth_uninstallFilter", types.NumberFromBigInt(id)); err != nil { + return false, err + } + return res, nil +} + +// GetFilterChanges implements the RPC interface. +func (c *baseClient) GetFilterChanges(ctx context.Context, id *big.Int) ([]types.Log, error) { + var res []types.Log + if err := c.transport.Call(ctx, &res, "eth_getFilterChanges", types.NumberFromBigInt(id)); err != nil { + return nil, err + } + return res, nil +} + +// GetFilterLogs implements the RPC interface. +func (c *baseClient) GetFilterLogs(ctx context.Context, id *big.Int) ([]types.Log, error) { + var res []types.Log + if err := c.transport.Call(ctx, &res, "eth_getFilterLogs", types.NumberFromBigInt(id)); err != nil { + return nil, err + } + return res, nil +} + +// GetBlockFilterChanges implements the RPC interface. +func (c *baseClient) GetBlockFilterChanges(ctx context.Context, id *big.Int) ([]types.Hash, error) { + var res []types.Hash + if err := c.transport.Call(ctx, &res, "eth_getFilterChanges", types.NumberFromBigInt(id)); err != nil { + return nil, err + } + return res, nil +} + // GetLogs implements the RPC interface. -func (c *baseClient) GetLogs(ctx context.Context, query types.FilterLogsQuery) ([]types.Log, error) { +func (c *baseClient) GetLogs(ctx context.Context, query *types.FilterLogsQuery) ([]types.Log, error) { var res []types.Log if err := c.transport.Call(ctx, &res, "eth_getLogs", query); err != nil { return nil, err @@ -273,17 +430,17 @@ func (c *baseClient) MaxPriorityFeePerGas(ctx context.Context) (*big.Int, error) } // SubscribeLogs implements the RPC interface. -func (c *baseClient) SubscribeLogs(ctx context.Context, query types.FilterLogsQuery) (chan types.Log, error) { +func (c *baseClient) SubscribeLogs(ctx context.Context, query *types.FilterLogsQuery) (<-chan types.Log, error) { return subscribe[types.Log](ctx, c.transport, "logs", query) } // SubscribeNewHeads implements the RPC interface. -func (c *baseClient) SubscribeNewHeads(ctx context.Context) (chan types.Block, error) { +func (c *baseClient) SubscribeNewHeads(ctx context.Context) (<-chan types.Block, error) { return subscribe[types.Block](ctx, c.transport, "newHeads") } // SubscribeNewPendingTransactions implements the RPC interface. -func (c *baseClient) SubscribeNewPendingTransactions(ctx context.Context) (chan types.Hash, error) { +func (c *baseClient) SubscribeNewPendingTransactions(ctx context.Context) (<-chan types.Hash, error) { return subscribe[types.Hash](ctx, c.transport, "newPendingTransactions") } diff --git a/rpc/base_test.go b/rpc/base_test.go index 21ed3b2..130a168 100644 --- a/rpc/base_test.go +++ b/rpc/base_test.go @@ -17,6 +17,206 @@ import ( "github.com/defiweb/go-eth/types" ) +const mockClientVersionRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "web3_clientVersion", + "params": [] + } +` + +const mockClientVersionResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": "Geth/v1.9.25-unstable-3f0b5e4e-20201014/linux-amd64/go1.15.2" + } +` + +func TestBaseClient_ClientVersion(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockClientVersionResponse)), + } + + clientVersion, err := client.ClientVersion(context.Background()) + require.NoError(t, err) + assert.JSONEq(t, mockClientVersionRequest, readBody(httpMock.Request)) + assert.Equal(t, "Geth/v1.9.25-unstable-3f0b5e4e-20201014/linux-amd64/go1.15.2", clientVersion) +} + +const mockListeningRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "net_listening", + "params": [] + } +` + +const mockListeningResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": true + } +` + +func TestBaseClient_Listening(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockListeningResponse)), + } + + listening, err := client.Listening(context.Background()) + require.NoError(t, err) + assert.JSONEq(t, mockListeningRequest, readBody(httpMock.Request)) + assert.True(t, listening) +} + +const mockPeerCountRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "net_peerCount", + "params": [] + } +` + +const mockPeerCountResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": "0x1" + } +` + +func TestBaseClient_PeerCount(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockPeerCountResponse)), + } + + peerCount, err := client.PeerCount(context.Background()) + require.NoError(t, err) + assert.JSONEq(t, mockPeerCountRequest, readBody(httpMock.Request)) + assert.Equal(t, uint64(1), peerCount) +} + +const mockProtocolVersionRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_protocolVersion", + "params": [] + } +` + +const mockProtocolVersionResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": "0x1" + } +` + +func TestBaseClient_ProtocolVersion(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockProtocolVersionResponse)), + } + + protocolVersion, err := client.ProtocolVersion(context.Background()) + require.NoError(t, err) + assert.JSONEq(t, mockProtocolVersionRequest, readBody(httpMock.Request)) + assert.Equal(t, uint64(1), protocolVersion) +} + +const mockSyncingRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_syncing", + "params": [] + } +` + +const mockSyncingResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "startingBlock": "0x384", + "currentBlock": "0x386", + "highestBlock": "0x454" + } + } +` + +func TestBaseClient_Syncing(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockSyncingResponse)), + } + + syncing, err := client.Syncing(context.Background()) + require.NoError(t, err) + assert.JSONEq(t, mockSyncingRequest, readBody(httpMock.Request)) + assert.Equal(t, &types.SyncStatus{ + StartingBlock: types.MustBlockNumberFromHex("0x384"), + CurrentBlock: types.MustBlockNumberFromHex("0x386"), + HighestBlock: types.MustBlockNumberFromHex("0x454"), + }, syncing) +} + +const mockNetworkIDRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "net_version", + "params": [] + } +` + +const mockNetworkIDResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": "0x1" + } +` + +func TestBaseClient_NetworkID(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockNetworkIDResponse)), + } + + networkID, err := client.NetworkID(context.Background()) + require.NoError(t, err) + assert.JSONEq(t, mockNetworkIDRequest, readBody(httpMock.Request)) + assert.Equal(t, uint64(1), networkID) +} + const mockChanIDRequest = ` { "jsonrpc": "2.0", @@ -516,7 +716,7 @@ func TestBaseClient_SignTransaction(t *testing.T) { chainID := uint64(1) raw, tx, err := client.SignTransaction( context.Background(), - types.Transaction{ + &types.Transaction{ ChainID: &chainID, Call: types.Call{ From: &from, @@ -585,7 +785,7 @@ func TestBaseClient_SendTransaction(t *testing.T) { chainID := uint64(1) txHash, tx, err := client.SendTransaction( context.Background(), - types.Transaction{ + &types.Transaction{ ChainID: &chainID, Call: types.Call{ From: &from, @@ -689,7 +889,7 @@ func TestBaseClient_Call(t *testing.T) { input := hexToBytes("0x3333333333333333333333333333333333333333333333333333333333333333333333333333333333") calldata, call, err := client.Call( context.Background(), - types.Call{ + &types.Call{ From: from, To: to, GasLimit: &gasLimit, @@ -747,9 +947,9 @@ func TestBaseClient_EstimateGas(t *testing.T) { } gasLimit := uint64(30400) - gas, err := client.EstimateGas( + gas, _, err := client.EstimateGas( context.Background(), - types.Call{ + &types.Call{ From: types.MustAddressFromHexPtr("0x1111111111111111111111111111111111111111"), To: types.MustAddressFromHexPtr("0x2222222222222222222222222222222222222222"), GasLimit: &gasLimit, @@ -1118,6 +1318,434 @@ func TestBaseClient_GetTransactionReceipt(t *testing.T) { assert.Equal(t, false, receipt.Logs[0].Removed) } +const mockGetBlockReceiptsRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_getBlockReceipts", + "params": [ + "0x1" + ] + } +` + +const mockGetBlockReceiptsResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": [ + { + "blockHash": "0x1111111111111111111111111111111111111111111111111111111111111111", + "blockNumber": "0x2222", + "contractAddress": null, + "cumulativeGasUsed": "0x33333", + "effectiveGasPrice": "0x4444444444", + "from": "0x5555555555555555555555555555555555555555", + "gasUsed": "0x66666", + "logs": [ + { + "address": "0x7777777777777777777777777777777777777777", + "blockHash": "0x1111111111111111111111111111111111111111111111111111111111111111", + "blockNumber": "0x2222", + "data": "0x000000000000000000000000398137383b3d25c92898c656696e41950e47316b00000000000000000000000000000000000000000000000000000000000cee6100000000000000000000000000000000000000000000000000000000000ac3e100000000000000000000000000000000000000000000000000000000005baf35", + "logIndex": "0x8", + "removed": false, + "topics": [ + "0x9999999999999999999999999999999999999999999999999999999999999999" + ], + "transactionHash": "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "transactionIndex": "0x11" + } + ], + "logsBloom": "0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000200000000000000000000000000000", + "status": "0x1", + "to": "0x7777777777777777777777777777777777777777", + "transactionHash": "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "transactionIndex": "0x11", + "type": "0x0" + } + ] + } +` + +func TestBaseClient_GetBlockReceipts(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockGetBlockReceiptsResponse)), + } + + receipts, err := client.GetBlockReceipts( + context.Background(), + types.MustBlockNumberFromHex("0x1"), + ) + + require.NoError(t, err) + assert.JSONEq(t, mockGetBlockReceiptsRequest, readBody(httpMock.Request)) + require.Len(t, receipts, 1) + assert.Equal(t, types.MustHashFromHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", types.PadNone), receipts[0].TransactionHash) +} + +const mockGetUncleByBlockHashAndIndexRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_getUncleByBlockHashAndIndex", + "params": [ + "0x1111111111111111111111111111111111111111111111111111111111111111", + "0x0" + ] + } +` + +func TestBaseClient_GetUncleByBlockHashAndIndex(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockBlockByNumberResponse)), + } + + block, err := client.GetUncleByBlockHashAndIndex( + context.Background(), + types.MustHashFromHex("0x1111111111111111111111111111111111111111111111111111111111111111", types.PadNone), + 0, + ) + + require.NoError(t, err) + assert.JSONEq(t, mockGetUncleByBlockHashAndIndexRequest, readBody(httpMock.Request)) + assert.Equal(t, types.MustHashFromHex("0x2222222222222222222222222222222222222222222222222222222222222222", types.PadNone), block.Hash) +} + +const mockGetUncleByBlockNumberAndIndexRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_getUncleByBlockNumberAndIndex", + "params": [ + "0x1", + "0x2" + ] + } +` + +func TestBaseClient_GetUncleByBlockNumberAndIndex(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockBlockByNumberResponse)), + } + + block, err := client.GetUncleByBlockNumberAndIndex( + context.Background(), + types.MustBlockNumberFromHex("0x1"), + 2, + ) + + require.NoError(t, err) + assert.JSONEq(t, mockGetUncleByBlockNumberAndIndexRequest, readBody(httpMock.Request)) + assert.Equal(t, types.MustHashFromHex("0x2222222222222222222222222222222222222222222222222222222222222222", types.PadNone), block.Hash) +} + +const mockNewFilterRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_newFilter", + "params": [ + { + "fromBlock": "0x1", + "toBlock": "0x2", + "address": "0x3333333333333333333333333333333333333333", + "topics": ["0x4444444444444444444444444444444444444444444444444444444444444444"] + } + ] + } +` + +const mockNewFilterResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": "0x1" + } +` + +func TestBaseClient_NewFilter(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockNewFilterResponse)), + } + + from := types.MustBlockNumberFromHex("0x1") + to := types.MustBlockNumberFromHex("0x2") + filterID, err := client.NewFilter(context.Background(), &types.FilterLogsQuery{ + FromBlock: &from, + ToBlock: &to, + Address: []types.Address{types.MustAddressFromHex("0x3333333333333333333333333333333333333333")}, + Topics: [][]types.Hash{ + {types.MustHashFromHex("0x4444444444444444444444444444444444444444444444444444444444444444", types.PadNone)}, + }, + }) + + require.NoError(t, err) + assert.JSONEq(t, mockNewFilterRequest, readBody(httpMock.Request)) + assert.Equal(t, big.NewInt(1), filterID) +} + +const mockNewBlockFilterRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_newBlockFilter", + "params": [] + } +` + +const mockNewBlockFilterResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": "0x1" + } +` + +func TestBaseClient_NewBlockFilter(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockNewBlockFilterResponse)), + } + + filterID, err := client.NewBlockFilter(context.Background()) + + require.NoError(t, err) + assert.JSONEq(t, mockNewBlockFilterRequest, readBody(httpMock.Request)) + assert.Equal(t, big.NewInt(1), filterID) +} + +const mockNewPendingTransactionFilterRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_newPendingTransactionFilter", + "params": [] + } +` + +const mockNewPendingTransactionFilterResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": "0x1" + } +` + +func TestBaseClient_NewPendingTransactionFilter(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockNewPendingTransactionFilterResponse)), + } + + filterID, err := client.NewPendingTransactionFilter(context.Background()) + + require.NoError(t, err) + assert.JSONEq(t, mockNewPendingTransactionFilterRequest, readBody(httpMock.Request)) + assert.Equal(t, big.NewInt(1), filterID) +} + +const mockUninstallFilterRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_uninstallFilter", + "params": ["0x1"] + } +` + +const mockUninstallFilterResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": true + } +` + +func TestBaseClient_UninstallFilter(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockUninstallFilterResponse)), + } + + filterID := big.NewInt(1) + success, err := client.UninstallFilter(context.Background(), filterID) + + require.NoError(t, err) + assert.JSONEq(t, mockUninstallFilterRequest, readBody(httpMock.Request)) + assert.True(t, success) +} + +const mockGetFilterChangesRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_getFilterChanges", + "params": ["0x1"] + } +` + +const mockGetFilterChangesResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": [ + { + "address": "0x1111111111111111111111111111111111111111", + "topics": ["0x2222222222222222222222222222222222222222222222222222222222222222"], + "data": "0x3333333333333333333333333333333333333333333333333333333333333333", + "blockNumber": "0x44444", + "transactionHash": "0x5555555555555555555555555555555555555555555555555555555555555555", + "transactionIndex": "0x66", + "blockHash": "0x7777777777777777777777777777777777777777777777777777777777777777", + "logIndex": "0x88", + "removed": false + } + ] + } +` + +func TestBaseClient_GetFilterChanges(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockGetFilterChangesResponse)), + } + + filterID := big.NewInt(1) + logs, err := client.GetFilterChanges(context.Background(), filterID) + + require.NoError(t, err) + assert.JSONEq(t, mockGetFilterChangesRequest, readBody(httpMock.Request)) + assert.Len(t, logs, 1) + assert.Equal(t, types.MustAddressFromHex("0x1111111111111111111111111111111111111111"), logs[0].Address) + assert.Equal(t, []types.Hash{types.MustHashFromHex("0x2222222222222222222222222222222222222222222222222222222222222222", types.PadNone)}, logs[0].Topics) + assert.Equal(t, hexutil.MustHexToBytes("0x3333333333333333333333333333333333333333333333333333333333333333"), logs[0].Data) + assert.Equal(t, types.MustBlockNumberFromHexPtr("0x44444").Big(), logs[0].BlockNumber) + assert.Equal(t, types.MustHashFromHexPtr("0x5555555555555555555555555555555555555555555555555555555555555555", types.PadNone), logs[0].TransactionHash) + assert.Equal(t, uint64(0x66), *logs[0].TransactionIndex) + assert.Equal(t, types.MustHashFromHexPtr("0x7777777777777777777777777777777777777777777777777777777777777777", types.PadNone), logs[0].BlockHash) + assert.Equal(t, uint64(0x88), *logs[0].LogIndex) + assert.False(t, logs[0].Removed) +} + +const mockGetBlockFilterChangesRequest = ` + { + "jsonrpc": "2.0", + "id": 1, + "method": "eth_getFilterChanges", + "params": ["0x1"] + } +` + +const mockGetBlockFilterChangesResponse = ` + { + "jsonrpc": "2.0", + "id": 1, + "result": ["0x1111111111111111111111111111111111111111111111111111111111111111"] + } +` + +func TestBaseClient_GetBlockFilterChanges(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockGetBlockFilterChangesResponse)), + } + + filterID := big.NewInt(1) + blockHashes, err := client.GetBlockFilterChanges(context.Background(), filterID) + + require.NoError(t, err) + assert.JSONEq(t, mockGetBlockFilterChangesRequest, readBody(httpMock.Request)) + assert.Len(t, blockHashes, 1) + assert.Equal(t, types.MustHashFromHex("0x1111111111111111111111111111111111111111111111111111111111111111", types.PadNone), blockHashes[0]) +} + +const mockGetFilterLogsRequest = ` +{ + "jsonrpc": "2.0", + "id": 1, + "method": "eth_getFilterLogs", + "params": ["0x1"] +} +` + +const mockGetFilterLogsResponse = ` +{ + "jsonrpc": "2.0", + "id": 1, + "result": [ + { + "address": "0x1111111111111111111111111111111111111111", + "topics": ["0x2222222222222222222222222222222222222222222222222222222222222222"], + "data": "0x3333333333333333333333333333333333333333333333333333333333333333", + "blockNumber": "0x1", + "transactionHash": "0x5555555555555555555555555555555555555555555555555555555555555555", + "transactionIndex": "0x0", + "blockHash": "0x7777777777777777777777777777777777777777777777777777777777777777", + "logIndex": "0x1", + "removed": false + } + ] +} +` + +func TestBaseClient_GetFilterLogs(t *testing.T) { + httpMock := newHTTPMock() + client := &baseClient{transport: httpMock} + + httpMock.ResponseMock = &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(mockGetFilterLogsResponse)), + } + + filterID := big.NewInt(1) + logs, err := client.GetFilterLogs(context.Background(), filterID) + + require.NoError(t, err) + assert.JSONEq(t, mockGetFilterLogsRequest, readBody(httpMock.Request)) + assert.Len(t, logs, 1) + assert.Equal(t, types.MustAddressFromHex("0x1111111111111111111111111111111111111111"), logs[0].Address) + assert.Equal(t, []types.Hash{types.MustHashFromHex("0x2222222222222222222222222222222222222222222222222222222222222222", types.PadNone)}, logs[0].Topics) + assert.Equal(t, hexutil.MustHexToBytes("0x3333333333333333333333333333333333333333333333333333333333333333"), logs[0].Data) + assert.Equal(t, big.NewInt(1), logs[0].BlockNumber) + assert.Equal(t, types.MustHashFromHexPtr("0x5555555555555555555555555555555555555555555555555555555555555555", types.PadNone), logs[0].TransactionHash) + assert.Equal(t, uint64(0), *logs[0].TransactionIndex) + assert.Equal(t, types.MustHashFromHexPtr("0x7777777777777777777777777777777777777777777777777777777777777777", types.PadNone), logs[0].BlockHash) + assert.Equal(t, uint64(1), *logs[0].LogIndex) + assert.False(t, logs[0].Removed) +} + const mockGetLogsRequest = ` { "jsonrpc": "2.0", @@ -1169,7 +1797,7 @@ func TestBaseClient_GetLogs(t *testing.T) { from := types.MustBlockNumberFromHex("0x1") to := types.MustBlockNumberFromHex("0x2") - logs, err := client.GetLogs(context.Background(), types.FilterLogsQuery{ + logs, err := client.GetLogs(context.Background(), &types.FilterLogsQuery{ FromBlock: &from, ToBlock: &to, Address: []types.Address{types.MustAddressFromHex("0x3333333333333333333333333333333333333333")}, @@ -1245,7 +1873,7 @@ func TestBaseClient_SubscribeLogs(t *testing.T) { // Mock subscribe response rawCh := make(chan json.RawMessage) - query := types.FilterLogsQuery{ + query := &types.FilterLogsQuery{ FromBlock: types.BlockNumberFromUint64Ptr(1), ToBlock: types.BlockNumberFromUint64Ptr(2), Address: []types.Address{types.MustAddressFromHex("0x3333333333333333333333333333333333333333")}, diff --git a/rpc/client.go b/rpc/client.go index ab31c3a..5737513 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -13,9 +13,8 @@ import ( type Client struct { baseClient - keys []wallet.Key + keys map[types.Address]wallet.Key defaultAddr *types.Address - chainID *uint64 txModifiers []TXModifier } @@ -42,8 +41,7 @@ func WithTransport(transport transport.Transport) ClientOptions { } // WithKeys allows to set keys that will be used to sign data. -// It allows to emulate the behavior of the RPC methods that require a key -// to be available in the node. +// It allows to emulate the behavior of the RPC methods that require a key. // // The following methods are affected: // - Accounts - returns the addresses of the provided keys @@ -53,15 +51,19 @@ func WithTransport(transport transport.Transport) ClientOptions { // using SendRawTransaction func WithKeys(keys ...wallet.Key) ClientOptions { return func(c *Client) error { - c.keys = keys + for _, k := range keys { + c.keys[k.Address()] = k + } return nil } } -// WithDefaultAddress sets the transaction.From address if it is not set -// in the following methods: -// - SignTransaction -// - SendTransaction +// WithDefaultAddress sets the call "from" address if it is not set in the +// following methods: +// - SignTransaction +// - SendTransaction +// - Call +// - EstimateGas func WithDefaultAddress(addr types.Address) ClientOptions { return func(c *Client) error { c.defaultAddr = &addr @@ -69,20 +71,6 @@ func WithDefaultAddress(addr types.Address) ClientOptions { } } -// WithChainID sets the transaction.ChainID if it is not set in the following -// methods: -// - SignTransaction -// - SendTransaction -// -// If the transaction has a ChainID set, it will return an error if it does not -// match the provided chain ID. -func WithChainID(chainID uint64) ClientOptions { - return func(c *Client) error { - c.chainID = &chainID - return nil - } -} - // WithTXModifiers allows to modify the transaction before it is signed and // sent to the node. // @@ -97,7 +85,7 @@ func WithTXModifiers(modifiers ...TXModifier) ClientOptions { // NewClient creates a new RPC client. // The WithTransport option is required. func NewClient(opts ...ClientOptions) (*Client, error) { - c := &Client{} + c := &Client{keys: make(map[types.Address]wallet.Key)} for _, opt := range opts { if err := opt(c); err != nil { return nil, err @@ -127,73 +115,47 @@ func (c *Client) Sign(ctx context.Context, account types.Address, data []byte) ( return c.baseClient.Sign(ctx, account, data) } if key := c.findKey(&account); key != nil { - return key.SignMessage(data) + return key.SignMessage(ctx, data) } return nil, fmt.Errorf("rpc client: no key found for address %s", account) } // SignTransaction implements the RPC interface. -func (c *Client) SignTransaction(ctx context.Context, tx types.Transaction) ([]byte, *types.Transaction, error) { - txCpy := tx.Copy() - if txCpy.ChainID == nil && c.chainID != nil { - chainID := *c.chainID - txCpy.ChainID = &chainID - } - if txCpy.Call.From == nil && c.defaultAddr != nil { - defaultAddr := *c.defaultAddr - txCpy.Call.From = &defaultAddr - } - if err := c.verifyTXChainID(txCpy); err != nil { +func (c *Client) SignTransaction(ctx context.Context, tx *types.Transaction) ([]byte, *types.Transaction, error) { + tx, err := c.PrepareTransaction(ctx, tx) + if err != nil { return nil, nil, err } - for _, modifier := range c.txModifiers { - if err := modifier.Modify(ctx, c, txCpy); err != nil { - return nil, nil, err - } - } if len(c.keys) == 0 { - return c.baseClient.SignTransaction(ctx, *txCpy) + return c.baseClient.SignTransaction(ctx, tx) } - if key := c.findKey(txCpy.Call.From); key != nil { - if err := key.SignTransaction(txCpy); err != nil { + if key := c.findKey(tx.Call.From); key != nil { + if err := key.SignTransaction(ctx, tx); err != nil { return nil, nil, err } - raw, err := txCpy.Raw() + raw, err := tx.Raw() if err != nil { return nil, nil, err } - return raw, txCpy, nil + return raw, tx, nil } return nil, nil, fmt.Errorf("rpc client: no key found for address %s", tx.Call.From) } // SendTransaction implements the RPC interface. -func (c *Client) SendTransaction(ctx context.Context, tx types.Transaction) (*types.Hash, *types.Transaction, error) { - txCpy := tx.Copy() - if txCpy.ChainID == nil && c.chainID != nil { - chainID := *c.chainID - txCpy.ChainID = &chainID - } - if txCpy.Call.From == nil && c.defaultAddr != nil { - defaultAddr := *c.defaultAddr - txCpy.Call.From = &defaultAddr - } - if err := c.verifyTXChainID(txCpy); err != nil { +func (c *Client) SendTransaction(ctx context.Context, tx *types.Transaction) (*types.Hash, *types.Transaction, error) { + tx, err := c.PrepareTransaction(ctx, tx) + if err != nil { return nil, nil, err } - for _, modifier := range c.txModifiers { - if err := modifier.Modify(ctx, c, txCpy); err != nil { - return nil, nil, err - } - } if len(c.keys) == 0 { - return c.baseClient.SendTransaction(ctx, *txCpy) + return c.baseClient.SendTransaction(ctx, tx) } - if key := c.findKey(txCpy.Call.From); key != nil { - if err := key.SignTransaction(txCpy); err != nil { + if key := c.findKey(tx.Call.From); key != nil { + if err := key.SignTransaction(ctx, tx); err != nil { return nil, nil, err } - raw, err := txCpy.Raw() + raw, err := tx.Raw() if err != nil { return nil, nil, err } @@ -201,42 +163,56 @@ func (c *Client) SendTransaction(ctx context.Context, tx types.Transaction) (*ty if err != nil { return nil, nil, err } - return txHash, txCpy, nil + return txHash, tx, nil } return nil, nil, fmt.Errorf("rpc client: no key found for address %s", tx.Call.From) } +// PrepareTransaction prepares the transaction by applying transaction +// modifiers and setting the default address if it is not set. +// +// A copy of the modified transaction is returned. +func (c *Client) PrepareTransaction(ctx context.Context, tx *types.Transaction) (*types.Transaction, error) { + if tx == nil { + return nil, fmt.Errorf("rpc client: transaction is nil") + } + txCpy := tx.Copy() + if txCpy.Call.From == nil && c.defaultAddr != nil { + defaultAddr := *c.defaultAddr + txCpy.Call.From = &defaultAddr + } + for _, modifier := range c.txModifiers { + if err := modifier.Modify(ctx, c, txCpy); err != nil { + return nil, err + } + } + return txCpy, nil +} + // Call implements the RPC interface. -func (c *Client) Call(ctx context.Context, call types.Call, block types.BlockNumber) ([]byte, *types.Call, error) { +func (c *Client) Call(ctx context.Context, call *types.Call, block types.BlockNumber) ([]byte, *types.Call, error) { + if call == nil { + return nil, nil, fmt.Errorf("rpc client: call is nil") + } callCpy := call.Copy() if callCpy.From == nil && c.defaultAddr != nil { defaultAddr := *c.defaultAddr callCpy.From = &defaultAddr } - return c.baseClient.Call(ctx, *callCpy, block) + return c.baseClient.Call(ctx, callCpy, block) } // EstimateGas implements the RPC interface. -func (c *Client) EstimateGas(ctx context.Context, call types.Call, block types.BlockNumber) (uint64, error) { +func (c *Client) EstimateGas(ctx context.Context, call *types.Call, block types.BlockNumber) (uint64, *types.Call, error) { + if call == nil { + return 0, nil, fmt.Errorf("rpc client: call is nil") + } callCpy := call.Copy() if callCpy.From == nil && c.defaultAddr != nil { defaultAddr := *c.defaultAddr callCpy.From = &defaultAddr } - return c.baseClient.EstimateGas(ctx, *callCpy, block) -} - -// verifyTXChainID verifies that the transaction chain ID is set. If the client -// has a chain ID set, it will also verify that the transaction chain ID matches -// the client chain ID. -func (c *Client) verifyTXChainID(tx *types.Transaction) error { - if tx.ChainID == nil { - return fmt.Errorf("rpc client: transaction chain ID is not set") - } - if c.chainID != nil && *tx.ChainID != *c.chainID { - return fmt.Errorf("rpc client: transaction chain ID does not match") - } - return nil + return c.baseClient.EstimateGas(ctx, callCpy, block) } // findKey finds a key by address. @@ -244,10 +220,8 @@ func (c *Client) findKey(addr *types.Address) wallet.Key { if addr == nil { return nil } - for _, key := range c.keys { - if key.Address() == *addr { - return key - } + if key, ok := c.keys[*addr]; ok { + return key } return nil } diff --git a/rpc/client_test.go b/rpc/client_test.go index 64e10cc..302a5ab 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -17,8 +17,6 @@ import ( func TestClient_Sign(t *testing.T) { httpMock := newHTTPMock() keyMock := &keyMock{} - client, _ := NewClient(WithTransport(httpMock), WithKeys(keyMock)) - keyMock.addressCallback = func() types.Address { return types.MustAddressFromHex("0x1111111111111111111111111111111111111111") } @@ -26,6 +24,8 @@ func TestClient_Sign(t *testing.T) { return types.MustSignatureFromHexPtr("0xa3a7b12762dbc5df6cfbedbecdf8a821929c6112d2634abbb0d99dc63ad914908051b2c8c7d159db49ad19bd01026156eedab2f3d8c1dfdd07d21c07a4bbdd846f"), nil } + client, _ := NewClient(WithTransport(httpMock), WithKeys(keyMock)) + signature, err := client.Sign( context.Background(), types.MustAddressFromHex("0x1111111111111111111111111111111111111111"), @@ -38,8 +38,6 @@ func TestClient_Sign(t *testing.T) { func TestClient_SignTransaction(t *testing.T) { httpMock := newHTTPMock() keyMock := &keyMock{} - client, _ := NewClient(WithTransport(httpMock), WithKeys(keyMock)) - keyMock.addressCallback = func() types.Address { return types.MustAddressFromHex("0xb60e8dd61c5d32be8058bb8eb970870f07233155") } @@ -48,13 +46,15 @@ func TestClient_SignTransaction(t *testing.T) { return nil } + client, _ := NewClient(WithTransport(httpMock), WithKeys(keyMock)) + from := types.MustAddressFromHex("0xb60e8dd61c5d32be8058bb8eb970870f07233155") to := types.MustAddressFromHex("0xd46e8dd67c5d32be8058bb8eb970870f07244567") gasLimit := uint64(30400) chainID := uint64(1) raw, tx, err := client.SignTransaction( context.Background(), - types.Transaction{ + &types.Transaction{ ChainID: &chainID, Call: types.Call{ From: &from, @@ -82,8 +82,6 @@ func TestClient_SignTransaction(t *testing.T) { func TestClient_SendTransaction(t *testing.T) { httpMock := newHTTPMock() keyMock := &keyMock{} - client, _ := NewClient(WithTransport(httpMock), WithKeys(keyMock)) - keyMock.addressCallback = func() types.Address { return types.MustAddressFromHex("0xb60e8dd61c5d32be8058bb8eb970870f07233155") } @@ -92,6 +90,8 @@ func TestClient_SendTransaction(t *testing.T) { return nil } + client, _ := NewClient(WithTransport(httpMock), WithKeys(keyMock)) + httpMock.ResponseMock = &http.Response{ StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString(mockSendRawTransactionResponse)), @@ -106,7 +106,7 @@ func TestClient_SendTransaction(t *testing.T) { chainID := uint64(1) txHash, tx, err := client.SendTransaction( context.Background(), - types.Transaction{ + &types.Transaction{ ChainID: &chainID, Call: types.Call{ From: &from, @@ -144,7 +144,7 @@ func TestClient_Call(t *testing.T) { gasLimit := uint64(30400) _, _, err := client.Call( context.Background(), - types.Call{ + &types.Call{ From: nil, To: &to, GasLimit: &gasLimit, @@ -172,9 +172,9 @@ func TestClient_EstimateGas(t *testing.T) { to := types.MustAddressFromHex("0x2222222222222222222222222222222222222222") gasLimit := uint64(30400) - _, err := client.EstimateGas( + _, _, err := client.EstimateGas( context.Background(), - types.Call{ + &types.Call{ From: nil, To: &to, GasLimit: &gasLimit, diff --git a/rpc/mocks_test.go b/rpc/mocks_test.go index b4be6cc..24fdab1 100644 --- a/rpc/mocks_test.go +++ b/rpc/mocks_test.go @@ -99,22 +99,22 @@ func (k *keyMock) Address() types.Address { return k.addressCallback() } -func (k *keyMock) SignHash(hash types.Hash) (*types.Signature, error) { +func (k *keyMock) SignHash(ctx context.Context, hash types.Hash) (*types.Signature, error) { return k.signHashCallback(hash) } -func (k *keyMock) SignMessage(data []byte) (*types.Signature, error) { +func (k *keyMock) SignMessage(ctx context.Context, data []byte) (*types.Signature, error) { return k.signMessageCallback(data) } -func (k *keyMock) SignTransaction(tx *types.Transaction) error { +func (k *keyMock) SignTransaction(ctx context.Context, tx *types.Transaction) error { return k.signTransactionCallback(tx) } -func (k *keyMock) VerifyHash(hash types.Hash, sig types.Signature) bool { +func (k *keyMock) VerifyHash(ctx context.Context, hash types.Hash, sig types.Signature) bool { return false } -func (k keyMock) VerifyMessage(data []byte, sig types.Signature) bool { +func (k keyMock) VerifyMessage(ctx context.Context, data []byte, sig types.Signature) bool { return false } diff --git a/rpc/rpc.go b/rpc/rpc.go index 76214e9..1c5c03c 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -9,16 +9,35 @@ import ( // RPC is an RPC client for the Ethereum-compatible nodes. type RPC interface { - // TODO: web3_clientVersion - // TODO: web3_sha3 - // TODO: net_version - // TODO: net_listening - // TODO: net_peerCount - // TODO: eth_protocolVersion - // TODO: eth_syncing - // TODO: eth_coinbase - // TODO: eth_mining - // TODO: eth_hashrate + // ClientVersion performs web3_clientVersion RPC call. + // + // It returns the current client version. + ClientVersion(ctx context.Context) (string, error) + + // Listening performs net_listening RPC call. + // + // It returns true if the client is actively listening for network. + Listening(ctx context.Context) (bool, error) + + // PeerCount performs net_peerCount RPC call. + // + // It returns the number of connected peers. + PeerCount(ctx context.Context) (uint64, error) + + // ProtocolVersion performs eth_protocolVersion RPC call. + // + // It returns the current Ethereum protocol version. + ProtocolVersion(ctx context.Context) (uint64, error) + + // Syncing performs eth_syncing RPC call. + // + // It returns an object with data about the sync status or false. + Syncing(ctx context.Context) (*types.SyncStatus, error) + + // NetworkID performs net_version RPC call. + // + // It returns the current network ID. + NetworkID(ctx context.Context) (uint64, error) // ChainID performs eth_chainId RPC call. // @@ -91,14 +110,14 @@ type RPC interface { // It signs the given transaction. // // If transaction was internally mutated, the mutated call is returned. - SignTransaction(ctx context.Context, tx types.Transaction) ([]byte, *types.Transaction, error) + SignTransaction(ctx context.Context, tx *types.Transaction) ([]byte, *types.Transaction, error) // SendTransaction performs eth_sendTransaction RPC call. // // It sends a transaction to the network. // // If transaction was internally mutated, the mutated call is returned. - SendTransaction(ctx context.Context, tx types.Transaction) (*types.Hash, *types.Transaction, error) + SendTransaction(ctx context.Context, tx *types.Transaction) (*types.Hash, *types.Transaction, error) // SendRawTransaction performs eth_sendRawTransaction RPC call. // @@ -111,12 +130,14 @@ type RPC interface { // transaction on the blockchain. // // If call was internally mutated, the mutated call is returned. - Call(ctx context.Context, call types.Call, block types.BlockNumber) ([]byte, *types.Call, error) + Call(ctx context.Context, call *types.Call, block types.BlockNumber) ([]byte, *types.Call, error) // EstimateGas performs eth_estimateGas RPC call. // // It estimates the gas necessary to execute a specific transaction. - EstimateGas(ctx context.Context, call types.Call, block types.BlockNumber) (uint64, error) + // + // If call was internally mutated, the mutated call is returned. + EstimateGas(ctx context.Context, call *types.Call, block types.BlockNumber) (uint64, *types.Call, error) // BlockByHash performs eth_getBlockByHash RPC call. // @@ -148,27 +169,64 @@ type RPC interface { // It returns the receipt of a transaction by transaction hash. GetTransactionReceipt(ctx context.Context, hash types.Hash) (*types.TransactionReceipt, error) - // TODO: eth_getUncleByBlockHashAndIndex - // TODO: eth_getUncleByBlockNumberAndIndex - // TODO: eth_getCompilers - // TODO: eth_compileSolidity - // TODO: eth_compileLLL - // TODO: eth_compileSerpent - // TODO: eth_newFilter - // TODO: eth_newBlockFilter - // TODO: eth_newPendingTransactionFilter - // TODO: eth_uninstallFilter - // TODO: eth_getFilterChanges - // TODO: eth_getFilterLogs + // GetBlockReceipts performs eth_getBlockReceipts RPC call. + // + // It returns all transaction receipts for a given block hash or number. + GetBlockReceipts(ctx context.Context, block types.BlockNumber) ([]*types.TransactionReceipt, error) + + // GetUncleByBlockHashAndIndex performs eth_getUncleByBlockNumberAndIndex RPC call. + // + // It returns information about an uncle of a block by number and uncle index position. + GetUncleByBlockHashAndIndex(ctx context.Context, hash types.Hash, index uint64) (*types.Block, error) + + // GetUncleByBlockNumberAndIndex performs eth_getUncleByBlockNumberAndIndex RPC call. + // + // It returns information about an uncle of a block by hash and uncle index position. + GetUncleByBlockNumberAndIndex(ctx context.Context, number types.BlockNumber, index uint64) (*types.Block, error) + + // NewFilter performs eth_newFilter RPC call. + // + // It creates a filter object based on the given filter options. To check + // if the state has changed, use GetFilterChanges. + NewFilter(ctx context.Context, query *types.FilterLogsQuery) (*big.Int, error) + + // NewBlockFilter performs eth_newBlockFilter RPC call. + // + // It creates a filter in the node, to notify when a new block arrives. To + // check if the state has changed, use GetBlockFilterChanges. + NewBlockFilter(ctx context.Context) (*big.Int, error) + + // NewPendingTransactionFilter performs eth_newPendingTransactionFilter RPC call. + // + // It creates a filter in the node, to notify when new pending transactions + // arrive. To check if the state has changed, use GetFilterChanges. + NewPendingTransactionFilter(ctx context.Context) (*big.Int, error) + + // UninstallFilter performs eth_uninstallFilter RPC call. + // + // It uninstalls a filter with given ID. Should always be called when watch + // is no longer needed. + UninstallFilter(ctx context.Context, id *big.Int) (bool, error) + + // GetFilterChanges performs eth_getFilterChanges RPC call. + // + // It returns an array of logs that occurred since the given filter ID. + GetFilterChanges(ctx context.Context, id *big.Int) ([]types.Log, error) + + // GetBlockFilterChanges performs eth_getFilterChanges RPC call. + // + // It returns an array of block hashes that occurred since the given filter ID. + GetBlockFilterChanges(ctx context.Context, id *big.Int) ([]types.Hash, error) + + // GetFilterLogs performs eth_getFilterLogs RPC call. + // + // It returns an array of all logs matching filter with given ID. + GetFilterLogs(ctx context.Context, id *big.Int) ([]types.Log, error) // GetLogs performs eth_getLogs RPC call. // // It returns logs that match the given query. - GetLogs(ctx context.Context, query types.FilterLogsQuery) ([]types.Log, error) - - // TODO: eth_getWork - // TODO: eth_submitWork - // TODO: eth_submitHashrate + GetLogs(ctx context.Context, query *types.FilterLogsQuery) ([]types.Log, error) // MaxPriorityFeePerGas performs eth_maxPriorityFeePerGas RPC call. // @@ -181,7 +239,7 @@ type RPC interface { // It creates a subscription that will send logs that match the given query. // // Subscription channel will be closed when the context is canceled. - SubscribeLogs(ctx context.Context, query types.FilterLogsQuery) (chan types.Log, error) + SubscribeLogs(ctx context.Context, query *types.FilterLogsQuery) (<-chan types.Log, error) // SubscribeNewHeads performs eth_subscribe RPC call with "newHeads" // subscription type. @@ -189,7 +247,7 @@ type RPC interface { // It creates a subscription that will send new block headers. // // Subscription channel will be closed when the context is canceled. - SubscribeNewHeads(ctx context.Context) (chan types.Block, error) + SubscribeNewHeads(ctx context.Context) (<-chan types.Block, error) // SubscribeNewPendingTransactions performs eth_subscribe RPC call with // "newPendingTransactions" subscription type. @@ -197,5 +255,5 @@ type RPC interface { // It creates a subscription that will send new pending transactions. // // Subscription channel will be closed when the context is canceled. - SubscribeNewPendingTransactions(ctx context.Context) (chan types.Hash, error) + SubscribeNewPendingTransactions(ctx context.Context) (<-chan types.Hash, error) } diff --git a/rpc/transport/stream.go b/rpc/transport/stream.go index 5650b53..79d67c6 100644 --- a/rpc/transport/stream.go +++ b/rpc/transport/stream.go @@ -53,15 +53,17 @@ func (s *stream) Call(ctx context.Context, result any, method string, args ...an return fmt.Errorf("failed to create RPC request: %w", err) } + // Prepare the channel for the response. + ch := make(chan rpcResponse) + s.addCallCh(id, ch) + defer s.delCallCh(id) + // Send the request. s.writerCh <- req // Wait for the response. // The response is handled by the streamRoutine. It will send the response // to the ch channel. - ch := make(chan rpcResponse) - s.addCallCh(id, ch) - defer s.delCallCh(id) select { case res := <-ch: if res.Error != nil { diff --git a/txmodifier/chainid.go b/txmodifier/chainid.go new file mode 100644 index 0000000..9cc9bee --- /dev/null +++ b/txmodifier/chainid.go @@ -0,0 +1,82 @@ +package txmodifier + +import ( + "context" + "fmt" + "sync" + + "github.com/defiweb/go-eth/rpc" + "github.com/defiweb/go-eth/types" +) + +// ChainIDProvider is a transaction modifier that sets the chain ID of the +// transaction. +// +// To use this modifier, add it using the WithTXModifiers option when creating +// a new rpc.Client. +type ChainIDProvider struct { + mu sync.Mutex + chainID uint64 + replace bool + cache bool +} + +// ChainIDProviderOptions is the options for NewChainIDProvider. +type ChainIDProviderOptions struct { + // ChainID is the chain ID that will be set for the transaction. + // If 0, the chain ID will be queried from the node. + ChainID uint64 + + // Replace is true if the transaction chain ID should be replaced even if + // it is already set. + Replace bool + + // Cache is true if the chain ID will be cached instead of being queried + // for each transaction. Cached chain ID will be used for all RPC clients + // that use the same ChainIDProvider instance. + // + // If ChainID is set, this option is ignored. + Cache bool +} + +// NewChainIDProvider returns a new ChainIDProvider. +func NewChainIDProvider(opts ChainIDProviderOptions) *ChainIDProvider { + if opts.ChainID != 0 { + opts.Cache = true + } + return &ChainIDProvider{ + chainID: opts.ChainID, + replace: opts.Replace, + cache: opts.Cache, + } +} + +// Modify implements the rpc.TXModifier interface. +func (p *ChainIDProvider) Modify(ctx context.Context, client rpc.RPC, tx *types.Transaction) error { + if !p.replace && tx.ChainID != nil { + return nil + } + if !p.cache { + chainID, err := client.ChainID(ctx) + if err != nil { + return fmt.Errorf("chain ID provider: %w", err) + } + tx.ChainID = &chainID + return nil + } + p.mu.Lock() + defer p.mu.Unlock() + var cid uint64 + if p.chainID != 0 { + cid = p.chainID + } else { + chainID, err := client.ChainID(ctx) + if err != nil { + return fmt.Errorf("chain ID provider: %w", err) + } + p.chainID = chainID + cid = chainID + } + tx.ChainID = &cid + return nil +} diff --git a/txmodifier/chainid_test.go b/txmodifier/chainid_test.go new file mode 100644 index 0000000..c67189c --- /dev/null +++ b/txmodifier/chainid_test.go @@ -0,0 +1,89 @@ +package txmodifier + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/defiweb/go-eth/types" +) + +func TestChainIDSetter_Modify(t *testing.T) { + ctx := context.Background() + fromAddress := types.MustAddressFromHex("0x1234567890abcdef1234567890abcdef12345678") + + t.Run("cache chain ID", func(t *testing.T) { + tx := &types.Transaction{Call: types.Call{From: &fromAddress}} + rpcMock := new(mockRPC) + + provider := NewChainIDProvider(ChainIDProviderOptions{ + ChainID: 1, + }) + _ = provider.Modify(ctx, rpcMock, tx) + + assert.Equal(t, uint64(1), *tx.ChainID) + }) + + t.Run("query RPC node", func(t *testing.T) { + tx := &types.Transaction{Call: types.Call{From: &fromAddress}} + rpcMock := new(mockRPC) + rpcMock.On("ChainID", ctx).Return(uint64(1), nil) + + provider := NewChainIDProvider(ChainIDProviderOptions{ + Replace: false, + Cache: false, + }) + err := provider.Modify(ctx, rpcMock, tx) + + assert.NoError(t, err) + assert.Equal(t, uint64(1), *tx.ChainID) + }) + + t.Run("replace chain ID", func(t *testing.T) { + tx := &types.Transaction{Call: types.Call{From: &fromAddress}, ChainID: uint64Ptr(2)} + rpcMock := new(mockRPC) + rpcMock.On("ChainID", ctx).Return(uint64(1), nil) + + provider := NewChainIDProvider(ChainIDProviderOptions{ + Replace: true, + Cache: false, + }) + err := provider.Modify(ctx, rpcMock, tx) + + assert.NoError(t, err) + assert.NotEqual(t, uint64(2), *tx.ChainID) + }) + + t.Run("do not replace chain ID", func(t *testing.T) { + tx := &types.Transaction{Call: types.Call{From: &fromAddress}, ChainID: uint64Ptr(2)} + rpcMock := new(mockRPC) + rpcMock.On("ChainID", ctx).Return(uint64(1), nil) + + provider := NewChainIDProvider(ChainIDProviderOptions{ + Replace: false, + Cache: false, + }) + err := provider.Modify(ctx, rpcMock, tx) + + assert.NoError(t, err) + assert.NotEqual(t, uint64(1), *tx.ChainID) + }) + + t.Run("cache chain ID", func(t *testing.T) { + tx := &types.Transaction{Call: types.Call{From: &fromAddress}, ChainID: uint64Ptr(2)} + rpcMock := new(mockRPC) + rpcMock.On("ChainID", ctx).Return(uint64(1), nil).Once() + + provider := NewChainIDProvider(ChainIDProviderOptions{ + Replace: true, + Cache: true, + }) + _ = provider.Modify(ctx, rpcMock, tx) + _ = provider.Modify(ctx, rpcMock, tx) + }) +} + +func uint64Ptr(i uint64) *uint64 { + return &i +} diff --git a/txmodifier/gaslimit.go b/txmodifier/gaslimit.go index c10319c..f9c4891 100644 --- a/txmodifier/gaslimit.go +++ b/txmodifier/gaslimit.go @@ -44,7 +44,7 @@ func (e *GasLimitEstimator) Modify(ctx context.Context, client rpc.RPC, tx *type if !e.replace && tx.GasLimit != nil { return nil } - gasLimit, err := client.EstimateGas(ctx, tx.Call, types.LatestBlockNumber) + gasLimit, _, err := client.EstimateGas(ctx, &tx.Call, types.LatestBlockNumber) if err != nil { return fmt.Errorf("gas limit estimator: failed to estimate gas limit: %w", err) } diff --git a/txmodifier/gaslimit_test.go b/txmodifier/gaslimit_test.go index 0c11289..abc0f02 100644 --- a/txmodifier/gaslimit_test.go +++ b/txmodifier/gaslimit_test.go @@ -16,7 +16,7 @@ func TestGasLimitEstimator_Modify(t *testing.T) { t.Run("successful gas estimation", func(t *testing.T) { tx := &types.Transaction{} rpcMock := new(mockRPC) - rpcMock.On("EstimateGas", ctx, tx.Call, types.LatestBlockNumber).Return(uint64(1000), nil) + rpcMock.On("EstimateGas", ctx, &tx.Call, types.LatestBlockNumber).Return(uint64(1000), &tx.Call, nil) estimator := NewGasLimitEstimator(GasLimitEstimatorOptions{ Multiplier: 1.5, @@ -32,7 +32,7 @@ func TestGasLimitEstimator_Modify(t *testing.T) { t.Run("gas estimation error", func(t *testing.T) { tx := &types.Transaction{} rpcMock := new(mockRPC) - rpcMock.On("EstimateGas", ctx, tx.Call, types.LatestBlockNumber).Return(uint64(0), errors.New("rpc error")) + rpcMock.On("EstimateGas", ctx, &tx.Call, types.LatestBlockNumber).Return(uint64(0), &tx.Call, errors.New("rpc error")) estimator := NewGasLimitEstimator(GasLimitEstimatorOptions{ Multiplier: 1.5, @@ -48,7 +48,7 @@ func TestGasLimitEstimator_Modify(t *testing.T) { t.Run("gas out of range", func(t *testing.T) { tx := &types.Transaction{} rpcMock := new(mockRPC) - rpcMock.On("EstimateGas", ctx, tx.Call, types.LatestBlockNumber).Return(uint64(3000), nil) + rpcMock.On("EstimateGas", ctx, &tx.Call, types.LatestBlockNumber).Return(uint64(3000), &tx.Call, nil) estimator := NewGasLimitEstimator(GasLimitEstimatorOptions{ Multiplier: 1.5, diff --git a/txmodifier/txmodifier_test.go b/txmodifier/txmodifier_test.go index 0c4286d..71f2d71 100644 --- a/txmodifier/txmodifier_test.go +++ b/txmodifier/txmodifier_test.go @@ -15,11 +15,16 @@ type mockRPC struct { mock.Mock } -func (m *mockRPC) EstimateGas(ctx context.Context, call types.Call, block types.BlockNumber) (uint64, error) { - args := m.Called(ctx, call, block) +func (m *mockRPC) ChainID(ctx context.Context) (uint64, error) { + args := m.Called(ctx) return args.Get(0).(uint64), args.Error(1) } +func (m *mockRPC) EstimateGas(ctx context.Context, call *types.Call, block types.BlockNumber) (uint64, *types.Call, error) { + args := m.Called(ctx, call, block) + return args.Get(0).(uint64), call, args.Error(2) +} + func (m *mockRPC) GasPrice(ctx context.Context) (*big.Int, error) { args := m.Called(ctx) return args.Get(0).(*big.Int), args.Error(1) diff --git a/types/types.go b/types/types.go index 0a50ecf..4bcaf41 100644 --- a/types/types.go +++ b/types/types.go @@ -393,20 +393,24 @@ func (t *Hash) DecodeRLP(data []byte) (int, error) { type BlockNumber struct{ x big.Int } const ( - earliestBlockNumber = -1 - latestBlockNumber = -2 - pendingBlockNumber = -3 + earliestBlockNumber = -1 + latestBlockNumber = -2 + pendingBlockNumber = -3 + safeBlockNumber = -4 + finalizedBlockNumber = -5 ) var ( - EarliestBlockNumber = BlockNumber{x: *new(big.Int).SetInt64(earliestBlockNumber)} - LatestBlockNumber = BlockNumber{x: *new(big.Int).SetInt64(latestBlockNumber)} - PendingBlockNumber = BlockNumber{x: *new(big.Int).SetInt64(pendingBlockNumber)} + EarliestBlockNumber = BlockNumber{x: *new(big.Int).SetInt64(earliestBlockNumber)} + LatestBlockNumber = BlockNumber{x: *new(big.Int).SetInt64(latestBlockNumber)} + PendingBlockNumber = BlockNumber{x: *new(big.Int).SetInt64(pendingBlockNumber)} + SafeBlockNumber = BlockNumber{x: *new(big.Int).SetInt64(safeBlockNumber)} + FinalizedBlockNumber = BlockNumber{x: *new(big.Int).SetInt64(finalizedBlockNumber)} ) // BlockNumberFromHex converts a string to a BlockNumber type. // The string can be a hex number or one of the following strings: -// "earliest", "latest", "pending". +// "earliest", "latest", "safe", "finalized", "pending". // If the string is not a valid block number, it returns an error. func BlockNumberFromHex(h string) (BlockNumber, error) { b := &BlockNumber{} @@ -416,7 +420,7 @@ func BlockNumberFromHex(h string) (BlockNumber, error) { // BlockNumberFromHexPtr converts a string to a *BlockNumber type. // The string can be a hex number or one of the following strings: -// "earliest", "latest", "pending". +// "earliest", "latest", "safe", "finalized", "pending". // If the string is not a valid block number, it returns nil. func BlockNumberFromHexPtr(h string) *BlockNumber { b, err := BlockNumberFromHex(h) @@ -428,7 +432,7 @@ func BlockNumberFromHexPtr(h string) *BlockNumber { // MustBlockNumberFromHex converts a string to a BlockNumber type. // The string can be a hex number or one of the following strings: -// "earliest", "latest", "pending". +// "earliest", "latest", "safe", "finalized", "pending". // It panics if the string is not a valid block number. func MustBlockNumberFromHex(h string) BlockNumber { b, err := BlockNumberFromHex(h) @@ -440,7 +444,7 @@ func MustBlockNumberFromHex(h string) BlockNumber { // MustBlockNumberFromHexPtr converts a string to a *BlockNumber type. // The string can be a hex number or one of the following strings: -// "earliest", "latest", "pending". +// "earliest", "latest", "safe", "finalized", "pending". // It panics if the string is not a valid block number. func MustBlockNumberFromHexPtr(h string) *BlockNumber { b := MustBlockNumberFromHex(h) @@ -487,6 +491,16 @@ func (t *BlockNumber) IsPending() bool { return t.Big().Int64() == pendingBlockNumber } +// IsSafe returns true if the block tag is "safe". +func (t *BlockNumber) IsSafe() bool { + return t.Big().Int64() == safeBlockNumber +} + +// IsFinalized returns true if the block tag is "finalized". +func (t *BlockNumber) IsFinalized() bool { + return t.Big().Int64() == finalizedBlockNumber +} + // IsTag returns true if the block tag is used. func (t *BlockNumber) IsTag() bool { return t.Big().Sign() < 0 @@ -506,6 +520,10 @@ func (t *BlockNumber) String() string { return "latest" case t.IsPending(): return "pending" + case t.IsSafe(): + return "safe" + case t.IsFinalized(): + return "finalized" default: return "0x" + t.x.Text(16) } @@ -531,6 +549,10 @@ func (t BlockNumber) MarshalText() ([]byte, error) { return []byte("latest"), nil case t.IsPending(): return []byte("pending"), nil + case t.IsSafe(): + return []byte("safe"), nil + case t.IsFinalized(): + return []byte("finalized"), nil default: return []byte(hexutil.BigIntToHex(&t.x)), nil } @@ -547,6 +569,12 @@ func (t *BlockNumber) UnmarshalText(input []byte) error { case "pending": *t = BlockNumber{x: *new(big.Int).SetInt64(pendingBlockNumber)} return nil + case "safe": + *t = BlockNumber{x: *new(big.Int).SetInt64(safeBlockNumber)} + return nil + case "finalized": + *t = BlockNumber{x: *new(big.Int).SetInt64(finalizedBlockNumber)} + return nil default: u, err := hexutil.HexToBigInt(string(input)) if err != nil { @@ -719,6 +747,33 @@ func (s Signature) IsZero() bool { return true } +// Equal returns true if the signature is equal to the given signature. +// +// Nil values are considered as zero. +func (s Signature) Equal(c Signature) bool { + sv, sr, ss := s.V, s.R, s.S + cv, cr, cs := c.V, c.R, c.S + if sv == nil { + sv = new(big.Int) + } + if sr == nil { + sr = new(big.Int) + } + if ss == nil { + ss = new(big.Int) + } + if cv == nil { + cv = new(big.Int) + } + if cr == nil { + cr = new(big.Int) + } + if cs == nil { + cs = new(big.Int) + } + return sv.Cmp(cv) == 0 && sr.Cmp(cr) == 0 && ss.Cmp(cs) == 0 +} + func (s Signature) Copy() *Signature { cpy := &Signature{} if s.V != nil { @@ -983,6 +1038,17 @@ func (b *Bytes) UnmarshalText(input []byte) error { return bytesUnmarshalText(input, (*[]byte)(b)) } +// +// SyncStatus type: +// + +// SyncStatus represents the sync status of a node. +type SyncStatus struct { + StartingBlock BlockNumber `json:"startingBlock"` + CurrentBlock BlockNumber `json:"currentBlock"` + HighestBlock BlockNumber `json:"highestBlock"` +} + // // Internal types: // diff --git a/types/types_test.go b/types/types_test.go index e79c1f1..b6e6c4a 100644 --- a/types/types_test.go +++ b/types/types_test.go @@ -317,13 +317,15 @@ func Test_AddressesType_Marshal(t *testing.T) { func Test_BlockNumberType_Unmarshal(t *testing.T) { tests := []struct { - arg string - want BlockNumber - wantErr bool - isTag bool - isEarliest bool - isLatest bool - isPending bool + arg string + want BlockNumber + wantErr bool + isTag bool + isEarliest bool + isLatest bool + isPending bool + isSafe bool + isFinalized bool }{ {arg: `"0x0"`, want: BlockNumberFromUint64(0)}, {arg: `"0xF"`, want: BlockNumberFromUint64(15)}, @@ -332,6 +334,8 @@ func Test_BlockNumberType_Unmarshal(t *testing.T) { {arg: `"earliest"`, want: EarliestBlockNumber, isTag: true, isEarliest: true}, {arg: `"latest"`, want: LatestBlockNumber, isTag: true, isLatest: true}, {arg: `"pending"`, want: PendingBlockNumber, isTag: true, isPending: true}, + {arg: `"safe"`, want: SafeBlockNumber, isTag: true, isSafe: true}, + {arg: `"finalized"`, want: FinalizedBlockNumber, isTag: true, isFinalized: true}, {arg: `"foo"`, wantErr: true}, {arg: `"0xZ"`, wantErr: true}, } @@ -348,6 +352,8 @@ func Test_BlockNumberType_Unmarshal(t *testing.T) { assert.Equal(t, tt.isEarliest, v.IsEarliest()) assert.Equal(t, tt.isLatest, v.IsLatest()) assert.Equal(t, tt.isPending, v.IsPending()) + assert.Equal(t, tt.isSafe, v.IsSafe()) + assert.Equal(t, tt.isFinalized, v.IsFinalized()) } }) } @@ -363,6 +369,8 @@ func Test_BlockNumberType_Marshal(t *testing.T) { {arg: EarliestBlockNumber, want: `"earliest"`}, {arg: LatestBlockNumber, want: `"latest"`}, {arg: PendingBlockNumber, want: `"pending"`}, + {arg: SafeBlockNumber, want: `"safe"`}, + {arg: FinalizedBlockNumber, want: `"finalized"`}, } for n, tt := range tests { t.Run(fmt.Sprintf("case-%d", n+1), func(t *testing.T) { @@ -373,6 +381,153 @@ func Test_BlockNumberType_Marshal(t *testing.T) { } } +func Test_SignatureType_Unmarshal(t *testing.T) { + tests := []struct { + arg string + want Signature + wantErr bool + }{ + { + arg: `"0x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"`, + want: Signature{ + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + }, + { + arg: `"0x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"`, + want: Signature{ + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + }, + { + arg: `"0x000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000021b"`, + want: Signature{ + V: big.NewInt(27), + R: big.NewInt(1), + S: big.NewInt(2), + }, + }, + } + for n, tt := range tests { + t.Run(fmt.Sprintf("case-%d", n+1), func(t *testing.T) { + v := &Signature{} + err := v.UnmarshalJSON([]byte(tt.arg)) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.True(t, tt.want.Equal(*v)) + } + }) + } +} + +func Test_SignatureType_Marshal(t *testing.T) { + tests := []struct { + signature Signature + want string + wantErr bool + }{ + { + signature: Signature{}, + want: `"0x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"`, + }, + { + signature: Signature{ + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + want: `"0x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"`, + }, + { + signature: Signature{ + V: big.NewInt(27), + R: big.NewInt(1), + S: big.NewInt(2), + }, + want: `"0x000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000021b"`, + }, + } + for n, tt := range tests { + t.Run(fmt.Sprintf("case-%d", n+1), func(t *testing.T) { + j, err := tt.signature.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, tt.want, string(j)) + }) + } +} + +func Test_SignatureType_Equal(t *testing.T) { + tests := []struct { + a, b Signature + want bool + }{ + { + a: Signature{}, + b: Signature{}, + want: true, + }, + { + a: Signature{}, + b: Signature{ + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + want: true, + }, + { + a: Signature{ + V: big.NewInt(0), + R: nil, + S: big.NewInt(0), + }, + b: Signature{ + V: nil, + R: big.NewInt(0), + S: big.NewInt(0), + }, + want: true, + }, + { + a: Signature{ + V: big.NewInt(27), + R: big.NewInt(1), + S: big.NewInt(2), + }, + b: Signature{ + V: big.NewInt(27), + R: big.NewInt(1), + S: big.NewInt(2), + }, + want: true, + }, + { + a: Signature{ + V: big.NewInt(27), + R: nil, + S: big.NewInt(2), + }, + b: Signature{ + V: nil, + R: big.NewInt(2), + S: big.NewInt(2), + }, + want: false, + }, + } + for n, tt := range tests { + t.Run(fmt.Sprintf("case-%d", n+1), func(t *testing.T) { + assert.Equal(t, tt.want, tt.a.Equal(tt.b)) + }) + } +} + func Test_BytesType_Unmarshal(t *testing.T) { tests := []struct { arg string diff --git a/wallet/key.go b/wallet/key.go index db9ab54..74e8f85 100644 --- a/wallet/key.go +++ b/wallet/key.go @@ -1,25 +1,35 @@ package wallet import ( + "context" + "github.com/defiweb/go-eth/types" ) +// Key is the interface for an Ethereum key. type Key interface { // Address returns the address of the key. Address() types.Address - // SignHash signs the given hash. - SignHash(hash types.Hash) (*types.Signature, error) - // SignMessage signs the given message. - SignMessage(data []byte) (*types.Signature, error) + SignMessage(ctx context.Context, data []byte) (*types.Signature, error) // SignTransaction signs the given transaction. - SignTransaction(tx *types.Transaction) error - - // VerifyHash whether the given hash is signed by the key. - VerifyHash(hash types.Hash, sig types.Signature) bool + SignTransaction(ctx context.Context, tx *types.Transaction) error // VerifyMessage verifies whether the given data is signed by the key. - VerifyMessage(data []byte, sig types.Signature) bool + VerifyMessage(ctx context.Context, data []byte, sig types.Signature) bool +} + +// KeyWithHashSigner is the interface for an Ethereum key that can sign data using +// a private key, skipping the EIP-191 message prefix. +type KeyWithHashSigner interface { + Key + + // SignHash signs the given hash without the EIP-191 message prefix. + SignHash(ctx context.Context, hash types.Hash) (*types.Signature, error) + + // VerifyHash whether the given hash is signed by the key without the + // EIP-191 message prefix. + VerifyHash(ctx context.Context, hash types.Hash, sig types.Signature) bool } diff --git a/wallet/key_priv.go b/wallet/key_priv.go index c2c59aa..3c4e371 100644 --- a/wallet/key_priv.go +++ b/wallet/key_priv.go @@ -1,6 +1,7 @@ package wallet import ( + "context" "crypto/ecdsa" "crypto/rand" "encoding/json" @@ -71,23 +72,23 @@ func (k *PrivateKey) Address() types.Address { return k.address } -// SignHash implements the Key interface. -func (k *PrivateKey) SignHash(hash types.Hash) (*types.Signature, error) { +// SignHash implements the KeyWithHashSigner interface. +func (k *PrivateKey) SignHash(_ context.Context, hash types.Hash) (*types.Signature, error) { return k.sign.SignHash(hash) } // SignMessage implements the Key interface. -func (k *PrivateKey) SignMessage(data []byte) (*types.Signature, error) { +func (k *PrivateKey) SignMessage(_ context.Context, data []byte) (*types.Signature, error) { return k.sign.SignMessage(data) } // SignTransaction implements the Key interface. -func (k *PrivateKey) SignTransaction(tx *types.Transaction) error { +func (k *PrivateKey) SignTransaction(_ context.Context, tx *types.Transaction) error { return k.sign.SignTransaction(tx) } -// VerifyHash implements the Key interface. -func (k *PrivateKey) VerifyHash(hash types.Hash, sig types.Signature) bool { +// VerifyHash implements the KeyWithHashSigner interface. +func (k *PrivateKey) VerifyHash(_ context.Context, hash types.Hash, sig types.Signature) bool { addr, err := k.recover.RecoverHash(hash, sig) if err != nil { return false @@ -96,7 +97,7 @@ func (k *PrivateKey) VerifyHash(hash types.Hash, sig types.Signature) bool { } // VerifyMessage implements the Key interface. -func (k *PrivateKey) VerifyMessage(data []byte, sig types.Signature) bool { +func (k *PrivateKey) VerifyMessage(_ context.Context, data []byte, sig types.Signature) bool { addr, err := k.recover.RecoverMessage(data, sig) if err != nil { return false diff --git a/wallet/key_rpc.go b/wallet/key_rpc.go new file mode 100644 index 0000000..36c4851 --- /dev/null +++ b/wallet/key_rpc.go @@ -0,0 +1,60 @@ +package wallet + +import ( + "context" + + "github.com/defiweb/go-eth/crypto" + "github.com/defiweb/go-eth/types" +) + +// RPCSigningClient is the interface for an Ethereum RPC client that can +// sign messages and transactions. +type RPCSigningClient interface { + Sign(ctx context.Context, account types.Address, data []byte) (*types.Signature, error) + SignTransaction(ctx context.Context, tx *types.Transaction) ([]byte, *types.Transaction, error) +} + +// KeyRPC is an Ethereum key that uses an RPC client to sign messages and transactions. +type KeyRPC struct { + client RPCSigningClient + address types.Address + recover crypto.Recoverer +} + +// NewKeyRPC returns a new KeyRPC. +func NewKeyRPC(client RPCSigningClient, address types.Address) *KeyRPC { + return &KeyRPC{ + client: client, + address: address, + recover: crypto.ECRecoverer, + } +} + +// Address implements the Key interface. +func (k *KeyRPC) Address() types.Address { + return k.address +} + +// SignMessage implements the Key interface. +func (k *KeyRPC) SignMessage(ctx context.Context, data []byte) (*types.Signature, error) { + return k.client.Sign(ctx, k.address, data) +} + +// SignTransaction implements the Key interface. +func (k *KeyRPC) SignTransaction(ctx context.Context, tx *types.Transaction) error { + _, signedTX, err := k.client.SignTransaction(ctx, tx) + if err != nil { + return err + } + *tx = *signedTX + return err +} + +// VerifyMessage implements the Key interface. +func (k *KeyRPC) VerifyMessage(_ context.Context, data []byte, sig types.Signature) bool { + addr, err := k.recover.RecoverMessage(data, sig) + if err != nil { + return false + } + return *addr == k.address +}