Skip to content

Commit

Permalink
Implement generic batching framework
Browse files Browse the repository at this point in the history
Signed-off-by: Eddie Torres <torredil@amazon.com>
  • Loading branch information
torredil committed Nov 7, 2023
1 parent 27d1a99 commit 02f4ed8
Show file tree
Hide file tree
Showing 12 changed files with 384 additions and 0 deletions.
3 changes: 3 additions & 0 deletions charts/aws-ebs-csi-driver/templates/controller.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ spec:
{{- if .Values.controller.sdkDebugLog }}
- --aws-sdk-debug-log=true
{{- end}}
{{- if .Values.controller.batching }}
- --batching=true
{{- end}}
{{- with .Values.controller.loggingFormat }}
- --logging-format={{ . }}
{{- end }}
Expand Down
1 change: 1 addition & 0 deletions charts/aws-ebs-csi-driver/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ awsAccessSecret:
accessKey: access_key

controller:
batching: true
volumeModificationFeature:
enabled: false
# Additional parameters provided by aws-ebs-csi-driver controller.
Expand Down
1 change: 1 addition & 0 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func main() {
driver.WithWarnOnInvalidTag(options.ControllerOptions.WarnOnInvalidTag),
driver.WithUserAgentExtra(options.ControllerOptions.UserAgentExtra),
driver.WithOtelTracing(options.ServerOptions.EnableOtelTracing),
driver.WithBatching(options.ControllerOptions.Batching),
)
if err != nil {
klog.ErrorS(err, "failed to create driver")
Expand Down
3 changes: 3 additions & 0 deletions cmd/options/controller_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type ControllerOptions struct {
WarnOnInvalidTag bool
// flag to set user agent
UserAgentExtra string
// flag to enable batching of API calls
Batching bool
}

func (s *ControllerOptions) AddFlags(fs *flag.FlagSet) {
Expand All @@ -48,4 +50,5 @@ func (s *ControllerOptions) AddFlags(fs *flag.FlagSet) {
fs.BoolVar(&s.AwsSdkDebugLog, "aws-sdk-debug-log", false, "To enable the aws sdk debug log level (default to false).")
fs.BoolVar(&s.WarnOnInvalidTag, "warn-on-invalid-tag", false, "To warn on invalid tags, instead of returning an error")
fs.StringVar(&s.UserAgentExtra, "user-agent-extra", "", "Extra string appended to user agent.")
fs.BoolVar(&s.Batching, "batching", false, "To enable batching of API calls. This is especially helpful for improving performance in workloads that are sensitive to EC2 rate limits.")
}
5 changes: 5 additions & 0 deletions cmd/options/controller_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func TestControllerOptions(t *testing.T) {
flag: "aws-sdk-debug-log",
found: true,
},
{
name: "lookup batching",
flag: "batching",
found: true,
},
{
name: "lookup user-agent-extra",
flag: "user-agent-extra",
Expand Down
10 changes: 10 additions & 0 deletions cmd/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ func TestGetOptions(t *testing.T) {
userAgentExtraFlagValue := "test"
otelTracingFlagName := "enable-otel-tracing"
otelTracingFlagValue := true
batchingFlagName := "batching"
batchingFlagValue := true

args := append([]string{
"aws-ebs-csi-driver",
Expand All @@ -68,6 +70,7 @@ func TestGetOptions(t *testing.T) {
args = append(args, "--"+extraTagsFlagName+"="+extraTagKey+"="+extraTagValue)
args = append(args, "--"+awsSdkDebugFlagName+"="+strconv.FormatBool(awsSdkDebugFlagValue))
args = append(args, "--"+userAgentExtraFlag+"="+userAgentExtraFlagValue)
args = append(args, "--"+batchingFlagName+"="+strconv.FormatBool(batchingFlagValue))
}
if withNodeOptions {
args = append(args, "--"+VolumeAttachLimitFlagName+"="+strconv.FormatInt(VolumeAttachLimit, 10))
Expand Down Expand Up @@ -110,6 +113,13 @@ func TestGetOptions(t *testing.T) {
if options.ControllerOptions.UserAgentExtra != userAgentExtraFlagValue {
t.Fatalf("expected user agent string to be %q but it is %q", userAgentExtraFlagValue, options.ControllerOptions.UserAgentExtra)
}
batchingFlag := flagSet.Lookup(batchingFlagName)
if batchingFlag == nil {
t.Fatalf("expected %q flag to be added but it is not", batchingFlagName)
}
if options.ControllerOptions.Batching != batchingFlagValue {
t.Fatalf("expected sdk debug flag to be %v but it is %v", batchingFlagValue, options.ControllerOptions.Batching)
}
}

if withNodeOptions {
Expand Down
1 change: 1 addition & 0 deletions deploy/kubernetes/base/controller.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ spec:
args:
# - {all,controller,node} # specify the driver mode
- --endpoint=$(CSI_ENDPOINT)
- --batching=true
- --logging-format=text
- --user-agent-extra=kustomize
- --v=2
Expand Down
1 change: 1 addition & 0 deletions docs/options.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ There are a couple of driver options that can be passed as arguments when starti
| logging-format | json | text | Sets the log format. Permitted formats: text, json|
| user-agent-extra | csi-ebs | helm | Extra string appended to user agent|
| enable-otel-tracing | true | false | If set to true, the driver will enable opentelemetry tracing. Might need [additional env variables](https://opentelemetry.io/docs/specs/otel/configuration/sdk-environment-variables/#general-sdk-configuration) to export the traces to the right collector|
| batching | true | true | If set to true, the driver will enable batching of API calls. This is especially helpful for improving performance in workloads that are sensitive to EC2 rate limits|
162 changes: 162 additions & 0 deletions pkg/batcher/batcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Package batcher facilitates task aggregation and execution.
//
// Basic Usage:
// Instantiate a Batcher, set up its constraints, and then start adding tasks. As tasks accumulate,
// they are batched together for execution, either when a maximum task count is reached or a specified
// duration elapses. Results of the executed tasks are communicated asynchronously via channels.
//
// Example:
// Create a Batcher with a maximum of 10 tasks or a 5-second wait:
//
// `b := batcher.New(10, 5*time.Second, execFunc)`
//
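// Add a task and receive its result. Results are delivered with a non-blocking send, so the
// result channel should have a buffer of at least one; with an unbuffered channel a result that
// is ready before the caller starts receiving is dropped:
//
// resultChan := make(chan batcher.BatchResult[ResultType], 1)
// b.AddTask(myTask, resultChan)
// result := <-resultChan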
//
// Key Components:
// - `Batcher`: The main component that manages task queueing, aggregation, and execution.
// - `BatchResult`: A structure encapsulating the response for a task.
// - `taskEntry`: Internal representation of a task and its associated result channel.
//
// Task Duplication:
// Batcher identifies tasks by content. For multiple identical tasks, each has a unique result channel.
// This distinction ensures that identical tasks return their results to the appropriate callers.
package batcher

import (
"time"

"k8s.io/klog/v2"
)

// Batcher aggregates tasks and executes them in batches. Tasks arrive through taskChan and are
// held in pendingTasks until a batch is dispatched, which happens either when the number of
// distinct pending tasks reaches maxEntries or when the maxDelay timer fires. The batch itself is
// run by execFunc, which receives the pending tasks and returns their results.
type Batcher[InputType comparable, ResultType any] struct {
	// execFunc executes one batch of tasks and returns a map from each task to its result,
	// or an error for the batch as a whole.
	execFunc func(inputs []InputType) (map[InputType]ResultType, error)

	// pendingTasks maps every task awaiting execution to the result channels of all callers
	// that submitted it; identical tasks share one key and contribute one channel each.
	pendingTasks map[InputType][]chan BatchResult[ResultType]

	// taskChan carries newly added tasks from AddTask to the task manager goroutine.
	taskChan chan taskEntry[InputType, ResultType]

	// maxEntries is the number of distinct pending tasks that triggers an immediate batch execution.
	maxEntries int

	// maxDelay is how long the Batcher waits after a batch starts filling before executing it,
	// even if maxEntries has not been reached.
	maxDelay time.Duration
}

// BatchResult is the outcome delivered to a caller for a single batched task.
// Result is the value the batch execution produced for the task (the zero value when none was
// produced). Err is the error returned by the batch execution, shared by every task in that
// batch; callers should check Err before using Result.
type BatchResult[ResultType any] struct {
	Result ResultType
	Err    error
}

// taskEntry pairs a submitted task with the channel on which its result is sent back to the
// caller. It is the unit passed from AddTask to the task manager.
type taskEntry[InputType comparable, ResultType any] struct {
	task       InputType
	resultChan chan BatchResult[ResultType]
}

// New constructs a Batcher that dispatches a batch once `entries` distinct tasks are pending or
// once `delay` has elapsed, and uses fn to execute each batch. The task queue is buffered to
// hold `entries` tasks. Before returning, New starts the task manager goroutine that
// collects and dispatches tasks.
func New[InputType comparable, ResultType any](entries int, delay time.Duration, fn func(inputs []InputType) (map[InputType]ResultType, error)) *Batcher[InputType, ResultType] {
	klog.V(7).InfoS("New: initializing Batcher", "maxEntries", entries, "maxDelay", delay)

	bt := new(Batcher[InputType, ResultType])
	bt.execFunc = fn
	bt.pendingTasks = map[InputType][]chan BatchResult[ResultType]{}
	bt.taskChan = make(chan taskEntry[InputType, ResultType], entries)
	bt.maxEntries = entries
	bt.maxDelay = delay

	go bt.taskManager()
	return bt
}

// AddTask submits task t to the Batcher; its result is later sent on resultChan.
// The call blocks while the Batcher's task queue is full. Results are delivered with a
// non-blocking send, so resultChan needs either a buffer or a receiver that is already waiting.
func (b *Batcher[InputType, ResultType]) AddTask(t InputType, resultChan chan BatchResult[ResultType]) {
	klog.V(7).InfoS("AddTask: queueing task", "task", t)

	entry := taskEntry[InputType, ResultType]{
		task:       t,
		resultChan: resultChan,
	}
	b.taskChan <- entry
}

// taskManager runs as a goroutine for the lifetime of the Batcher and owns pendingTasks.
// It collects tasks from taskChan and dispatches the pending batch when either constraint is met:
//   - maxDelay has elapsed since the first task of the current batch was queued, or
//   - the number of distinct pending tasks reaches maxEntries.
//
// Each dispatch hands the current pendingTasks map to execute in a new goroutine and starts a
// fresh map, so queueing continues while a batch is being executed.
func (b *Batcher[InputType, ResultType]) taskManager() {
	klog.V(7).InfoS("taskManager: started taskManager")

	// timerCh is nil while no batch is waiting. Receiving from a nil channel blocks forever,
	// which disables the timer case of the select below until the timer is armed.
	var timerCh <-chan time.Time

	// exec dispatches the current batch and resets the manager's state for the next one.
	exec := func() {
		timerCh = nil
		go b.execute(b.pendingTasks)
		b.pendingTasks = make(map[InputType][]chan BatchResult[ResultType])
	}

	for {
		select {
		case <-timerCh:
			klog.V(7).InfoS("taskManager: maxDelay execution")
			exec()

		case t := <-b.taskChan:
			if _, exists := b.pendingTasks[t.task]; exists {
				klog.InfoS("taskManager: duplicate task detected", "task", t.task)
			}
			// append handles the missing-key case: it allocates when the slice is nil.
			b.pendingTasks[t.task] = append(b.pendingTasks[t.task], t.resultChan)

			// Arm the timer only once per batch. Keying this on len(pendingTasks) == 1 would
			// restart the timer every time a duplicate of the sole pending task arrives, letting
			// the batch wait longer than maxDelay.
			if timerCh == nil {
				klog.V(7).InfoS("taskManager: starting maxDelay timer")
				timerCh = time.After(b.maxDelay)
			}

			if len(b.pendingTasks) == b.maxEntries {
				klog.V(7).InfoS("taskManager: maxEntries reached")
				exec()
			}
		}
	}
}

// execute runs one batch. It flattens the keys of pendingTasks into a slice, passes that slice to
// execFunc, and then delivers a BatchResult to every result channel registered for each task.
// An error from execFunc is logged and attached to every delivered result. Delivery is
// non-blocking: a channel that cannot accept the result immediately is skipped.
func (b *Batcher[InputType, ResultType]) execute(pendingTasks map[InputType][]chan BatchResult[ResultType]) {
	inputs := make([]InputType, 0, len(pendingTasks))
	for input := range pendingTasks {
		inputs = append(inputs, input)
	}

	klog.V(7).InfoS("execute: calling execFunc", "batchSize", len(inputs))
	results, execErr := b.execFunc(inputs)
	if execErr != nil {
		klog.ErrorS(execErr, "execute: error executing batch")
	}

	klog.V(7).InfoS("execute: sending batch results", "batch", inputs)
	for i := range inputs {
		// A task missing from results (including a nil map) yields the zero ResultType.
		outcome := BatchResult[ResultType]{Result: results[inputs[i]], Err: execErr}
		waiters := pendingTasks[inputs[i]]
		for _, waiter := range waiters {
			select {
			case waiter <- outcome:
			default:
				klog.V(7).InfoS("execute: ignoring channel with no receiver")
			}
		}
	}
	klog.V(7).InfoS("execute: finished execution", "batchSize", len(inputs))
}
Loading

0 comments on commit 02f4ed8

Please sign in to comment.