diff --git a/dht_test.go b/dht_test.go
index 74eb52724..76331951f 100644
--- a/dht_test.go
+++ b/dht_test.go
@@ -1910,3 +1910,51 @@ func TestProtocolUpgrade(t *testing.T) {
 		t.Fatalf("Expected 'buzz' got '%s'", string(value))
 	}
 }
+
+func TestInvalidKeys(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	nDHTs := 2
+	dhts := setupDHTS(t, ctx, nDHTs)
+	defer func() {
+		for i := 0; i < nDHTs; i++ {
+			dhts[i].Close()
+			defer dhts[i].host.Close()
+		}
+	}()
+
+	t.Logf("connecting %d dhts in a ring", nDHTs)
+	for i := 0; i < nDHTs; i++ {
+		connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)])
+	}
+
+	querier := dhts[0]
+	_, err := querier.GetClosestPeers(ctx, "")
+	if err == nil {
+		t.Fatal("get closest peers should have failed")
+	}
+
+	_, err = querier.FindProviders(ctx, cid.Cid{})
+	switch err {
+	case routing.ErrNotFound, routing.ErrNotSupported, kb.ErrLookupFailure:
+		t.Fatal("failed with the wrong error: ", err)
+	case nil:
+		t.Fatal("find providers should have failed")
+	}
+
+	_, err = querier.FindPeer(ctx, peer.ID(""))
+	if err != peer.ErrEmptyPeerID {
+		t.Fatal("expected to fail due to the empty peer ID")
+	}
+
+	_, err = querier.GetValue(ctx, "")
+	if err == nil {
+		t.Fatal("expected to have failed")
+	}
+
+	err = querier.PutValue(ctx, "", []byte("foobar"))
+	if err == nil {
+		t.Fatal("expected to have failed")
+	}
+}
diff --git a/handlers.go b/handlers.go
index 42445b70f..3dc709792 100644
--- a/handlers.go
+++ b/handlers.go
@@ -52,16 +52,15 @@ func (dht *IpfsDHT) handlerForMsgType(t pb.Message_MessageType) dhtHandler {
 }
 
 func (dht *IpfsDHT) handleGetValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
-	// setup response
-	resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
-
 	// first, is there even a key?
 	k := pmes.GetKey()
 	if len(k) == 0 {
 		return nil, errors.New("handleGetValue but no key was provided")
-		// TODO: send back an error response? could be bad, but the other node's hanging.
 	}
 
+	// setup response
+	resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
+
 	rec, err := dht.checkLocalDatastore(k)
 	if err != nil {
 		return nil, err
@@ -150,6 +149,10 @@ func cleanRecord(rec *recpb.Record) {
 
 // Store a value in this peer local storage
 func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
+	if len(pmes.GetKey()) == 0 {
+		return nil, errors.New("handlePutValue but no key was provided")
+	}
+
 	rec := pmes.GetRecord()
 	if rec == nil {
 		logger.Debugw("got nil record from", "from", p)
@@ -253,6 +256,10 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, from peer.ID, pmes *pb.M
 	resp := pb.NewMessage(pmes.GetType(), nil, pmes.GetClusterLevel())
 	var closest []peer.ID
 
+	if len(pmes.GetKey()) == 0 {
+		return nil, fmt.Errorf("handleFindPeer with empty key")
+	}
+
 	// if looking for self... special case where we send it on CloserPeers.
 	targetPid := peer.ID(pmes.GetKey())
 	if targetPid == dht.self {
@@ -300,12 +307,15 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, from peer.ID, pmes *pb.M
 }
 
 func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, _err error) {
-	resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
 	key := pmes.GetKey()
 	if len(key) > 80 {
 		return nil, fmt.Errorf("handleGetProviders key size too large")
+	} else if len(key) == 0 {
+		return nil, fmt.Errorf("handleGetProviders key is empty")
 	}
 
+	resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
+
 	// check if we have this value, to add ourselves as provider.
 	has, err := dht.datastore.Has(convertToDsKey(key))
 	if err != nil && err != ds.ErrNotFound {
@@ -341,7 +351,9 @@ func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.
 func (dht *IpfsDHT) handleAddProvider(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, _err error) {
 	key := pmes.GetKey()
 	if len(key) > 80 {
-		return nil, fmt.Errorf("handleAddProviders key size too large")
+		return nil, fmt.Errorf("handleAddProvider key size too large")
+	} else if len(key) == 0 {
+		return nil, fmt.Errorf("handleAddProvider key is empty")
 	}
 
 	logger.Debugf("adding provider", "from", p, "key", key)
diff --git a/handlers_test.go b/handlers_test.go
index 6e098f815..d88fdac1f 100644
--- a/handlers_test.go
+++ b/handlers_test.go
@@ -67,6 +67,27 @@ func TestCleanRecord(t *testing.T) {
 	}
 }
 
+func TestBadMessage(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	dht := setupDHT(ctx, t, false)
+
+	for _, typ := range []pb.Message_MessageType{
+		pb.Message_PUT_VALUE, pb.Message_GET_VALUE, pb.Message_ADD_PROVIDER,
+		pb.Message_GET_PROVIDERS, pb.Message_FIND_NODE,
+	} {
+		msg := &pb.Message{
+			Type: typ,
+			// explicitly avoid the key.
+		}
+		_, err := dht.handlerForMsgType(typ)(ctx, dht.Host().ID(), msg)
+		if err == nil {
+			t.Fatalf("expected processing message to fail for type %s", typ)
+		}
+	}
+}
+
 func BenchmarkHandleFindPeer(b *testing.B) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
diff --git a/lookup.go b/lookup.go
index a602a9a8f..a696df4cf 100644
--- a/lookup.go
+++ b/lookup.go
@@ -72,6 +72,9 @@ func (lk loggableKeyBytes) String() string {
 // If the context is canceled, this function will return the context error along
 // with the closest K peers it has found so far.
 func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan peer.ID, error) {
+	if key == "" {
+		return nil, fmt.Errorf("can't lookup empty key")
+	}
 	//TODO: I can break the interface! return []peer.ID
 	lookupRes, err := dht.runLookupWithFollowup(ctx, key,
 		func(ctx context.Context, p peer.ID) ([]*peer.AddrInfo, error) {
diff --git a/pb/message.go b/pb/message.go
index a7e9d14f2..3023d6438 100644
--- a/pb/message.go
+++ b/pb/message.go
@@ -5,7 +5,6 @@ import (
 	"github.com/libp2p/go-libp2p-core/peer"
 
 	logging "github.com/ipfs/go-log"
-	b58 "github.com/mr-tron/base58/base58"
 	ma "github.com/multiformats/go-multiaddr"
 )
 
@@ -138,16 +137,6 @@ func (m *Message) SetClusterLevel(level int) {
 	m.ClusterLevelRaw = lvl + 1
 }
 
-// Loggable turns a Message into machine-readable log output
-func (m *Message) Loggable() map[string]interface{} {
-	return map[string]interface{}{
-		"message": map[string]string{
-			"type": m.Type.String(),
-			"key":  b58.Encode([]byte(m.GetKey())),
-		},
-	}
-}
-
 // ConnectionType returns a Message_ConnectionType associated with the
 // network.Connectedness.
 func ConnectionType(c network.Connectedness) Message_ConnectionType {
diff --git a/routing.go b/routing.go
index 8ae8180b5..6808284a1 100644
--- a/routing.go
+++ b/routing.go
@@ -394,6 +394,8 @@ func (dht *IpfsDHT) refreshRTIfNoShortcut(key kb.ID, lookupRes *lookupWithFollow
 func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err error) {
 	if !dht.enableProviders {
 		return routing.ErrNotSupported
+	} else if !key.Defined() {
+		return fmt.Errorf("invalid cid: undefined")
 	}
 
 	logger.Debugw("finding provider", "cid", key)
@@ -486,7 +488,10 @@ func (dht *IpfsDHT) makeProvRecord(key []byte) (*pb.Message, error) {
 func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
 	if !dht.enableProviders {
 		return nil, routing.ErrNotSupported
+	} else if !c.Defined() {
+		return nil, fmt.Errorf("invalid cid: undefined")
 	}
+
 	var providers []peer.AddrInfo
 	for p := range dht.FindProvidersAsync(ctx, c, dht.bucketSize) {
 		providers = append(providers, p)
@@ -500,7 +505,7 @@ func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrIn
 // completes. Note: not reading from the returned channel may block the query
 // from progressing.
 func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo {
-	if !dht.enableProviders {
+	if !dht.enableProviders || !key.Defined() {
 		peerOut := make(chan peer.AddrInfo)
 		close(peerOut)
 		return peerOut
@@ -613,6 +618,10 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
 
 // FindPeer searches for a peer with given ID.
 func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (_ peer.AddrInfo, err error) {
+	if err := id.Validate(); err != nil {
+		return peer.AddrInfo{}, err
+	}
+
 	logger.Debugw("finding peer", "peer", id)
 
 	// Check if were already connected to them