package org.apache.pinot.query.runtime.operator;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.calcite.sql.SqlKind;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.datatable.DataTable;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.operator.docvalsets.DataBlockValSet;
import org.apache.pinot.core.operator.docvalsets.FilteredDataBlockValSet;
import org.apache.pinot.core.operator.docvalsets.FilteredRowBasedBlockValSet;
import org.apache.pinot.core.operator.docvalsets.RowBasedBlockValSet;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
import org.apache.pinot.core.query.aggregation.function.CountAggregationFunction;
import org.apache.pinot.core.util.DataBlockExtractUtils;
import org.apache.pinot.query.planner.logical.LiteralHintUtils;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.roaringbitmap.RoaringBitmap;

/* loaded from: input_file:org/apache/pinot/query/runtime/operator/AggregateOperator.class */
public class AggregateOperator extends MultiStageOperator {
    /** Name returned by {@link #toExplainString()}. */
    private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
    /** Shared COUNT function built over the identifier "*"; returned for aggregate calls that carry no operands. */
    private static final CountAggregationFunction COUNT_STAR_AGG_FUNCTION = new CountAggregationFunction(Collections.singletonList(ExpressionContext.forIdentifier("*")), false);
    /** Filler argument for intermediate-input aggregate argument positions that have no literal in the hint map. */
    private static final ExpressionContext PLACEHOLDER_IDENTIFIER = ExpressionContext.forIdentifier("__PLACEHOLDER__");
    /** Upstream operator whose blocks are consumed by this operator. */
    private final MultiStageOperator _inputOperator;
    /** Schema attached to the result block built by this operator and handed to the executors. */
    private final DataSchema _resultSchema;
    /** Aggregation type; decides whether the input is treated as intermediate format when building the functions. */
    private final AggregateNode.AggType _aggType;
    /** Set only when no group keys were supplied; otherwise null. */
    private final MultistageAggregationExecutor _aggregationExecutor;
    /** Set only when group keys were supplied; otherwise null. */
    private final MultistageGroupByExecutor _groupByExecutor;
    /** Becomes true once the aggregated result has been produced; afterwards getNextBlock() returns end-of-stream. */
    private boolean _hasConstructedAggregateBlock;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.apache.pinot.query.runtime.operator.AggregateOperator$1, reason: invalid class name */
    /* loaded from: input_file:org/apache/pinot/query/runtime/operator/AggregateOperator$1.class */
    /**
     * Lookup table that maps {@link SqlKind} ordinals to small switch-case labels:
     * {@code INPUT_REF -> 1}, {@code LITERAL -> 2}; every other kind keeps the array default of 0.
     * NOTE(review): this looks like the compiler-generated helper for a switch over an enum, recovered
     * by the decompiler as a named class — confirm before editing by hand.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per SqlKind constant, indexed by ordinal().
        static final /* synthetic */ int[] $SwitchMap$org$apache$calcite$sql$SqlKind = new int[SqlKind.values().length];

        static {
            // Each registration is isolated in its own try block so a failure on one constant
            // does not prevent the other from being registered.
            try {
                $SwitchMap$org$apache$calcite$sql$SqlKind[SqlKind.INPUT_REF.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Deliberately ignored; presumably covers a SqlKind version lacking this constant — TODO confirm.
            }
            try {
                $SwitchMap$org$apache$calcite$sql$SqlKind[SqlKind.LITERAL.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Deliberately ignored; presumably covers a SqlKind version lacking this constant — TODO confirm.
            }
        }
    }

    /**
     * Creates the operator and the single executor it will drive: a plain aggregation executor when
     * there are no group keys, a group-by executor otherwise.
     *
     * @param opChainExecutionContext context passed to the parent operator; its op-chain metadata is
     *                                also handed to the group-by executor
     * @param multiStageOperator      upstream operator to consume blocks from
     * @param dataSchema              schema of the produced result
     * @param list                    aggregate call expressions, one per aggregation function
     * @param list2                   group key expressions; empty selects plain aggregation
     * @param aggType                 aggregation type, forwarded to the executors
     * @param list3                   one integer per aggregation function; presumably the filter
     *                                argument column id (the running maximum starts at -1) — confirm
     *                                against the executors
     * @param nodeHint                optional hint; when it carries "aggOptionsInternal" options, the
     *                                "agg_call_signature" entry is parsed into per-call literal arguments
     */
    public AggregateOperator(OpChainExecutionContext opChainExecutionContext, MultiStageOperator multiStageOperator, DataSchema dataSchema, List<RexExpression> list, List<RexExpression> list2, AggregateNode.AggType aggType, List<Integer> list3, @Nullable AbstractPlanNode.NodeHint nodeHint) {
        super(opChainExecutionContext);
        _inputOperator = multiStageOperator;
        _resultSchema = dataSchema;
        // Must be assigned before getAggFunctions() runs, because that method reads it.
        _aggType = aggType;

        // Literal arguments keyed by aggregate call index, then by argument index.
        Map<Integer, Map<Integer, Literal>> literalArguments = null;
        if (nodeHint != null) {
            Map<?, ?> aggOptions = (Map<?, ?>) nodeHint._hintOptions.get("aggOptionsInternal");
            if (aggOptions != null) {
                literalArguments = LiteralHintUtils.hintStringToLiteralMap((String) aggOptions.get("agg_call_signature"));
            }
        }
        if (literalArguments == null) {
            literalArguments = Collections.emptyMap();
        }
        AggregationFunction<?, ?>[] aggFunctions = getAggFunctions(list, literalArguments);

        // Copy the per-function ids into a primitive array and track the largest one.
        int numFunctions = aggFunctions.length;
        int[] filterArgIds = new int[numFunctions];
        int maxFilterArgId = -1;
        for (int idx = 0; idx < numFunctions; idx++) {
            int argId = list3.get(idx);
            filterArgIds[idx] = argId;
            if (argId > maxFilterArgId) {
                maxFilterArgId = argId;
            }
        }

        if (!list2.isEmpty()) {
            _groupByExecutor = new MultistageGroupByExecutor(getGroupKeyIds(list2), aggFunctions, filterArgIds, maxFilterArgId, aggType, _resultSchema, opChainExecutionContext.getOpChainMetadata(), nodeHint);
            _aggregationExecutor = null;
        } else {
            _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, filterArgIds, maxFilterArgId, aggType, _resultSchema);
            _groupByExecutor = null;
        }
    }

    /**
     * Returns the operators feeding this one.
     *
     * @return an immutable single-element list holding the input operator
     */
    @Override // org.apache.pinot.query.runtime.operator.MultiStageOperator
    public List<MultiStageOperator> getChildOperators() {
        MultiStageOperator onlyChild = _inputOperator;
        return ImmutableList.of(onlyChild);
    }

    /**
     * Returns the fixed display name of this operator.
     *
     * @return the constant "AGGREGATE_OPERATOR"
     */
    @Nullable
    public String toExplainString() {
        final String explainName = EXPLAIN_NAME;
        return explainName;
    }

    /**
     * Produces the next block of this operator.
     *
     * <p>Once the aggregated result has been produced, every call returns an end-of-stream block.
     * Before that, the input is drained through whichever executor is configured. If draining stops
     * on an error block, that block is returned as is; otherwise the aggregated result is built.
     *
     * @return an end-of-stream block, the upstream error block, or the aggregated result block
     */
    @Override // org.apache.pinot.query.runtime.operator.MultiStageOperator
    protected TransferableBlock getNextBlock() {
        if (_hasConstructedAggregateBlock) {
            return TransferableBlockUtils.getEndOfStreamTransferableBlock();
        }
        // The first non-data block seen while draining the input.
        TransferableBlock terminalBlock;
        if (_aggregationExecutor == null) {
            terminalBlock = consumeGroupBy();
        } else {
            terminalBlock = consumeAggregation();
        }
        if (terminalBlock.isErrorBlock()) {
            return terminalBlock;
        }
        return produceAggregatedBlock();
    }

    /**
     * Builds the result block from the configured executor and marks the operator as finished.
     *
     * <p>For group-by, an empty result yields an end-of-stream block. When the group-by executor
     * reports that the group limit was reached, that is recorded in the operator stats and the input
     * operator is asked to terminate early.
     *
     * @return the aggregated row block, or an end-of-stream block when group-by produced no rows
     */
    private TransferableBlock produceAggregatedBlock() {
        _hasConstructedAggregateBlock = true;
        if (_aggregationExecutor == null) {
            List<Object[]> groupedRows = _groupByExecutor.getResult();
            if (groupedRows.isEmpty()) {
                return TransferableBlockUtils.getEndOfStreamTransferableBlock();
            }
            TransferableBlock resultBlock = new TransferableBlock(groupedRows, _resultSchema, DataBlock.Type.ROW);
            if (_groupByExecutor.isNumGroupsLimitReached()) {
                String statKey = DataTable.MetadataKey.NUM_GROUPS_LIMIT_REACHED.getName();
                _opChainStats.getOperatorStats(_context, _operatorId).recordSingleStat(statKey, "true");
                _inputOperator.earlyTerminate();
            }
            return resultBlock;
        }
        return new TransferableBlock(_aggregationExecutor.getResult(), _resultSchema, DataBlock.Type.ROW);
    }

    /**
     * Feeds every data block from the input operator into the group-by executor.
     *
     * @return the first non-data block returned by the input operator
     */
    private TransferableBlock consumeGroupBy() {
        TransferableBlock current = _inputOperator.m22nextBlock();
        while (current.isDataBlock()) {
            _groupByExecutor.processBlock(current);
            current = _inputOperator.m22nextBlock();
        }
        return current;
    }

    /**
     * Feeds every data block from the input operator into the aggregation executor.
     *
     * @return the first non-data block returned by the input operator
     */
    private TransferableBlock consumeAggregation() {
        TransferableBlock current = _inputOperator.m22nextBlock();
        while (current.isDataBlock()) {
            _aggregationExecutor.processBlock(current);
            current = _inputOperator.m22nextBlock();
        }
        return current;
    }

    /**
     * Resolves one aggregation function per aggregate call.
     *
     * <p>Each call is resolved as an intermediate-input or raw-input function depending on
     * {@code _aggType.isInputIntermediateFormat()}.
     *
     * @param list aggregate call expressions; each element must be a {@link RexExpression.FunctionCall}
     * @param map  literal arguments keyed by call index, then by argument index; calls without an
     *             entry get an empty map
     * @return the resolved functions, in the same order as {@code list}
     * @throws ClassCastException if an element of {@code list} is not a function call
     */
    private AggregationFunction<?, ?>[] getAggFunctions(List<RexExpression> list, Map<Integer, Map<Integer, Literal>> map) {
        int numCalls = list.size();
        AggregationFunction<?, ?>[] functions = new AggregationFunction[numCalls];
        boolean intermediateInput = _aggType.isInputIntermediateFormat();
        for (int callId = 0; callId < numCalls; callId++) {
            // Both resolvers take a FunctionCall, so the downcast is needed on either path.
            RexExpression.FunctionCall call = (RexExpression.FunctionCall) list.get(callId);
            Map<Integer, Literal> literalArguments = map.getOrDefault(callId, Collections.emptyMap());
            if (intermediateInput) {
                functions[callId] = getAggFunctionForIntermediateInput(call, literalArguments);
            } else {
                functions[callId] = getAggFunctionForRawInput(call, literalArguments);
            }
        }
        return functions;
    }

    /**
     * Resolves the aggregation function for a call whose input is raw (non-intermediate) data.
     *
     * <p>A call without operands must be COUNT and resolves to the shared count-star function.
     * Otherwise each argument comes from {@code map} when a literal is registered for its position,
     * and from the operand itself when not: input references become "$&lt;index&gt;" identifiers and
     * literals become literal contexts.
     *
     * @param functionCall the aggregate call
     * @param map          literal arguments keyed by argument index; these take precedence over the
     *                     operand at the same position
     * @return the resolved aggregation function
     * @throws IllegalStateException if a call without operands is not COUNT, or an operand is neither
     *                               an input reference nor a literal
     */
    private AggregationFunction<?, ?> getAggFunctionForRawInput(RexExpression.FunctionCall functionCall, Map<Integer, Literal> map) {
        String functionName = functionCall.getFunctionName();
        List<RexExpression> operands = functionCall.getFunctionOperands();
        int numArguments = operands.size();
        if (numArguments == 0) {
            Preconditions.checkState(functionName.equals("COUNT"), "Aggregate function without argument must be COUNT, got: %s", functionName);
            return COUNT_STAR_AGG_FUNCTION;
        }
        List<ExpressionContext> arguments = new ArrayList<>(numArguments);
        for (int argId = 0; argId < numArguments; argId++) {
            Literal hintedLiteral = map.get(argId);
            if (hintedLiteral != null) {
                arguments.add(ExpressionContext.forLiteralContext(hintedLiteral));
                continue;
            }
            // Keep the operand typed as the base RexExpression and downcast only inside the case
            // that has established its concrete kind.
            RexExpression operand = operands.get(argId);
            switch (operand.getKind()) {
                case INPUT_REF:
                    RexExpression.InputRef inputRef = (RexExpression.InputRef) operand;
                    arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(inputRef.getIndex())));
                    break;
                case LITERAL:
                    RexExpression.Literal literalOperand = (RexExpression.Literal) operand;
                    arguments.add(ExpressionContext.forLiteralContext(literalOperand.getDataType().toDataType(), literalOperand.getValue()));
                    break;
                default:
                    throw new IllegalStateException("Illegal aggregation function operand type: " + operand.getKind());
            }
        }
        return AggregationFunctionFactory.getAggregationFunction(new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, arguments), true);
    }

    /**
     * Resolves the aggregation function for a call whose input is an intermediate result.
     *
     * <p>The call must have exactly one operand, and it must be an input reference. When {@code map}
     * has an entry under key -1, its int value is the total argument count: the first argument is
     * the input column and each further position is filled with the literal registered for it, or
     * with the placeholder identifier when none is registered. Without that entry the function is
     * built with the input column as its only argument.
     *
     * @param functionCall the aggregate call
     * @param map          literal arguments keyed by argument index, with the argument count under -1
     * @return the resolved aggregation function
     * @throws IllegalStateException if the call does not have exactly one operand, or the operand is
     *                               not an input reference
     */
    private static AggregationFunction<?, ?> getAggFunctionForIntermediateInput(RexExpression.FunctionCall functionCall, Map<Integer, Literal> map) {
        String functionName = functionCall.getFunctionName();
        List<RexExpression> operands = functionCall.getFunctionOperands();
        int numOperands = operands.size();
        Preconditions.checkState(numOperands == 1, "Intermediate aggregate must have 1 argument, got: %s", numOperands);
        RexExpression operand = operands.get(0);
        Preconditions.checkState(operand.getKind() == SqlKind.INPUT_REF, "Intermediate aggregate argument must be an input reference, got: %s", operand.getKind());
        // Downcast only after the kind check above has passed.
        ExpressionContext inputColumn = ExpressionContext.forIdentifier(fromColIdToIdentifier(((RexExpression.InputRef) operand).getIndex()));

        Literal numArgumentsLiteral = map.get(-1);
        if (numArgumentsLiteral == null) {
            return AggregationFunctionFactory.getAggregationFunction(new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, Collections.singletonList(inputColumn)), true);
        }
        int numArguments = numArgumentsLiteral.getIntValue();
        List<ExpressionContext> arguments = new ArrayList<>(numArguments);
        arguments.add(inputColumn);
        for (int argId = 1; argId < numArguments; argId++) {
            Literal hintedLiteral = map.get(argId);
            arguments.add(hintedLiteral != null ? ExpressionContext.forLiteralContext(hintedLiteral) : PLACEHOLDER_IDENTIFIER);
        }
        return AggregationFunctionFactory.getAggregationFunction(new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, arguments), true);
    }

    /**
     * Encodes a column index as an identifier of the form "$&lt;index&gt;".
     *
     * @param i column index
     * @return "$" followed by the decimal form of {@code i}
     */
    private static String fromColIdToIdentifier(int i) {
        return new StringBuilder("$").append(i).toString();
    }

    /**
     * Decodes an identifier of the form "$&lt;index&gt;" back into the column index.
     *
     * @param str identifier to decode
     * @return the integer that follows the leading '$'
     * @throws IllegalArgumentException if {@code str} does not start with '$' (including when it is
     *                                  empty), or if the remainder is not a valid integer
     *                                  ({@link NumberFormatException})
     */
    private static int fromIdentifierToColId(String str) {
        // startsWith() is false for an empty string, so the empty case is reported by the
        // precondition instead of failing earlier inside charAt(0).
        Preconditions.checkArgument(str.startsWith("$"), "Got identifier not representing column index: %s", str);
        return Integer.parseInt(str.substring(1));
    }

    /**
     * Extracts the input column index of every group key.
     *
     * @param list group key expressions; each must be an input reference
     * @return the column indices, in the same order as {@code list}
     * @throws IllegalStateException if a group key is not an input reference
     */
    private int[] getGroupKeyIds(List<RexExpression> list) {
        int numKeys = list.size();
        int[] keyIds = new int[numKeys];
        for (int keyId = 0; keyId < numKeys; keyId++) {
            RexExpression groupKey = list.get(keyId);
            Preconditions.checkState(groupKey.getKind() == SqlKind.INPUT_REF, "Group key must be an input reference, got: %s", groupKey.getKind());
            // Downcast only after the kind check above has passed.
            keyIds[keyId] = ((RexExpression.InputRef) groupKey).getIndex();
        }
        return keyIds;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Collects the ids of the rows whose value in the given column equals 1.
     *
     * <p>Reads from the row container when the block has one constructed, otherwise from the
     * underlying data block.
     *
     * @param transferableBlock block to scan
     * @param i                 index of the column to test; must be non-negative
     * @return bitmap holding the row ids whose column value is 1
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public static RoaringBitmap getMatchedBitmap(TransferableBlock transferableBlock, int i) {
        Preconditions.checkArgument(i >= 0, "Got negative filter argument id: %s", i);
        RoaringBitmap matchedRows = new RoaringBitmap();
        if (!transferableBlock.isContainerConstructed()) {
            DataBlock dataBlock = transferableBlock.getDataBlock();
            int numRows = dataBlock.getNumberOfRows();
            for (int rowId = 0; rowId < numRows; rowId++) {
                if (dataBlock.getInt(rowId, i) == 1) {
                    matchedRows.add(rowId);
                }
            }
            return matchedRows;
        }
        List<Object[]> rows = transferableBlock.getContainer();
        int numRows = rows.size();
        for (int rowId = 0; rowId < numRows; rowId++) {
            // The column value is stored boxed in the row container.
            int flag = (Integer) rows.get(rowId)[i];
            if (flag == 1) {
                matchedRows.add(rowId);
            }
        }
        return matchedRows;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds a value set for every identifier argument of the aggregation function.
     *
     * <p>Input expressions without an identifier are skipped. Each identifier is decoded as
     * "$&lt;index&gt;" to find its column. The value set is row-based when the block has its
     * container constructed and data-block-based otherwise.
     *
     * @param aggregationFunction function whose input expressions are resolved
     * @param transferableBlock   block supplying the column values
     * @return map from identifier expression to its value set; empty when the function has no
     *         input expressions
     */
    public static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction<?, ?> aggregationFunction, TransferableBlock transferableBlock) {
        List<ExpressionContext> arguments = aggregationFunction.getInputExpressions();
        if (arguments.isEmpty()) {
            return Collections.emptyMap();
        }
        DataSchema schema = transferableBlock.getDataSchema();
        Map<ExpressionContext, BlockValSet> valSets = new HashMap<>();
        // Fetch the backing storage once, up front; exactly one of the two is non-null.
        boolean rowBased = transferableBlock.isContainerConstructed();
        List<Object[]> rows = null;
        DataBlock dataBlock = null;
        if (rowBased) {
            rows = transferableBlock.getContainer();
        } else {
            dataBlock = transferableBlock.getDataBlock();
        }
        for (ExpressionContext argument : arguments) {
            String columnIdentifier = argument.getIdentifier();
            if (columnIdentifier == null) {
                continue;
            }
            int colId = fromIdentifierToColId(columnIdentifier);
            BlockValSet valSet;
            if (rowBased) {
                valSet = new RowBasedBlockValSet(schema.getColumnDataType(colId), rows, colId, true);
            } else {
                valSet = new DataBlockValSet(schema.getColumnDataType(colId), dataBlock, colId);
            }
            valSets.put(argument, valSet);
        }
        return valSets;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds a filtered value set for every identifier argument of the aggregation function.
     *
     * <p>Input expressions without an identifier are skipped. Each identifier is decoded as
     * "$&lt;index&gt;" to find its column. The value set is row-based when the block has its
     * container constructed and data-block-based otherwise; {@code i} and {@code roaringBitmap} are
     * passed through to the value set unchanged.
     *
     * @param aggregationFunction function whose input expressions are resolved
     * @param transferableBlock   block supplying the column values
     * @param i                   integer forwarded to the filtered value sets; presumably the number
     *                            of matched rows — confirm against the callers
     * @param roaringBitmap       bitmap forwarded to the filtered value sets; presumably the matched
     *                            row ids — confirm against the callers
     * @return map from identifier expression to its filtered value set; empty when the function has
     *         no input expressions
     */
    public static Map<ExpressionContext, BlockValSet> getFilteredBlockValSetMap(AggregationFunction<?, ?> aggregationFunction, TransferableBlock transferableBlock, int i, RoaringBitmap roaringBitmap) {
        List<ExpressionContext> arguments = aggregationFunction.getInputExpressions();
        if (arguments.isEmpty()) {
            return Collections.emptyMap();
        }
        DataSchema schema = transferableBlock.getDataSchema();
        Map<ExpressionContext, BlockValSet> valSets = new HashMap<>();
        // Fetch the backing storage once, up front; exactly one of the two is non-null.
        boolean rowBased = transferableBlock.isContainerConstructed();
        List<Object[]> rows = null;
        DataBlock dataBlock = null;
        if (rowBased) {
            rows = transferableBlock.getContainer();
        } else {
            dataBlock = transferableBlock.getDataBlock();
        }
        for (ExpressionContext argument : arguments) {
            String columnIdentifier = argument.getIdentifier();
            if (columnIdentifier == null) {
                continue;
            }
            int colId = fromIdentifierToColId(columnIdentifier);
            BlockValSet valSet;
            if (rowBased) {
                valSet = new FilteredRowBasedBlockValSet(schema.getColumnDataType(colId), rows, colId, i, roaringBitmap, true);
            } else {
                valSet = new FilteredDataBlockValSet(schema.getColumnDataType(colId), dataBlock, colId, i, roaringBitmap);
            }
            valSets.put(argument, valSet);
        }
        return valSets;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Extracts the column named by the function's first argument as an array of objects.
     *
     * <p>The first input expression must be an identifier of the form "$&lt;index&gt;". Values are
     * copied out of the row container when the block has one constructed, and extracted from the
     * data block otherwise.
     *
     * @param aggregationFunction function whose first input expression names the column
     * @param transferableBlock   block to read from
     * @return one value per row of the block for the referenced column
     * @throws IllegalStateException if the first input expression is not an identifier
     */
    public static Object[] getIntermediateResults(AggregationFunction<?, ?> aggregationFunction, TransferableBlock transferableBlock) {
        ExpressionContext firstArgument = aggregationFunction.getInputExpressions().get(0);
        Preconditions.checkState(firstArgument.getType() == ExpressionContext.Type.IDENTIFIER, "Expected the first argument to be IDENTIFIER, got: %s", firstArgument.getType());
        int colId = fromIdentifierToColId(firstArgument.getIdentifier());
        int rowCount = transferableBlock.getNumRows();
        if (transferableBlock.isContainerConstructed()) {
            List<Object[]> rows = transferableBlock.getContainer();
            Object[] columnValues = new Object[rowCount];
            for (int rowId = 0; rowId < rowCount; rowId++) {
                columnValues[rowId] = rows.get(rowId)[colId];
            }
            return columnValues;
        }
        return DataBlockExtractUtils.extractColumn(transferableBlock.getDataBlock(), colId);
    }
}
