package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.hint.PinotHintOptions;
import org.apache.calcite.rel.hint.PinotHintStrategyTable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.PinotLogicalExchange;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.PinotSqlAggFunction;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.shaded.com.google.common.collect.ImmutableList;

/* loaded from: input_file:org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.class */
/**
 * Planner rule that splits a {@link LogicalAggregate} around a {@link PinotLogicalExchange}.
 *
 * <p>The rule only fires on aggregates that do not yet carry the internal aggregation-type hint
 * ({@code INTERNAL_AGG_OPTIONS} / {@code AGG_TYPE}). Every aggregate it produces is tagged with that hint, so
 * {@link #matches(RelOptRuleCall)} rejects its own output. For an aggregate with a non-empty group set:
 * <ul>
 *   <li>if the {@code IS_PARTITIONED_BY_GROUP_BY_KEYS} aggregate hint option is true, the aggregate is kept as-is and
 *       only tagged {@code DIRECT} (no exchange is inserted);</li>
 *   <li>otherwise, if the {@code SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION} option is true, the input is hash-exchanged on
 *       the group keys and aggregated once, tagged {@code DIRECT}.</li>
 * </ul>
 * In every other case (including aggregates without GROUP BY) a {@code LEAF} aggregate is placed below an exchange
 * hashed on the group-key columns, and a {@code FINAL} aggregate is placed above it.
 */
public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
    public static final PinotAggregateExchangeNodeInsertRule INSTANCE =
            new PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY);

    /** Function name that {@code COUNT(DISTINCT ...)} is mapped to before looking up the Pinot function type. */
    private static final String DISTINCT_COUNT = "distinctCount";

    public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory relBuilderFactory) {
        super(operand(LogicalAggregate.class, any()), relBuilderFactory, null);
    }

    /**
     * Accepts an aggregate only if it has not been assigned an internal aggregation type yet.
     */
    @Override // org.apache.calcite.plan.RelOptRule
    public boolean matches(RelOptRuleCall call) {
        if (call.rels.length < 1 || !(call.rel(0) instanceof Aggregate)) {
            return false;
        }
        Aggregate aggregate = call.rel(0);
        return !PinotHintStrategyTable.containsHintOption(aggregate.getHints(), PinotHintOptions.INTERNAL_AGG_OPTIONS,
                PinotHintOptions.InternalAggregateOptions.AGG_TYPE);
    }

    /**
     * Rewrites the matched aggregate into one of the three plan shapes described on the class.
     */
    @Override // org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = call.rel(0);
        ImmutableList<RelHint> hints = aggregate.getHints();
        boolean hasGroupBy = !aggregate.getGroupSet().isEmpty();

        RelNode rewritten;
        if (hasGroupBy && PinotHintStrategyTable.isHintOptionTrue(hints, PinotHintOptions.AGGREGATE_HINT_OPTIONS,
                PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) {
            // Hint says the input is already partitioned by the group keys: keep the aggregate, just tag it DIRECT.
            rewritten = new LogicalAggregate(aggregate.getCluster(), aggregate.getTraitSet(),
                    hintsWithAggType(aggregate, AggregateNode.AggType.DIRECT), aggregate.getInput(),
                    aggregate.getGroupSet(), aggregate.getGroupSets(), aggregate.getAggCallList());
        } else if (hasGroupBy && PinotHintStrategyTable.isHintOptionTrue(hints,
                PinotHintOptions.AGGREGATE_HINT_OPTIONS,
                PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION)) {
            rewritten = createPlanWithExchangeDirectAggregation(call);
        } else {
            rewritten = createPlanWithLeafExchangeFinalAggregate(call);
        }
        // transformTo accepts any RelNode; no cast to Aggregate is needed (the builder may return a non-Aggregate).
        call.transformTo(rewritten);
    }

    /**
     * Builds LEAF aggregate -> exchange hashed on the group keys -> FINAL aggregate.
     */
    private RelNode createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call) {
        Aggregate aggregate = call.rel(0);
        Aggregate leafAggregate = convertAggForLeafInput(aggregate);
        // The leaf aggregate's group-key columns are [0, groupCount); the list is empty when there is no GROUP BY.
        PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAggregate,
                RelDistributions.hash(ImmutableIntList.range(0, aggregate.getGroupCount())));
        return convertAggFromIntermediateInput(call, aggregate, exchange, AggregateNode.AggType.FINAL);
    }

    /**
     * Builds exchange hashed on the group keys -> single DIRECT aggregate (no leaf-stage aggregation).
     *
     * <p>If the aggregate's input is a {@link Project}, the exchange is placed directly on it. Otherwise,
     * {@link #convertAggForExchangeDirectAggregate} first projects away unused columns.
     */
    private RelNode createPlanWithExchangeDirectAggregation(RelOptRuleCall call) {
        Aggregate aggregate = call.rel(0);
        List<RelHint> directHints = hintsWithAggType(aggregate, AggregateNode.AggType.DIRECT);
        RelNode input = unwrapHepVertex(aggregate.getInput());
        if (!(input instanceof Project)) {
            return convertAggForExchangeDirectAggregate(call, directHints);
        }
        RelNode exchange = PinotLogicalExchange.create(input, RelDistributions.hash(aggregate.getGroupSet().asList()));
        return new LogicalAggregate(aggregate.getCluster(), aggregate.getTraitSet(), directHints, exchange,
                aggregate.getGroupSet(), aggregate.getGroupSets(), aggregate.getAggCallList());
    }

    /**
     * Projects only the input columns the aggregate reads, then hash-exchanges on the remapped group keys.
     * It then re-creates the aggregate (with remapped group sets and calls) on top.
     *
     * @param call  the rule call whose {@code rel(0)} is the aggregate being rewritten
     * @param hints hints to attach to the new aggregate
     * @return the new aggregate (as built by {@link RelBuilder})
     */
    private RelNode convertAggForExchangeDirectAggregate(RelOptRuleCall call, List<RelHint> hints) {
        Aggregate aggregate = call.rel(0);
        RelNode input = aggregate.getInput();

        // Columns the aggregate reads: group keys, aggregate-call arguments and FILTER columns.
        ImmutableBitSet.Builder usedFieldsBuilder = aggregate.getGroupSet().rebuild();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            for (int arg : aggCall.getArgList()) {
                usedFieldsBuilder.set(arg);
            }
            if (aggCall.filterArg >= 0) {
                usedFieldsBuilder.set(aggCall.filterArg);
            }
        }
        ImmutableBitSet usedFields = usedFieldsBuilder.build();

        // Old input column index -> position in the narrowed projection.
        Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, input.getRowType().getFieldCount(),
                usedFields.cardinality());
        int target = 0;
        for (int source : usedFields) {
            mapping.set(source, target++);
        }

        RelBuilder projectBuilder = call.builder().push(input);
        // RelBuilder.project drops an identity projection and returns its input unchanged, so the result is not
        // necessarily a Project; do not cast it.
        RelNode projected = projectBuilder.project(projectBuilder.fields(usedFields.asList())).build();

        ImmutableBitSet newGroupSet = Mappings.apply(mapping, aggregate.getGroupSet());
        RelBuilder aggBuilder = call.builder()
                .push(PinotLogicalExchange.create(projected, RelDistributions.hash(newGroupSet.asList())));
        List<ImmutableBitSet> newGroupSets = aggregate.getGroupSets().stream()
                .map(groupSet -> Mappings.apply(mapping, groupSet))
                .collect(Util.toImmutableList());
        List<RelBuilder.AggCall> newAggCalls = aggregate.getAggCallList().stream()
                .map(aggCall -> aggBuilder.aggregateCall(aggCall, mapping))
                .collect(Util.toImmutableList());
        return aggBuilder.aggregate(aggBuilder.groupKey(newGroupSet, newGroupSets), newAggCalls)
                .hints(hints)
                .build();
    }

    /**
     * Creates the LEAF-stage copy of {@code aggregate}: same input and grouping, with each call rebuilt for the
     * LEAF stage and the internal aggregation type set to {@code LEAF}.
     */
    private Aggregate convertAggForLeafInput(Aggregate aggregate) {
        List<AggregateCall> leafCalls = new ArrayList<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            leafCalls.add(buildAggregateCall(aggregate.getInput(), aggCall, aggCall.getArgList(),
                    aggregate.getGroupCount(), AggregateNode.AggType.LEAF));
        }
        return new LogicalAggregate(aggregate.getCluster(), aggregate.getTraitSet(),
                hintsWithAggType(aggregate, AggregateNode.AggType.LEAF), aggregate.getInput(),
                aggregate.getGroupSet(), aggregate.getGroupSets(), leafCalls);
    }

    /**
     * Builds the aggregate that sits above {@code exchange} and merges the leaf stage's partial results.
     *
     * <p>The exchange's row is assumed to be the leaf aggregate's output: group keys in {@code [0, groupCount)}
     * followed by one partial result per original aggregate call.
     *
     * @param call      the rule call (used for its {@link RelBuilder})
     * @param aggregate the original aggregate being split
     * @param exchange  the exchange on top of the leaf aggregate
     * @param aggType   the aggregation type to build calls for and tag the new aggregate with
     * @return the new aggregate
     */
    private RelNode convertAggFromIntermediateInput(RelOptRuleCall call, Aggregate aggregate,
            PinotLogicalExchange exchange, AggregateNode.AggType aggType) {
        RexBuilder rexBuilder = exchange.getCluster().getRexBuilder();
        int groupCount = aggregate.getGroupCount();
        RelNode originalInput = aggregate.getInput();

        List<AggregateCall> oldCalls = aggregate.getAggCallList();
        List<AggregateCall> newCalls = new ArrayList<>();
        Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
        for (int i = 0; i < oldCalls.size(); i++) {
            // Aggregate call i reads the leaf stage's partial result at column (groupCount + i).
            AggregateCall newCall = buildAggregateCall(exchange, oldCalls.get(i),
                    Collections.singletonList(groupCount + i), groupCount, aggType);
            // NOTE(review): nullability is looked up on the ORIGINAL input at (groupCount + i) although the argument
            // refers to an exchange column; kept unchanged to preserve existing plans — confirm this is intended.
            rexBuilder.addAggCall(newCall, groupCount, newCalls, aggCallMapping, originalInput::fieldIsNullable);
        }

        ImmutableBitSet groupKeys = ImmutableBitSet.range(groupCount);
        RelBuilder builder = call.builder().push(exchange);
        builder.aggregate(builder.groupKey(groupKeys, ImmutableList.of(groupKeys)), newCalls);
        builder.hints(hintsWithAggType(aggregate, aggType));
        return builder.build();
    }

    /** Returns {@code aggregate}'s hints with the internal {@code AGG_TYPE} option replaced by {@code aggType}. */
    private static List<RelHint> hintsWithAggType(Aggregate aggregate, AggregateNode.AggType aggType) {
        return PinotHintStrategyTable.replaceHintOptions(aggregate.getHints(), PinotHintOptions.INTERNAL_AGG_OPTIONS,
                PinotHintOptions.InternalAggregateOptions.AGG_TYPE, aggType.name());
    }

    /** Returns the current rel of a {@link HepRelVertex}, or {@code rel} itself for any other node. */
    private static RelNode unwrapHepVertex(RelNode rel) {
        return rel instanceof HepRelVertex ? ((HepRelVertex) rel).getCurrentRel() : rel;
    }

    /**
     * Rebuilds {@code aggCall} for the given stage.
     *
     * <p>If the Pinot function type has an intermediate return-type inference, a stage-specific
     * {@link PinotSqlAggFunction} replaces the original operator. Otherwise, the original operator is kept. For stages
     * whose input is in intermediate format, the FILTER argument is dropped.
     *
     * @param input      the rel the new call's arguments refer to
     * @param aggCall    the original aggregate call
     * @param argList    argument column indexes in {@code input}
     * @param groupCount number of group keys of the aggregate that will hold the call
     * @param aggType    the stage the call is built for
     * @return the rebuilt call
     * @throws UnsupportedOperationException if a stage-specific function is needed and {@code aggType} is not
     *                                       LEAF, INTERMEDIATE or FINAL
     */
    private static AggregateCall buildAggregateCall(RelNode input, AggregateCall aggCall, List<Integer> argList,
            int groupCount, AggregateNode.AggType aggType) {
        String functionName = getFunctionNameFromAggregateCall(aggCall);
        AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName);
        SqlAggFunction function = functionType.getIntermediateReturnTypeInference() == null
                ? aggCall.getAggregation()
                : createStageFunction(functionName, functionType, aggCall, aggType);
        return AggregateCall.create(function, DISTINCT_COUNT.equals(functionName) || aggCall.isDistinct(),
                aggCall.isApproximate(), aggCall.ignoreNulls(), argList,
                aggType.isInputIntermediateFormat() ? -1 : aggCall.filterArg, aggCall.distinctKeys,
                aggCall.collation, groupCount, input, null, null);
    }

    /**
     * Creates the stage-specific aggregate operator.
     *
     * <p>The return type depends on the stage. LEAF and INTERMEDIATE use the intermediate return type; FINAL uses the
     * original call's type. Operand type checking uses the function's own checker for LEAF and {@code ANY} otherwise.
     */
    private static SqlAggFunction createStageFunction(String functionName, AggregationFunctionType functionType,
            AggregateCall aggCall, AggregateNode.AggType aggType) {
        String upperName = functionName.toUpperCase(Locale.ROOT);
        switch (aggType) {
            case LEAF:
                return new PinotSqlAggFunction(upperName, null, functionType.getSqlKind(),
                        functionType.getIntermediateReturnTypeInference(), null, functionType.getOperandTypeChecker(),
                        functionType.getSqlFunctionCategory());
            case INTERMEDIATE:
                return new PinotSqlAggFunction(upperName, null, functionType.getSqlKind(),
                        functionType.getIntermediateReturnTypeInference(), null, OperandTypes.ANY,
                        functionType.getSqlFunctionCategory());
            case FINAL:
                return new PinotSqlAggFunction(upperName, null, functionType.getSqlKind(),
                        ReturnTypes.explicit(aggCall.getType()), null, OperandTypes.ANY,
                        functionType.getSqlFunctionCategory());
            default:
                throw new UnsupportedOperationException("Unsupported aggType: " + aggType + " for " + functionName);
        }
    }

    /**
     * Returns the Pinot function name for {@code aggCall}.
     * {@code COUNT(DISTINCT ...)} maps to {@value #DISTINCT_COUNT}; every other call uses its operator name.
     */
    private static String getFunctionNameFromAggregateCall(AggregateCall aggCall) {
        String name = aggCall.getAggregation().getName();
        return name.equalsIgnoreCase("COUNT") && aggCall.isDistinct() ? DISTINCT_COUNT : name;
    }
}
