diff --git a/pkg/daemon/topology.go b/pkg/daemon/topology.go index bc208033e0..44def41eef 100644 --- a/pkg/daemon/topology.go +++ b/pkg/daemon/topology.go @@ -64,16 +64,29 @@ type ReloadingTopology struct { // NewReloadingTopology creates a new ReloadingTopology that reloads the // interface information periodically. The Run method must be called for -// interface information to be populated. +// interface information to be populated. NOTE: The reloading topology does not +// clean up old interface information, so if you have a lot of interface churn, +// you may want to use a different implementation. func NewReloadingTopology(ctx context.Context, conn Connector) (*ReloadingTopology, error) { - topo, err := LoadTopology(ctx, conn) + ia, err := conn.LocalIA(ctx) + if err != nil { + return nil, serrors.Wrap("loading local ISD-AS", err) + } + start, end, err := conn.PortRange(ctx) if err != nil { + return nil, serrors.Wrap("loading port range", err) + } + t := &ReloadingTopology{ + conn: conn, + baseTopology: snet.Topology{ + LocalIA: ia, + PortRange: snet.TopologyPortRange{Start: start, End: end}, + }, + } + if err := t.loadInterfaces(ctx); err != nil { return nil, err } - return &ReloadingTopology{ - conn: conn, - baseTopology: topo, - }, nil + return t, nil } func (t *ReloadingTopology) Topology() snet.Topology { @@ -96,15 +109,12 @@ func (t *ReloadingTopology) Run(ctx context.Context, period time.Duration) { defer ticker.Stop() reload := func() { - intfs, err := t.conn.Interfaces(ctx) - if err != nil { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + if err := t.loadInterfaces(ctx); err != nil { log.FromCtx(ctx).Error("Failed to reload interfaces", "err", err) } - for ifID, addr := range intfs { - t.interfaces.Store(ifID, addr) - } } - reload() for { select { @@ -115,3 +125,14 @@ func (t *ReloadingTopology) Run(ctx context.Context, period time.Duration) { } } } + +func (t *ReloadingTopology) loadInterfaces(ctx context.Context) error { + intfs, err := t.conn.Interfaces(ctx) + if err != nil { + return err + } + for ifID, addr := range intfs { + t.interfaces.Store(ifID, addr) + } + return nil +} diff --git a/pkg/daemon/topology_test.go b/pkg/daemon/topology_test.go new file mode 100644 index 0000000000..fa46d6662b --- /dev/null +++ b/pkg/daemon/topology_test.go @@ -0,0 +1,107 @@ +package daemon_test + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/scionproto/scion/pkg/addr" + "github.com/scionproto/scion/pkg/daemon" + "github.com/scionproto/scion/pkg/daemon/mock_daemon" + "github.com/scionproto/scion/pkg/snet" + "github.com/stretchr/testify/assert" +) + +func TestLoadTopology(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + conn := mock_daemon.NewMockConnector(ctrl) + wantTopo := testTopology{ + ia: addr.MustParseIA("1-ff00:0:110"), + start: uint16(4096), + end: uint16(8192), + interfaces: map[uint16]netip.AddrPort{ + 1: netip.MustParseAddrPort("10.0.0.1:5153"), + 2: netip.MustParseAddrPort("10.0.0.2:6421"), + }, + } + wantTopo.setupMockResponses(conn) + + topo, err := daemon.LoadTopology(context.Background(), conn) + assert.NoError(t, err) + wantTopo.checkTopology(t, topo) +} + +func TestReloadingTopology(t *testing.T) { + ctrl := gomock.NewController(t) + conn := mock_daemon.NewMockConnector(ctrl) + + wantTopo := testTopology{ + ia: addr.MustParseIA("1-ff00:0:110"), + start: uint16(4096), + end: uint16(8192), + interfaces: map[uint16]netip.AddrPort{ + 1: netip.MustParseAddrPort("10.0.0.1:5153"), + 2: netip.MustParseAddrPort("10.0.0.2:6421"), + }, + } + interfacesLater := map[uint16]netip.AddrPort{ + 2: netip.MustParseAddrPort("10.0.0.2:6421"), + 3: netip.MustParseAddrPort("10.0.0.3:7539"), + } + calls := wantTopo.setupMockResponses(conn) + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + gomock.InOrder( + append(calls, + conn.EXPECT().Interfaces(gomock.Any()).DoAndReturn(func(context.Context) (map[uint16]netip.AddrPort, error) { + cancel() + return interfacesLater, nil + }).AnyTimes(), + )..., + ) + + loader, err := daemon.NewReloadingTopology(ctx, conn) + assert.NoError(t, err) + topo := loader.Topology() + wantTopo.checkTopology(t, topo) + + go func() { + loader.Run(ctx, 100*time.Millisecond) + close(done) + }() + <-done + wantTopo.interfaces = interfacesLater + wantTopo.checkTopology(t, loader.Topology()) +} + +type testTopology struct { + ia addr.IA + start uint16 + end uint16 + interfaces map[uint16]netip.AddrPort +} + +func (tt testTopology) setupMockResponses(c *mock_daemon.MockConnector) []*gomock.Call { + return []*gomock.Call{ + c.EXPECT().LocalIA(gomock.Any()).Return(tt.ia, nil), + c.EXPECT().PortRange(gomock.Any()).Return(tt.start, tt.end, nil), + c.EXPECT().Interfaces(gomock.Any()).Return(tt.interfaces, nil), + } +} + +func (tt testTopology) checkTopology(t *testing.T, topo snet.Topology) { + t.Helper() + + assert.Equal(t, tt.ia, topo.LocalIA) + assert.Equal(t, tt.start, topo.PortRange.Start) + assert.Equal(t, tt.end, topo.PortRange.End) + for ifID, want := range tt.interfaces { + got, ok := topo.Interface(ifID) + assert.True(t, ok, "interface %d", ifID) + assert.Equal(t, want, got, "interface %d", ifID) + } +}