/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.flowframework.workflow;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessageFactory;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.flowframework.workflow.AbstractRetryableWorkflowStep;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.BaseModelConfig;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.threadpool.ThreadPool;

/**
 * Base workflow step that registers a local model through the ML Commons {@link MachineLearningNodeClient}.
 *
 * <p>The step builds an {@link MLRegisterModelInput} from its inputs and submits it with
 * {@code mlClient.register}. It then hands the returned task id to {@code retryableGetMlTask} (inherited
 * from {@link AbstractRetryableWorkflowStep}) to obtain the step's resulting {@link WorkflowData}.
 * Subclasses decide which input keys are required or optional.
 */
public abstract class AbstractRegisterLocalModelStep extends AbstractRetryableWorkflowStep {
    private static final Logger logger = LogManager.getLogger(AbstractRegisterLocalModelStep.class);
    private final MachineLearningNodeClient mlClient;
    private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;

    /**
     * Instantiate this step.
     *
     * @param threadPool the thread pool, forwarded to the retryable base step
     * @param mlClient the ML Commons client used to register the model
     * @param flowFrameworkIndicesHandler handler used to record the deploy resource in the state index
     * @param flowFrameworkSettings plugin settings, forwarded to the retryable base step
     */
    protected AbstractRegisterLocalModelStep(
        ThreadPool threadPool,
        MachineLearningNodeClient mlClient,
        FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
        FlowFrameworkSettings flowFrameworkSettings
    ) {
        super(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings);
        this.mlClient = mlClient;
        this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
    }

    /**
     * Register a local model and complete the returned future with the registration result.
     *
     * <p>Inputs read:
     * <ul>
     * <li>{@code name}, {@code version}, {@code model_format}</li>
     * <li>optional {@code function_name}, {@code model_content_hash_value}, {@code url}, {@code description},
     * {@code model_group_id}</li>
     * <li>{@code interface}: a JSON string that is parsed into a string-to-string map</li>
     * <li>{@code deploy}: parsed as a Boolean</li>
     * <li>{@code model_config}: a map with {@code model_type}, {@code additional_config} and
     * {@code all_config}</li>
     * </ul>
     * If {@code deploy} is {@code true}, the registered resource id is also recorded in the state index
     * under the {@code deploy_model} step name before the future completes.
     *
     * @param currentNodeId id of the current workflow node
     * @param currentNodeInputs inputs declared on the current node
     * @param outputs outputs of previously executed nodes, keyed by node id
     * @param previousNodeInputs mapping of input keys to previous node ids
     * @param params workflow parameters
     * @param tenantId tenant id, forwarded to the ML task and state-index calls
     * @return a future that completes with the registration {@link WorkflowData}. It fails with a
     *     {@link WorkflowStepException} (bad input, unparsable interface, registration failure) or a
     *     {@link FlowFrameworkException} (input resolution or state-index update failure).
     */
    @Override
    public PlainActionFuture<WorkflowData> execute(
        String currentNodeId,
        WorkflowData currentNodeInputs,
        Map<String, WorkflowData> outputs,
        Map<String, String> previousNodeInputs,
        Map<String, String> params,
        String tenantId
    ) {
        PlainActionFuture<WorkflowData> registerLocalModelFuture = PlainActionFuture.newFuture();
        try {
            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
                this.getRequiredKeys(),
                this.getOptionalKeys(),
                currentNodeInputs,
                outputs,
                previousNodeInputs,
                params
            );

            String modelName = (String) inputs.get("name");
            String modelVersion = (String) inputs.get("version");
            String modelFormat = (String) inputs.get("model_format");
            String functionName = (String) inputs.get("function_name");
            String modelContentHashValue = (String) inputs.get("model_content_hash_value");
            String url = (String) inputs.get("url");
            String description = (String) inputs.get("description");
            String modelGroupId = (String) inputs.get("model_group_id");
            String modelInterface = (String) inputs.get("interface");
            Boolean deploy = ParseUtils.parseIfExists(inputs, "deploy", Boolean.class);

            MLRegisterModelInput.MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder()
                .modelName(modelName)
                .version(modelVersion)
                .modelFormat(MLModelFormat.from(modelFormat));

            if (functionName != null) {
                mlInputBuilder.functionName(FunctionName.from(functionName));
            }
            if (modelContentHashValue != null) {
                mlInputBuilder.hashValue(modelContentHashValue);
            }
            if (url != null) {
                mlInputBuilder.url(url);
            }

            // model_config arrives as an untyped nested map from the parsed template.
            @SuppressWarnings("unchecked")
            Map<String, Object> modelConfig = (Map<String, Object>) inputs.get("model_config");
            if (modelConfig != null) {
                @SuppressWarnings("unchecked")
                Map<String, Object> additionalConfig = modelConfig.containsKey("additional_config")
                    ? (Map<String, Object>) modelConfig.get("additional_config")
                    : null;
                BaseModelConfig baseModelConfig = BaseModelConfig.baseModelConfigBuilder()
                    .modelType((String) modelConfig.get("model_type"))
                    .additionalConfig(additionalConfig)
                    .allConfig((String) modelConfig.get("all_config"))
                    .build();
                mlInputBuilder.modelConfig(baseModelConfig);
            }
            if (description != null) {
                mlInputBuilder.description(description);
            }
            if (modelGroupId != null) {
                mlInputBuilder.modelGroupId(modelGroupId);
            }
            if (modelInterface != null) {
                try {
                    // The interface is supplied as a JSON string; parse it, then flatten to String->String.
                    BytesArray modelInterfaceBytes = new BytesArray(modelInterface.getBytes(StandardCharsets.UTF_8));
                    Map<String, Object> modelInterfaceAsMap = XContentHelper.convertToMap(
                        modelInterfaceBytes,
                        false,
                        MediaTypeRegistry.JSON
                    ).v2();
                    Map<String, String> parameters = ParseUtils.convertStringToObjectMapToStringToStringMap(modelInterfaceAsMap);
                    mlInputBuilder.modelInterface(parameters);
                } catch (Exception ex) {
                    String errorMessage = "Failed to create model interface";
                    logger.error(errorMessage, ex);
                    registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST));
                    // Stop here: the step has already failed, so the model must not be registered anyway.
                    return registerLocalModelFuture;
                }
            }
            if (deploy != null) {
                mlInputBuilder.deployModel(deploy);
            }

            MLRegisterModelInput mlInput = mlInputBuilder.build();
            this.mlClient.register(mlInput, ActionListener.wrap(response -> {
                logger.info("Local Model registration task creation successful");
                String taskId = response.getTaskId();
                this.retryableGetMlTask(
                    currentNodeInputs,
                    currentNodeId,
                    registerLocalModelFuture,
                    taskId,
                    "Local model registration",
                    tenantId,
                    ActionListener.wrap(mlTaskWorkflowData -> {
                        String resourceName = WorkflowResources.getResourceByWorkflowStep(this.getName());
                        if (Boolean.TRUE.equals(deploy)) {
                            // Registration also deployed the model: record it as a deploy_model resource too.
                            String id = (String) mlTaskWorkflowData.getContent().get(resourceName);
                            ActionListener<WorkflowData> deployUpdateListener = ActionListener.wrap(
                                deployUpdateResponse -> registerLocalModelFuture.onResponse(mlTaskWorkflowData),
                                deployUpdateException -> {
                                    String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage(
                                        "Failed to update simulated deploy step resource {}",
                                        id
                                    ).getFormattedMessage();
                                    logger.error(errorMessage, deployUpdateException);
                                    registerLocalModelFuture.onFailure(
                                        new FlowFrameworkException(errorMessage, ExceptionsHelper.status(deployUpdateException))
                                    );
                                }
                            );
                            this.flowFrameworkIndicesHandler.addResourceToStateIndex(
                                currentNodeInputs,
                                currentNodeId,
                                "deploy_model",
                                id,
                                tenantId,
                                deployUpdateListener
                            );
                        } else {
                            registerLocalModelFuture.onResponse(mlTaskWorkflowData);
                        }
                    }, registerLocalModelFuture::onFailure)
                );
            }, exception -> {
                // Only surface a "safe" exception message; otherwise fall back to a generic one.
                Exception e = WorkflowStepException.getSafeException(exception);
                String errorMessage = e == null
                    ? ParameterizedMessageFactory.INSTANCE.newMessage("Failed to register local model in step {}", currentNodeId)
                        .getFormattedMessage()
                    : e.getMessage();
                logger.error(errorMessage, e);
                registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e)));
            }));
        } catch (IllegalArgumentException iae) {
            registerLocalModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST));
        } catch (FlowFrameworkException e) {
            registerLocalModelFuture.onFailure(e);
        }
        return registerLocalModelFuture;
    }

    /**
     * Keys passed as required keys to {@code ParseUtils.getInputsFromPreviousSteps}.
     *
     * @return the required input keys for this step
     */
    protected abstract Set<String> getRequiredKeys();

    /**
     * Keys passed as optional keys to {@code ParseUtils.getInputsFromPreviousSteps}.
     *
     * @return the optional input keys for this step
     */
    protected abstract Set<String> getOptionalKeys();
}

