diff --git a/cmd/ostracon/commands/show_validator.go b/cmd/ostracon/commands/show_validator.go index e19914b60..43b31d041 100644 --- a/cmd/ostracon/commands/show_validator.go +++ b/cmd/ostracon/commands/show_validator.go @@ -31,7 +31,7 @@ func showValidator(cmd *cobra.Command, args []string, config *cfg.Config) error if err != nil { return err } - pv, err = node.CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, chainID, logger) + pv, err = node.CreateAndStartPrivValidatorSocketClient(config, chainID, logger) if err != nil { return err } diff --git a/cmd/ostracon/commands/show_validator_test.go b/cmd/ostracon/commands/show_validator_test.go index 9b4c43bf7..5f708d428 100644 --- a/cmd/ostracon/commands/show_validator_test.go +++ b/cmd/ostracon/commands/show_validator_test.go @@ -3,6 +3,7 @@ package commands import ( "bytes" "os" + "strings" "sync" "testing" @@ -79,6 +80,7 @@ func TestShowValidatorWithKMS(t *testing.T) { } privval.WithMockKMS(t, dir, chainID, func(addr string, privKey crypto.PrivKey) { config.PrivValidatorListenAddr = addr + config.PrivValidatorRemoteAddr = addr[:strings.Index(addr, ":")] require.NoFileExists(t, config.PrivValidatorKeyFile()) output, err := captureStdout(func() { err := showValidator(ShowValidatorCmd, nil, config) diff --git a/config/config.go b/config/config.go index a39a65ad9..052f01a85 100644 --- a/config/config.go +++ b/config/config.go @@ -242,8 +242,14 @@ type BaseConfig struct { //nolint: maligned // TCP or UNIX socket address for Ostracon to listen on for // connections from an external PrivValidator process + // example) tcp://0.0.0.0:26659 PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"` + // Validator's remote address(without port) to allow a connection + // ostracon only allow a connection from this address + // example) 10.0.0.7 + PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"` + // A JSON file containing the private key to use for p2p authenticated encryption NodeKey string `mapstructure:"node_key_file"` diff --git a/config/toml.go b/config/toml.go index 2f52e4c2e..34dc59204 100644 --- a/config/toml.go +++ b/config/toml.go @@ -156,8 +156,14 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}" # TCP or UNIX socket address for Ostracon to listen on for # connections from an external PrivValidator process +# example) tcp://0.0.0.0:26659 priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}" +# Validator's remote address to allow a connection +# ostracon only allow a connection from this address +# example) 127.0.0.1 +priv_validator_raddr = "127.0.0.1" + # Path to the JSON file containing the private key to use for node authentication in the p2p protocol node_key_file = "{{ js .BaseConfig.NodeKey }}" diff --git a/node/node.go b/node/node.go index 5db7dc43f..7bad342bd 100644 --- a/node/node.go +++ b/node/node.go @@ -794,7 +794,7 @@ func NewNode(config *cfg.Config, // external signing process. if config.PrivValidatorListenAddr != "" { // FIXME: we should start services inside OnStart - privValidator, err = CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, genDoc.ChainID, logger) + privValidator, err = CreateAndStartPrivValidatorSocketClient(config, genDoc.ChainID, logger) if err != nil { return nil, fmt.Errorf("error with private validator socket client: %w", err) } @@ -1523,12 +1523,8 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error { return nil } -func CreateAndStartPrivValidatorSocketClient( - listenAddr, - chainID string, - logger log.Logger, -) (types.PrivValidator, error) { - pve, err := privval.NewSignerListener(listenAddr, logger) +func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) { + pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr) if err != nil { return nil, fmt.Errorf("failed to start private validator: %w", err) } diff --git a/node/node_test.go b/node/node_test.go index aba8fa313..f7d0e4f16 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -160,11 +160,17 @@ func TestNodeSetAppVersion(t *testing.T) { } func TestNodeSetPrivValTCP(t *testing.T) { + address := testFreeAddr(t) addr := "tcp://" + testFreeAddr(t) config := cfg.ResetTestRoot("node_priv_val_tcp_test") defer os.RemoveAll(config.RootDir) config.BaseConfig.PrivValidatorListenAddr = addr + addrPart, _, err := net.SplitHostPort(address) + if err != nil { + return + } + config.BaseConfig.PrivValidatorRemoteAddr = addrPart dialer := privval.DialTCPFn(addr, 100*time.Millisecond, ed25519.GenPrivKey()) dialerEndpoint := privval.NewSignerDialerEndpoint( diff --git a/privval/internal/conn_filter.go b/privval/internal/conn_filter.go new file mode 100644 index 000000000..405cb8b2b --- /dev/null +++ b/privval/internal/conn_filter.go @@ -0,0 +1,8 @@ +package internal + +import "net" + +type ConnectionFilter interface { + Filter(addr net.Addr) net.Addr + String() string +} diff --git a/privval/internal/ip_filter.go b/privval/internal/ip_filter.go new file mode 100644 index 000000000..e9f7cceb8 --- /dev/null +++ b/privval/internal/ip_filter.go @@ -0,0 +1,44 @@ +package internal + +import ( + "fmt" + "github.com/Finschia/ostracon/libs/log" + "net" +) + +type IpFilter struct { + allowAddr string + log log.Logger +} + +func NewIpFilter(addr string, l log.Logger) *IpFilter { + return &IpFilter{ + allowAddr: addr, + log: l, + } +} + +func (f *IpFilter) Filter(addr net.Addr) net.Addr { + if f.isAllowedAddr(addr) { + return addr + } + return nil +} + +func (f *IpFilter) String() string { + return f.allowAddr +} + +func (f *IpFilter) isAllowedAddr(addr net.Addr) bool { + if len(f.allowAddr) == 0 { + return false + } + hostAddr, _, err := net.SplitHostPort(addr.String()) + if err != nil { + if f.log != nil { + f.log.Error(fmt.Sprintf("IpFilter: can't split host and port from addr.String()=%s", addr.String())) + } + return false + } + return f.allowAddr == hostAddr +} diff --git a/privval/internal/ip_filter_test.go b/privval/internal/ip_filter_test.go new file mode 100644 index 000000000..6257fa7ca --- /dev/null +++ b/privval/internal/ip_filter_test.go @@ -0,0 +1,91 @@ +package internal + +import ( + "github.com/stretchr/testify/assert" + "net" + "testing" +) + +type addrStub struct { + address string +} + +func (a addrStub) Network() string { + return "" +} + +func (a addrStub) String() string { + return a.address +} + +func TestFilterRemoteConnectionByIP(t *testing.T) { + type fields struct { + allowIP string + remoteAddr net.Addr + expected net.Addr + } + tests := []struct { + name string + fields fields + }{ + { + "should allow correct ip", + struct { + allowIP string + remoteAddr net.Addr + expected net.Addr + }{"127.0.0.1", addrStub{"127.0.0.1:45678"}, addrStub{"127.0.0.1:45678"}}, + }, + { + "should not allow different ip", + struct { + allowIP string + remoteAddr net.Addr + expected net.Addr + }{"127.0.0.1", addrStub{"10.0.0.2:45678"}, nil}, + }, + { + "should works for IPv6 with correct ip", + struct { + allowIP string + remoteAddr net.Addr + expected net.Addr + }{"2001:db8::1", addrStub{"[2001:db8::1]:80"}, addrStub{"[2001:db8::1]:80"}}, + }, + { + "should works for IPv6 with incorrect ip", + struct { + allowIP string + remoteAddr net.Addr + expected net.Addr + }{"2001:db8::2", addrStub{"[2001:db8::1]:80"}, nil}, + }, + { + "empty allowIP should deny all", + struct { + allowIP string + remoteAddr net.Addr + expected net.Addr + }{"", addrStub{"127.0.0.1:45678"}, nil}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cut := NewIpFilter(tt.fields.allowIP, nil) + assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name) + }) + } +} + +func TestIpFilterShouldSetAllowAddress(t *testing.T) { + expected := "192.168.0.1" + + cut := NewIpFilter(expected, nil) + + assert.Equal(t, expected, cut.allowAddr) +} + +func TestIpFilterStringShouldReturnsIP(t *testing.T) { + expected := "127.0.0.1" + assert.Equal(t, expected, NewIpFilter(expected, nil).String()) +} diff --git a/privval/internal/null_object_filter.go b/privval/internal/null_object_filter.go new file mode 100644 index 000000000..9914df73d --- /dev/null +++ b/privval/internal/null_object_filter.go @@ -0,0 +1,19 @@ +package internal + +import "net" + +// NullObject is null object pattern. It does nothing +type NullObject struct { +} + +func NewNullObject() *NullObject { + return &NullObject{} +} + +func (n NullObject) Filter(addr net.Addr) net.Addr { + return addr +} + +func (n NullObject) String() string { + return "NullObject" +} diff --git a/privval/internal/null_object_filter_test.go b/privval/internal/null_object_filter_test.go new file mode 100644 index 000000000..6df8177f8 --- /dev/null +++ b/privval/internal/null_object_filter_test.go @@ -0,0 +1,40 @@ +package internal + +import ( + "github.com/stretchr/testify/assert" + "net" + "reflect" + "testing" +) + +func TestNullObject_filter(t *testing.T) { + stubInput := addrStub{} + tests := []struct { + name string + addr net.Addr + want net.Addr + }{ + { + name: "null object does nothing, returns what it receives", + addr: stubInput, + want: stubInput, + }, + { + name: "null object does nothing, returns nil it receives nil", + addr: nil, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NewNullObject() + if got := n.Filter(tt.addr); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Filter() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNullObjectString(t *testing.T) { + assert.Equal(t, "NullObject", NewNullObject().String()) +} diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index 030b4de4d..c2c96642d 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -2,6 +2,7 @@ package privval import ( "fmt" + "github.com/Finschia/ostracon/privval/internal" "net" "time" @@ -24,6 +25,19 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout } } +// SignerListenerEndpointAllowAddress sets the address to allow +// connections from the only allowed address +// +func SignerListenerEndpointAllowAddress(protocol string, addr string) SignerListenerEndpointOption { + return func(sl *SignerListenerEndpoint) { + if protocol == "tcp" || len(protocol) == 0 { + sl.connFilter = internal.NewIpFilter(addr, sl.Logger) + return + } + sl.connFilter = internal.NewNullObject() + } +} + // SignerListenerEndpoint listens for an external process to dial in and keeps // the connection alive by dropping and reconnecting. // @@ -41,6 +55,7 @@ type SignerListenerEndpoint struct { pingInterval time.Duration instanceMtx tmsync.Mutex // Ensures instance public methods access, i.e. SendRequest + connFilter internal.ConnectionFilter } // NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint. @@ -186,6 +201,12 @@ func (sl *SignerListenerEndpoint) serviceLoop() { { conn, err := sl.acceptNewConnection() if err == nil { + remoteAddr := conn.RemoteAddr() + if sl.filter(remoteAddr) == nil { + sl.Logger.Info(fmt.Sprintf("SignerListener: deny a connection request from remote address=%s, expected=%s", remoteAddr, sl.connFilter)) + conn.Close() + continue + } sl.Logger.Info("SignerListener: Connected") // We have a good connection, wait for someone that needs one otherwise cancellation @@ -207,6 +228,13 @@ func (sl *SignerListenerEndpoint) serviceLoop() { } } +func (sl *SignerListenerEndpoint) filter(addr net.Addr) net.Addr { + if sl.connFilter == nil { + return addr + } + return sl.connFilter.Filter(addr) +} + func (sl *SignerListenerEndpoint) pingLoop() { for { select { diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index 27a707b74..317bab82f 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -1,6 +1,7 @@ package privval import ( + "github.com/Finschia/ostracon/privval/internal" "net" "testing" "time" @@ -213,3 +214,15 @@ func getMockEndpoints( return listenerEndpoint, dialerEndpoint } + +func TestSignerListenerEndpointAllowAddressSetIpFilterForTCP(t *testing.T) { + cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", "127.0.0.1")) + _, ok := cut.connFilter.(*internal.IpFilter) + assert.True(t, ok) +} + +func TestSignerListenerEndpointAllowAddressSetNullObjectFilterForUDS(t *testing.T) { + cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", "/mnt/uds/sock01")) + _, ok := cut.connFilter.(*internal.NullObject) + assert.True(t, ok) +} diff --git a/privval/utils.go b/privval/utils.go index fe52ec7e2..34607235c 100644 --- a/privval/utils.go +++ b/privval/utils.go @@ -26,7 +26,7 @@ func IsConnTimeout(err error) bool { } // NewSignerListener creates a new SignerListenerEndpoint using the corresponding listen address -func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEndpoint, error) { +func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*SignerListenerEndpoint, error) { var listener net.Listener protocol, address := tmnet.ProtocolAndAddress(listenAddr) @@ -47,7 +47,7 @@ func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEnd ) } - pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener) + pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(protocol, remoteAddr)) return pve, nil }