package org.apache.pinot.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
import org.apache.pinot.calcite.rel.hint.PinotHintStrategyTable;
import org.apache.pinot.calcite.rel.logical.PinotLogicalAggregate;
import org.apache.pinot.calcite.rel.logical.PinotLogicalExchange;
import org.apache.pinot.calcite.rel.logical.PinotLogicalSortExchange;
import org.apache.pinot.common.function.sql.PinotSqlAggFunction;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.segment.spi.AggregationFunctionType;

/* loaded from: input_file:org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.class */
/**
 * Planner rule that matches a {@link LogicalAggregate} and rewrites it into Pinot aggregate node(s), inserting an
 * exchange where rows sharing the same group keys must be brought together.
 *
 * <p>One of three plans is produced:
 * <ul>
 *   <li>If any aggregate call carries a {@code WITHIN GROUP} collation, or the aggregate has group keys and the
 *       {@code SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION} hint is enabled: an exchange hashed on the group keys (a sort
 *       exchange when a collation is present) feeding a single {@code DIRECT} aggregate.</li>
 *   <li>Else, if the aggregate has group keys and the {@code IS_PARTITIONED_BY_GROUP_BY_KEYS} hint is enabled: a single
 *       {@code DIRECT} aggregate with no exchange.</li>
 *   <li>Otherwise: a {@code LEAF} aggregate, an exchange hashed on the group keys, and a {@code FINAL} aggregate over
 *       the leaf output ({@code IS_LEAF_RETURN_FINAL_RESULT} is forwarded to both aggregates).</li>
 * </ul>
 */
public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
    public static final PinotAggregateExchangeNodeInsertRule INSTANCE =
            new PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY);

    public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory relBuilderFactory) {
        super(operand(LogicalAggregate.class, any()), relBuilderFactory, (String) null);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate aggRel = call.rel(0);
        boolean hasGroupBy = !aggRel.getGroupSet().isEmpty();
        RelCollation collation = extractWithInGroupCollation(aggRel);
        Map<String, String> hintOptions =
                PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS);
        if (collation != null || (hasGroupBy && isHintOptionEnabled(hintOptions,
                PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION))) {
            // Shuffle the raw input on the group keys and aggregate it in a single DIRECT stage.
            call.transformTo(createPlanWithExchangeDirectAggregation(call, collation));
        } else if (hasGroupBy && isHintOptionEnabled(hintOptions,
                PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) {
            // The hint states the input is already partitioned by the group keys, so no exchange is inserted.
            call.transformTo(new PinotLogicalAggregate(aggRel,
                    buildAggCalls(aggRel, AggregateNode.AggType.DIRECT, false), AggregateNode.AggType.DIRECT));
        } else {
            boolean leafReturnFinalResult = isHintOptionEnabled(hintOptions,
                    PinotHintOptions.AggregateOptions.IS_LEAF_RETURN_FINAL_RESULT);
            call.transformTo(createPlanWithLeafExchangeFinalAggregate(call, leafReturnFinalResult));
        }
    }

    /**
     * Returns whether {@code option} is present in {@code hintOptions} with a value that parses as {@code true}.
     * A {@code null} options map means no aggregate hint was given.
     */
    private static boolean isHintOptionEnabled(@Nullable Map<String, String> hintOptions, String option) {
        return hintOptions != null && Boolean.parseBoolean(hintOptions.get(option));
    }

    /**
     * Returns the collation of the first aggregate call that has a non-empty {@code WITHIN GROUP} collation, or
     * {@code null} if no aggregate call has one. Field indexes refer to the aggregate's input.
     */
    @Nullable
    private static RelCollation extractWithInGroupCollation(Aggregate aggRel) {
        for (AggregateCall aggCall : aggRel.getAggCallList()) {
            RelCollation collation = aggCall.getCollation();
            if (!collation.getFieldCollations().isEmpty()) {
                return collation;
            }
        }
        return null;
    }

    /**
     * Builds {@code exchange(hash on group keys) -> DIRECT aggregate}. When the aggregate's input is not a Project,
     * a Project of only the referenced input fields is generated first.
     *
     * @param collation the WITHIN GROUP collation of the original aggregate, or {@code null}; when non-null a sort
     *     exchange is used instead of a plain exchange
     */
    private static PinotLogicalAggregate createPlanWithExchangeDirectAggregation(RelOptRuleCall call,
            @Nullable RelCollation collation) {
        Aggregate aggRel = call.rel(0);
        RelNode input = aggRel.getInput();
        if (!(PinotRuleUtils.unboxRel(input) instanceof Project)) {
            aggRel = (Aggregate) generateProjectUnderAggregate(call);
            input = aggRel.getInput();
            if (collation != null) {
                // The generated Project renumbers the input fields, so the collation read from the original
                // aggregate no longer matches the new input. Re-read it from the rewritten (remapped) agg calls.
                collation = extractWithInGroupCollation(aggRel);
            }
        }
        RelDistribution distribution = RelDistributions.hash(aggRel.getGroupSet().asList());
        return new PinotLogicalAggregate(aggRel,
                collation != null
                        ? PinotLogicalSortExchange.create(input, distribution, collation, false, true)
                        : PinotLogicalExchange.create(input, distribution),
                buildAggCalls(aggRel, AggregateNode.AggType.DIRECT, false), AggregateNode.AggType.DIRECT);
    }

    /**
     * Builds {@code LEAF aggregate -> exchange(hash on the first groupCount columns) -> FINAL aggregate}.
     *
     * @param leafReturnFinalResult forwarded to both aggregates and to return-type inference
     */
    private static PinotLogicalAggregate createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call,
            boolean leafReturnFinalResult) {
        Aggregate aggRel = call.rel(0);
        PinotLogicalAggregate leafAgg = new PinotLogicalAggregate(aggRel,
                buildAggCalls(aggRel, AggregateNode.AggType.LEAF, leafReturnFinalResult), AggregateNode.AggType.LEAF,
                leafReturnFinalResult);
        PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAgg,
                RelDistributions.hash(ImmutableIntList.range(0, aggRel.getGroupCount())));
        return convertAggFromIntermediateInput(call, exchange, AggregateNode.AggType.FINAL, leafReturnFinalResult);
    }

    /**
     * Rebuilds the matched aggregate on top of a Project that keeps only the input fields the aggregate references.
     * Those are the group keys, call arguments, filter arguments, WITHIN GROUP collation fields and WITHIN DISTINCT
     * keys, kept in ascending field order. Group sets and aggregate calls are remapped to the projected field
     * positions.
     *
     * @return the node built by the RelBuilder; callers cast it to {@link Aggregate}
     */
    private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) {
        Aggregate aggRel = call.rel(0);
        RelNode input = aggRel.getInput();

        ImmutableBitSet.Builder usedFieldsBuilder = aggRel.getGroupSet().rebuild();
        for (AggregateCall aggCall : aggRel.getAggCallList()) {
            for (int arg : aggCall.getArgList()) {
                usedFieldsBuilder.set(arg);
            }
            if (aggCall.filterArg >= 0) {
                usedFieldsBuilder.set(aggCall.filterArg);
            }
            // Collation and distinct-key fields are remapped by RelBuilder#aggregateCall(AggregateCall, Mapping), so
            // they must be kept by the projection, otherwise they are lost or fail to map.
            aggCall.getCollation().getFieldCollations()
                    .forEach(fieldCollation -> usedFieldsBuilder.set(fieldCollation.getFieldIndex()));
            if (aggCall.distinctKeys != null) {
                for (int key : aggCall.distinctKeys) {
                    usedFieldsBuilder.set(key);
                }
            }
        }
        ImmutableBitSet usedFields = usedFieldsBuilder.build();
        int numUsedFields = usedFields.cardinality();

        // Project the used fields in ascending order, recording old field index -> projected position.
        RelBuilder relBuilder = call.builder().push(input);
        Mapping mapping =
                Mappings.create(MappingType.INVERSE_SURJECTION, input.getRowType().getFieldCount(), numUsedFields);
        List<RexNode> projects = new ArrayList<>(numUsedFields);
        int newIndex = 0;
        for (int oldIndex : usedFields) {
            projects.add(relBuilder.field(oldIndex));
            mapping.set(oldIndex, newIndex++);
        }
        relBuilder.project(projects);

        ImmutableBitSet newGroupSet = Mappings.apply(mapping, aggRel.getGroupSet());
        List<ImmutableBitSet> newGroupSets = aggRel.getGroupSets().stream()
                .map(groupSet -> Mappings.apply(mapping, groupSet))
                .collect(ImmutableList.toImmutableList());
        List<RelBuilder.AggCall> newAggCalls = aggRel.getAggCallList().stream()
                .map(aggCall -> relBuilder.aggregateCall(aggCall, mapping))
                .collect(ImmutableList.toImmutableList());
        relBuilder.aggregate(relBuilder.groupKey(newGroupSet, newGroupSets), newAggCalls);
        return relBuilder.build();
    }

    /**
     * Creates the aggregate that consumes the exchanged leaf output. Group keys occupy the first {@code groupCount}
     * columns, followed by one column per original aggregate call.
     *
     * <p>The first operand of each new call is the column holding that call's leaf result. Any further operands are
     * kept as literals when the immediate input Project defines them as {@link RexLiteral}. Otherwise they are also
     * replaced by the leaf-result column.
     */
    private static PinotLogicalAggregate convertAggFromIntermediateInput(RelOptRuleCall call,
            PinotLogicalExchange exchange, AggregateNode.AggType aggType, boolean leafReturnFinalResult) {
        Aggregate aggRel = call.rel(0);
        List<RexNode> projects = findImmediateProjects(aggRel.getInput());
        int groupCount = aggRel.getGroupCount();
        List<AggregateCall> orgAggCalls = aggRel.getAggCallList();
        int numAggCalls = orgAggCalls.size();
        List<AggregateCall> aggCalls = new ArrayList<>(numAggCalls);
        for (int i = 0; i < numAggCalls; i++) {
            AggregateCall orgAggCall = orgAggCalls.get(i);
            List<Integer> argList = orgAggCall.getArgList();
            RexInputRef inputRef = RexInputRef.of(groupCount + i, aggRel.getRowType());
            int numArguments = argList.size();
            List<RexNode> rexList;
            if (numArguments <= 1) {
                rexList = ImmutableList.of(inputRef);
            } else {
                rexList = new ArrayList<>(numArguments);
                rexList.add(inputRef);
                for (int j = 1; j < numArguments; j++) {
                    int argument = argList.get(j);
                    if (projects != null && projects.get(argument) instanceof RexLiteral) {
                        rexList.add(projects.get(argument));
                    } else {
                        rexList.add(inputRef);
                    }
                }
            }
            aggCalls.add(buildAggCall(exchange, orgAggCall, rexList, groupCount, aggType, leafReturnFinalResult));
        }
        return new PinotLogicalAggregate(aggRel, exchange, ImmutableBitSet.range(groupCount), aggCalls, aggType,
                leafReturnFinalResult);
    }

    /**
     * Creates the aggregate calls for an aggregate that reads {@code aggRel}'s own input. The first argument is
     * referenced as an input field. Further arguments are inlined as literals when the immediate input Project defines
     * them as {@link RexLiteral}, and referenced as input fields otherwise.
     */
    private static List<AggregateCall> buildAggCalls(Aggregate aggRel, AggregateNode.AggType aggType,
            boolean leafReturnFinalResult) {
        RelNode input = aggRel.getInput();
        RelDataType inputRowType = input.getRowType();
        List<RexNode> projects = findImmediateProjects(input);
        List<AggregateCall> orgAggCalls = aggRel.getAggCallList();
        List<AggregateCall> aggCalls = new ArrayList<>(orgAggCalls.size());
        for (AggregateCall orgAggCall : orgAggCalls) {
            List<Integer> argList = orgAggCall.getArgList();
            int numArguments = argList.size();
            List<RexNode> rexList;
            if (numArguments == 0) {
                rexList = ImmutableList.of();
            } else if (numArguments == 1) {
                rexList = ImmutableList.of(RexInputRef.of(argList.get(0), inputRowType));
            } else {
                rexList = new ArrayList<>(numArguments);
                rexList.add(RexInputRef.of(argList.get(0), inputRowType));
                for (int i = 1; i < numArguments; i++) {
                    int argument = argList.get(i);
                    if (projects != null && projects.get(argument) instanceof RexLiteral) {
                        rexList.add(projects.get(argument));
                    } else {
                        rexList.add(RexInputRef.of(argument, inputRowType));
                    }
                }
            }
            aggCalls.add(buildAggCall(input, orgAggCall, rexList, aggRel.getGroupCount(), aggType,
                    leafReturnFinalResult));
        }
        return aggCalls;
    }

    /**
     * Creates a non-distinct {@link AggregateCall} backed by a {@link PinotSqlAggFunction}.
     *
     * <p>A distinct {@code COUNT} becomes {@code DISTINCTCOUNT}. A distinct {@code LISTAGG} gets an extra
     * {@code true} literal operand. When {@code aggType} outputs an intermediate format, the return type is inferred
     * from the Pinot {@link AggregationFunctionType}: final or intermediate, per {@code leafReturnFinalResult}.
     * Otherwise, or if that inference is {@code null}, the original call's type is used explicitly. When
     * {@code aggType} takes intermediate input, operand checking is relaxed to {@code ANY} and the filter is dropped.
     *
     * @param rexList operand expressions; may be immutable, and is not modified
     */
    private static AggregateCall buildAggCall(RelNode input, AggregateCall orgAggCall, List<RexNode> rexList,
            int numGroups, AggregateNode.AggType aggType, boolean leafReturnFinalResult) {
        SqlAggFunction orgAggFunction = orgAggCall.getAggregation();
        String functionName = orgAggFunction.getName();
        SqlKind kind = orgAggFunction.getKind();
        SqlFunctionCategory functionCategory = orgAggFunction.getFunctionType();
        if (orgAggCall.isDistinct()) {
            if (kind == SqlKind.COUNT) {
                functionName = "DISTINCTCOUNT";
                kind = SqlKind.OTHER_FUNCTION;
                functionCategory = SqlFunctionCategory.USER_DEFINED_FUNCTION;
            } else if (kind == SqlKind.LISTAGG) {
                // Callers pass an ImmutableList for single-operand calls, so copy before appending the flag.
                List<RexNode> listAggOperands = new ArrayList<>(rexList);
                listAggOperands.add(input.getCluster().getRexBuilder().makeLiteral(true));
                rexList = listAggOperands;
            }
        }
        SqlReturnTypeInference returnTypeInference = null;
        RelDataType returnType = null;
        if (aggType.isOutputIntermediateFormat()) {
            AggregationFunctionType pinotFunctionType = AggregationFunctionType.getAggregationFunctionType(functionName);
            returnTypeInference = leafReturnFinalResult ? pinotFunctionType.getFinalReturnTypeInference()
                    : pinotFunctionType.getIntermediateReturnTypeInference();
        }
        if (returnTypeInference == null) {
            returnType = orgAggCall.getType();
            returnTypeInference = ReturnTypes.explicit(returnType);
        }
        return AggregateCall.create(
                new PinotSqlAggFunction(functionName, kind, returnTypeInference,
                        aggType.isInputIntermediateFormat() ? OperandTypes.ANY : orgAggFunction.getOperandTypeChecker(),
                        functionCategory),
                false, orgAggCall.isApproximate(), orgAggCall.ignoreNulls(), rexList, ImmutableList.of(),
                aggType.isInputIntermediateFormat() ? -1 : orgAggCall.filterArg, orgAggCall.distinctKeys,
                orgAggCall.collation, numGroups, input, returnType, (String) null);
    }

    /**
     * Returns the expressions of the Project directly under {@code relNode} after unboxing. For a Union, returns
     * those of its first input, looked up recursively. Returns {@code null} when neither applies.
     */
    @Nullable
    private static List<RexNode> findImmediateProjects(RelNode relNode) {
        RelNode unboxed = PinotRuleUtils.unboxRel(relNode);
        if (unboxed instanceof Project) {
            return ((Project) unboxed).getProjects();
        }
        if (unboxed instanceof Union) {
            return findImmediateProjects(unboxed.getInput(0));
        }
        return null;
    }
}
