From 66cc859b3203010148204769d1a9d105877f5402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Hanuszczak?= Date: Mon, 13 Jan 2025 08:05:52 -0800 Subject: [PATCH] Add support for batched processing in services. PiperOrigin-RevId: 714970543 --- fleetspeak/src/server/comms.go | 29 ++++ .../src/server/internal/services/manager.go | 31 ++++ .../src/server/servertests/comms_test.go | 132 ++++++++++++++++++ fleetspeak/src/server/service/service.go | 12 ++ .../src/server/testserver/testserver.go | 42 ++++++ go.mod | 1 + 6 files changed, 247 insertions(+) diff --git a/fleetspeak/src/server/comms.go b/fleetspeak/src/server/comms.go index 753f489b7..363b75e9f 100644 --- a/fleetspeak/src/server/comms.go +++ b/fleetspeak/src/server/comms.go @@ -402,6 +402,35 @@ func (c commsContext) handleMessagesFromClient(ctx context.Context, info *comms. if len(msgs) == 0 { return nil } + // TODO(b/371158380): Refactor validation and splitting by service to a single + // pass. + msgsByService := make(map[string][]*fspb.Message, len(msgs)) + for _, msg := range msgs { + msgsByService[msg.Destination.ServiceName] = append(msgsByService[msg.Destination.ServiceName], msg) + } + + unbatchedMsgs := make([]*fspb.Message, 0) + + for service, msgs := range msgsByService { + if len(msgs) == 0 { + continue + } + if service == "" { + log.ErrorContextf(ctx, "dropping %v messages with no service set", len(msgs)) + continue + } + + // TODO(b/371158380): Verify the batching configuration. + if c.s.serviceConfig.ShouldProcessMessageBatches(service) { + c.s.serviceConfig.ProcessMessageBatch(ctx, service, msgs) + } else { + unbatchedMsgs = append(unbatchedMsgs, msgs...) + } + } + + // TODO(hanuszczak): Is it better to assign `msgs` to `unbatchedMsgs` here or + // to change the occurrences below (that makes the diff worse?). + msgs = unbatchedMsgs sort.Slice(msgs, func(a, b int) bool { return bytes.Compare(msgs[a].MessageId, msgs[b].MessageId) == -1 diff --git a/fleetspeak/src/server/internal/services/manager.go b/fleetspeak/src/server/internal/services/manager.go index 6785b84cf..5f00610f8 100644 --- a/fleetspeak/src/server/internal/services/manager.go +++ b/fleetspeak/src/server/internal/services/manager.go @@ -145,6 +145,37 @@ func (c *Manager) Stop() { c.services = map[string]*liveService{} } +// ShouldProcessMessageBatches returns true if the specified service is +// configured to process messages in batches. +func (c *Manager) ShouldProcessMessageBatches(serviceName string) bool { + svc := c.services[serviceName] + if svc == nil { + return false + } + + _, ok := svc.service.(service.BatchedService) + return ok +} + +// ProcessMessageBatch processes a batch of messages using the specified +// service. +func (c *Manager) ProcessMessageBatch(ctx context.Context, serviceName string, msgs []*fspb.Message) { + svc := c.services[serviceName] + if svc == nil { + log.ErrorContextf(ctx, "no such service: %v", serviceName) + return + } + + batchedSvc, ok := svc.service.(service.BatchedService) + if !ok { + log.ErrorContextf(ctx, "service %v does not implement BatchedService", serviceName) + } + + if err := batchedSvc.ProcessMessageBatch(ctx, msgs); err != nil { + log.ErrorContextf(ctx, "process batched messages: %v", err) + } +} + // ProcessMessages implements MessageProcessor and is called by the datastore on // backlogged messages. func (c *Manager) ProcessMessages(msgs []*fspb.Message) { diff --git a/fleetspeak/src/server/servertests/comms_test.go b/fleetspeak/src/server/servertests/comms_test.go index 9a0f52064..3169bf89a 100644 --- a/fleetspeak/src/server/servertests/comms_test.go +++ b/fleetspeak/src/server/servertests/comms_test.go @@ -35,7 +35,9 @@ import ( "github.com/google/fleetspeak/fleetspeak/src/server/sertesting" "github.com/google/fleetspeak/fleetspeak/src/server/service" "github.com/google/fleetspeak/fleetspeak/src/server/testserver" + "github.com/google/go-cmp/cmp" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" tspb "google.golang.org/protobuf/types/known/timestamppb" fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" @@ -576,3 +578,133 @@ func TestServiceError(t *testing.T) { t.Errorf("Unexpected failure reason: got [%v], want [%v]", messageResult.FailedReason, expectedFailedReason) } } + +type fakeBatchedService struct { + batches [][]*fspb.Message +} + +func (s *fakeBatchedService) Start(sctx service.Context) error { + return nil +} + +func (s *fakeBatchedService) ProcessMessage(ctx context.Context, msg *fspb.Message) error { + return s.ProcessMessageBatch(ctx, []*fspb.Message{msg}) +} + +func (s *fakeBatchedService) ProcessMessageBatch(ctx context.Context, msgs []*fspb.Message) error { + s.batches = append(s.batches, msgs) + return nil +} + +func (s *fakeBatchedService) Stop() error { + return nil +} + +func TestBatchedService(t *testing.T) { + ctx := context.Background() + + service := &fakeBatchedService{} + server := testserver.MakeWithBatchedService(t, "TestServerService", service) + defer server.S.Stop() + + clientKey, err := server.AddClient() + if err != nil { + t.Fatalf("add client: %v", err) + } + clientID, err := common.MakeClientID(clientKey) + if err != nil { + t.Fatalf("make client id: %v", err) + } + + _, err = server.SimulateContactFromClient(ctx, clientKey, []*fspb.Message{ + { + Source: &fspb.Address{ + ClientId: clientID.Bytes(), + ServiceName: "TestEndpointService", + }, + Destination: &fspb.Address{ + ServiceName: "TestServerService", + }, + SourceMessageId: []byte("AA"), + MessageType: "TestMessageType", + }, + }) + if err != nil { + t.Fatalf("simulate contact ('AA'): %v", err) + } + + _, err = server.SimulateContactFromClient(ctx, clientKey, []*fspb.Message{ + { + Source: &fspb.Address{ + ClientId: clientID.Bytes(), + ServiceName: "TestEndpointService", + }, + Destination: &fspb.Address{ + ServiceName: "TestServerService", + }, + SourceMessageId: []byte("BA"), + MessageType: "TestMessageType", + }, + { + Source: &fspb.Address{ + ClientId: clientID.Bytes(), + ServiceName: "TestEndpointService", + }, + Destination: &fspb.Address{ + ServiceName: "TestServerService", + }, + SourceMessageId: []byte("BB"), + MessageType: "TestMessageType", + }, + }) + if err != nil { + t.Fatalf("simulate contact ('BA', 'BB'): %v", err) + } + + wantBatches := [][]*fspb.Message{ + { + { + Source: &fspb.Address{ + ClientId: clientID.Bytes(), + ServiceName: "TestEndpointService", + }, + Destination: &fspb.Address{ + ServiceName: "TestServerService", + }, + SourceMessageId: []byte("AA"), + MessageType: "TestMessageType", + }, + }, + { + { + Source: &fspb.Address{ + ClientId: clientID.Bytes(), + ServiceName: "TestEndpointService", + }, + Destination: &fspb.Address{ + ServiceName: "TestServerService", + }, + SourceMessageId: []byte("BA"), + MessageType: "TestMessageType", + }, + { + Source: &fspb.Address{ + ClientId: clientID.Bytes(), + ServiceName: "TestEndpointService", + }, + Destination: &fspb.Address{ + ServiceName: "TestServerService", + }, + SourceMessageId: []byte("BB"), + MessageType: "TestMessageType", + }, + }, + } + + if diff := cmp.Diff(wantBatches, service.batches, + protocmp.Transform(), + protocmp.IgnoreFields(&fspb.Message{}, "message_id"), + ); diff != "" { + t.Errorf("unexpected batches from simulated contact (-want +got):\n%s", diff) + } +} diff --git a/fleetspeak/src/server/service/service.go b/fleetspeak/src/server/service/service.go index f63e487fa..4a0ca845c 100644 --- a/fleetspeak/src/server/service/service.go +++ b/fleetspeak/src/server/service/service.go @@ -50,6 +50,18 @@ type Service interface { Stop() error } +// A BatchedService is an extension of the Service interface that allows +// processing multiple messages at once. +type BatchedService interface { + // ProcessMessageBatch processes a batch of messages at once. Unlike the + // ProcessMessage method (of the Service interface), batches that failed to + // be processed will never be retried. + // + // In order for this method to be used, the service needs to enable `BATCHED` + // processing mode in its configuration. + ProcessMessageBatch(context.Context, []*fspb.Message) error +} + // Context allows a Fleetspeak Service to communicate back to the Fleetspeak system. type Context interface { // Set sends a message to a client machine or other server component. It can be called diff --git a/fleetspeak/src/server/testserver/testserver.go b/fleetspeak/src/server/testserver/testserver.go index 42183f31a..7e1552870 100644 --- a/fleetspeak/src/server/testserver/testserver.go +++ b/fleetspeak/src/server/testserver/testserver.go @@ -140,6 +140,48 @@ func MakeWithService(t *testing.T, testName, caseName string, serviceInstance se return testServer } +// MakeWithBatchedService creates with the given batched service backed by a +// SQLite datastore. +func MakeWithBatchedService(t *testing.T, svcName string, svc service.Service) *Server { + t.Helper() + + if _, ok := svc.(service.BatchedService); !ok { + t.Fatalf("service %v does not implement BatchedService", svcName) + } + + ds, err := sqlite.MakeDatastore(path.Join(t.TempDir(), "test.sqlite")) + if err != nil { + t.Fatalf("create datastore: %v", err) + } + + result := &Server{} + + server, err := server.MakeServer( + &spb.ServerConfig{ + Services: []*spb.ServiceConfig{{ + Name: svcName, + Factory: svcName, + }}, + }, + server.Components{ + Datastore: ds, + ServiceFactories: map[string]service.Factory{ + svcName: func(conf *spb.ServiceConfig) (service.Service, error) { + return svc, nil + }, + }, + Communicators: []comms.Communicator{FakeCommunicator{result}}, + }, + ) + if err != nil { + t.Fatalf("create server: %v", err) + } + + result.S = server + result.DS = ds + return result +} + // AddClient adds a new client with a random id to a server. func (s Server) AddClient() (crypto.PublicKey, error) { k, err := rsa.GenerateKey(rand.Reader, 2048) diff --git a/go.mod b/go.mod index be2ac1ff0..53a0ec2af 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/Microsoft/go-winio v0.6.1 github.com/go-sql-driver/mysql v1.6.0 github.com/golang/glog v1.2.4 + github.com/google/go-cmp v0.6.0 github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 github.com/mattn/go-sqlite3 v1.14.16 github.com/pires/go-proxyproto v0.6.2