Skip to content

Commit

Permalink
Add support for batched processing in services.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718789499
  • Loading branch information
panhania authored and copybara-github committed Jan 23, 2025
1 parent 94557ef commit a3821ce
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 0 deletions.
24 changes: 24 additions & 0 deletions fleetspeak/src/server/comms.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,30 @@ func (c commsContext) handleMessagesFromClient(ctx context.Context, info *comms.
if len(msgs) == 0 {
return nil
}
// TODO(b/371158380): Refactor validation and splitting by service to a single
// pass.
msgsByService := make(map[string][]*fspb.Message)
for _, msg := range msgs {
msgsByService[msg.Destination.ServiceName] = append(msgsByService[msg.Destination.ServiceName], msg)
}

var unbatchedMsgs []*fspb.Message

for service, msgs := range msgsByService {
if service == "" {
log.ErrorContextf(ctx, "Dropping %v messages with no service set", len(msgs))
continue
}

// TODO(b/371158380): Verify the batching configuration.
if c.s.serviceConfig.ShouldProcessMessageBatches(service) {
c.s.serviceConfig.ProcessMessageBatch(ctx, service, msgs)
} else {
unbatchedMsgs = append(unbatchedMsgs, msgs...)
}
}

msgs = unbatchedMsgs

sort.Slice(msgs, func(a, b int) bool {
return bytes.Compare(msgs[a].MessageId, msgs[b].MessageId) == -1
Expand Down
31 changes: 31 additions & 0 deletions fleetspeak/src/server/internal/services/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,37 @@ func (c *Manager) Stop() {
c.services = map[string]*liveService{}
}

// ShouldProcessMessageBatches returns true if the specified service is
// configured to process messages in batches.
func (c *Manager) ShouldProcessMessageBatches(serviceName string) bool {
svc := c.services[serviceName]
if svc == nil {
return false
}

_, ok := svc.service.(service.BatchedService)
return ok
}

// ProcessMessageBatch processes a batch of messages using the specified
// service.
func (c *Manager) ProcessMessageBatch(ctx context.Context, serviceName string, msgs []*fspb.Message) {
svc := c.services[serviceName]
if svc == nil {
log.ErrorContextf(ctx, "No such service: %v", serviceName)
return
}

batchedSvc, ok := svc.service.(service.BatchedService)
if !ok {
log.ErrorContextf(ctx, "Service %v does not implement BatchedService", serviceName)
}

if err := batchedSvc.ProcessMessageBatch(ctx, msgs); err != nil {
log.ErrorContextf(ctx, "Process batched messages: %v", err)
}
}

// ProcessMessages implements MessageProcessor and is called by the datastore on
// backlogged messages.
func (c *Manager) ProcessMessages(msgs []*fspb.Message) {
Expand Down
132 changes: 132 additions & 0 deletions fleetspeak/src/server/servertests/comms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ import (
"github.com/google/fleetspeak/fleetspeak/src/server/sertesting"
"github.com/google/fleetspeak/fleetspeak/src/server/service"
"github.com/google/fleetspeak/fleetspeak/src/server/testserver"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
tspb "google.golang.org/protobuf/types/known/timestamppb"

fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
Expand Down Expand Up @@ -576,3 +578,133 @@ func TestServiceError(t *testing.T) {
t.Errorf("Unexpected failure reason: got [%v], want [%v]", messageResult.FailedReason, expectedFailedReason)
}
}

type fakeBatchedService struct {
batches [][]*fspb.Message
}

func (s *fakeBatchedService) Start(sctx service.Context) error {
return nil
}

func (s *fakeBatchedService) ProcessMessage(ctx context.Context, msg *fspb.Message) error {
return s.ProcessMessageBatch(ctx, []*fspb.Message{msg})
}

func (s *fakeBatchedService) ProcessMessageBatch(ctx context.Context, msgs []*fspb.Message) error {
s.batches = append(s.batches, msgs)
return nil
}

func (s *fakeBatchedService) Stop() error {
return nil
}

func TestBatchedService(t *testing.T) {
ctx := context.Background()

service := &fakeBatchedService{}
server := testserver.MakeWithBatchedService(t, "TestServerService", service)
defer server.S.Stop()

clientKey, err := server.AddClient()
if err != nil {
t.Fatalf("add client: %v", err)
}
clientID, err := common.MakeClientID(clientKey)
if err != nil {
t.Fatalf("make client id: %v", err)
}

_, err = server.SimulateContactFromClient(ctx, clientKey, []*fspb.Message{
{
Source: &fspb.Address{
ClientId: clientID.Bytes(),
ServiceName: "TestEndpointService",
},
Destination: &fspb.Address{
ServiceName: "TestServerService",
},
SourceMessageId: []byte("AA"),
MessageType: "TestMessageType",
},
})
if err != nil {
t.Fatalf("simulate contact ('AA'): %v", err)
}

_, err = server.SimulateContactFromClient(ctx, clientKey, []*fspb.Message{
{
Source: &fspb.Address{
ClientId: clientID.Bytes(),
ServiceName: "TestEndpointService",
},
Destination: &fspb.Address{
ServiceName: "TestServerService",
},
SourceMessageId: []byte("BA"),
MessageType: "TestMessageType",
},
{
Source: &fspb.Address{
ClientId: clientID.Bytes(),
ServiceName: "TestEndpointService",
},
Destination: &fspb.Address{
ServiceName: "TestServerService",
},
SourceMessageId: []byte("BB"),
MessageType: "TestMessageType",
},
})
if err != nil {
t.Fatalf("simulate contact ('BA', 'BB'): %v", err)
}

wantBatches := [][]*fspb.Message{
{
{
Source: &fspb.Address{
ClientId: clientID.Bytes(),
ServiceName: "TestEndpointService",
},
Destination: &fspb.Address{
ServiceName: "TestServerService",
},
SourceMessageId: []byte("AA"),
MessageType: "TestMessageType",
},
},
{
{
Source: &fspb.Address{
ClientId: clientID.Bytes(),
ServiceName: "TestEndpointService",
},
Destination: &fspb.Address{
ServiceName: "TestServerService",
},
SourceMessageId: []byte("BA"),
MessageType: "TestMessageType",
},
{
Source: &fspb.Address{
ClientId: clientID.Bytes(),
ServiceName: "TestEndpointService",
},
Destination: &fspb.Address{
ServiceName: "TestServerService",
},
SourceMessageId: []byte("BB"),
MessageType: "TestMessageType",
},
},
}

if diff := cmp.Diff(wantBatches, service.batches,
protocmp.Transform(),
protocmp.IgnoreFields(&fspb.Message{}, "message_id"),
); diff != "" {
t.Errorf("unexpected batches from simulated contact (-want +got):\n%s", diff)
}
}
15 changes: 15 additions & 0 deletions fleetspeak/src/server/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ type Service interface {
Stop() error
}

// A BatchedService is an extension of the Service interface that allows
// processing multiple messages at once.
//
// If a Service implements this interface, the original ProcessMessage method
// (for processing individual messages) will not be used.
type BatchedService interface {
// ProcessMessageBatch processes a batch of messages at once. Unlike the
// ProcessMessage method (of the Service interface), batches that failed to
// be processed will never be retried.
//
// In order for this method to be used, the service needs to enable `BATCHED`
// processing mode in its configuration.
ProcessMessageBatch(context.Context, []*fspb.Message) error
}

// Context allows a Fleetspeak Service to communicate back to the Fleetspeak system.
type Context interface {
// Set sends a message to a client machine or other server component. It can be called
Expand Down
42 changes: 42 additions & 0 deletions fleetspeak/src/server/testserver/testserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,48 @@ func MakeWithService(t *testing.T, testName, caseName string, serviceInstance se
return testServer
}

// MakeWithBatchedService creates with the given batched service backed by a
// SQLite datastore.
func MakeWithBatchedService(t *testing.T, svcName string, svc service.Service) *Server {
t.Helper()

if _, ok := svc.(service.BatchedService); !ok {
t.Fatalf("service %v does not implement BatchedService", svcName)
}

ds, err := sqlite.MakeDatastore(path.Join(t.TempDir(), "test.sqlite"))
if err != nil {
t.Fatalf("create datastore: %v", err)
}

result := &Server{}

server, err := server.MakeServer(
&spb.ServerConfig{
Services: []*spb.ServiceConfig{{
Name: svcName,
Factory: svcName,
}},
},
server.Components{
Datastore: ds,
ServiceFactories: map[string]service.Factory{
svcName: func(conf *spb.ServiceConfig) (service.Service, error) {
return svc, nil
},
},
Communicators: []comms.Communicator{FakeCommunicator{result}},
},
)
if err != nil {
t.Fatalf("create server: %v", err)
}

result.S = server
result.DS = ds
return result
}

// AddClient adds a new client with a random id to a server.
func (s Server) AddClient() (crypto.PublicKey, error) {
k, err := rsa.GenerateKey(rand.Reader, 2048)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/Microsoft/go-winio v0.6.1
github.com/go-sql-driver/mysql v1.6.0
github.com/golang/glog v1.2.4
github.com/google/go-cmp v0.6.0
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95
github.com/mattn/go-sqlite3 v1.14.16
github.com/pires/go-proxyproto v0.6.2
Expand Down

0 comments on commit a3821ce

Please sign in to comment.