/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.trino.sql.analyzer.FieldId;
import io.trino.sql.analyzer.RelationId;
import io.trino.sql.analyzer.ResolvedField;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArrayConstructor;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.GroupingOperation;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.SubscriptExpression;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public final class GroupingOperationRewriter {
    private GroupingOperationRewriter() {
    }

    public static Expression rewriteGroupingOperation(GroupingOperation expression, List<Set<Integer>> groupingSets, Map<NodeRef<Expression>, ResolvedField> columnReferenceFields, Optional<Symbol> groupIdSymbol) {
        Objects.requireNonNull(groupIdSymbol, "groupIdSymbol is null");
        if (groupingSets.size() == 1) {
            return new LongLiteral("0");
        }
        Preconditions.checkState((boolean)groupIdSymbol.isPresent(), (Object)"groupId symbol is missing");
        RelationId relationId = columnReferenceFields.get(NodeRef.of((Node)((Expression)expression.getGroupingColumns().get(0)))).getFieldId().getRelationId();
        List columns = (List)expression.getGroupingColumns().stream().map(NodeRef::of).peek(groupingColumn -> Preconditions.checkState((boolean)columnReferenceFields.containsKey(groupingColumn), (Object)"the grouping column is not in the columnReferencesField map")).map(columnReferenceFields::get).map(ResolvedField::getFieldId).map(fieldId -> GroupingOperationRewriter.translateFieldToInteger(fieldId, relationId)).collect(ImmutableList.toImmutableList());
        List groupingResults = (List)groupingSets.stream().map(groupingSet -> String.valueOf(GroupingOperationRewriter.calculateGrouping(groupingSet, columns))).map(LongLiteral::new).collect(ImmutableList.toImmutableList());
        return new SubscriptExpression((Expression)new ArrayConstructor(groupingResults), (Expression)new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.ADD, (Expression)groupIdSymbol.get().toSymbolReference(), (Expression)new GenericLiteral("BIGINT", "1")));
    }

    private static int translateFieldToInteger(FieldId fieldId, RelationId requiredOriginRelationId) {
        Preconditions.checkState((boolean)fieldId.getRelationId().equals(requiredOriginRelationId), (Object)"grouping arguments must all come from the same relation");
        return fieldId.getFieldIndex();
    }

    static long calculateGrouping(Set<Integer> groupingSet, List<Integer> columns) {
        long grouping = (1L << columns.size()) - 1L;
        for (int index = 0; index < columns.size(); ++index) {
            int column = columns.get(index);
            if (!groupingSet.contains(column)) continue;
            grouping &= 1L << columns.size() - 1 - index ^ 0xFFFFFFFFFFFFFFFFL;
        }
        return grouping;
    }
}

