package org.apache.pinot.query.runtime.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.calcite.sql.SqlKind;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
import org.apache.pinot.query.planner.logical.LiteralHintUtils;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.block.DataBlockValSet;
import org.apache.pinot.query.runtime.operator.block.FilteredDataBlockValSet;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.spi.data.FieldSpec;

/* loaded from: input_file:org/apache/pinot/query/runtime/operator/AggregateOperator.class */
/**
 * Multi-stage operator that aggregates the rows produced by its input operator.
 *
 * <p>When the group set is empty, input blocks are fed to a {@link MultistageAggregationExecutor}; otherwise they are
 * fed to a {@link MultistageGroupByExecutor}. Each call to {@link #getNextBlock()} drains the input until a non-data
 * block arrives. The first successful call emits the aggregated result; later calls emit end-of-stream. Upstream error
 * blocks are returned unchanged.
 */
public class AggregateOperator extends MultiStageOperator {
    private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";

    /** Stand-in identifier/literal for aggregation arguments that are not backed by an input column. */
    private static final String PLACEHOLDER = "__PLACEHOLDER__";

    /** Node-hint option group that carries internal aggregation options. */
    private static final String AGG_OPTIONS_INTERNAL_HINT = "aggOptionsInternal";

    /** Key inside {@link #AGG_OPTIONS_INTERNAL_HINT} holding the serialized aggregation call signatures. */
    private static final String AGG_CALL_SIGNATURE_HINT = "agg_call_signature";

    // NOTE(review): value 245 is taken from the compiled original; presumably this is
    // QueryException.SERVER_RESOURCE_LIMIT_EXCEEDED_ERROR_CODE — confirm and reference that constant directly.
    private static final int NUM_GROUPS_LIMIT_REACHED_ERROR_CODE = 245;

    private final MultiStageOperator _inputOperator;
    private final DataSchema _resultSchema;
    private final DataSchema _inputSchema;
    private final AggregateNode.AggType _aggType;
    // Column name -> column index in _inputSchema, filled while converting group-by keys and aggregation operands.
    private final Map<String, Integer> _colNameToIndexMap;
    // Aggregation call index -> (argument index -> literal); key -1 of the inner map holds the argument count.
    private final Map<Integer, Map<Integer, Literal>> _aggCallSignatureMap;
    private boolean _hasReturnedAggregateBlock;
    private final boolean _isGroupByAggregation;
    // Exactly one of the two executors below is non-null, depending on _isGroupByAggregation.
    private MultistageAggregationExecutor _aggregationExecutor;
    private MultistageGroupByExecutor _groupByExecutor;

    /**
     * Creates the operator and its aggregation (or group-by) executor.
     *
     * @param context          execution context; its op-chain metadata is handed to the group-by executor
     * @param inputOperator    upstream operator whose blocks are aggregated
     * @param resultSchema     schema of the blocks this operator emits
     * @param inputSchema      schema of the input blocks, used to resolve input-ref indexes to column names
     * @param aggCalls         aggregation calls; every element must be a {@link RexExpression.FunctionCall}
     * @param groupSet         group-by keys; an empty list means a plain (non-grouped) aggregation
     * @param aggType          aggregation stage type
     * @param filterArgIndices optional per-call filter argument indexes; {@code null} or empty is passed on as
     *                         {@code null}
     * @param nodeHint         optional hints; the aggregation call signatures are read from them when present
     * @throws IllegalStateException if an operand or group-by key is neither an input reference nor a literal
     */
    @VisibleForTesting
    public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator,
        DataSchema resultSchema, DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet,
        AggregateNode.AggType aggType, @Nullable List<Integer> filterArgIndices,
        @Nullable AbstractPlanNode.NodeHint nodeHint) {
        super(context);
        _inputOperator = inputOperator;
        _resultSchema = resultSchema;
        _inputSchema = inputSchema;
        _aggType = aggType;
        int[] filterArgIndexArray = (filterArgIndices == null || filterArgIndices.isEmpty()) ? null
            : filterArgIndices.stream().mapToInt(Integer::intValue).toArray();
        // Must be set before the conversions below, which consult it.
        _aggCallSignatureMap = extractAggCallSignatureMap(nodeHint);
        _hasReturnedAggregateBlock = false;
        _colNameToIndexMap = new HashMap<>();

        // Group-by keys are converted first, then aggregation calls; both populate _colNameToIndexMap.
        List<ExpressionContext> groupByExpressions = getGroupSet(groupSet);
        List<FunctionContext> functionContexts = getFunctionContexts(aggCalls);
        AggregationFunction[] aggFunctions = new AggregationFunction[functionContexts.size()];
        for (int i = 0; i < aggFunctions.length; i++) {
            aggFunctions[i] = AggregationFunctionFactory.getAggregationFunction(functionContexts.get(i), true);
        }

        if (groupSet.isEmpty()) {
            _isGroupByAggregation = false;
            _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, filterArgIndexArray, aggType,
                _colNameToIndexMap, _resultSchema);
        } else {
            _isGroupByAggregation = true;
            _groupByExecutor = new MultistageGroupByExecutor(groupByExpressions, aggFunctions, filterArgIndexArray,
                aggType, _colNameToIndexMap, _resultSchema, context.getOpChainMetadata(), nodeHint);
        }
    }

    /**
     * Reads the serialized aggregation call signatures from the node hint.
     *
     * @return the parsed signature map, or an empty map when the hint or the internal options are absent
     */
    private static Map<Integer, Map<Integer, Literal>> extractAggCallSignatureMap(
        @Nullable AbstractPlanNode.NodeHint nodeHint) {
        if (nodeHint == null || nodeHint._hintOptions == null) {
            return Collections.emptyMap();
        }
        Map<?, ?> aggOptions = (Map<?, ?>) nodeHint._hintOptions.get(AGG_OPTIONS_INTERNAL_HINT);
        if (aggOptions == null) {
            return Collections.emptyMap();
        }
        return LiteralHintUtils.hintStringToLiteralMap((String) aggOptions.get(AGG_CALL_SIGNATURE_HINT));
    }

    @Override
    public List<MultiStageOperator> getChildOperators() {
        return ImmutableList.of(_inputOperator);
    }

    @Nullable
    public String toExplainString() {
        return EXPLAIN_NAME;
    }

    /**
     * Drains the input, then returns one of the following:
     * <ul>
     *   <li>the upstream error block, if the input ended with an error;</li>
     *   <li>the aggregated block, on the first successful call;</li>
     *   <li>end-of-stream, on every later call.</li>
     * </ul>
     * Any exception is converted into an error block.
     */
    @Override
    protected TransferableBlock getNextBlock() {
        try {
            TransferableBlock lastUpstreamBlock = _isGroupByAggregation ? consumeGroupBy() : consumeAggregation();
            if (lastUpstreamBlock.isErrorBlock()) {
                return lastUpstreamBlock;
            }
            if (!_hasReturnedAggregateBlock) {
                return produceAggregatedBlock();
            }
            return TransferableBlockUtils.getEndOfStreamTransferableBlock();
        } catch (Exception e) {
            return TransferableBlockUtils.getErrorTransferableBlock(e);
        }
    }

    /**
     * Builds the result block from the executor and marks the result as returned.
     *
     * <p>A group-by that produced no groups yields end-of-stream instead of an empty data block. When the group limit
     * was hit, an exception entry is attached to the result block.
     */
    private TransferableBlock produceAggregatedBlock() {
        _hasReturnedAggregateBlock = true;
        if (!_isGroupByAggregation) {
            return new TransferableBlock(_aggregationExecutor.getResult(), _resultSchema, DataBlock.Type.ROW);
        }
        List<Object[]> rows = _groupByExecutor.getResult();
        if (rows.isEmpty()) {
            return TransferableBlockUtils.getEndOfStreamTransferableBlock();
        }
        TransferableBlock resultBlock = new TransferableBlock(rows, _resultSchema, DataBlock.Type.ROW);
        if (_groupByExecutor.isNumGroupsLimitReached()) {
            resultBlock.addException(NUM_GROUPS_LIMIT_REACHED_ERROR_CODE,
                String.format("Reached numGroupsLimit of: %d for group-by, ignoring the extra groups",
                    _groupByExecutor.getNumGroupsLimit()));
        }
        return resultBlock;
    }

    /** Feeds input data blocks to the group-by executor and returns the first non-data block. */
    private TransferableBlock consumeGroupBy() {
        // NOTE(review): m20nextBlock looks like a decompiler-renamed method; confirm its name on MultiStageOperator.
        TransferableBlock block = _inputOperator.m20nextBlock();
        while (block.isDataBlock()) {
            _groupByExecutor.processBlock(block, _inputSchema);
            block = _inputOperator.m20nextBlock();
        }
        return block;
    }

    /** Feeds input data blocks to the aggregation executor and returns the first non-data block. */
    private TransferableBlock consumeAggregation() {
        TransferableBlock block = _inputOperator.m20nextBlock();
        while (block.isDataBlock()) {
            _aggregationExecutor.processBlock(block, _inputSchema);
            block = _inputOperator.m20nextBlock();
        }
        return block;
    }

    /** Converts each aggregation call, in order, into a {@link FunctionContext}; the position is its call index. */
    private List<FunctionContext> getFunctionContexts(List<RexExpression> aggCalls) {
        List<FunctionContext> functionContexts = new ArrayList<>(aggCalls.size());
        int aggIdx = 0;
        for (RexExpression aggCall : aggCalls) {
            functionContexts.add(
                convertRexExpressionsToFunctionContext(aggIdx++, (RexExpression.FunctionCall) aggCall));
        }
        return functionContexts;
    }

    /**
     * Converts a single aggregation call into an aggregation {@link FunctionContext}.
     *
     * <p>For intermediate-format input, arguments from the call signature are appended. A call that still has no
     * arguments gets a placeholder STRING literal, so that the argument list is never empty.
     */
    private FunctionContext convertRexExpressionsToFunctionContext(int aggIdx, RexExpression.FunctionCall aggCall) {
        String functionName = aggCall.getFunctionName();
        List<?> operands = aggCall.getFunctionOperands();
        List<ExpressionContext> arguments = new ArrayList<>(operands.size());
        for (int argIdx = 0; argIdx < operands.size(); argIdx++) {
            arguments.add(convertRexExpressionToExpressionContext(aggIdx, argIdx, (RexExpression) operands.get(argIdx)));
        }
        if (_aggType.isInputIntermediateFormat()) {
            rewriteAggArgumentForIntermediateInput(arguments, aggIdx);
        }
        if (arguments.isEmpty()) {
            arguments.add(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, PLACEHOLDER));
        }
        return new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, arguments);
    }

    /**
     * Appends arguments 1..(argCount - 1) from the call signature of aggregation {@code aggIdx}.
     *
     * <p>The argument count is stored under key -1. Arguments with a recorded literal become literals; all others
     * become the placeholder identifier. Does nothing when the call has no signature.
     *
     * @throws IllegalStateException if the signature is non-empty but has no argument count under key -1
     */
    private void rewriteAggArgumentForIntermediateInput(List<ExpressionContext> arguments, int aggIdx) {
        Map<Integer, Literal> signature = _aggCallSignatureMap.get(aggIdx);
        if (signature == null || signature.isEmpty()) {
            return;
        }
        Literal argCountLiteral = signature.get(-1);
        Preconditions.checkState(argCountLiteral != null,
            "Missing argument count (key -1) in call signature of aggregation: %s", aggIdx);
        int argCount = argCountLiteral.getIntValue();
        for (int argIdx = 1; argIdx < argCount; argIdx++) {
            Literal literal = signature.get(argIdx);
            arguments.add(literal != null ? ExpressionContext.forLiteralContext(literal)
                : ExpressionContext.forIdentifier(PLACEHOLDER));
        }
    }

    /** Converts the group-by keys into expressions; -1/-1 is used so no call-signature override applies. */
    private List<ExpressionContext> getGroupSet(List<RexExpression> groupSet) {
        List<ExpressionContext> groupByExpressions = new ArrayList<>(groupSet.size());
        for (RexExpression groupByKey : groupSet) {
            groupByExpressions.add(convertRexExpressionToExpressionContext(-1, -1, groupByKey));
        }
        return groupByExpressions;
    }

    /**
     * Converts an operand or group-by key into an {@link ExpressionContext}.
     *
     * <p>A literal recorded in the call signature for ({@code aggIdx}, {@code argIdx}) takes precedence. Otherwise:
     * <ul>
     *   <li>an input reference becomes an identifier, and its column index is recorded in the column index map;</li>
     *   <li>a literal becomes a literal expression.</li>
     * </ul>
     *
     * @throws IllegalStateException for any other expression kind (e.g. a nested function call)
     */
    private ExpressionContext convertRexExpressionToExpressionContext(int aggIdx, int argIdx,
        RexExpression rexExpression) {
        Map<Integer, Literal> signature = _aggCallSignatureMap.get(aggIdx);
        if (signature != null) {
            Literal signatureLiteral = signature.get(argIdx);
            if (signatureLiteral != null) {
                return ExpressionContext.forLiteralContext(signatureLiteral);
            }
        }
        switch (rexExpression.getKind()) {
            case INPUT_REF: {
                int index = ((RexExpression.InputRef) rexExpression).getIndex();
                String columnName = _inputSchema.getColumnName(index);
                _colNameToIndexMap.put(columnName, index);
                return ExpressionContext.forIdentifier(columnName);
            }
            case LITERAL: {
                RexExpression.Literal literal = (RexExpression.Literal) rexExpression;
                return ExpressionContext.forLiteralContext(literal.getDataType().toDataType(), literal.getValue());
            }
            default:
                throw new IllegalStateException(
                    "Aggregation Function operands or GroupBy columns cannot be a function.");
        }
    }

    /**
     * Builds a value-set map for the identifier (non-placeholder) inputs of an aggregation function.
     *
     * @param aggregationFunction function whose input expressions are resolved
     * @param block               input block; must be of type ROW when any identifier input is present
     * @param inputSchema         schema used to look up column data types
     * @param colNameToIndexMap   column name to column index mapping
     * @param filterArgIdx        -1 for an unfiltered value set; otherwise the filter column index passed to
     *                            {@link FilteredDataBlockValSet}
     * @return the value sets keyed by expression, or an empty map if the function has no inputs
     * @throws IllegalStateException if the block is not ROW-typed, or an identifier has no column index
     */
    public static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggregationFunction,
        TransferableBlock block, DataSchema inputSchema, Map<String, Integer> colNameToIndexMap, int filterArgIdx) {
        List<ExpressionContext> inputExpressions = aggregationFunction.getInputExpressions();
        if (inputExpressions.isEmpty()) {
            return Collections.emptyMap();
        }
        Map<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<>();
        for (ExpressionContext expression : inputExpressions) {
            if (expression.getType() != ExpressionContext.Type.IDENTIFIER
                || PLACEHOLDER.equals(expression.getIdentifier())) {
                continue;
            }
            String columnName = expression.getIdentifier();
            Integer boxedIndex = colNameToIndexMap.get(columnName);
            Preconditions.checkState(boxedIndex != null, "Failed to find index for column: %s", columnName);
            int columnIndex = boxedIndex;
            DataSchema.ColumnDataType columnDataType = inputSchema.getColumnDataType(columnIndex);
            Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
            if (filterArgIdx == -1) {
                blockValSetMap.put(expression,
                    new DataBlockValSet(columnDataType, block.getDataBlock(), columnIndex));
            } else {
                blockValSetMap.put(expression,
                    new FilteredDataBlockValSet(columnDataType, block.getDataBlock(), columnIndex, filterArgIdx));
            }
        }
        return blockValSetMap;
    }

    /**
     * Returns the row count of {@code block} when {@code filterArgIdx} is -1. Otherwise returns the sum of the int
     * values in column {@code filterArgIdx}; presumably each value is 1 when the row passes the filter and 0 when it
     * does not — confirm.
     */
    public static int computeBlockNumRows(TransferableBlock block, int filterArgIdx) {
        int numRows = block.getNumRows();
        if (filterArgIdx == -1) {
            return numRows;
        }
        DataBlock dataBlock = block.getDataBlock();
        int numMatchedRows = 0;
        for (int rowId = 0; rowId < numRows; rowId++) {
            numMatchedRows += dataBlock.getInt(rowId, filterArgIdx);
        }
        return numMatchedRows;
    }

    /**
     * Returns the value of the function's first input for {@code row}: the indexed column for an identifier, or the
     * literal's value for a literal.
     *
     * @throws IllegalStateException if the first input is neither an identifier nor a literal
     */
    public static Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row,
        Map<String, Integer> colNameToIndexMap) {
        List<ExpressionContext> inputExpressions = aggregationFunction.getInputExpressions();
        ExpressionContext expression = inputExpressions.get(0);
        ExpressionContext.Type type = expression.getType();
        if (type == ExpressionContext.Type.IDENTIFIER) {
            return row[colNameToIndexMap.get(expression.getIdentifier())];
        }
        Preconditions.checkState(type == ExpressionContext.Type.LITERAL, "Unsupported expression type: %s", type);
        return expression.getLiteral().getValue();
    }
}
