/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.functions.sql.ml;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlModelCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlNameMatcher;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.util.NlsString;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.ml.TaskType;
import org.apache.flink.table.planner.functions.sql.ml.SqlMLTableFunction;
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
import org.apache.flink.table.types.logical.LogicalType;

/**
 * SQL table function {@code ML_EVALUATE}.
 *
 * <p>Operands, in positional order: an input table, a model, a label descriptor, a feature
 * descriptor, a task string literal, and an optional config map. The function produces a row
 * type with a single {@code result} column of type {@code MAP<VARCHAR, DOUBLE>}.
 */
public class SqlMLEvaluateTableFunction
extends SqlMLTableFunction {
    public static final String PARAM_LABEL = "LABEL";
    public static final String PARAM_ARGS = "ARGS";
    public static final String PARAM_TASK = "TASK";

    public SqlMLEvaluateTableFunction() {
        super("ML_EVALUATE", new EvaluateOperandMetadata());
    }

    /**
     * Every operand except the input table (ordinal 0) is treated as a scalar argument.
     *
     * @param ordinal zero-based operand position
     * @return {@code false} only for the input table operand
     */
    @Override
    public boolean argumentMustBeScalar(int ordinal) {
        return ordinal != 0;
    }

    /**
     * Infers the output row type: a single {@code result} column of type
     * {@code MAP<VARCHAR, DOUBLE>}. The row keeps the struct kind of the input table operand.
     * (Presumably the map holds evaluation metric name to value; TODO confirm against the runtime.)
     */
    @Override
    protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
        RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
        // Operand 0 is the input table; only its struct kind is carried over.
        RelDataType inputRowType = opBinding.getOperandType(0);
        RelDataType mapType = typeFactory.createMapType(
                typeFactory.createSqlType(SqlTypeName.VARCHAR),
                typeFactory.createSqlType(SqlTypeName.DOUBLE));
        return typeFactory.builder().kind(inputRowType.getStructKind()).add("result", mapType).build();
    }

    /** Operand names, count range and type checking for {@code ML_EVALUATE}. */
    private static class EvaluateOperandMetadata
    implements SqlOperandMetadata {
        private static final List<String> PARAM_NAMES =
                List.of("INPUT", "MODEL", PARAM_LABEL, PARAM_ARGS, PARAM_TASK, "CONFIG");
        private static final List<String> MANDATORY_PARAM_NAMES =
                List.of("INPUT", "MODEL", PARAM_LABEL, PARAM_ARGS, PARAM_TASK);

        // Operand positions derived from PARAM_NAMES so they cannot drift from the declared order.
        private static final int INPUT_INDEX = PARAM_NAMES.indexOf("INPUT");
        private static final int MODEL_INDEX = PARAM_NAMES.indexOf("MODEL");
        private static final int LABEL_INDEX = PARAM_NAMES.indexOf(PARAM_LABEL);
        private static final int ARGS_INDEX = PARAM_NAMES.indexOf(PARAM_ARGS);
        private static final int TASK_INDEX = PARAM_NAMES.indexOf(PARAM_TASK);
        private static final int CONFIG_INDEX = PARAM_NAMES.indexOf("CONFIG");

        EvaluateOperandMetadata() {
        }

        /** All parameters are declared as {@code ANY}; real checks happen in {@link #checkOperandTypes}. */
        @Override
        public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
            return Collections.nCopies(PARAM_NAMES.size(), typeFactory.createSqlType(SqlTypeName.ANY));
        }

        @Override
        public List<String> paramNames() {
            return PARAM_NAMES;
        }

        /**
         * Validates the operands in order and stops at the first failure:
         * <ol>
         *   <li>the input is a table, and the LABEL and ARGS operands are descriptors;</li>
         *   <li>the model signature, checked by the shared helper using the ARGS descriptor;</li>
         *   <li>the label descriptor against the model's single output field;</li>
         *   <li>the task literal;</li>
         *   <li>the config map, when supplied.</li>
         * </ol>
         *
         * @param callBinding the call being validated
         * @param throwOnFailure whether to throw instead of returning {@code false}
         * @return {@code true} if all checks pass
         */
        @Override
        public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
            if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, LABEL_INDEX, ARGS_INDEX)) {
                return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(callBinding, throwOnFailure);
            }
            if (!SqlValidatorUtils.throwExceptionOrReturnFalse(
                    SqlMLTableFunction.checkModelSignature(callBinding, ARGS_INDEX), throwOnFailure)) {
                return false;
            }
            if (!SqlValidatorUtils.throwExceptionOrReturnFalse(
                    checkModelOutputType(callBinding, LABEL_INDEX), throwOnFailure)) {
                return false;
            }
            if (!SqlValidatorUtils.throwExceptionOrReturnFalse(
                    checkTask(callBinding.operand(TASK_INDEX)), throwOnFailure)) {
                return false;
            }
            // CONFIG is the only optional operand; check it only when it is present.
            if (callBinding.getOperandCount() == PARAM_NAMES.size()) {
                return SqlValidatorUtils.throwExceptionOrReturnFalse(
                        SqlMLTableFunction.checkConfig(callBinding, callBinding.operand(CONFIG_INDEX)),
                        throwOnFailure);
            }
            return true;
        }

        @Override
        public SqlOperandCountRange getOperandCountRange() {
            return SqlOperandCountRanges.between(MANDATORY_PARAM_NAMES.size(), PARAM_NAMES.size());
        }

        /** An operand is optional when its position lies beyond the mandatory ones. */
        @Override
        public boolean isOptional(int i) {
            SqlOperandCountRange range = getOperandCountRange();
            return i >= range.getMin() && i < range.getMax();
        }

        /**
         * Human-readable signature used in validation error messages. TASK is mandatory (it is in
         * {@code MANDATORY_PARAM_NAMES}), so only the trailing config map is shown as optional.
         */
        @Override
        public String getAllowedSignatures(SqlOperator op, String opName) {
            return opName + "(TABLE table_name, MODEL model_name, DESCRIPTOR(label_column), DESCRIPTOR(feature_columns), task, [MAP[]])";
        }

        /**
         * Checks that the task operand is a character string literal naming a supported task.
         *
         * @param node the task operand
         * @return an error if the node is not a string literal or the task is unsupported;
         *     empty otherwise
         */
        private static Optional<RuntimeException> checkTask(SqlNode node) {
            if (!(node instanceof SqlCharStringLiteral)) {
                return Optional.of(new ValidationException("Expected a valid task string literal, but got: " + node.getClass().getSimpleName() + "."));
            }
            String task = ((SqlCharStringLiteral) node).getValueAs(NlsString.class).getValue();
            if (!TaskType.isValidTaskType(task)) {
                return Optional.of(new ValidationException("Unsupported task: " + task + ". Supported tasks are: " + Arrays.toString(TaskType.values()) + "."));
            }
            return Optional.empty();
        }

        /**
         * Checks the label descriptor against the model output. The descriptor must name exactly
         * one column, the model must have exactly one output field, and the shared helper must
         * accept the label column type for that output type.
         *
         * @param callBinding the call being validated
         * @param outputDescriptorIndex position of the label descriptor operand
         * @return an error describing the first failed condition, or empty if all pass
         */
        private static Optional<RuntimeException> checkModelOutputType(SqlCallBinding callBinding, int outputDescriptorIndex) {
            // Safe cast: checkTableAndDescriptorOperands ran first in checkOperandTypes.
            SqlCall descriptorCall = (SqlCall) callBinding.operand(outputDescriptorIndex);
            List<SqlNode> descriptorColumns = descriptorCall.getOperandList();
            if (descriptorColumns.size() != 1) {
                return Optional.of(new ValidationException("Label descriptor must have exactly one column for evaluation."));
            }
            SqlValidator validator = callBinding.getValidator();
            SqlModelCall modelCall = (SqlModelCall) callBinding.operand(MODEL_INDEX);
            RelDataType modelOutputType = modelCall.getOutputType(validator);
            if (modelOutputType.getFieldCount() != 1) {
                return Optional.of(new ValidationException("Model output must have exactly one field for evaluation."));
            }
            RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(INPUT_INDEX));
            SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
            Tuple3<Boolean, LogicalType, LogicalType> result = SqlMLTableFunction.checkModelDescriptorType(
                    tableType, modelOutputType.getFieldList().get(0).getType(), descriptorColumns.get(0), matcher);
            if (!result.f0) {
                return Optional.of(new ValidationException(String.format("Label descriptor column type %s cannot be assigned to model output type %s for evaluation.", result.f1, result.f2)));
            }
            return Optional.empty();
        }
    }
}

