package org.apache.pinot.calcite.rel.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.UnmodifiableIterator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalWindow;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.pinot.calcite.rel.logical.PinotLogicalExchange;
import org.apache.pinot.calcite.rel.logical.PinotLogicalSortExchange;

/* loaded from: input_file:org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.class */
/**
 * Planner rule that inserts an exchange directly below a {@link Window} so that the window's input is redistributed
 * by its PARTITION BY keys (and sorted by its ORDER BY keys when that is required).
 *
 * <p>Before the exchange is added the window is validated:
 * <ul>
 *   <li>exactly one window group;</li>
 *   <li>only window function kinds listed in {@link #SUPPORTED_WINDOW_FUNCTION_KIND};</li>
 *   <li>no RANGE frame with an offset bound.</li>
 * </ul>
 *
 * <p>In addition, function arguments and frame offsets that refer either to the window's constants or to literal
 * columns of an input {@link Project} are inlined as {@link RexLiteral}s in the rewritten window group.
 */
public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
    public static final PinotWindowExchangeNodeInsertRule INSTANCE =
        new PinotWindowExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY);

    /** Window function kinds accepted by {@link #validateWindows}; any other kind fails validation. */
    private static final EnumSet<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND = EnumSet.of(SqlKind.SUM, SqlKind.SUM0,
        SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK, SqlKind.DENSE_RANK, SqlKind.NTILE,
        SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE, SqlKind.OTHER_FUNCTION);

    public PinotWindowExchangeNodeInsertRule(RelBuilderFactory relBuilderFactory) {
        super(operand(Window.class, any()), relBuilderFactory, (String) null);
    }

    /** Fires only when the window's input is not already an exchange, so the rule does not re-apply to its output. */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Window window = call.rel(0);
        return !PinotRuleUtils.isExchange(window.getInput());
    }

    /**
     * Validates the window, then replaces it with a window whose input is an exchange.
     *
     * <ul>
     *   <li>PARTITION BY present and no ORDER BY (or ORDER BY on exactly the partition keys): a
     *       {@link PinotLogicalExchange} hashed on the partition keys.</li>
     *   <li>PARTITION BY present with a different ORDER BY: a {@link PinotLogicalSortExchange} hashed on the partition
     *       keys with the ORDER BY collation.</li>
     *   <li>Empty {@code OVER()}: a {@link PinotLogicalExchange} with an empty hash key list. An empty project below
     *       the window is handled specially by {@link #handleEmptyProjectBelowWindow}.</li>
     *   <li>ORDER BY only: a {@link PinotLogicalSortExchange} with an empty hash key list and the ORDER BY
     *       collation.</li>
     * </ul>
     *
     * @throws IllegalStateException if the window fails validation or a frame offset cannot be resolved to a literal
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        Window window = call.rel(0);
        validateWindows(window);

        RelNode input = window.getInput();
        Window.Group windowGroup = updateLiteralArgumentsInWindowGroup(window);
        RelNode exchange;
        if (!windowGroup.keys.isEmpty()) {
            if (isPartitionByOnlyQuery(windowGroup)) {
                exchange = PinotLogicalExchange.create(input, RelDistributions.hash(windowGroup.keys.toList()));
            } else {
                exchange = PinotLogicalSortExchange.create(input, RelDistributions.hash(windowGroup.keys.toList()),
                    windowGroup.orderKeys, false, true);
            }
        } else if (windowGroup.orderKeys.getKeys().isEmpty()) {
            // Empty OVER(): no partitioning and no sorting.
            if (PinotRuleUtils.isProject(input)) {
                Project project = (Project) unwrap(input);
                if (project.getProjects().isEmpty()) {
                    call.transformTo(handleEmptyProjectBelowWindow(window, project, windowGroup));
                    return;
                }
            }
            exchange = PinotLogicalExchange.create(input, RelDistributions.hash(List.of()));
        } else {
            // ORDER BY only.
            exchange = PinotLogicalSortExchange.create(input, RelDistributions.hash(List.of()),
                windowGroup.orderKeys, false, true);
        }
        call.transformTo(LogicalWindow.create(window.getTraitSet(), exchange, window.constants, window.getRowType(),
            List.of(windowGroup)));
    }

    /**
     * Returns the window's single group, with literal references inlined.
     *
     * <p>Aggregate-call operands and frame offsets that resolve (via {@link #getLiteral}) to a literal are replaced
     * with that {@link RexLiteral}. Returns the original group instance when nothing was replaced.
     *
     * @throws IllegalStateException if a frame bound has an offset that cannot be resolved to a literal
     */
    private Window.Group updateLiteralArgumentsInWindowGroup(Window window) {
        Window.Group group = window.groups.get(0);
        RelNode input = unwrap(window.getInput());
        int inputFieldCount = input.getRowType().getFieldCount();
        List<RexNode> inputProjects = input instanceof Project ? ((Project) input).getProjects() : null;

        boolean changed = false;
        List<Window.RexWinAggCall> newAggCalls = new ArrayList<>(group.aggCalls.size());
        for (Window.RexWinAggCall aggCall : group.aggCalls) {
            List<RexNode> operands = aggCall.getOperands();
            List<RexNode> newOperands = new ArrayList<>(operands.size());
            boolean operandsChanged = false;
            for (RexNode operand : operands) {
                RexLiteral literal = getLiteral(operand, inputFieldCount, window.constants, inputProjects);
                if (literal != null) {
                    newOperands.add(literal);
                    operandsChanged = true;
                } else {
                    newOperands.add(operand);
                }
            }
            if (operandsChanged) {
                newAggCalls.add(new Window.RexWinAggCall((SqlAggFunction) aggCall.getOperator(), aggCall.type,
                    newOperands, aggCall.ordinal, aggCall.distinct, aggCall.ignoreNulls));
                changed = true;
            } else {
                newAggCalls.add(aggCall);
            }
        }

        RexWindowBound lowerBound = group.lowerBound;
        if (lowerBound.getOffset() != null) {
            lowerBound = inlineBoundOffset(lowerBound, "lower", group, inputFieldCount, window.constants, inputProjects);
            changed = true;
        }
        RexWindowBound upperBound = group.upperBound;
        if (upperBound.getOffset() != null) {
            upperBound = inlineBoundOffset(upperBound, "upper", group, inputFieldCount, window.constants, inputProjects);
            changed = true;
        }

        return changed
            ? new Window.Group(group.keys, group.isRows, lowerBound, upperBound, group.orderKeys, newAggCalls)
            : group;
    }

    /**
     * Rebuilds an offset bound with its offset replaced by the resolved literal, keeping its PRECEDING/FOLLOWING
     * direction.
     *
     * @param side "lower" or "upper"; used only in the error message
     * @throws IllegalStateException if the offset cannot be resolved to a literal
     */
    private static RexWindowBound inlineBoundOffset(RexWindowBound bound, String side, Window.Group group,
        int inputFieldCount, List<RexLiteral> constants, @Nullable List<RexNode> inputProjects) {
        RexLiteral literal = getLiteral(bound.getOffset(), inputFieldCount, constants, inputProjects);
        if (literal == null) {
            throw new IllegalStateException(
                "Could not read window " + side + " bound literal value from window group: " + group);
        }
        return bound.isPreceding() ? RexWindowBounds.preceding(literal) : RexWindowBounds.following(literal);
    }

    /**
     * Resolves {@code node} to a literal when possible.
     *
     * <p>A {@link RexInputRef} whose index is at or beyond the input field count refers to the window's constants
     * (offset by the input field count). A smaller index is looked up in {@code inputProjects}, when available.
     *
     * @param node expression to resolve
     * @param inputFieldCount number of fields produced by the window's input
     * @param constants the window's constants
     * @param inputProjects project expressions of the window's input, or {@code null} if the input is not a
     *     {@link Project}
     * @return the resolved literal, or {@code null} if {@code node} does not resolve to one
     */
    @Nullable
    private static RexLiteral getLiteral(RexNode node, int inputFieldCount, List<RexLiteral> constants,
        @Nullable List<RexNode> inputProjects) {
        if (node instanceof RexLiteral) {
            return (RexLiteral) node;
        }
        if (!(node instanceof RexInputRef)) {
            return null;
        }
        int index = ((RexInputRef) node).getIndex();
        if (index >= inputFieldCount) {
            return constants.get(index - inputFieldCount);
        }
        if (inputProjects == null) {
            return null;
        }
        RexNode projected = inputProjects.get(index);
        return projected instanceof RexLiteral ? (RexLiteral) projected : null;
    }

    /** Checks the group count, the function kinds, and the frame of the (single) window group. */
    private void validateWindows(Window window) {
        int numGroups = window.groups.size();
        Preconditions.checkState(numGroups == 1,
            "Currently only 1 window group is supported, query has %s groups", numGroups);
        Window.Group group = window.groups.get(0);
        validateWindowAggCallsSupported(group);
        validateWindowFrames(group);
    }

    /** @throws IllegalStateException if any window call's kind is not in {@link #SUPPORTED_WINDOW_FUNCTION_KIND} */
    private void validateWindowAggCallsSupported(Window.Group group) {
        for (Window.RexWinAggCall aggCall : group.aggCalls) {
            SqlKind kind = aggCall.getKind();
            Preconditions.checkState(SUPPORTED_WINDOW_FUNCTION_KIND.contains(kind),
                "Unsupported Window function kind: %s. Only aggregation functions are supported!", kind);
        }
    }

    /**
     * Rejects RANGE frames where either bound has an offset (N PRECEDING / N FOLLOWING).
     *
     * <p>ROWS frames are not restricted here.
     *
     * @throws IllegalStateException for a RANGE frame with an offset bound
     */
    private void validateWindowFrames(Window.Group group) {
        if (group.isRows) {
            return;
        }
        // Check both bounds in both directions: e.g. "UNBOUNDED PRECEDING AND 1 PRECEDING" also has an offset.
        boolean hasOffset = hasOffset(group.lowerBound) || hasOffset(group.upperBound);
        Preconditions.checkState(!hasOffset, "RANGE window frame with offset PRECEDING / FOLLOWING is not supported");
    }

    /** True when the bound is neither UNBOUNDED nor CURRENT ROW, i.e. it carries an offset. */
    private static boolean hasOffset(RexWindowBound bound) {
        return !bound.isUnbounded() && !bound.isCurrentRow();
    }

    /**
     * Returns true when the group has no ORDER BY, or its ORDER BY keys are exactly its PARTITION BY keys (as sets).
     *
     * <p>In either case a plain hash exchange on the partition keys is used instead of a sort exchange.
     */
    private static boolean isPartitionByOnlyQuery(Window.Group group) {
        List<Integer> orderKeys = group.orderKeys.getKeys();
        if (orderKeys.isEmpty()) {
            return true;
        }
        List<Integer> partitionKeys = group.keys.asList();
        return orderKeys.size() == partitionKeys.size()
            && new HashSet<>(partitionKeys).equals(new HashSet<>(orderKeys));
    }

    /**
     * Rewrites {@code Window(EmptyProject)} as {@code Project(Window(Exchange(LiteralProject)))}.
     *
     * <p>The literal project emits a single INTEGER literal column named {@code "winLiteral"} over the empty
     * project's input. The top project drops that column again, keeping the original window's field types and names.
     *
     * @param window the window being replaced
     * @param project the empty project currently below the window
     * @param windowGroup the window's group with literal references already inlined. The group must not reference
     *     constants by input index, because the new window's input has one more field (the literal) than the empty
     *     project had.
     */
    private RelNode handleEmptyProjectBelowWindow(Window window, Project project, Window.Group windowGroup) {
        RelOptCluster cluster = window.getCluster();
        RelDataTypeFactory typeFactory = cluster.getTypeFactory();
        RexNode literal = cluster.getRexBuilder().makeLiteral(0, typeFactory.createSqlType(SqlTypeName.INTEGER));
        LogicalProject literalProject = LogicalProject.create(project.getInput(), project.getHints(),
            Collections.singletonList(literal), Collections.singletonList("winLiteral"));

        // New window row type = [winLiteral] + original window fields.
        RelDataTypeFactory.Builder rowTypeBuilder = typeFactory.builder();
        rowTypeBuilder.addAll(literalProject.getRowType().getFieldList());
        rowTypeBuilder.addAll(window.getRowType().getFieldList());
        RelNode exchange = PinotLogicalExchange.create(literalProject, RelDistributions.hash(List.of()));
        LogicalWindow newWindow = LogicalWindow.create(window.getTraitSet(), exchange, window.constants,
            rowTypeBuilder.build(), List.of(windowGroup));

        // Project away the literal column (index 0), keeping the original window's field names.
        List<RelDataTypeField> windowFields = window.getRowType().getFieldList();
        List<RexNode> projects = new ArrayList<>(windowFields.size());
        List<String> fieldNames = new ArrayList<>(windowFields.size());
        for (int i = 0; i < windowFields.size(); i++) {
            RelDataTypeField field = windowFields.get(i);
            projects.add(new RexInputRef(i + 1, field.getType()));
            fieldNames.add(field.getName());
        }
        return LogicalProject.create(newWindow, project.getHints(), projects, fieldNames);
    }

    /** Returns the rel wrapped by a {@link HepRelVertex}, or {@code rel} itself when it is not wrapped. */
    private static RelNode unwrap(RelNode rel) {
        return rel instanceof HepRelVertex ? ((HepRelVertex) rel).getCurrentRel() : rel;
    }
}
