diff --git a/dht_test.go b/dht_test.go index 30a84bc82..cb19ff0cc 100644 --- a/dht_test.go +++ b/dht_test.go @@ -1833,3 +1833,110 @@ func TestDynamicModeSwitching(t *testing.T) { assertDHTClient() } + +func TestProtocolUpgrade(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + os := []opts.Option{ + opts.Mode(opts.ModeServer), + opts.NamespacedValidator("v", blankValidator{}), + opts.DisableAutoRefresh(), + opts.DisjointPaths(1), + } + + // This test verifies that we can have a node serving both old and new DHTs that will respond as a server to the old + // DHT, but only act as a client of the new DHT. In it's capacity as a server it should also only tell queriers + // about other DHT servers in the new DHT. + + protoNew := protocol.ID("/dht/B") + protoOld := protocol.ID("/dht/C") + + dhtA, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), + append([]opts.Option{opts.Protocols(protoNew, protoOld), opts.ClientProtocols(protoNew)}, os...)...) + if err != nil { + t.Fatal(err) + } + + dhtB, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), + append([]opts.Option{opts.Protocols(protoNew, protoOld), opts.ClientProtocols(protoNew)}, os...)...) + if err != nil { + t.Fatal(err) + } + + dhtC, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), + append([]opts.Option{opts.Protocols(protoOld)}, os...)...) + if err != nil { + t.Fatal(err) + } + + connect(t, ctx, dhtA, dhtB) + connectNoSync(t, ctx, dhtA, dhtC) + wait(t, ctx, dhtC, dhtA) + + if sz := dhtA.RoutingTable().Size(); sz != 1 { + t.Fatalf("Expected routing table to be of size %d got %d", 1, sz) + } + + ctxT, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + if err := dhtB.PutValue(ctxT, "/v/bat", []byte("screech")); err != nil { + t.Fatal(err) + } + + value, err := dhtC.GetValue(ctxT, "/v/bat") + if err != nil { + t.Fatal(err) + } + + if string(value) != "screech" { + t.Fatalf("Expected 'screech' got '%s'", string(value)) + } + + if err := dhtC.PutValue(ctxT, "/v/cat", []byte("meow")); err != nil { + t.Fatal(err) + } + + value, err = dhtB.GetValue(ctxT, "/v/cat") + if err != nil { + t.Fatal(err) + } + + if string(value) != "meow" { + t.Fatalf("Expected 'meow' got '%s'", string(value)) + } + + // Add record into local DHT only + rec := record.MakePutRecord("/v/crow", []byte("caw")) + rec.TimeReceived = u.FormatRFC3339(time.Now()) + err = dhtC.putLocal(string(rec.Key), rec) + if err != nil { + t.Fatal(err) + } + + value, err = dhtB.GetValue(ctxT, "/v/crow") + switch err { + case nil: + t.Fatalf("should not have been able to find value for %s", "/v/crow") + case routing.ErrNotFound: + default: + t.Fatal(err) + } + + // Add record into local DHT only + rec = record.MakePutRecord("/v/bee", []byte("buzz")) + rec.TimeReceived = u.FormatRFC3339(time.Now()) + err = dhtB.putLocal(string(rec.Key), rec) + if err != nil { + t.Fatal(err) + } + + value, err = dhtC.GetValue(ctxT, "/v/bee") + if err != nil { + t.Fatal(err) + } + + if string(value) != "buzz" { + t.Fatalf("Expected 'buzz' got '%s'", string(value)) + } +}