From 56ee1392e8fdb8dd49719370e6c01cc8b641deb7 Mon Sep 17 00:00:00 2001 From: Antoine Pourchet Date: Tue, 28 May 2024 18:05:51 -0600 Subject: [PATCH] KAFKA-15045: (KIP-924 pt. 11) Implemented StickyTaskAssignor (#16052) This PR implements the StickyTaskAssignor with the new KIP 924 API. Reviewers: Anna Sophie Blee-Goldman --- .../processor/assignment/TaskAssignor.java | 4 + .../assignors/StickyTaskAssignor.java | 468 ++++++++++++++++++ 2 files changed, 472 insertions(+) create mode 100644 streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignor.java index acfbe4beb84e9..f1b8e48005f54 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignor.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignor.java @@ -17,6 +17,7 @@ package org.apache.kafka.streams.processor.assignment; import java.util.Collection; +import java.util.Map; import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupAssignment; import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription; import org.apache.kafka.common.Configurable; @@ -73,6 +74,9 @@ enum AssignmentError { */ default void onAssignmentComputed(GroupAssignment assignment, GroupSubscription subscription, AssignmentError error) {} + @Override + default void configure(final Map configs) {} + /** * Wrapper class for the final assignment of active and standbys tasks to individual * KafkaStreams clients. diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java new file mode 100644 index 0000000000000..aa8b174198ab4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.assignment.assignors; + +import static java.util.Collections.unmodifiableMap; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeSet; +import java.util.stream.Collectors; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.assignment.ApplicationState; +import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment; +import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment.AssignedTask; +import org.apache.kafka.streams.processor.assignment.KafkaStreamsState; +import org.apache.kafka.streams.processor.assignment.ProcessId; +import org.apache.kafka.streams.processor.assignment.TaskAssignmentUtils; +import org.apache.kafka.streams.processor.assignment.TaskAssignor; +import org.apache.kafka.streams.processor.assignment.TaskInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class StickyTaskAssignor implements TaskAssignor { + private static final Logger LOG = LoggerFactory.getLogger(StickyTaskAssignor.class); + + private final boolean mustPreserveActiveTaskAssignment; + + public StickyTaskAssignor() { + this(false); + } + + public StickyTaskAssignor(final boolean mustPreserveActiveTaskAssignment) { + this.mustPreserveActiveTaskAssignment = mustPreserveActiveTaskAssignment; + } + + @Override + public TaskAssignment assign(final ApplicationState applicationState) { + final Map clients = applicationState.kafkaStreamsStates(false); + final Map previousActiveAssignment = mapPreviousActiveTasks(clients); + final Map> previousStandbyAssignment = mapPreviousStandbyTasks(clients); + final AssignmentState assignmentState = new AssignmentState(applicationState, clients, + previousActiveAssignment, previousStandbyAssignment); + + assignActive(applicationState, clients.values(), assignmentState, this.mustPreserveActiveTaskAssignment); + optimizeActive(applicationState, assignmentState); + assignStandby(applicationState, assignmentState); + optimizeStandby(applicationState, assignmentState); + + final Map finalAssignments = assignmentState.buildKafkaStreamsAssignments(); + if (mustPreserveActiveTaskAssignment && !finalAssignments.isEmpty()) { + // We set the followup deadline for only one of the clients. + final ProcessId clientId = finalAssignments.entrySet().iterator().next().getKey(); + final KafkaStreamsAssignment previousAssignment = finalAssignments.get(clientId); + finalAssignments.put(clientId, previousAssignment.withFollowupRebalance(Instant.ofEpochMilli(0))); + } + + final Collection taskAssignments = finalAssignments.entrySet().stream() + .map(Map.Entry::getValue).collect(Collectors.toList()); + return new TaskAssignment(taskAssignments); + } + + private void optimizeActive(final ApplicationState applicationState, + final AssignmentState assignmentState) { + if (mustPreserveActiveTaskAssignment) { + return; + } + + final Map currentAssignments = assignmentState.buildKafkaStreamsAssignments(); + + final Set statefulTasks = applicationState.allTasks().stream() + .filter(TaskInfo::isStateful) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map optimizedAssignmentsForStatefulTasks = TaskAssignmentUtils.optimizeRackAwareActiveTasks( + applicationState, currentAssignments, new TreeSet<>(statefulTasks)); + + final Set statelessTasks = applicationState.allTasks().stream() + .filter(task -> !task.isStateful()) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map optimizedAssignmentsForAllTasks = TaskAssignmentUtils.optimizeRackAwareActiveTasks( + applicationState, optimizedAssignmentsForStatefulTasks, new TreeSet<>(statelessTasks)); + + assignmentState.processOptimizedAssignments(optimizedAssignmentsForAllTasks); + } + + private void optimizeStandby(final ApplicationState applicationState, final AssignmentState assignmentState) { + if (applicationState.assignmentConfigs().numStandbyReplicas() <= 0) { + return; + } + + if (mustPreserveActiveTaskAssignment) { + return; + } + + final Map currentAssignments = assignmentState.buildKafkaStreamsAssignments(); + final Map optimizedAssignments = TaskAssignmentUtils.optimizeRackAwareStandbyTasks( + applicationState, currentAssignments); + assignmentState.processOptimizedAssignments(optimizedAssignments); + } + + private static void assignActive(final ApplicationState applicationState, + final Collection clients, + final AssignmentState assignmentState, + final boolean mustPreserveActiveTaskAssignment) { + final int totalCapacity = computeTotalProcessingThreads(clients); + final Set allTaskIds = applicationState.allTasks().stream() + .map(TaskInfo::id).collect(Collectors.toSet()); + final int taskCount = allTaskIds.size(); + final int activeTasksPerThread = taskCount / totalCapacity; + final Set unassigned = new HashSet<>(allTaskIds); + + // first try and re-assign existing active tasks to clients that previously had + // the same active task + for (final TaskId taskId : assignmentState.previousActiveAssignment.keySet()) { + final ProcessId previousClientForTask = assignmentState.previousActiveAssignment.get(taskId); + if (allTaskIds.contains(taskId)) { + if (mustPreserveActiveTaskAssignment || assignmentState.hasRoomForActiveTask(previousClientForTask, activeTasksPerThread)) { + assignmentState.finalizeAssignment(taskId, previousClientForTask, AssignedTask.Type.ACTIVE); + unassigned.remove(taskId); + } + } + } + + // try and assign any remaining unassigned tasks to clients that previously + // have seen the task. + for (final Iterator iterator = unassigned.iterator(); iterator.hasNext(); ) { + final TaskId taskId = iterator.next(); + final Set previousClientsForStandbyTask = assignmentState.previousStandbyAssignment.getOrDefault(taskId, new HashSet<>()); + for (final ProcessId client: previousClientsForStandbyTask) { + if (assignmentState.hasRoomForActiveTask(client, activeTasksPerThread)) { + assignmentState.finalizeAssignment(taskId, client, AssignedTask.Type.ACTIVE); + iterator.remove(); + break; + } + } + } + + // assign any remaining unassigned tasks + final List sortedTasks = new ArrayList<>(unassigned); + Collections.sort(sortedTasks); + for (final TaskId taskId : sortedTasks) { + final Set candidateClients = clients.stream() + .map(KafkaStreamsState::processId) + .collect(Collectors.toSet()); + final ProcessId bestClient = assignmentState.findBestClientForTask(taskId, candidateClients); + assignmentState.finalizeAssignment(taskId, bestClient, AssignedTask.Type.ACTIVE); + } + } + + private static void assignStandby(final ApplicationState applicationState, + final AssignmentState assignmentState) { + final Set statefulTasks = applicationState.allTasks().stream() + .filter(TaskInfo::isStateful) + .collect(Collectors.toSet()); + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + for (final TaskInfo task : statefulTasks) { + for (int i = 0; i < numStandbyReplicas; i++) { + final Set candidateClients = assignmentState.findClientsWithoutAssignedTask(task.id()); + if (candidateClients.isEmpty()) { + LOG.warn("Unable to assign {} of {} standby tasks for task [{}]. " + + "There is not enough available capacity. You should " + + "increase the number of threads and/or application instances " + + "to maintain the requested number of standby replicas.", + numStandbyReplicas - i, + numStandbyReplicas, task.id()); + break; + } + + final ProcessId bestClient = assignmentState.findBestClientForTask(task.id(), candidateClients); + assignmentState.finalizeAssignment(task.id(), bestClient, AssignedTask.Type.STANDBY); + } + } + } + + private static Map mapPreviousActiveTasks(final Map clients) { + final Map previousActiveTasks = new HashMap<>(); + for (final KafkaStreamsState client : clients.values()) { + for (final TaskId taskId : client.previousActiveTasks()) { + previousActiveTasks.put(taskId, client.processId()); + } + } + return previousActiveTasks; + } + + private static Map> mapPreviousStandbyTasks(final Map clients) { + final Map> previousStandbyTasks = new HashMap<>(); + for (final KafkaStreamsState client : clients.values()) { + for (final TaskId taskId : client.previousActiveTasks()) { + previousStandbyTasks.computeIfAbsent(taskId, k -> new HashSet<>()); + previousStandbyTasks.get(taskId).add(client.processId()); + } + } + return previousStandbyTasks; + } + + private static int computeTotalProcessingThreads(final Collection clients) { + int count = 0; + for (final KafkaStreamsState client : clients) { + count += client.numProcessingThreads(); + } + return count; + } + + private static class AssignmentState { + private final Map clients; + private final Map previousActiveAssignment; + private final Map> previousStandbyAssignment; + + private final TaskPairs taskPairs; + + private Map> newTaskLocations; + private Map> newAssignments; + + private AssignmentState(final ApplicationState applicationState, + final Map clients, + final Map previousActiveAssignment, + final Map> previousStandbyAssignment) { + this.clients = clients; + this.previousActiveAssignment = unmodifiableMap(previousActiveAssignment); + this.previousStandbyAssignment = unmodifiableMap(previousStandbyAssignment); + + final int taskCount = applicationState.allTasks().size(); + final int maxPairs = taskCount * (taskCount - 1) / 2; + this.taskPairs = new TaskPairs(maxPairs); + + this.newTaskLocations = new HashMap<>(); + this.newAssignments = new HashMap<>(); + } + + public void finalizeAssignment(final TaskId taskId, final ProcessId client, final AssignedTask.Type type) { + newAssignments.computeIfAbsent(client, k -> new HashSet<>()); + newTaskLocations.computeIfAbsent(taskId, k -> new HashSet<>()); + + final Set newAssignmentsForClient = newAssignments.get(client) + .stream().map(AssignedTask::id).collect(Collectors.toSet()); + + taskPairs.addPairs(taskId, newAssignmentsForClient); + newAssignments.get(client).add(new AssignedTask(taskId, type)); + newTaskLocations.get(taskId).add(client); + } + + public Map buildKafkaStreamsAssignments() { + final Map kafkaStreamsAssignments = new HashMap<>(); + for (final Map.Entry> entry : newAssignments.entrySet()) { + final ProcessId processId = entry.getKey(); + final Set assignedTasks = newAssignments.get(processId); + final KafkaStreamsAssignment assignment = KafkaStreamsAssignment.of(processId, assignedTasks); + kafkaStreamsAssignments.put(processId, assignment); + } + return kafkaStreamsAssignments; + } + + public void processOptimizedAssignments(final Map optimizedAssignments) { + final Map> newTaskLocations = new HashMap<>(); + final Map> newAssignments = new HashMap<>(); + + for (final Map.Entry entry : optimizedAssignments.entrySet()) { + final ProcessId processId = entry.getKey(); + final Set assignedTasks = optimizedAssignments.get(processId).assignment(); + newAssignments.put(processId, assignedTasks); + + for (final AssignedTask task : assignedTasks) { + newTaskLocations.computeIfAbsent(task.id(), k -> new HashSet<>()); + newTaskLocations.get(task.id()).add(processId); + } + } + + this.newTaskLocations = newTaskLocations; + this.newAssignments = newAssignments; + } + + public boolean hasRoomForActiveTask(final ProcessId processId, final int activeTasksPerThread) { + final int capacity = clients.get(processId).numProcessingThreads(); + final int newActiveTaskCount = newAssignments.computeIfAbsent(processId, k -> new HashSet<>()) + .stream().filter(assignedTask -> assignedTask.type() == AssignedTask.Type.ACTIVE) + .collect(Collectors.toSet()) + .size(); + return newActiveTaskCount < capacity * activeTasksPerThread; + } + + public ProcessId findBestClientForTask(final TaskId taskId, final Set clientsWithin) { + if (clientsWithin.size() == 1) { + return clientsWithin.iterator().next(); + } + + final ProcessId previousClient = findLeastLoadedClientWithPreviousActiveOrStandbyTask( + taskId, clientsWithin); + if (previousClient == null) { + return findLeastLoadedClient(taskId, clientsWithin); + } + + if (shouldBalanceLoad(previousClient)) { + final ProcessId standby = findLeastLoadedClientWithPreviousStandbyTask(taskId, clientsWithin); + if (standby == null || shouldBalanceLoad(standby)) { + return findLeastLoadedClient(taskId, clientsWithin); + } + return standby; + } + return previousClient; + } + + public Set findClientsWithoutAssignedTask(final TaskId taskId) { + final Set unavailableClients = newTaskLocations.computeIfAbsent(taskId, k -> new HashSet<>()); + return clients.values().stream() + .map(KafkaStreamsState::processId) + .filter(o -> !unavailableClients.contains(o)) + .collect(Collectors.toSet()); + } + + public double clientLoad(final ProcessId processId) { + final int capacity = clients.get(processId).numProcessingThreads(); + final double totalTaskCount = newAssignments.getOrDefault(processId, new HashSet<>()).size(); + return totalTaskCount / capacity; + } + + public ProcessId findLeastLoadedClient(final TaskId taskId, final Set clientIds) { + ProcessId leastLoaded = null; + for (final ProcessId processId : clientIds) { + final double thisClientLoad = clientLoad(processId); + if (thisClientLoad == 0) { + return processId; + } + + if (leastLoaded == null || thisClientLoad < clientLoad(leastLoaded)) { + final Set assignedTasks = newAssignments.getOrDefault(processId, new HashSet<>()) + .stream().map(AssignedTask::id).collect(Collectors.toSet()); + if (taskPairs.hasNewPair(taskId, assignedTasks)) { + leastLoaded = processId; + } + } + } + + if (leastLoaded != null) { + return leastLoaded; + } + + for (final ProcessId processId : clientIds) { + final double thisClientLoad = clientLoad(processId); + + if (leastLoaded == null || thisClientLoad < clientLoad(leastLoaded)) { + leastLoaded = processId; + } + } + + return leastLoaded; + } + + public ProcessId findLeastLoadedClientWithPreviousActiveOrStandbyTask(final TaskId taskId, + final Set clientsWithin) { + final ProcessId previous = previousActiveAssignment.get(taskId); + if (previous != null && clientsWithin.contains(previous)) { + return previous; + } + return findLeastLoadedClientWithPreviousStandbyTask(taskId, clientsWithin); + } + + public ProcessId findLeastLoadedClientWithPreviousStandbyTask(final TaskId taskId, + final Set clientsWithin) { + final Set ids = previousStandbyAssignment.getOrDefault(taskId, new HashSet<>()); + final HashSet constrainTo = new HashSet<>(ids); + constrainTo.retainAll(clientsWithin); + return findLeastLoadedClient(taskId, constrainTo); + } + + public boolean shouldBalanceLoad(final ProcessId client) { + final double thisClientLoad = clientLoad(client); + if (thisClientLoad < 1) { + return false; + } + + for (final ProcessId otherClient : clients.keySet()) { + if (clientLoad(otherClient) < thisClientLoad) { + return true; + } + } + return false; + } + } + + private static class TaskPairs { + private final Set pairs; + private final int maxPairs; + + public TaskPairs(final int maxPairs) { + this.maxPairs = maxPairs; + this.pairs = new HashSet<>(maxPairs); + } + + public boolean hasNewPair(final TaskId task1, + final Set taskIds) { + if (pairs.size() == maxPairs) { + return false; + } + for (final TaskId taskId : taskIds) { + if (!pairs.contains(pair(task1, taskId))) { + return true; + } + } + return false; + } + + public void addPairs(final TaskId taskId, final Set assigned) { + for (final TaskId id : assigned) { + pairs.add(pair(id, taskId)); + } + } + + public TaskPair pair(final TaskId task1, final TaskId task2) { + if (task1.compareTo(task2) < 0) { + return new TaskPair(task1, task2); + } + return new TaskPair(task2, task1); + } + } + + private static class TaskPair { + private final TaskId task1; + private final TaskId task2; + + TaskPair(final TaskId task1, final TaskId task2) { + this.task1 = task1; + this.task2 = task2; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final TaskPair pair = (TaskPair) o; + return Objects.equals(task1, pair.task1) && + Objects.equals(task2, pair.task2); + } + + @Override + public int hashCode() { + return Objects.hash(task1, task2); + } + } +}