Skip to content

Commit

Permalink
fix(metarepos): check the NodeID for adding and removing peers
Browse files Browse the repository at this point in the history
This PR makes the MetadataRepository server validate the NodeID while processing
RPCs such as AddPeer and RemovePeer. When the given NodeID is invalid, it
returns the gRPC InvalidArgument status code.
In the case of AddPeer, it prevents unexpected peer addition that has an invalid
NodeID.
  • Loading branch information
ijsong committed Jul 30, 2024
1 parent 9cf865a commit 43003d4
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 38 deletions.
8 changes: 8 additions & 0 deletions internal/metarepos/raft_metadata_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,10 @@ func (mr *RaftMetadataRepository) Unseal(ctx context.Context, lsID types.LogStre
}

func (mr *RaftMetadataRepository) AddPeer(ctx context.Context, _ types.ClusterID, nodeID types.NodeID, url string) error {
if nodeID == types.InvalidNodeID {
return status.Error(codes.InvalidArgument, "invalid node id")
}

if mr.membership.IsMember(nodeID) ||
mr.membership.IsLearner(nodeID) {
return status.Errorf(codes.AlreadyExists, "node %d, addr:%s", nodeID, url)
Expand Down Expand Up @@ -1394,6 +1398,10 @@ func (mr *RaftMetadataRepository) AddPeer(ctx context.Context, _ types.ClusterID
}

func (mr *RaftMetadataRepository) RemovePeer(ctx context.Context, _ types.ClusterID, nodeID types.NodeID) error {
if nodeID == types.InvalidNodeID {
return status.Error(codes.InvalidArgument, "invalid node id")
}

if !mr.membership.IsMember(nodeID) &&
!mr.membership.IsLearner(nodeID) {
return status.Errorf(codes.NotFound, "node %d", nodeID)
Expand Down
174 changes: 174 additions & 0 deletions internal/metarepos/raft_metadata_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (

"github.com/pkg/errors"
. "github.com/smartystreets/goconvey/convey"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/raft/v3/raftpb"
"go.uber.org/goleak"
"go.uber.org/multierr"
Expand Down Expand Up @@ -3072,3 +3074,175 @@ func TestMain(m *testing.M) {
),
)
}

func TestMetadataRepository_AddPeer(t *testing.T) {
const clusterID = types.ClusterID(1)

tcs := []struct {
name string
testf func(t *testing.T, server *RaftMetadataRepository, client mrpb.ManagementClient)
}{
{
name: "InvalidNodeID",
testf: func(t *testing.T, _ *RaftMetadataRepository, client mrpb.ManagementClient) {
_, err := client.AddPeer(context.Background(), &mrpb.AddPeerRequest{
ClusterID: clusterID,
NodeID: types.InvalidNodeID,
Url: "http://127.0.0.1:11000",
})
require.Error(t, err)
require.Equal(t, codes.InvalidArgument, status.Code(err))
},
},
{
name: "AlreadyExists",
testf: func(t *testing.T, server *RaftMetadataRepository, client mrpb.ManagementClient) {
_, err := client.AddPeer(context.Background(), &mrpb.AddPeerRequest{
ClusterID: clusterID,
NodeID: server.nodeID,
Url: server.raftNode.url,
})
require.Error(t, err)
require.Equal(t, codes.AlreadyExists, status.Code(err))
},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
portLease, err := ports.ReserveWeaklyWithRetry(10000)
require.NoError(t, err)

peer := fmt.Sprintf("http://127.0.0.1:%d", portLease.Base())
node := NewRaftMetadataRepository(
WithClusterID(clusterID),
WithRaftAddress(peer),
WithRPCAddress("127.0.0.1:0"),
WithRaftDirectory(t.TempDir()+"/raftdata"),
)
t.Cleanup(func() {
err := node.Close()
require.NoError(t, err)
})

node.Run()

// Wait for initialization
require.EventuallyWithT(t, func(collect *assert.CollectT) {
addr := node.endpointAddr.Load()
if !assert.NotNil(collect, addr) {
return
}
}, 3*time.Second, 100*time.Millisecond)
addr := node.endpointAddr.Load().(string)

require.EventuallyWithT(t, func(collect *assert.CollectT) {
conn, err := rpc.NewConn(context.Background(), addr)
assert.NoError(collect, err)
defer func() {
err := conn.Close()
assert.NoError(collect, err)
}()

healthClient := grpc_health_v1.NewHealthClient(conn.Conn)
_, err = healthClient.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{})
assert.NoError(collect, err)
}, 3*time.Second, 100*time.Millisecond)

conn, err := rpc.NewConn(context.Background(), addr)
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
require.NoError(t, err)
})

client := mrpb.NewManagementClient(conn.Conn)
tc.testf(t, node, client)
})
}
}

func TestMetadataRepository_RemovePeer(t *testing.T) {
const clusterID = types.ClusterID(1)

tcs := []struct {
name string
testf func(t *testing.T, server *RaftMetadataRepository, client mrpb.ManagementClient)
}{
{
name: "InvalidNodeID",
testf: func(t *testing.T, _ *RaftMetadataRepository, client mrpb.ManagementClient) {
_, err := client.RemovePeer(context.Background(), &mrpb.RemovePeerRequest{
ClusterID: clusterID,
NodeID: types.InvalidNodeID,
})
require.Error(t, err)
require.Equal(t, codes.InvalidArgument, status.Code(err))
},
},
{
name: "NotFound",
testf: func(t *testing.T, server *RaftMetadataRepository, client mrpb.ManagementClient) {
_, err := client.RemovePeer(context.Background(), &mrpb.RemovePeerRequest{
ClusterID: clusterID,
NodeID: server.nodeID + 1,
})
require.Error(t, err)
require.Equal(t, codes.NotFound, status.Code(err))
},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
portLease, err := ports.ReserveWeaklyWithRetry(10000)
require.NoError(t, err)

peer := fmt.Sprintf("http://127.0.0.1:%d", portLease.Base())
node := NewRaftMetadataRepository(
WithClusterID(clusterID),
WithRaftAddress(peer),
WithRPCAddress("127.0.0.1:0"),
WithRaftDirectory(t.TempDir()+"/raftdata"),
)
t.Cleanup(func() {
err := node.Close()
require.NoError(t, err)
})

node.Run()

// Wait for initialization
require.EventuallyWithT(t, func(collect *assert.CollectT) {
addr := node.endpointAddr.Load()
if !assert.NotNil(collect, addr) {
return
}
}, 3*time.Second, 100*time.Millisecond)
addr := node.endpointAddr.Load().(string)

require.EventuallyWithT(t, func(collect *assert.CollectT) {
conn, err := rpc.NewConn(context.Background(), addr)
assert.NoError(collect, err)
defer func() {
err := conn.Close()
assert.NoError(collect, err)
}()

healthClient := grpc_health_v1.NewHealthClient(conn.Conn)
_, err = healthClient.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{})
assert.NoError(collect, err)
}, 3*time.Second, 100*time.Millisecond)

conn, err := rpc.NewConn(context.Background(), addr)
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
require.NoError(t, err)
})

client := mrpb.NewManagementClient(conn.Conn)
tc.testf(t, node, client)
})
}
}
70 changes: 45 additions & 25 deletions proto/mrpb/management.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 25 additions & 13 deletions proto/mrpb/management.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ option (gogoproto.goproto_sizecache_all) = false;

// AddPeerRequest is a request message for AddPeer RPC.
//
// TODO: TODO: Define a new message representing a new peer, such as "Peer" or
// TODO: Define a new message representing a new peer, such as "Peer" or
// "PeerInfo" and use it rather than primitive-type fields.
// See:
// - https://protobuf.dev/programming-guides/api/#dont-include-primitive-types
Expand Down Expand Up @@ -96,22 +96,34 @@ message GetClusterInfoResponse {

// Management service manages the Raft cluster of the Metadata Repository.
service Management {
// AddPeer is a remote procedure to add a new node to the Raft cluster. If the
// node is already a member or learner, it fails and returns the gRPC status
// code "AlreadyExists". Users can cancel this RPC, but it doesn't guarantee
// that adding a new peer is not handled.
// AddPeer adds a new node to the Raft cluster.
//
// TODO: Check if the cluster ID is the same as the current node's. If they
// are not the same, return a proper gRPC status code.
// It takes an AddPeerRequest as an argument and checks the validity of the
// given Node ID. If the Node ID is invalid, it returns a gRPC status code
// "InvalidArgument". If the node is already a member or learner, it returns a
// gRPC status code "AlreadyExists". Upon successful execution, this operation
// returns an instance of google.protobuf.Empty.
//
// Note that users can cancel this operation, but cancellation does not
// guarantee that the addition of a new peer will not be handled.
//
// TODO: Implement a check for the cluster ID.
rpc AddPeer(AddPeerRequest) returns (google.protobuf.Empty) {}
// RemovePeer is a remote procedure to remove a node from the Raft cluster. If
// the node is neither a member nor a learner of the cluster, it fails and
// returns the gRPC status code "NotFound". Users can cancel this RPC, but it
// doesn't guarantee that the node will not be removed.

// RemovePeer removes a specific node from a Raft cluster.
//
// TODO: Check if the cluster ID is the same as the current node's. If they
// are not the same, return a proper gRPC status code.
// It takes a RemovePeerRequest as an argument and checks the validity of the
// Node ID. If the Node ID is invalid, it returns a gRPC status code
// "InvalidArgument". If the node is neither a member nor a learner in the
// cluster, it returns a gRPC status code "NotFound". Upon successful
// execution, this operation returns an instance of google.protobuf.Empty.
//
// Note that although users can cancel this operation, cancellation does not
// guarantee that the node will not be removed.
//
// TODO: Implement a check for the cluster ID.
rpc RemovePeer(RemovePeerRequest) returns (google.protobuf.Empty) {}

// GetClusterInfo is a remote procedure used to retrieve information about the
// Raft cluster, specifically the ClusterInfo. If the current node is not a
// member of the cluster, it will fail and return the gRPC status code
Expand Down

0 comments on commit 43003d4

Please sign in to comment.