/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.expressions;

import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.ApiExpression;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Model;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.api.internal.ModelImpl;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.catalog.ContextResolvedFunction;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionUtils;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.LookupCallExpression;
import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.SqlCallExpression;
import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.TypeLiteralExpression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.functions.BuiltInFunctionDefinition;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionKind;
import org.apache.flink.table.operations.QueryOperation;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;

@Internal
public final class ApiExpressionUtils {
    public static final long MILLIS_PER_SECOND = 1000L;
    public static final long MILLIS_PER_MINUTE = 60000L;
    public static final long MILLIS_PER_HOUR = 3600000L;
    public static final long MILLIS_PER_DAY = 86400000L;

    private ApiExpressionUtils() {
    }

    public static Expression objectToExpression(Object expression) {
        if (expression == null) {
            return ApiExpressionUtils.valueLiteral(null, DataTypes.NULL());
        }
        if (expression instanceof ApiExpression) {
            return ((ApiExpression)expression).toExpr();
        }
        if (expression instanceof Expression) {
            return (Expression)expression;
        }
        if (expression instanceof Row) {
            RowKind kind = ((Row)expression).getKind();
            if (kind != RowKind.INSERT) {
                throw new ValidationException(String.format("Unsupported kind '%s' of a row [%s]. Only rows with 'INSERT' kind are supported when converting to an expression.", kind, expression));
            }
            return ApiExpressionUtils.convertRow((Row)expression);
        }
        if (expression instanceof Map) {
            return ApiExpressionUtils.convertJavaMap((Map)expression);
        }
        if (expression instanceof byte[]) {
            return ApiExpressionUtils.valueLiteral(expression);
        }
        if (expression.getClass().isArray()) {
            return ApiExpressionUtils.convertArray(expression);
        }
        if (expression instanceof List) {
            return ApiExpressionUtils.convertJavaList((List)expression);
        }
        return ApiExpressionUtils.convertScala(expression).orElseGet(() -> ApiExpressionUtils.valueLiteral(expression));
    }

    private static Expression convertRow(Row expression) {
        List<Expression> fields = IntStream.range(0, expression.getArity()).mapToObj(arg_0 -> ((Row)expression).getField(arg_0)).map(ApiExpressionUtils::objectToExpression).collect(Collectors.toList());
        return ApiExpressionUtils.unresolvedCall((FunctionDefinition)BuiltInFunctionDefinitions.ROW, fields);
    }

    private static Expression convertJavaMap(Map<?, ?> expression) {
        List<Expression> entries = expression.entrySet().stream().flatMap(e -> Stream.of(ApiExpressionUtils.objectToExpression(e.getKey()), ApiExpressionUtils.objectToExpression(e.getValue()))).collect(Collectors.toList());
        return ApiExpressionUtils.unresolvedCall((FunctionDefinition)BuiltInFunctionDefinitions.MAP, entries);
    }

    private static Expression convertJavaList(List<?> expression) {
        List<Expression> entries = expression.stream().map(ApiExpressionUtils::objectToExpression).collect(Collectors.toList());
        return ApiExpressionUtils.unresolvedCall((FunctionDefinition)BuiltInFunctionDefinitions.ARRAY, entries);
    }

    private static Expression convertArray(Object expression) {
        int length = Array.getLength(expression);
        List<Expression> entries = IntStream.range(0, length).mapToObj(idx -> Array.get(expression, idx)).map(ApiExpressionUtils::objectToExpression).collect(Collectors.toList());
        return ApiExpressionUtils.unresolvedCall((FunctionDefinition)BuiltInFunctionDefinitions.ARRAY, entries);
    }

    private static Optional<Expression> convertScala(Object obj) {
        try {
            Optional<Expression> array = ApiExpressionUtils.convertScalaSeq(obj);
            if (array.isPresent()) {
                return array;
            }
            Optional<Expression> bigDecimal = ApiExpressionUtils.convertScalaBigDecimal(obj);
            if (bigDecimal.isPresent()) {
                return bigDecimal;
            }
            return ApiExpressionUtils.convertScalaMap(obj);
        }
        catch (Exception e) {
            return Optional.empty();
        }
    }

    private static Optional<Expression> convertScalaMap(Object obj) throws ClassNotFoundException, NoSuchMethodException, IllegalAccessException, InvocationTargetException {
        Class<?> mapClass = Class.forName("scala.collection.Map");
        if (mapClass.isAssignableFrom(obj.getClass())) {
            Class<?> seqClass = Class.forName("scala.collection.Seq");
            Class<?> productClass = Class.forName("scala.Product");
            Method getElement = productClass.getMethod("productElement", Integer.TYPE);
            Method toSeq = mapClass.getMethod("toSeq", new Class[0]);
            Method getMethod = seqClass.getMethod("apply", Object.class);
            Method lengthMethod = seqClass.getMethod("length", new Class[0]);
            Object mapAsSeq = toSeq.invoke(obj, new Object[0]);
            ArrayList<Expression> entries = new ArrayList<Expression>();
            for (int i = 0; i < (Integer)lengthMethod.invoke(mapAsSeq, new Object[0]); ++i) {
                Object mapEntry = getMethod.invoke(mapAsSeq, i);
                Object key = getElement.invoke(mapEntry, 0);
                Object value = getElement.invoke(mapEntry, 1);
                entries.add(ApiExpressionUtils.objectToExpression(key));
                entries.add(ApiExpressionUtils.objectToExpression(value));
            }
            return Optional.of(ApiExpressionUtils.unresolvedCall((FunctionDefinition)BuiltInFunctionDefinitions.MAP, entries));
        }
        return Optional.empty();
    }

    private static Optional<Expression> convertScalaSeq(Object obj) throws ClassNotFoundException, NoSuchMethodException, IllegalAccessException, InvocationTargetException {
        Class<?> seqClass = Class.forName("scala.collection.Seq");
        if (seqClass.isAssignableFrom(obj.getClass())) {
            Method getMethod = seqClass.getMethod("apply", Object.class);
            Method lengthMethod = seqClass.getMethod("length", new Class[0]);
            ArrayList<Expression> entries = new ArrayList<Expression>();
            for (int i = 0; i < (Integer)lengthMethod.invoke(obj, new Object[0]); ++i) {
                entries.add(ApiExpressionUtils.objectToExpression(getMethod.invoke(obj, i)));
            }
            return Optional.of(ApiExpressionUtils.unresolvedCall((FunctionDefinition)BuiltInFunctionDefinitions.ARRAY, entries));
        }
        return Optional.empty();
    }

    private static Optional<Expression> convertScalaBigDecimal(Object obj) throws ClassNotFoundException, NoSuchMethodException, IllegalAccessException, InvocationTargetException {
        Class<?> decimalClass = Class.forName("scala.math.BigDecimal");
        if (decimalClass.equals(obj.getClass())) {
            Method toJava = decimalClass.getMethod("underlying", new Class[0]);
            BigDecimal bigDecimal = (BigDecimal)toJava.invoke(obj, new Object[0]);
            return Optional.of(ApiExpressionUtils.valueLiteral(bigDecimal));
        }
        return Optional.empty();
    }

    public static Expression unwrapFromApi(Expression expression) {
        if (expression instanceof ApiExpression) {
            return ((ApiExpression)expression).toExpr();
        }
        return expression;
    }

    public static LocalReferenceExpression localRef(String name, DataType dataType) {
        return new LocalReferenceExpression(name, dataType);
    }

    public static ValueLiteralExpression valueLiteral(Object value) {
        return new ValueLiteralExpression(value);
    }

    public static ValueLiteralExpression valueLiteral(Object value, DataType dataType) {
        return new ValueLiteralExpression(value, dataType);
    }

    public static TypeLiteralExpression typeLiteral(DataType dataType) {
        return new TypeLiteralExpression(dataType);
    }

    public static UnresolvedReferenceExpression unresolvedRef(String name) {
        return new UnresolvedReferenceExpression(name);
    }

    public static UnresolvedCallExpression unresolvedCall(ContextResolvedFunction resolvedFunction, Expression ... args) {
        return ApiExpressionUtils.unresolvedCall(resolvedFunction, Arrays.asList(args));
    }

    public static UnresolvedCallExpression unresolvedCall(ContextResolvedFunction resolvedFunction, List<Expression> args) {
        return new UnresolvedCallExpression(resolvedFunction, args.stream().map(ApiExpressionUtils::unwrapFromApi).collect(Collectors.toList()));
    }

    public static UnresolvedCallExpression unresolvedCall(FunctionDefinition functionDefinition, Expression ... args) {
        return ApiExpressionUtils.unresolvedCall(functionDefinition, Arrays.asList(args));
    }

    public static UnresolvedCallExpression unresolvedCall(FunctionDefinition functionDefinition, List<Expression> args) {
        return new UnresolvedCallExpression(ContextResolvedFunction.anonymous(functionDefinition), args.stream().map(ApiExpressionUtils::unwrapFromApi).collect(Collectors.toList()));
    }

    public static TableReferenceExpression tableRef(String name, Table table) {
        return new TableReferenceExpression(name, table.getQueryOperation(), ((TableImpl)table).getTableEnvironment());
    }

    public static TableReferenceExpression tableRef(String name, QueryOperation queryOperation, TableEnvironment env) {
        return new TableReferenceExpression(name, queryOperation, env);
    }

    public static ModelReferenceExpression modelRef(String name, Model model) {
        return new ModelReferenceExpression(name, ((ModelImpl)model).getModel(), ((ModelImpl)model).getTableEnvironment());
    }

    public static LookupCallExpression lookupCall(String name, Expression ... args) {
        return new LookupCallExpression(name, Arrays.stream(args).map(ApiExpressionUtils::unwrapFromApi).collect(Collectors.toList()));
    }

    public static SqlCallExpression sqlCall(String sqlExpression) {
        return new SqlCallExpression(sqlExpression);
    }

    public static Expression toMonthInterval(Expression e, int multiplier) {
        return ExpressionUtils.extractValue(e, BigDecimal.class).map(v -> ApiExpressionUtils.intervalOfMonths(v.intValue() * multiplier)).orElseThrow(() -> new ValidationException("Invalid constant for year-month interval: " + String.valueOf(e)));
    }

    public static ValueLiteralExpression intervalOfMillis(long millis) {
        return ApiExpressionUtils.valueLiteral(millis, (DataType)((DataType)DataTypes.INTERVAL(DataTypes.SECOND(3)).notNull()).bridgedTo(Long.class));
    }

    public static Expression toMilliInterval(Expression e, long multiplier) {
        return ExpressionUtils.extractValue(e, BigDecimal.class).map(v -> ApiExpressionUtils.intervalOfMillis(v.longValue() * multiplier)).orElseThrow(() -> new ValidationException("Invalid constant for day-time interval: " + String.valueOf(e)));
    }

    public static ValueLiteralExpression intervalOfMonths(int months) {
        return ApiExpressionUtils.valueLiteral(months, (DataType)((DataType)DataTypes.INTERVAL(DataTypes.MONTH()).notNull()).bridgedTo(Integer.class));
    }

    public static Expression toRowInterval(Expression e) {
        return ExpressionUtils.extractValue(e, BigDecimal.class).map(bd -> ApiExpressionUtils.valueLiteral(bd.longValue())).orElseThrow(() -> new ValidationException("Invalid constant for row interval: " + String.valueOf(e)));
    }

    public static boolean isFunctionOfKind(Expression expression, FunctionKind kind) {
        if (expression instanceof UnresolvedCallExpression) {
            return ((UnresolvedCallExpression)expression).getFunctionDefinition().getKind() == kind;
        }
        if (expression instanceof CallExpression) {
            return ((CallExpression)expression).getFunctionDefinition().getKind() == kind;
        }
        return false;
    }

    public static boolean isFunction(Expression expression, BuiltInFunctionDefinition functionDefinition) {
        if (expression instanceof UnresolvedCallExpression) {
            return ((UnresolvedCallExpression)expression).getFunctionDefinition() == functionDefinition;
        }
        if (expression instanceof CallExpression) {
            return ((CallExpression)expression).getFunctionDefinition() == functionDefinition;
        }
        return false;
    }
}

