add listPrefix in awsS3WriteCommitPrefix #31776

Merged · 6 commits · Jun 2, 2022
1 change: 1 addition & 0 deletions CHANGELOG.next.asciidoc
@@ -49,6 +49,7 @@ https://github.com/elastic/beats/compare/v8.2.0\...main[Check the HEAD diff]
- sophos.xg: Update module to handle new log fields. {issue}31038[31038] {pull}31388[31388]
- Fix MISP documentation for `var.filters` config option. {pull}31434[31434]
- Fix type mapping of client.as.number in okta module. {pull}31676[31676]
- Fix last write pagination commit checkpoint on `aws-s3` input for s3 direct polling when using the same bucket and different list prefixes. {pull}31776[31776]

*Heartbeat*

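The changelog entry above summarizes the core of the fix: the aws-s3 input stores its last-write commit checkpoint in the registry under a key that, before this PR, was derived from the bucket name alone, so two inputs polling the same bucket under different list prefixes overwrote each other's checkpoint. A minimal sketch of the keying change follows; it is not the Beats source, and the constant's value here is an assumed placeholder for illustration only.

// Sketch of the checkpoint-keying change in this PR (not the Beats source).
package main

import "fmt"

const awsS3WriteCommitPrefix = "commit_time::" // assumed placeholder value

// Before the PR: one checkpoint per bucket, shared by every list prefix.
func commitKeyBefore(bucket string) string {
	return awsS3WriteCommitPrefix + bucket
}

// After the PR: one checkpoint per bucket+prefix pair.
func commitKeyAfter(bucket, listPrefix string) string {
	return awsS3WriteCommitPrefix + bucket + listPrefix
}

func main() {
	// Two inputs on the same bucket with different prefixes no longer collide.
	fmt.Println(commitKeyBefore("logs"))               // commit_time::logs
	fmt.Println(commitKeyAfter("logs", "cloudtrail/")) // commit_time::logscloudtrail/
	fmt.Println(commitKeyAfter("logs", "vpcflow/"))    // commit_time::logsvpcflow/
}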
95 changes: 65 additions & 30 deletions x-pack/filebeat/input/awss3/input_benchmark_test.go
@@ -6,34 +6,37 @@ package awss3

import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"

"github.com/elastic/beats/v7/libbeat/beat"
"github.com/elastic/beats/v7/libbeat/statestore"
"github.com/elastic/beats/v7/libbeat/statestore/storetest"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/dustin/go-humanize"
"github.com/olekukonko/tablewriter"
"github.com/pkg/errors"

"github.com/elastic/beats/v7/libbeat/beat"
pubtest "github.com/elastic/beats/v7/libbeat/publisher/testing"
"github.com/elastic/beats/v7/libbeat/statestore"
"github.com/elastic/beats/v7/libbeat/statestore/storetest"
awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws"
conf "github.com/elastic/elastic-agent-libs/config"
"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-libs/monitoring"
)

const (
cloudtrailTestFile = "testdata/aws-cloudtrail.json.gz"
totalListingObjects = 10000
cloudtrailTestFile = "testdata/aws-cloudtrail.json.gz"
totalListingObjects = 10000
totalListingObjectsForInputS3 = totalListingObjects / 5
)

type constantSQS struct {
@@ -54,11 +54,11 @@ func (c *constantSQS) ReceiveMessage(ctx context.Context, maxMessages int) ([]sq
return c.msgs, nil
}

func (_ *constantSQS) DeleteMessage(ctx context.Context, msg *sqs.Message) error {
func (*constantSQS) DeleteMessage(ctx context.Context, msg *sqs.Message) error {
return nil
}

func (_ *constantSQS) ChangeMessageVisibility(ctx context.Context, msg *sqs.Message, timeout time.Duration) error {
func (*constantSQS) ChangeMessageVisibility(ctx context.Context, msg *sqs.Message, timeout time.Duration) error {
return nil
}

@@ -93,16 +93,16 @@ func (c *s3PagerConstant) Err() error {
return nil
}

func newS3PagerConstant() *s3PagerConstant {
func newS3PagerConstant(listPrefix string) *s3PagerConstant {
lastModified := time.Now()
ret := &s3PagerConstant{
currentIndex: 0,
}

for i := 0; i < totalListingObjects; i++ {
for i := 0; i < totalListingObjectsForInputS3; i++ {
ret.objects = append(ret.objects, s3.Object{
Key: aws.String(fmt.Sprintf("key-%d.json.gz", i)),
ETag: aws.String(fmt.Sprintf("etag-%d", i)),
Key: aws.String(fmt.Sprintf("%s-%d.json.gz", listPrefix, i)),
ETag: aws.String(fmt.Sprintf("etag-%s-%d", listPrefix, i)),
LastModified: aws.Time(lastModified),
})
}
@@ -213,7 +216,7 @@ func benchmarkInputSQS(t *testing.T, maxMessagesInflight int) testing.BenchmarkResult
}

func TestBenchmarkInputSQS(t *testing.T) {
logp.TestingSetup(logp.WithLevel(logp.InfoLevel))
_ = logp.TestingSetup(logp.WithLevel(logp.InfoLevel))

results := []testing.BenchmarkResult{
benchmarkInputSQS(t, 1),
@@ -236,7 +239,7 @@ func TestBenchmarkInputSQS(t *testing.T) {
"Time (sec)",
"CPUs",
}
var data [][]string
data := make([][]string, 0)
for _, r := range results {
data = append(data, []string{
fmt.Sprintf("%v", r.Extra["max_messages_inflight"]),
@@ -258,8 +261,7 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult
log := logp.NewLogger(inputName)
metricRegistry := monitoring.NewRegistry()
metrics := newInputMetrics(metricRegistry, "test_id")
s3API := newConstantS3(t)
s3API.pagerConstant = newS3PagerConstant()

client := pubtest.NewChanClientWithCallback(100, func(event beat.Event) {
event.Private.(*awscommon.EventACKTracker).ACK()
})
@@ -273,14 +275,8 @@
t.Fatalf("Failed to access store: %v", err)
}

err = store.Set(awsS3WriteCommitPrefix+"bucket", &commitWriteState{time.Time{}})
if err != nil {
t.Fatalf("Failed to reset store: %v", err)
}

s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, client, conf.FileSelectors)
s3Poller := newS3Poller(logp.NewLogger(inputName), metrics, s3API, s3EventHandlerFactory, newStates(inputCtx), store, "bucket", "key-", "region", "provider", numberOfWorkers, time.Second)

b.ResetTimer()
start := time.Now()
ctx, cancel := context.WithCancel(context.Background())
b.Cleanup(cancel)

@@ -291,13 +287,42 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult
cancel()
}()

b.ResetTimer()
start := time.Now()
if err := s3Poller.Poll(ctx); err != nil {
if !errors.Is(err, context.DeadlineExceeded) {
errChan := make(chan error)
wg := new(sync.WaitGroup)
for i := 0; i < 5; i++ {
wg.Add(1)
go func(i int, wg *sync.WaitGroup) {
defer wg.Done()
listPrefix := fmt.Sprintf("list_prefix_%d", i)
s3API := newConstantS3(t)
s3API.pagerConstant = newS3PagerConstant(listPrefix)
err = store.Set(awsS3WriteCommitPrefix+"bucket"+listPrefix, &commitWriteState{time.Time{}})
if err != nil {
errChan <- err
return
}

s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, client, conf.FileSelectors)
s3Poller := newS3Poller(logp.NewLogger(inputName), metrics, s3API, s3EventHandlerFactory, newStates(inputCtx), store, "bucket", listPrefix, "region", "provider", numberOfWorkers, time.Second)

if err := s3Poller.Poll(ctx); err != nil {
if !errors.Is(err, context.DeadlineExceeded) {
errChan <- err
}
}
}(i, wg)
}

wg.Wait()
select {
case err := <-errChan:
if err != nil {
t.Fatal(err)
}
default:

}

b.StopTimer()
elapsed := time.Since(start)

@@ -322,7 +347,7 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult
}

func TestBenchmarkInputS3(t *testing.T) {
logp.TestingSetup(logp.WithLevel(logp.InfoLevel))
_ = logp.TestingSetup(logp.WithLevel(logp.InfoLevel))

results := []testing.BenchmarkResult{
benchmarkInputS3(t, 1),
@@ -340,22 +365,32 @@ func TestBenchmarkInputS3(t *testing.T) {

headers := []string{
"Number of workers",
"Objects listed total",
"Objects listed per sec",
"Objects processed total",
"Objects processed per sec",
"Objects acked total",
"Objects acked per sec",
"Events total",
"Events per sec",
"S3 Bytes total",
"S3 Bytes per sec",
"Time (sec)",
"CPUs",
}
var data [][]string
data := make([][]string, 0)
for _, r := range results {
data = append(data, []string{
fmt.Sprintf("%v", r.Extra["number_of_workers"]),
fmt.Sprintf("%v", r.Extra["objects_listed"]),
fmt.Sprintf("%v", r.Extra["objects_listed_per_sec"]),
fmt.Sprintf("%v", r.Extra["objects_processed"]),
fmt.Sprintf("%v", r.Extra["objects_processed_per_sec"]),
fmt.Sprintf("%v", r.Extra["objects_acked"]),
fmt.Sprintf("%v", r.Extra["objects_acked_per_sec"]),
fmt.Sprintf("%v", r.Extra["events"]),
fmt.Sprintf("%v", r.Extra["events_per_sec"]),
fmt.Sprintf("%v", humanize.Bytes(uint64(r.Extra["s3_bytes"]))),
fmt.Sprintf("%v", humanize.Bytes(uint64(r.Extra["s3_bytes_per_sec"]))),
fmt.Sprintf("%v", r.Extra["sec"]),
fmt.Sprintf("%v", runtime.GOMAXPROCS(0)),
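The updated benchmark above fans out five pollers, one per list prefix, each with its own constant pager and its own commit checkpoint, and collects worker errors on a channel that is drained only after wg.Wait(). Below is a generic sketch of the same fan-out shape, simplified from the test; as one deliberate adjustment, the error channel here is buffered to the worker count so a failing worker can never block on send.

// Sketch of the benchmark's one-goroutine-per-prefix fan-out (not the test code).
package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

// runWorkers runs fn once per prefix and returns the first collected error.
func runWorkers(ctx context.Context, prefixes []string, fn func(ctx context.Context, prefix string) error) error {
	errChan := make(chan error, len(prefixes)) // buffered: a send can never block
	var wg sync.WaitGroup
	for _, p := range prefixes {
		wg.Add(1)
		go func(prefix string) {
			defer wg.Done()
			// Mirror the test: a deadline expiring is expected, not a failure.
			if err := fn(ctx, prefix); err != nil && !errors.Is(err, context.DeadlineExceeded) {
				errChan <- err
			}
		}(p)
	}
	wg.Wait()
	close(errChan)
	return <-errChan // nil when the closed channel is empty
}

func main() {
	prefixes := []string{"list_prefix_0", "list_prefix_1", "list_prefix_2"}
	err := runWorkers(context.Background(), prefixes, func(_ context.Context, prefix string) error {
		fmt.Println("polling", prefix)
		return nil
	})
	fmt.Println("first error:", err)
}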
49 changes: 32 additions & 17 deletions x-pack/filebeat/input/awss3/s3.go
@@ -6,12 +6,13 @@ package awss3

import (
"context"
"errors"
"fmt"
"net/url"
"sync"
"time"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.uber.org/multierr"

"github.com/elastic/beats/v7/libbeat/statestore"
@@ -126,9 +127,11 @@ func (p *s3Poller) ProcessObject(s3ObjectPayloadChan <-chan *s3ObjectPayload) er

if err != nil {
event := s3ObjectPayload.s3ObjectEvent
errs = append(errs, errors.Wrapf(err,
"failed processing S3 event for object key %q in bucket %q",
event.S3.Object.Key, event.S3.Bucket.Name))
errs = append(errs,
fmt.Errorf(
fmt.Sprintf("failed processing S3 event for object key %q in bucket %q: %%w",
event.S3.Object.Key, event.S3.Bucket.Name),
err))

p.handlePurgingLock(info, false)
continue
@@ -178,7 +181,7 @@ func (p *s3Poller) GetS3Objects(ctx context.Context, s3ObjectPayloadChan chan<-
continue
}

state := newState(bucketName, filename, *object.ETag, *object.LastModified)
state := newState(bucketName, filename, *object.ETag, p.listPrefix, *object.LastModified)
if p.states.MustSkip(state, p.store) {
p.log.Debugw("skipping state.", "state", state)
continue
@@ -197,6 +200,7 @@

s3Processor := p.s3ObjectHandler.Create(ctx, p.log, acker, event)
if s3Processor == nil {
p.log.Debugw("empty s3 processor.", "state", state)
continue
}

@@ -216,6 +220,7 @@
}

if totProcessableObjects == 0 {
p.log.Debugw("0 processable objects on bucket pagination.", "bucket", p.bucket, "listPrefix", p.listPrefix, "listingID", listingID)
// nothing to be ACKed, unlock here
p.states.DeleteListing(listingID.String())
lock.Unlock()
@@ -236,12 +241,11 @@
if err := paginator.Err(); err != nil {
p.log.Warnw("Error when paginating listing.", "error", err)
}

return
}

func (p *s3Poller) Purge() {
listingIDs := p.states.GetListingIDs()
p.log.Debugw("purging listing.", "listingIDs", listingIDs)
for _, listingID := range listingIDs {
// we lock here in order to process the purge only after
// full listing page is ACKed by all the workers
@@ -250,39 +254,45 @@ func (p *s3Poller) Purge() {
// purge calls can overlap, GetListingIDs can return
// an outdated snapshot with listing already purged
p.states.DeleteListing(listingID)
p.log.Debugw("deleting already purged listing from states.", "listingID", listingID)
continue
}

lock.(*sync.Mutex).Lock()

keys := map[string]struct{}{}
latestStoredTimeByBucket := make(map[string]time.Time, 0)
latestStoredTimeByBucketAndListPrefix := make(map[string]time.Time, 0)

for _, state := range p.states.GetStatesByListingID(listingID) {
// it is not stored, keep
if !state.Stored {
p.log.Debugw("state not stored, skip purge", "state", state)
continue
}

var latestStoredTime time.Time
keys[state.ID] = struct{}{}
latestStoredTime, ok := latestStoredTimeByBucket[state.Bucket]
latestStoredTime, ok := latestStoredTimeByBucketAndListPrefix[state.Bucket+state.ListPrefix]
Copy link
Author

@aspacca aspacca May 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is indeed both a bugfix and a breaking change
from the elastic side the only integrations that does that is the one on cisco managed s3 buckets, that's indeed the bugged integration that will fix
but in the case of users using custom setup with a single bucket and multiple list prefixes it will be a breaking change as in the measure that s3 objects could be ingested again (and probably fails ingestion because of same document ids)

if !ok {
var commitWriteState commitWriteState
err := p.store.Get(awsS3WriteCommitPrefix+state.Bucket, &commitWriteState)
err := p.store.Get(awsS3WriteCommitPrefix+state.Bucket+state.ListPrefix, &commitWriteState)
if err == nil {
// we have no entry in the map and we have no entry in the store
// set zero time
latestStoredTime = time.Time{}
p.log.Debugw("last stored time is zero time", "bucket", state.Bucket, "listPrefix", state.ListPrefix)
} else {
latestStoredTime = commitWriteState.Time
p.log.Debugw("last stored time is commitWriteState", "commitWriteState", commitWriteState, "bucket", state.Bucket, "listPrefix", state.ListPrefix)
}
} else {
p.log.Debugw("last stored time from memory", "latestStoredTime", latestStoredTime, "bucket", state.Bucket, "listPrefix", state.ListPrefix)
}

if state.LastModified.After(latestStoredTime) {
latestStoredTimeByBucket[state.Bucket] = state.LastModified
p.log.Debugw("last stored time updated", "state.LastModified", state.LastModified, "bucket", state.Bucket, "listPrefix", state.ListPrefix)
latestStoredTimeByBucketAndListPrefix[state.Bucket+state.ListPrefix] = state.LastModified
}

}

for key := range keys {
@@ -293,8 +303,8 @@ func (p *s3Poller) Purge() {
p.log.Errorw("Failed to write states to the registry", "error", err)
}

for bucket, latestStoredTime := range latestStoredTimeByBucket {
if err := p.store.Set(awsS3WriteCommitPrefix+bucket, commitWriteState{latestStoredTime}); err != nil {
for bucketAndListPrefix, latestStoredTime := range latestStoredTimeByBucketAndListPrefix {
if err := p.store.Set(awsS3WriteCommitPrefix+bucketAndListPrefix, commitWriteState{latestStoredTime}); err != nil {
p.log.Errorw("Failed to write commit time to the registry", "error", err)
}
}
Expand All @@ -304,8 +314,6 @@ func (p *s3Poller) Purge() {
p.workersListingMap.Delete(listingID)
p.states.DeleteListing(listingID)
}

return
}

func (p *s3Poller) Poll(ctx context.Context) error {
@@ -349,8 +357,15 @@ func (p *s3Poller) Poll(ctx context.Context) error {
}()
}

timed.Wait(ctx, p.bucketPollInterval)
err = timed.Wait(ctx, p.bucketPollInterval)
if err != nil {
if errors.Is(err, context.Canceled) {
// A canceled context is a normal shutdown.
return nil
}

return err
}
}

// Wait for all workers to finish.
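Alongside the keying change, the Poll loop above now checks the error returned by timed.Wait instead of discarding it, and treats context.Canceled as a clean shutdown. The sketch below shows that pattern under one stated assumption: waitInterval stands in for the timed.Wait helper from elastic-agent-libs, whose exact behavior is assumed to be "sleep until the interval elapses or the context ends".

// Sketch of the poll-interval shutdown pattern (not the Beats source).
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// waitInterval sleeps for d, or returns early with ctx.Err() on cancellation.
func waitInterval(ctx context.Context, d time.Duration) error {
	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case <-t.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// poll runs one polling pass per interval; as in the diff above, a canceled
// context is a normal shutdown rather than an error.
func poll(ctx context.Context, interval time.Duration, once func(context.Context)) error {
	for {
		once(ctx)
		if err := waitInterval(ctx, interval); err != nil {
			if errors.Is(err, context.Canceled) {
				return nil
			}
			return err
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go func() { time.Sleep(120 * time.Millisecond); cancel() }()
	err := poll(ctx, 50*time.Millisecond, func(context.Context) { fmt.Println("poll cycle") })
	fmt.Println("poll returned:", err) // nil: cancellation is a clean stop
}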