/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.packageloader.action;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.client.internal.ParentTaskAssigningClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Strings;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.RemovedTaskListener;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.common.notifications.Level;
import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction;
import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
import org.elasticsearch.xpack.ml.packageloader.action.DownloadTaskRemovedListener;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter;

/**
 * Master-node transport action that imports a trained model package.
 *
 * <p>Each import is executed under its own {@link ModelDownloadTask} by a {@link ModelImporter}.
 * If a download task for the same model is already registered, no new import is started.
 * The request is then either acknowledged immediately or, when it asks to wait for completion,
 * answered once the running task is removed. Start, completion and failure of an import are
 * both logged and written as ML audit notifications.
 */
public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
    LoadTrainedModelPackageAction.Request,
    AcknowledgedResponse> {

    private static final Logger logger = LogManager.getLogger(TransportLoadTrainedModelPackage.class);

    /** Client wrapped so that its requests carry the "ml" origin; used for the import and for audit notifications. */
    private final Client client;
    private final CircuitBreakerService circuitBreakerService;

    /**
     * Listeners waiting for an in-progress download of a model to finish, keyed by model id.
     * Only read or modified while holding the monitor of {@code this}. Package-private,
     * presumably for tests — TODO confirm.
     */
    final Map<String, List<DownloadTaskRemovedListener>> taskRemovedListenersByModelId;

    @Inject
    public TransportLoadTrainedModelPackage(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        ActionFilters actionFilters,
        IndexNameExpressionResolver indexNameExpressionResolver,
        Client client,
        CircuitBreakerService circuitBreakerService
    ) {
        super(
            "cluster:internal/xpack/ml/trained_models/package_loader/load",
            transportService,
            clusterService,
            threadPool,
            actionFilters,
            LoadTrainedModelPackageAction.Request::new,
            indexNameExpressionResolver,
            NodeAcknowledgedResponse::new,
            EsExecutors.DIRECT_EXECUTOR_SERVICE
        );
        this.client = new OriginSettingClient(client, "ml");
        this.circuitBreakerService = circuitBreakerService;
        this.taskRemovedListenersByModelId = new HashMap<>();
    }

    /** No cluster-block check is applied to this action. */
    @Override
    protected ClusterBlockException checkBlock(LoadTrainedModelPackageAction.Request request, ClusterState state) {
        return null;
    }

    /**
     * Starts an import of the requested model package, or attaches to an import of the same model
     * that is already in progress.
     *
     * <p>If {@code request.isWaitForCompletion()} is false, {@code listener} is acknowledged as soon as
     * the import has been started (or found to be running). Otherwise it is completed with the
     * import's outcome.
     */
    @Override
    protected void masterOperation(
        Task task,
        LoadTrainedModelPackageAction.Request request,
        ClusterState state,
        ActionListener<AcknowledgedResponse> listener
    ) throws Exception {
        final ModelDownloadTask downloadTask;
        // The "is a download already running?" check and the registration of the new download task
        // must happen as one atomic step. Otherwise two concurrent requests for the same model could
        // both find no running download and both start an import. This is the same monitor held by
        // handleDownloadInProgress and unregisterTask.
        synchronized (this) {
            if (handleDownloadInProgress(request.getModelId(), request.isWaitForCompletion(), listener)) {
                logger.debug("Existing download of model [{}] in progress", request.getModelId());
                return;
            }
            downloadTask = createDownloadTask(request);
        }

        try {
            ParentTaskAssigningClient parentTaskAssigningClient = getParentTaskAssigningClient(downloadTask);
            ModelImporter modelImporter = new ModelImporter(
                parentTaskAssigningClient,
                request.getModelId(),
                request.getModelPackageConfig(),
                downloadTask,
                threadPool,
                circuitBreakerService
            );
            // When the caller does not wait, it is answered below. The import's own outcome is then
            // not forwarded to it (importModel still logs and audits it).
            ActionListener<AcknowledgedResponse> downloadCompleteListener = request.isWaitForCompletion()
                ? listener
                : ActionListener.<AcknowledgedResponse>noop();
            importModel(client, () -> unregisterTask(downloadTask), request, modelImporter, downloadTask, downloadCompleteListener);
        } catch (Exception e) {
            // The import never got going, so release the task registered above.
            taskManager.unregister(downloadTask);
            listener.onFailure(e);
            return;
        }

        if (request.isWaitForCompletion() == false) {
            listener.onResponse(AcknowledgedResponse.TRUE);
        }
    }

    /** Returns a client whose requests are parented to {@code originTask} on the local node. */
    private ParentTaskAssigningClient getParentTaskAssigningClient(Task originTask) {
        TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), originTask.getId());
        return new ParentTaskAssigningClient(client, parentTaskId);
    }

    /**
     * Checks whether a {@link ModelDownloadTask} for {@code modelId} is already registered and, if so,
     * takes care of {@code listener}.
     *
     * <p>When a download is in progress:
     * <ul>
     *   <li>If {@code isWaitForCompletion} is false, the listener is acknowledged immediately.</li>
     *   <li>Otherwise a {@link DownloadTaskRemovedListener} is registered that answers the listener
     *       when the running task is removed.</li>
     * </ul>
     *
     * @return {@code true} if a download was in progress and {@code listener} has been handled;
     *         {@code false} if no matching download task exists and {@code listener} was not touched
     */
    synchronized boolean handleDownloadInProgress(
        String modelId,
        boolean isWaitForCompletion,
        ActionListener<AcknowledgedResponse> listener
    ) {
        String description = ModelDownloadTask.taskDescription(modelId);

        ModelDownloadTask inProgress = null;
        for (CancellableTask task : taskManager.getCancellableTasks().values()) {
            if (description.equals(task.getDescription()) && task instanceof ModelDownloadTask) {
                inProgress = (ModelDownloadTask) task;
                break;
            }
        }

        if (inProgress == null) {
            return false;
        }

        if (isWaitForCompletion == false) {
            // It is enough that the download is running; respond now rather than when it completes.
            listener.onResponse(AcknowledgedResponse.TRUE);
            return true;
        }

        // Answer the listener once the running task is unregistered.
        DownloadTaskRemovedListener tracker = new DownloadTaskRemovedListener(inProgress, listener);
        taskRemovedListenersByModelId.computeIfAbsent(modelId, ignored -> new ArrayList<>()).add(tracker);
        taskManager.registerRemovedTaskListener(tracker);
        return true;
    }

    /**
     * Unregisters {@code task}, then drops and deregisters every {@link DownloadTaskRemovedListener}
     * tracked for the task's model.
     *
     * <p>The task is unregistered first, presumably so the waiting removed-task listeners observe the
     * removal before they are deregistered — TODO confirm.
     */
    synchronized void unregisterTask(ModelDownloadTask task) {
        taskManager.unregister(task);
        List<DownloadTaskRemovedListener> trackers = taskRemovedListenersByModelId.remove(task.getModelId());
        if (trackers != null) {
            for (DownloadTaskRemovedListener tracker : trackers) {
                taskManager.unregisterRemovedTaskListener(tracker);
            }
        }
    }

    /**
     * Runs {@code modelImporter.doImport} and reports its outcome.
     *
     * <ul>
     *   <li>Start and successful completion (with elapsed seconds) are logged and audited at INFO.</li>
     *   <li>On failure, the exception is stored on {@code task} and passed through
     *       {@link #processException}; the result is sent to {@code listener}.</li>
     *   <li>{@code unregisterTaskFn} runs after {@code listener} has been notified, on both success
     *       and failure.</li>
     * </ul>
     */
    static void importModel(
        Client auditClient,
        Runnable unregisterTaskFn,
        LoadTrainedModelPackageAction.Request request,
        ModelImporter modelImporter,
        ModelDownloadTask task,
        ActionListener<AcknowledgedResponse> listener
    ) {
        final String modelId = request.getModelId();
        final long relativeStartNanos = System.nanoTime();

        logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO);

        ActionListener<AcknowledgedResponse> finishListener = ActionListener.<AcknowledgedResponse>wrap(success -> {
            final long totalRuntimeNanos = System.nanoTime() - relativeStartNanos;
            logAndWriteNotificationAtLevel(
                auditClient,
                modelId,
                Strings.format("finished model import after [%d] seconds", TimeUnit.NANOSECONDS.toSeconds(totalRuntimeNanos)),
                Level.INFO
            );
            listener.onResponse(AcknowledgedResponse.TRUE);
        }, exception -> {
            task.setTaskException(exception);
            listener.onFailure(processException(auditClient, modelId, exception));
        });

        modelImporter.doImport(ActionListener.runAfter(finishListener, unregisterTaskFn));
    }

    /**
     * Audits {@code e} for {@code modelId} and returns the exception to report to the caller.
     *
     * <ul>
     *   <li>{@link TaskCancelledException}: audited at WARNING and returned unchanged.</li>
     *   <li>Other {@link ElasticsearchException}s: audited at ERROR and returned unchanged.</li>
     *   <li>{@link MalformedURLException} / {@link URISyntaxException}: wrapped with status 400.</li>
     *   <li>Other {@link IOException}s: wrapped with status 503.</li>
     *   <li>Anything else: wrapped with status 500.</li>
     * </ul>
     */
    static Exception processException(Client auditClient, String modelId, Exception e) {
        // Order matters: TaskCancelledException is an ElasticsearchException, and MalformedURLException
        // is an IOException, so the more specific checks come first.
        if (e instanceof TaskCancelledException) {
            return recordError(auditClient, modelId, (TaskCancelledException) e, Level.WARNING);
        }
        if (e instanceof ElasticsearchException) {
            return recordError(auditClient, modelId, (ElasticsearchException) e, Level.ERROR);
        }
        if (e instanceof MalformedURLException) {
            return recordError(auditClient, modelId, "an invalid URL", e, Level.ERROR, RestStatus.BAD_REQUEST);
        }
        if (e instanceof URISyntaxException) {
            return recordError(auditClient, modelId, "an invalid URL syntax", e, Level.ERROR, RestStatus.BAD_REQUEST);
        }
        if (e instanceof IOException) {
            return recordError(auditClient, modelId, "an IOException", e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE);
        }
        return recordError(auditClient, modelId, "an Exception", e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR);
    }

    /** Registers a new {@link ModelDownloadTask} for {@code request}'s model with the task manager. */
    private ModelDownloadTask createDownloadTask(final LoadTrainedModelPackageAction.Request request) {
        // The download is registered as its own task under a fresh trace context, presumably so it does
        // not inherit the trace of the request that started it.
        try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().newTraceContext()) {
            return (ModelDownloadTask) taskManager.register("model_import", "xpack/ml/model_import[n]", new TaskAwareRequest() {
                @Override
                public void setParentTask(TaskId taskId) {
                    request.setParentTask(taskId);
                }

                @Override
                public void setRequestId(long requestId) {
                    request.setRequestId(requestId);
                }

                @Override
                public TaskId getParentTask() {
                    return request.getParentTask();
                }

                @Override
                public ModelDownloadTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
                    return new ModelDownloadTask(id, type, action, request.getModelId(), parentTaskId, headers);
                }
            }, false);
        }
    }

    /** Logs and audits an {@link ElasticsearchException} at {@code level}, then returns it unchanged. */
    private static Exception recordError(Client client, String modelId, ElasticsearchException e, Level level) {
        String message = Strings.format("Model importing failed due to [%s]", e.getDetailedMessage());
        logAndWriteNotificationAtLevel(client, modelId, message, level);
        return e;
    }

    /**
     * Logs and audits a non-Elasticsearch failure. Returns it wrapped in an
     * {@link ElasticsearchStatusException} with the given {@code status} and {@code e} as the cause.
     */
    private static Exception recordError(Client client, String modelId, String failureType, Exception e, Level level, RestStatus status) {
        String message = Strings.format("Model importing failed due to %s [%s]", failureType, e);
        logAndWriteNotificationAtLevel(client, modelId, message, level);
        return new ElasticsearchStatusException(message, status, e);
    }

    /** Writes {@code message} as an audit notification for {@code modelId} and logs it at the matching level. */
    private static void logAndWriteNotificationAtLevel(Client client, String modelId, String message, Level level) {
        writeNotification(client, modelId, message, level);
        logger.log(level.log4jLevel(), Strings.format("[%s] %s", modelId, message));
    }

    /** Sends an INFERENCE-type ML audit notification; the response is ignored. */
    private static void writeNotification(Client client, String modelId, String message, Level level) {
        client.execute(
            AuditMlNotificationAction.INSTANCE,
            new AuditMlNotificationAction.Request(AuditMlNotificationAction.AuditType.INFERENCE, modelId, message, level),
            ActionListener.noop()
        );
    }
}

