/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.streams.processor.internals.assignment;

import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.BiPredicate;
import java.util.function.Function;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration;
import org.apache.kafka.streams.processor.internals.assignment.ClientState;
import org.apache.kafka.streams.processor.internals.assignment.ConstrainedPrioritySet;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor;
import org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignor;
import org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignorFactory;
import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
import org.apache.kafka.streams.processor.internals.assignment.TaskMovement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TaskAssignor} that builds an assignment in phases:
 * <ol>
 *   <li>stateful tasks are assigned as actives round-robin and then balanced across clients;</li>
 *   <li>standby replicas are placed by a {@link StandbyTaskAssignor} and balanced;</li>
 *   <li>{@link TaskMovement} plans warmup replicas for tasks whose assigned owner is not caught up;</li>
 *   <li>stateless tasks are assigned as actives via a {@link ConstrainedPrioritySet} weighted by active task load.</li>
 * </ol>
 * When rack awareness can be enabled, the rack-aware assignor is asked to optimize each placement phase.
 */
public class HighAvailabilityTaskAssignor implements TaskAssignor {
    private static final Logger log = LoggerFactory.getLogger(HighAvailabilityTaskAssignor.class);

    /** Rack-aware traffic cost for stateful tasks when {@code rackAwareAssignmentTrafficCost} is unset. */
    private static final int DEFAULT_STATEFUL_TRAFFIC_COST = 10;
    /** Rack-aware non-overlap cost for stateful tasks when {@code rackAwareAssignmentNonOverlapCost} is unset. */
    private static final int DEFAULT_STATEFUL_NON_OVERLAP_COST = 1;
    /** Fixed rack-aware traffic cost used when optimizing stateless active tasks. */
    private static final int STATELESS_TRAFFIC_COST = 1;
    /** Fixed rack-aware non-overlap cost used when optimizing stateless active tasks. */
    private static final int STATELESS_NON_OVERLAP_COST = 0;
    /**
     * Lag sentinel that {@link #activeRunning(long)} recognizes.
     * NOTE(review): presumably mirrors a lag constant defined alongside ClientState/Task; confirm.
     */
    private static final long ACTIVE_RUNNING_LAG = -2L;

    /**
     * Computes a complete assignment by mutating the given {@link ClientState} instances in place.
     *
     * @param clients               client states keyed by client UUID; their assignments are updated
     * @param allTaskIds            every task (stateful and stateless) to assign
     * @param statefulTaskIds       the stateful subset of {@code allTaskIds}; the rest are assigned as stateless
     * @param rackAwareTaskAssignor used to optimize placements when rack awareness can be enabled
     * @param configs               standby count, warmup limit, acceptable recovery lag and rack-aware costs
     * @return {@code true} if active or standby task movements are still pending, i.e. a follow-up probing
     *         rebalance is needed
     */
    @Override
    public boolean assign(Map<UUID, ClientState> clients, Set<TaskId> allTaskIds, Set<TaskId> statefulTaskIds, RackAwareTaskAssignor rackAwareTaskAssignor, AssignorConfiguration.AssignmentConfigs configs) {
        // Sorted copies make every iteration below (and thus the resulting assignment) deterministic.
        TreeSet<TaskId> statefulTasks = new TreeSet<>(statefulTaskIds);
        TreeMap<UUID, ClientState> clientStates = new TreeMap<>(clients);

        assignActiveStatefulTasks(clientStates, statefulTasks, rackAwareTaskAssignor, configs);
        this.assignStandbyReplicaTasks(clientStates, allTaskIds, statefulTasks, rackAwareTaskAssignor, configs);

        // A single warmup budget is shared by the active and the standby movement planning below.
        AtomicInteger remainingWarmupReplicas = new AtomicInteger(configs.maxWarmupReplicas);
        Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = tasksToCaughtUpClients(statefulTasks, clientStates, configs.acceptableRecoveryLag);
        Map<TaskId, SortedSet<UUID>> tasksToClientByLag = tasksToClientByLag(statefulTasks, clientStates);
        TreeMap<UUID, Set<TaskId>> warmups = new TreeMap<>();
        int neededActiveTaskMovements = TaskMovement.assignActiveTaskMovements(tasksToCaughtUpClients, tasksToClientByLag, clientStates, warmups, remainingWarmupReplicas);
        int neededStandbyTaskMovements = TaskMovement.assignStandbyTaskMovements(tasksToCaughtUpClients, tasksToClientByLag, clientStates, remainingWarmupReplicas, warmups);

        assignStatelessActiveTasks(clientStates, Utils.diff(TreeSet::new, allTaskIds, statefulTasks), rackAwareTaskAssignor);

        boolean probingRebalanceNeeded = neededActiveTaskMovements + neededStandbyTaskMovements > 0;
        // Parameterized so the (potentially large) assignment is only rendered when INFO is enabled.
        log.info("Decided on assignment: {} with{} followup probing rebalance.", clientStates, probingRebalanceNeeded ? "" : " no");
        return probingRebalanceNeeded;
    }

    /**
     * Assigns every stateful task as an active task, balances actives across clients, and then lets the
     * rack-aware assignor optimize the placement when rack awareness can be enabled.
     */
    private static void assignActiveStatefulTasks(SortedMap<UUID, ClientState> clientStates, SortedSet<TaskId> statefulTasks, RackAwareTaskAssignor rackAwareTaskAssignor, AssignorConfiguration.AssignmentConfigs configs) {
        // Round-robin over clients in key order, wrapping back to the first client after the last one.
        Iterator<ClientState> clientStateIterator = null;
        for (TaskId task : statefulTasks) {
            if (clientStateIterator == null || !clientStateIterator.hasNext()) {
                clientStateIterator = clientStates.values().iterator();
            }
            clientStateIterator.next().assignActive(task);
        }
        // Any pair of clients may exchange active tasks, so the movement predicate always allows it.
        balanceTasksOverThreads(clientStates, ClientState::activeTasks, ClientState::unassignActive, ClientState::assignActive, (source, destination) -> true);
        if (rackAwareTaskAssignor.canEnableRackAwareAssignor()) {
            rackAwareTaskAssignor.optimizeActiveTasks(statefulTasks, clientStates, statefulTrafficCost(configs), statefulNonOverlapCost(configs));
        }
    }

    /**
     * Delegates standby placement to a {@link StandbyTaskAssignor}, balances standbys across clients where
     * that assignor allows the move, and optionally runs the rack-aware standby optimization.
     * Does nothing when {@code configs.numStandbyReplicas} is zero.
     */
    private void assignStandbyReplicaTasks(TreeMap<UUID, ClientState> clientStates, Set<TaskId> allTaskIds, Set<TaskId> statefulTasks, RackAwareTaskAssignor rackAwareTaskAssignor, AssignorConfiguration.AssignmentConfigs configs) {
        if (configs.numStandbyReplicas == 0) {
            return;
        }
        // NOTE(review): null is passed as the factory's second argument; confirm that rack information is
        // intentionally not used when choosing the standby assignor implementation.
        StandbyTaskAssignor standbyTaskAssignor = StandbyTaskAssignorFactory.create(configs, null);
        standbyTaskAssignor.assign(clientStates, allTaskIds, statefulTasks, configs);
        balanceTasksOverThreads(clientStates, ClientState::standbyTasks, ClientState::unassignStandby, ClientState::assignStandby, standbyTaskAssignor::isAllowedTaskMovement);
        if (rackAwareTaskAssignor.canEnableRackAwareAssignor()) {
            rackAwareTaskAssignor.optimizeStandbyTasks(clientStates, statefulTrafficCost(configs), statefulNonOverlapCost(configs), standbyTaskAssignor::isAllowedTaskMovement);
        }
    }

    /** Returns the configured rack-aware traffic cost, or {@link #DEFAULT_STATEFUL_TRAFFIC_COST} if unset. */
    private static int statefulTrafficCost(AssignorConfiguration.AssignmentConfigs configs) {
        return configs.rackAwareAssignmentTrafficCost == null ? DEFAULT_STATEFUL_TRAFFIC_COST : configs.rackAwareAssignmentTrafficCost;
    }

    /** Returns the configured rack-aware non-overlap cost, or {@link #DEFAULT_STATEFUL_NON_OVERLAP_COST} if unset. */
    private static int statefulNonOverlapCost(AssignorConfiguration.AssignmentConfigs configs) {
        return configs.rackAwareAssignmentNonOverlapCost == null ? DEFAULT_STATEFUL_NON_OVERLAP_COST : configs.rackAwareAssignmentNonOverlapCost;
    }

    /**
     * Repeatedly moves tasks of one kind (active or standby) from a more-loaded client to a less-loaded one,
     * until a full pass over all ordered client pairs makes no move.
     *
     * @param clientStates                 clients to balance; their assignments are mutated
     * @param currentAssignmentAccessor    returns a client's tasks of the kind being balanced
     * @param taskUnassignor               removes a task of that kind from a client
     * @param taskAssignor                 adds a task of that kind to a client
     * @param taskMovementAttemptPredicate whether a move from the first client to the second is allowed
     */
    private static void balanceTasksOverThreads(SortedMap<UUID, ClientState> clientStates, Function<ClientState, Set<TaskId>> currentAssignmentAccessor, BiConsumer<ClientState, TaskId> taskUnassignor, BiConsumer<ClientState, TaskId> taskAssignor, BiPredicate<ClientState, ClientState> taskMovementAttemptPredicate) {
        boolean keepBalancing = true;
        while (keepBalancing) {
            keepBalancing = false;
            for (Map.Entry<UUID, ClientState> sourceEntry : clientStates.entrySet()) {
                UUID sourceClient = sourceEntry.getKey();
                ClientState sourceClientState = sourceEntry.getValue();
                for (Map.Entry<UUID, ClientState> destinationEntry : clientStates.entrySet()) {
                    UUID destinationClient = destinationEntry.getKey();
                    ClientState destinationClientState = destinationEntry.getValue();
                    if (sourceClient.equals(destinationClient)) {
                        continue;
                    }
                    // Iterate over a sorted snapshot: moving a task mutates the source client's assignment.
                    TreeSet<TaskId> sourceTasks = new TreeSet<>(currentAssignmentAccessor.apply(sourceClientState));
                    Iterator<TaskId> sourceIterator = sourceTasks.iterator();
                    while (shouldMoveATask(sourceClientState, destinationClientState) && sourceIterator.hasNext()) {
                        TaskId taskToMove = sourceIterator.next();
                        // Never give the destination a task it already holds (in any role).
                        boolean canMove = !destinationClientState.hasAssignedTask(taskToMove)
                            && taskMovementAttemptPredicate.test(sourceClientState, destinationClientState);
                        if (!canMove) {
                            continue;
                        }
                        taskUnassignor.accept(sourceClientState, taskToMove);
                        taskAssignor.accept(destinationClientState, taskToMove);
                        keepBalancing = true;
                    }
                }
            }
        }
    }

    /**
     * Decides whether moving one task from {@code sourceClientState} to {@code destinationClientState}
     * improves balance: the source must currently be more loaded, and after the move the per-thread skew
     * must be non-negative (no reversed imbalance) and strictly smaller than before.
     */
    private static boolean shouldMoveATask(ClientState sourceClientState, ClientState destinationClientState) {
        double skew = sourceClientState.assignedTaskLoad() - destinationClientState.assignedTaskLoad();
        if (skew <= 0.0) {
            return false;
        }
        double proposedAssignedTasksPerStreamThreadAtDestination = ((double) destinationClientState.assignedTaskCount() + 1.0) / (double) destinationClientState.capacity();
        double proposedAssignedTasksPerStreamThreadAtSource = ((double) sourceClientState.assignedTaskCount() - 1.0) / (double) sourceClientState.capacity();
        double proposedSkew = proposedAssignedTasksPerStreamThreadAtSource - proposedAssignedTasksPerStreamThreadAtDestination;
        if (proposedSkew < 0.0) {
            // The move would merely create an imbalance in the opposite direction.
            return false;
        }
        // Only move if it actually reduces the skew.
        return proposedSkew < skew;
    }

    /**
     * Assigns each stateless task as an active task to the client returned by a {@link ConstrainedPrioritySet}
     * weighted by {@link ClientState#activeTaskLoad()} (presumably the least-loaded client — confirm against
     * ConstrainedPrioritySet), then optionally runs the rack-aware optimization with fixed stateless costs.
     */
    private static void assignStatelessActiveTasks(TreeMap<UUID, ClientState> clientStates, Iterable<TaskId> statelessTasks, RackAwareTaskAssignor rackAwareTaskAssignor) {
        ConstrainedPrioritySet statelessActiveTaskClientsByTaskLoad = new ConstrainedPrioritySet((candidate, unused) -> true, candidate -> clientStates.get(candidate).activeTaskLoad());
        statelessActiveTaskClientsByTaskLoad.offerAll(clientStates.keySet());
        TreeSet<TaskId> sortedTasks = new TreeSet<>();
        for (TaskId task : statelessTasks) {
            sortedTasks.add(task);
            UUID chosenClient = statelessActiveTaskClientsByTaskLoad.poll(task);
            clientStates.get(chosenClient).assignActive(task);
            // Re-offer so the client is re-ranked with its updated load.
            statelessActiveTaskClientsByTaskLoad.offer(chosenClient);
        }
        if (rackAwareTaskAssignor.canEnableRackAwareAssignor()) {
            rackAwareTaskAssignor.optimizeActiveTasks(sortedTasks, clientStates, STATELESS_TRAFFIC_COST, STATELESS_NON_OVERLAP_COST);
        }
    }

    /**
     * Maps each stateful task to the sorted set of clients considered caught up on it. A client qualifies
     * when its lag is the {@link #ACTIVE_RUNNING_LAG} sentinel, when the acceptable recovery lag is unbounded,
     * or when its lag lies in {@code [0, acceptableRecoveryLag]}.
     */
    private static Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients(Set<TaskId> statefulTasks, Map<UUID, ClientState> clientStates, long acceptableRecoveryLag) {
        HashMap<TaskId, SortedSet<UUID>> taskToCaughtUpClients = new HashMap<>();
        for (TaskId task : statefulTasks) {
            TreeSet<UUID> caughtUpClients = new TreeSet<>();
            for (Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
                long taskLag = clientEntry.getValue().lagFor(task);
                if (activeRunning(taskLag) || unbounded(acceptableRecoveryLag) || acceptable(acceptableRecoveryLag, taskLag)) {
                    caughtUpClients.add(clientEntry.getKey());
                }
            }
            taskToCaughtUpClients.put(task, caughtUpClients);
        }
        return taskToCaughtUpClients;
    }

    /**
     * Maps each stateful task to all clients ordered by ascending lag on that task, with ties broken by
     * client UUID so the ordering is total.
     */
    private static Map<TaskId, SortedSet<UUID>> tasksToClientByLag(Set<TaskId> statefulTasks, Map<UUID, ClientState> clientStates) {
        HashMap<TaskId, SortedSet<UUID>> tasksToClientByLag = new HashMap<>();
        for (TaskId task : statefulTasks) {
            TreeSet<UUID> clientLag = new TreeSet<>(
                Comparator.<UUID>comparingLong(client -> clientStates.get(client).lagFor(task))
                    .thenComparing(client -> client));
            clientLag.addAll(clientStates.keySet());
            tasksToClientByLag.put(task, clientLag);
        }
        return tasksToClientByLag;
    }

    /** {@code Long.MAX_VALUE} as the acceptable recovery lag means every client counts as caught up. */
    private static boolean unbounded(long acceptableRecoveryLag) {
        return acceptableRecoveryLag == Long.MAX_VALUE;
    }

    /** A non-negative lag within the acceptable bound; negative (sentinel) lags never qualify here. */
    private static boolean acceptable(long acceptableRecoveryLag, long taskLag) {
        return taskLag >= 0L && taskLag <= acceptableRecoveryLag;
    }

    /** Whether {@code taskLag} is the {@link #ACTIVE_RUNNING_LAG} sentinel. */
    private static boolean activeRunning(long taskLag) {
        return taskLag == ACTIVE_RUNNING_LAG;
    }
}

