Skip to content

Commit

Permalink
fix: handle empty keys
Browse files Browse the repository at this point in the history
Handle empty keys, both when sent in RPC requests and in the local API.
  • Loading branch information
Stebalien committed Apr 7, 2020
1 parent 067f8ab commit c45167c
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 18 deletions.
48 changes: 48 additions & 0 deletions dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1910,3 +1910,51 @@ func TestProtocolUpgrade(t *testing.T) {
t.Fatalf("Expected 'buzz' got '%s'", string(value))
}
}

// TestInvalidKeys checks that the public query API rejects empty keys,
// empty peer IDs, and undefined CIDs instead of launching lookups.
func TestInvalidKeys(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	const nDHTs = 2
	dhts := setupDHTS(t, ctx, nDHTs)
	defer func() {
		// Close every DHT first; the hosts are then closed (in reverse
		// order) when this deferred closure returns.
		for _, d := range dhts {
			d.Close()
			defer d.host.Close()
		}
	}()

	t.Logf("connecting %d dhts in a ring", nDHTs)
	for i := range dhts {
		connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)])
	}

	querier := dhts[0]

	// An empty key must be refused before any lookup starts.
	if _, err := querier.GetClosestPeers(ctx, ""); err == nil {
		t.Fatal("get closest peers should have failed")
	}

	// An undefined CID must fail with the dedicated "invalid cid" error,
	// not with one of the generic lookup failures.
	_, err := querier.FindProviders(ctx, cid.Cid{})
	switch err {
	case nil:
		t.Fatal("find providers should have failed")
	case routing.ErrNotFound, routing.ErrNotSupported, kb.ErrLookupFailure:
		t.Fatal("failed with the wrong error: ", err)
	}

	// An empty peer ID must surface peer.ErrEmptyPeerID specifically.
	if _, err := querier.FindPeer(ctx, peer.ID("")); err != peer.ErrEmptyPeerID {
		t.Fatal("expected to fail due to the empty peer ID")
	}

	if _, err := querier.GetValue(ctx, ""); err == nil {
		t.Fatal("expected to have failed")
	}

	if err := querier.PutValue(ctx, "", []byte("foobar")); err == nil {
		t.Fatal("expected to have failed")
	}
}
24 changes: 18 additions & 6 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,15 @@ func (dht *IpfsDHT) handlerForMsgType(t pb.Message_MessageType) dhtHandler {
}

func (dht *IpfsDHT) handleGetValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
// setup response
resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())

// first, is there even a key?
k := pmes.GetKey()
if len(k) == 0 {
return nil, errors.New("handleGetValue but no key was provided")
// TODO: send back an error response? could be bad, but the other node's hanging.
}

// setup response
resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())

rec, err := dht.checkLocalDatastore(k)
if err != nil {
return nil, err
Expand Down Expand Up @@ -150,6 +149,10 @@ func cleanRecord(rec *recpb.Record) {

// Store a value in this peer local storage
func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
if len(pmes.GetKey()) == 0 {
return nil, errors.New("handleGetValue but no key was provided")
}

rec := pmes.GetRecord()
if rec == nil {
logger.Debugw("got nil record from", "from", p)
Expand Down Expand Up @@ -253,6 +256,10 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, from peer.ID, pmes *pb.M
resp := pb.NewMessage(pmes.GetType(), nil, pmes.GetClusterLevel())
var closest []peer.ID

if len(pmes.GetKey()) == 0 {
return nil, fmt.Errorf("handleFindPeer with empty key")
}

// if looking for self... special case where we send it on CloserPeers.
targetPid := peer.ID(pmes.GetKey())
if targetPid == dht.self {
Expand Down Expand Up @@ -300,12 +307,15 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, from peer.ID, pmes *pb.M
}

func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, _err error) {
resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
key := pmes.GetKey()
if len(key) > 80 {
return nil, fmt.Errorf("handleGetProviders key size too large")
} else if len(key) == 0 {
return nil, fmt.Errorf("handleGetProviders key is empty")
}

resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())

// check if we have this value, to add ourselves as provider.
has, err := dht.datastore.Has(convertToDsKey(key))
if err != nil && err != ds.ErrNotFound {
Expand Down Expand Up @@ -341,7 +351,9 @@ func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.
func (dht *IpfsDHT) handleAddProvider(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, _err error) {
key := pmes.GetKey()
if len(key) > 80 {
return nil, fmt.Errorf("handleAddProviders key size too large")
return nil, fmt.Errorf("handleAddProvider key size too large")
} else if len(key) == 0 {
return nil, fmt.Errorf("handleAddProvider key is empty")
}

logger.Debugf("adding provider", "from", p, "key", key)
Expand Down
21 changes: 21 additions & 0 deletions handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,27 @@ func TestCleanRecord(t *testing.T) {
}
}

// TestBadMessage verifies that every keyed RPC handler rejects a message
// whose key is missing instead of processing it.
func TestBadMessage(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	dht := setupDHT(ctx, t, false)

	for _, typ := range []pb.Message_MessageType{
		pb.Message_PUT_VALUE, pb.Message_GET_VALUE, pb.Message_ADD_PROVIDER,
		pb.Message_GET_PROVIDERS, pb.Message_FIND_NODE,
	} {
		msg := &pb.Message{
			Type: typ,
			// explicitly avoid setting the key so the handler must reject it.
		}
		_, err := dht.handlerForMsgType(typ)(ctx, dht.Host().ID(), msg)
		if err == nil {
			// Report the message type actually under test; the original
			// hard-coded pb.Message_FIND_NODE here, which would misreport
			// which handler failed to reject the empty key.
			t.Fatalf("expected processing message to fail for type %s", typ)
		}
	}
}

func BenchmarkHandleFindPeer(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
3 changes: 3 additions & 0 deletions lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ func (lk loggableKeyBytes) String() string {
// If the context is canceled, this function will return the context error along
// with the closest K peers it has found so far.
func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan peer.ID, error) {
if key == "" {
return nil, fmt.Errorf("can't lookup empty key")
}
//TODO: I can break the interface! return []peer.ID
lookupRes, err := dht.runLookupWithFollowup(ctx, key,
func(ctx context.Context, p peer.ID) ([]*peer.AddrInfo, error) {
Expand Down
11 changes: 0 additions & 11 deletions pb/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"github.com/libp2p/go-libp2p-core/peer"

logging "github.com/ipfs/go-log"
b58 "github.com/mr-tron/base58/base58"
ma "github.com/multiformats/go-multiaddr"
)

Expand Down Expand Up @@ -138,16 +137,6 @@ func (m *Message) SetClusterLevel(level int) {
m.ClusterLevelRaw = lvl + 1
}

// Loggable renders the Message as a nested map for machine-readable log
// output: the message type's name plus its key, base58-encoded so binary
// keys remain printable.
func (m *Message) Loggable() map[string]interface{} {
	fields := map[string]string{
		"type": m.Type.String(),
		"key":  b58.Encode([]byte(m.GetKey())),
	}
	return map[string]interface{}{"message": fields}
}

// ConnectionType returns a Message_ConnectionType associated with the
// network.Connectedness.
func ConnectionType(c network.Connectedness) Message_ConnectionType {
Expand Down
11 changes: 10 additions & 1 deletion routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ func (dht *IpfsDHT) refreshRTIfNoShortcut(key kb.ID, lookupRes *lookupWithFollow
func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err error) {
if !dht.enableProviders {
return routing.ErrNotSupported
} else if !key.Defined() {
return fmt.Errorf("invalid cid: undefined")
}
logger.Debugw("finding provider", "cid", key)

Expand Down Expand Up @@ -486,7 +488,10 @@ func (dht *IpfsDHT) makeProvRecord(key []byte) (*pb.Message, error) {
func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
if !dht.enableProviders {
return nil, routing.ErrNotSupported
} else if !c.Defined() {
return nil, fmt.Errorf("invalid cid: undefined")
}

var providers []peer.AddrInfo
for p := range dht.FindProvidersAsync(ctx, c, dht.bucketSize) {
providers = append(providers, p)
Expand All @@ -500,7 +505,7 @@ func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrIn
// completes. Note: not reading from the returned channel may block the query
// from progressing.
func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo {
if !dht.enableProviders {
if !dht.enableProviders || !key.Defined() {
peerOut := make(chan peer.AddrInfo)
close(peerOut)
return peerOut
Expand Down Expand Up @@ -613,6 +618,10 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash

// FindPeer searches for a peer with given ID.
func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (_ peer.AddrInfo, err error) {
if err := id.Validate(); err != nil {
return peer.AddrInfo{}, err
}

logger.Debugw("finding peer", "peer", id)

// Check if were already connected to them
Expand Down

0 comments on commit c45167c

Please sign in to comment.