diff --git a/scheduler/pkg/coordinator/hub.go b/scheduler/pkg/coordinator/hub.go index d03b3daf78..450959f552 100644 --- a/scheduler/pkg/coordinator/hub.go +++ b/scheduler/pkg/coordinator/hub.go @@ -20,6 +20,7 @@ import ( const ( topicModelEvents = "model.event" + topicServerEvents = "server.event" topicExperimentEvents = "experiment.event" topicPipelineEvents = "pipeline.event" ) @@ -39,6 +40,7 @@ type EventHub struct { bus *busV3.Bus logger log.FieldLogger modelEventHandlerChannels []chan ModelEventMsg + serverEventHandlerChannels []chan ServerEventMsg experimentEventHandlerChannels []chan ExperimentEventMsg pipelineEventHandlerChannels []chan PipelineEventMsg lock sync.RWMutex @@ -59,7 +61,7 @@ func NewEventHub(l log.FieldLogger) (*EventHub, error) { bus: bus, } - hub.bus.RegisterTopics(topicModelEvents, topicExperimentEvents, topicPipelineEvents) + hub.bus.RegisterTopics(topicModelEvents, topicServerEvents, topicExperimentEvents, topicPipelineEvents) return &hub, nil } @@ -74,6 +76,10 @@ func (h *EventHub) Close() { close(c) } + for _, c := range h.serverEventHandlerChannels { + close(c) + } + for _, c := range h.experimentEventHandlerChannels { close(c) } diff --git a/scheduler/pkg/coordinator/hub_test.go b/scheduler/pkg/coordinator/hub_test.go index 6f0543733b..77ee02136b 100644 --- a/scheduler/pkg/coordinator/hub_test.go +++ b/scheduler/pkg/coordinator/hub_test.go @@ -24,8 +24,8 @@ func TestNewEventHub(t *testing.T) { tests := []test{ { - name: "Should register two topics", - expectedTopics: []string{topicModelEvents, topicExperimentEvents, topicPipelineEvents}, + name: "Should register four topics", + expectedTopics: []string{topicModelEvents, topicServerEvents, topicExperimentEvents, topicPipelineEvents}, }, } diff --git a/scheduler/pkg/coordinator/server.go b/scheduler/pkg/coordinator/server.go new file mode 100644 index 0000000000..b3da2c64e4 --- /dev/null +++ b/scheduler/pkg/coordinator/server.go @@ -0,0 +1,100 @@ +/* +Copyright (c) 2024 Seldon Technologies Ltd. + +Use of this software is governed by +(1) the license included in the LICENSE file or +(2) if the license included in the LICENSE file is the Business Source License 1.1, +the Change License after the Change Date as each is defined in accordance with the LICENSE file. +*/ + +package coordinator + +import ( + "context" + "reflect" + + busV3 "github.com/mustafaturan/bus/v3" + log "github.com/sirupsen/logrus" +) + +func (h *EventHub) RegisterServerEventHandler( + name string, + queueSize int, + logger log.FieldLogger, + handle func(event ServerEventMsg), +) { + events := make(chan ServerEventMsg, queueSize) + h.addServerEventHandlerChannel(events) + + go func() { + for e := range events { + handle(e) + } + }() + + handler := h.newServerEventHandler(logger, events, handle) + h.bus.RegisterHandler(name, handler) +} + +func (h *EventHub) newServerEventHandler( + logger log.FieldLogger, + events chan ServerEventMsg, + _ func(event ServerEventMsg), +) busV3.Handler { + handleServerEventMessage := func(_ context.Context, e busV3.Event) { + l := logger.WithField("func", "handleServerEventMessage") + l.Debugf("Received event on %s from %s (ID: %s, TxID: %s)", e.Topic, e.Source, e.ID, e.TxID) + + me, ok := e.Data.(ServerEventMsg) + if !ok { + l.Warnf( + "Event (ID %s, TxID %s) on topic %s from %s is not a ServerEventMsg: %s", + e.ID, + e.TxID, + e.Topic, + e.Source, + reflect.TypeOf(e.Data).String(), + ) + return + } + + h.lock.RLock() + if h.closed { + return + } + // Propagate the busV3.Event source to the ServerEventMsg + // This is useful for logging, but also in case we want to distinguish + // the action to take based on where the event came from. + me.Source = e.Source + events <- me + h.lock.RUnlock() + } + + return busV3.Handler{ + Matcher: topicServerEvents, + Handle: handleServerEventMessage, + } +} + +func (h *EventHub) addServerEventHandlerChannel(c chan ServerEventMsg) { + h.lock.Lock() + defer h.lock.Unlock() + + h.serverEventHandlerChannels = append(h.serverEventHandlerChannels, c) +} + +func (h *EventHub) PublishServerEvent(source string, event ServerEventMsg) { + err := h.bus.EmitWithOpts( + context.Background(), + topicServerEvents, + event, + busV3.WithSource(source), + ) + if err != nil { + h.logger.WithError(err).Errorf( + "unable to publish server event message from %s to %s", + source, + topicServerEvents, + ) + } +} diff --git a/scheduler/pkg/coordinator/types.go b/scheduler/pkg/coordinator/types.go index bb7e8c3a09..d83e84d5ac 100644 --- a/scheduler/pkg/coordinator/types.go +++ b/scheduler/pkg/coordinator/types.go @@ -11,6 +11,13 @@ package coordinator import "fmt" +type ServerEventUpdateContext int + +const ( + SERVER_STATUS_UPDATE ServerEventUpdateContext = iota + SERVER_REPLICA_CONNECTED +) + type ModelEventMsg struct { ModelName string ModelVersion uint32 @@ -20,6 +27,16 @@ func (m ModelEventMsg) String() string { return fmt.Sprintf("%s:%d", m.ModelName, m.ModelVersion) } +type ServerEventMsg struct { + ServerName string + Source string + UpdateContext ServerEventUpdateContext +} + +func (m ServerEventMsg) String() string { + return m.ServerName +} + type ExperimentEventMsg struct { ExperimentName string UpdatedExperiment bool