package org.apache.pinot.query.runtime.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.datatable.StatMap;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.config.QueryOptionsUtils;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.operator.docvalsets.DataBlockValSet;
import org.apache.pinot.core.operator.docvalsets.FilteredDataBlockValSet;
import org.apache.pinot.core.operator.docvalsets.FilteredRowBasedBlockValSet;
import org.apache.pinot.core.operator.docvalsets.RowBasedBlockValSet;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
import org.apache.pinot.core.query.aggregation.function.CountAggregationFunction;
import org.apache.pinot.core.util.DataBlockExtractUtils;
import org.apache.pinot.core.util.GroupByUtils;
import org.apache.pinot.query.parser.CalciteRexExpressionParser;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.operator.MultiStageOperator;
import org.apache.pinot.query.runtime.operator.utils.SortUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.roaringbitmap.RoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/pinot/query/runtime/operator/AggregateOperator.class */
/**
 * Multi-stage operator that aggregates every row produced by its single input operator.
 *
 * <p>When the {@link AggregateNode} has no group keys, rows are fed to a {@link MultistageAggregationExecutor};
 * otherwise they are fed to a {@link MultistageGroupByExecutor}. The input is drained on the first call to
 * {@link #getNextBlock()}:
 * <ul>
 *   <li>an upstream error block is returned unchanged;</li>
 *   <li>otherwise the aggregated rows are returned as one data block (or the EOS block directly when a group-by
 *       produced no rows), and every subsequent call returns the cached EOS block.</li>
 * </ul>
 */
public class AggregateOperator extends MultiStageOperator {
    private static final Logger LOGGER = LoggerFactory.getLogger(AggregateOperator.class);
    private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
    /** Shared function instance used for argument-less {@code COUNT} calls, i.e. {@code COUNT(*)}. */
    private static final CountAggregationFunction COUNT_STAR_AGG_FUNCTION =
        new CountAggregationFunction(Collections.singletonList(ExpressionContext.forIdentifier("*")), false);
    /** Fallback for the minimum group trim size when neither the node hint nor the query options set one. */
    private static final int DEFAULT_MIN_GROUP_TRIM_SIZE = 5000;

    private final MultiStageOperator _input;
    private final DataSchema _resultSchema;
    private final AggregationFunction<?, ?>[] _aggFunctions;
    // Exactly one of the two executors is non-null, depending on whether the node has group keys.
    private final MultistageAggregationExecutor _aggregationExecutor;
    private final MultistageGroupByExecutor _groupByExecutor;

    // Set once the input has been fully consumed; returned on every call after the aggregated block.
    @Nullable
    private TransferableBlock _eosBlock;
    private final StatMap<StatKey> _statMap;
    private boolean _hasConstructedAggregateBlock;
    private final boolean _errorOnNumGroupsLimit;
    private final int _groupTrimSize;

    // Ordering used to trim groups; null when results are trimmed without sorting (or not trimmed at all).
    @Nullable
    private final Comparator<Object[]> _comparator;

    /** Statistics recorded by this operator. */
    public enum StatKey implements StatMap.Key {
        EXECUTION_TIME_MS(StatMap.Type.LONG) {
            @Override
            public boolean includeDefaultInJson() {
                return true;
            }
        },
        EMITTED_ROWS(StatMap.Type.LONG) {
            @Override
            public boolean includeDefaultInJson() {
                return true;
            }
        },
        NUM_GROUPS_LIMIT_REACHED(StatMap.Type.BOOLEAN),
        NUM_GROUPS_WARNING_LIMIT_REACHED(StatMap.Type.BOOLEAN);

        private final StatMap.Type _type;

        StatKey(StatMap.Type type) {
            _type = type;
        }

        @Override
        public StatMap.Type getType() {
            return _type;
        }
    }

    /**
     * Builds the operator and its aggregation / group-by executor from the plan node.
     *
     * @param context execution context of the op chain (supplies query-option metadata)
     * @param input   upstream operator whose rows are aggregated
     * @param node    plan node describing aggregate calls, filter args, group keys, limit and collations
     */
    public AggregateOperator(OpChainExecutionContext context, MultiStageOperator input, AggregateNode node) {
        super(context);
        _statMap = new StatMap<>(StatKey.class);
        _input = input;
        _resultSchema = node.getDataSchema();
        _aggFunctions = getAggFunctions(node.getAggCalls());

        // One filter-argument column id per aggregate call, plus the largest id seen (stays -1 if all ids are < 0).
        int numFunctions = _aggFunctions.length;
        List<Integer> filterArgs = node.getFilterArgs();
        int[] filterArgIds = new int[numFunctions];
        int maxFilterArgId = -1;
        for (int i = 0; i < numFunctions; i++) {
            filterArgIds[i] = filterArgs.get(i);
            maxFilterArgId = Math.max(maxFilterArgId, filterArgIds[i]);
        }

        // Group trimming: start effectively unbounded and tighten only when both a limit and a positive minimum
        // trim size are configured. With collations, a sort comparator decides which groups survive the trim.
        int groupTrimSize = Integer.MAX_VALUE;
        Comparator<Object[]> comparator = null;
        int limit = node.getLimit();
        int minGroupTrimSize = getMinGroupTrimSize(node.getNodeHint(), context.getOpChainMetadata());
        if (limit > 0 && minGroupTrimSize > 0) {
            if (node.getCollations().isEmpty()) {
                groupTrimSize = limit;
            } else {
                groupTrimSize = GroupByUtils.getTableCapacity(limit, minGroupTrimSize);
                if (groupTrimSize < Integer.MAX_VALUE) {
                    comparator = new SortUtils.SortComparator(_resultSchema, node.getCollations(), true);
                }
            }
        }
        _groupTrimSize = groupTrimSize;
        _comparator = comparator;
        _errorOnNumGroupsLimit = getErrorOnNumGroupsLimit(node.getNodeHint(), context.getOpChainMetadata());

        AggregateNode.AggType aggType = node.getAggType();
        boolean leafReturnFinalResult = node.isLeafReturnFinalResult();
        List<Integer> groupKeys = node.getGroupKeys();
        if (groupKeys.isEmpty()) {
            _aggregationExecutor =
                new MultistageAggregationExecutor(_aggFunctions, filterArgIds, maxFilterArgId, aggType, _resultSchema);
            _groupByExecutor = null;
        } else {
            _groupByExecutor = new MultistageGroupByExecutor(getGroupKeyIds(groupKeys), _aggFunctions, filterArgIds,
                maxFilterArgId, aggType, leafReturnFinalResult, _resultSchema, context.getOpChainMetadata(),
                node.getNodeHint());
            _aggregationExecutor = null;
        }
    }

    /**
     * Resolves the minimum group trim size: the {@code mse_min_group_trim_size} node hint wins, then the query
     * option, then {@link #DEFAULT_MIN_GROUP_TRIM_SIZE}.
     *
     * @throws NumberFormatException if the hint value is not an integer
     */
    private static int getMinGroupTrimSize(PlanNode.NodeHint nodeHint, Map<String, String> queryOptions) {
        String hintValue = getOption(nodeHint, "mse_min_group_trim_size");
        if (hintValue != null) {
            return Integer.parseInt(hintValue);
        }
        Integer fromOptions = QueryOptionsUtils.getMSEMinGroupTrimSize(queryOptions);
        return fromOptions != null ? fromOptions : DEFAULT_MIN_GROUP_TRIM_SIZE;
    }

    /** Resolves whether hitting the num-groups limit fails the query: node hint first, then the query option. */
    private static boolean getErrorOnNumGroupsLimit(PlanNode.NodeHint nodeHint, Map<String, String> queryOptions) {
        String hintValue = getOption(nodeHint, "error_on_num_groups_limit");
        return hintValue != null ? Boolean.parseBoolean(hintValue) : QueryOptionsUtils.getErrorOnNumGroupsLimit(queryOptions);
    }

    /** Returns the value of {@code key} in the node's {@code aggOptions} hint, or null when absent. */
    @Nullable
    private static String getOption(PlanNode.NodeHint nodeHint, String key) {
        @SuppressWarnings("unchecked")
        Map<String, String> aggOptions = (Map<String, String>) nodeHint.getHintOptions().get("aggOptions");
        return aggOptions != null ? aggOptions.get(key) : null;
    }

    @Override
    public void registerExecution(long executionTimeMs, int numRows) {
        _statMap.merge(StatKey.EXECUTION_TIME_MS, executionTimeMs);
        _statMap.merge(StatKey.EMITTED_ROWS, numRows);
    }

    @Override
    public MultiStageOperator.Type getOperatorType() {
        return MultiStageOperator.Type.AGGREGATE;
    }

    @Override
    protected Logger logger() {
        return LOGGER;
    }

    @Override
    public List<MultiStageOperator> getChildOperators() {
        return List.of(_input);
    }

    public String toExplainString() {
        return EXPLAIN_NAME;
    }

    /**
     * Drains the input on the first call and returns the aggregated block; later calls return the cached EOS block.
     * An upstream error block is returned as-is.
     */
    @Override
    protected TransferableBlock getNextBlock() {
        if (_hasConstructedAggregateBlock) {
            assert _eosBlock != null;
            return _eosBlock;
        }
        TransferableBlock finalBlock = _aggregationExecutor != null ? consumeAggregation() : consumeGroupBy();
        if (finalBlock.isErrorBlock()) {
            return finalBlock;
        }
        assert finalBlock.isSuccessfulEndOfStreamBlock() : "Final block must be EOS block";
        _eosBlock = updateEosBlock(finalBlock, _statMap);
        return produceAggregatedBlock();
    }

    /**
     * Materializes the executor's result as a ROW block. For group-by, an empty result yields the EOS block, and
     * the num-groups limit / warning limit are enforced or recorded.
     *
     * @throws RuntimeException if the num-groups limit was reached and error-on-limit is enabled
     */
    private TransferableBlock produceAggregatedBlock() {
        _hasConstructedAggregateBlock = true;
        if (_aggregationExecutor != null) {
            return new TransferableBlock(_aggregationExecutor.getResult(), _resultSchema, DataBlock.Type.ROW,
                _aggFunctions);
        }
        List<Object[]> rows = _comparator != null
            ? _groupByExecutor.getResult(_comparator, _groupTrimSize)
            : _groupByExecutor.getResult(_groupTrimSize);
        if (rows.isEmpty()) {
            return _eosBlock;
        }
        TransferableBlock dataBlock = new TransferableBlock(rows, _resultSchema, DataBlock.Type.ROW, _aggFunctions);
        if (_groupByExecutor.isNumGroupsLimitReached()) {
            if (_errorOnNumGroupsLimit) {
                _input.earlyTerminate();
                throw new RuntimeException("NUM_GROUPS_LIMIT has been reached at " + _operatorId);
            }
            _statMap.merge(StatKey.NUM_GROUPS_LIMIT_REACHED, true);
            _input.earlyTerminate();
        }
        int numGroups = _groupByExecutor.getNumGroups();
        int numGroupsWarningLimit = _groupByExecutor.getNumGroupsWarningLimit();
        if (numGroups >= numGroupsWarningLimit) {
            LOGGER.warn("numGroups reached warning limit: {} (actual: {})", numGroupsWarningLimit, numGroups);
            _statMap.merge(StatKey.NUM_GROUPS_WARNING_LIMIT_REACHED, true);
        }
        return dataBlock;
    }

    /** Feeds every input data block to the group-by executor and returns the first non-data block. */
    private TransferableBlock consumeGroupBy() {
        // NOTE(review): m43nextBlock() is a decompiler-renamed method; presumably nextBlock() in the original
        // MultiStageOperator source — confirm before renaming.
        TransferableBlock block = _input.m43nextBlock();
        while (block.isDataBlock()) {
            _groupByExecutor.processBlock(block);
            sampleAndCheckInterruption();
            block = _input.m43nextBlock();
        }
        return block;
    }

    /** Feeds every input data block to the aggregation executor and returns the first non-data block. */
    private TransferableBlock consumeAggregation() {
        TransferableBlock block = _input.m43nextBlock();
        while (block.isDataBlock()) {
            _aggregationExecutor.processBlock(block);
            sampleAndCheckInterruption();
            block = _input.m43nextBlock();
        }
        return block;
    }

    /** Converts each aggregate call of the plan node into an {@link AggregationFunction}. */
    private static AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression.FunctionCall> aggCalls) {
        int numCalls = aggCalls.size();
        AggregationFunction<?, ?>[] functions = new AggregationFunction[numCalls];
        for (int i = 0; i < numCalls; i++) {
            functions[i] = getAggFunction(aggCalls.get(i));
        }
        return functions;
    }

    /**
     * Converts one aggregate call into an {@link AggregationFunction}. Input-ref operands become {@code $<index>}
     * identifiers; every other operand must be a literal.
     *
     * @throws IllegalStateException if a call without operands is not {@code COUNT}
     */
    private static AggregationFunction<?, ?> getAggFunction(RexExpression.FunctionCall functionCall) {
        String functionName = functionCall.getFunctionName();
        // Operands are a mix of InputRef and Literal, so they must be typed as the common RexExpression supertype.
        List<? extends RexExpression> operands = functionCall.getFunctionOperands();
        int numOperands = operands.size();
        if (numOperands == 0) {
            Preconditions.checkState(functionName.equals("COUNT"),
                "Aggregate function without argument must be COUNT, got: %s", functionName);
            return COUNT_STAR_AGG_FUNCTION;
        }
        List<ExpressionContext> arguments = new ArrayList<>(numOperands);
        for (RexExpression operand : operands) {
            if (operand instanceof RexExpression.InputRef) {
                int columnId = ((RexExpression.InputRef) operand).getIndex();
                arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(columnId)));
            } else {
                assert operand instanceof RexExpression.Literal;
                arguments.add(
                    ExpressionContext.forLiteral(CalciteRexExpressionParser.toLiteral((RexExpression.Literal) operand)));
            }
        }
        return AggregationFunctionFactory.getAggregationFunction(
            new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, arguments), true);
    }

    /** Encodes a column index as the {@code $<index>} identifier used by the aggregation functions. */
    private static String fromColIdToIdentifier(int colId) {
        return "$" + colId;
    }

    /**
     * Decodes a {@code $<index>} identifier back into a column index.
     *
     * @throws IllegalArgumentException if the identifier does not start with {@code $}
     */
    private static int fromIdentifierToColId(String identifier) {
        // startsWith (rather than charAt(0)) so an empty identifier fails the precondition instead of throwing
        // StringIndexOutOfBoundsException.
        Preconditions.checkArgument(identifier.startsWith("$"), "Got identifier not representing column index: %s",
            identifier);
        return Integer.parseInt(identifier.substring(1));
    }

    /** Copies the group-key column indexes into a primitive array. */
    private static int[] getGroupKeyIds(List<Integer> groupKeys) {
        int numKeys = groupKeys.size();
        int[] groupKeyIds = new int[numKeys];
        for (int i = 0; i < numKeys; i++) {
            groupKeyIds[i] = groupKeys.get(i);
        }
        return groupKeyIds;
    }

    /**
     * Returns the ids of the rows whose value in column {@code filterArgId} equals 1.
     * NOTE(review): assumes the filter column holds INT 0/1 flags with no nulls — confirm.
     * (The decompiler widened this and the helpers below from package-private to public.)
     *
     * @throws IllegalArgumentException if {@code filterArgId} is negative
     */
    public static RoaringBitmap getMatchedBitmap(TransferableBlock transferableBlock, int filterArgId) {
        Preconditions.checkArgument(filterArgId >= 0, "Got negative filter argument id: %s", filterArgId);
        RoaringBitmap matchedBitmap = new RoaringBitmap();
        if (transferableBlock.isContainerConstructed()) {
            List<Object[]> rows = transferableBlock.getContainer();
            int numRows = rows.size();
            for (int rowId = 0; rowId < numRows; rowId++) {
                if ((Integer) rows.get(rowId)[filterArgId] == 1) {
                    matchedBitmap.add(rowId);
                }
            }
        } else {
            DataBlock dataBlock = transferableBlock.getDataBlock();
            int numRows = dataBlock.getNumberOfRows();
            for (int rowId = 0; rowId < numRows; rowId++) {
                if (dataBlock.getInt(rowId, filterArgId) == 1) {
                    matchedBitmap.add(rowId);
                }
            }
        }
        return matchedBitmap;
    }

    /**
     * Builds a value set for each identifier input of {@code aggregationFunction}, backed by either the block's row
     * container or its serialized data block. Non-identifier inputs (literals) are skipped.
     */
    public static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction<?, ?> aggregationFunction,
        TransferableBlock transferableBlock) {
        List<ExpressionContext> inputExpressions = aggregationFunction.getInputExpressions();
        if (inputExpressions.isEmpty()) {
            return Collections.emptyMap();
        }
        DataSchema dataSchema = transferableBlock.getDataSchema();
        assert dataSchema != null;
        Map<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<>();
        if (transferableBlock.isContainerConstructed()) {
            List<Object[]> rows = transferableBlock.getContainer();
            for (ExpressionContext expression : inputExpressions) {
                String identifier = expression.getIdentifier();
                if (identifier != null) {
                    int colId = fromIdentifierToColId(identifier);
                    blockValSetMap.put(expression,
                        new RowBasedBlockValSet(dataSchema.getColumnDataType(colId), rows, colId, true));
                }
            }
        } else {
            DataBlock dataBlock = transferableBlock.getDataBlock();
            for (ExpressionContext expression : inputExpressions) {
                String identifier = expression.getIdentifier();
                if (identifier != null) {
                    int colId = fromIdentifierToColId(identifier);
                    blockValSetMap.put(expression,
                        new DataBlockValSet(dataSchema.getColumnDataType(colId), dataBlock, colId));
                }
            }
        }
        return blockValSetMap;
    }

    /**
     * Same as {@link #getBlockValSetMap} but each value set only exposes the rows in {@code matchedBitmap} for filter
     * column {@code filterArgId}.
     */
    public static Map<ExpressionContext, BlockValSet> getFilteredBlockValSetMap(
        AggregationFunction<?, ?> aggregationFunction, TransferableBlock transferableBlock, int filterArgId,
        RoaringBitmap matchedBitmap) {
        List<ExpressionContext> inputExpressions = aggregationFunction.getInputExpressions();
        if (inputExpressions.isEmpty()) {
            return Collections.emptyMap();
        }
        DataSchema dataSchema = transferableBlock.getDataSchema();
        assert dataSchema != null;
        Map<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<>();
        if (transferableBlock.isContainerConstructed()) {
            List<Object[]> rows = transferableBlock.getContainer();
            for (ExpressionContext expression : inputExpressions) {
                String identifier = expression.getIdentifier();
                if (identifier != null) {
                    int colId = fromIdentifierToColId(identifier);
                    blockValSetMap.put(expression, new FilteredRowBasedBlockValSet(dataSchema.getColumnDataType(colId),
                        rows, colId, filterArgId, matchedBitmap, true));
                }
            }
        } else {
            DataBlock dataBlock = transferableBlock.getDataBlock();
            for (ExpressionContext expression : inputExpressions) {
                String identifier = expression.getIdentifier();
                if (identifier != null) {
                    int colId = fromIdentifierToColId(identifier);
                    blockValSetMap.put(expression, new FilteredDataBlockValSet(dataSchema.getColumnDataType(colId),
                        dataBlock, colId, filterArgId, matchedBitmap));
                }
            }
        }
        return blockValSetMap;
    }

    /**
     * Extracts, per row, the intermediate result stored in the column referenced by the function's first input.
     *
     * @throws IllegalStateException if the first input expression is not an identifier
     */
    public static Object[] getIntermediateResults(AggregationFunction<?, ?> aggregationFunction,
        TransferableBlock transferableBlock) {
        ExpressionContext firstInput = aggregationFunction.getInputExpressions().get(0);
        Preconditions.checkState(firstInput.getType() == ExpressionContext.Type.IDENTIFIER,
            "Expected the first argument to be IDENTIFIER, got: %s", firstInput.getType());
        int colId = fromIdentifierToColId(firstInput.getIdentifier());
        if (!transferableBlock.isContainerConstructed()) {
            return DataBlockExtractUtils.extractAggResult(transferableBlock.getDataBlock(), colId, aggregationFunction);
        }
        int numRows = transferableBlock.getNumRows();
        List<Object[]> rows = transferableBlock.getContainer();
        Object[] results = new Object[numRows];
        for (int rowId = 0; rowId < numRows; rowId++) {
            results[rowId] = rows.get(rowId)[colId];
        }
        return results;
    }

    @VisibleForTesting
    int getGroupTrimSize() {
        return _groupTrimSize;
    }
}
