package org.apache.pinot.query.runtime.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.data.table.Key;
import org.apache.pinot.core.operator.BaseOperator;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.spi.data.FieldSpec;

/**
 * Hash-based GROUP BY aggregation operator for the multi-stage query engine.
 *
 * <p>The operator drains its input operator, possibly across several {@link #nextBlock()} calls when the input
 * yields no-op blocks. For every (aggregation call, group key) pair it keeps one running value. Once the input
 * reaches end-of-stream, it emits all groups as a single ROW block. Subsequent calls return end-of-stream. If
 * the input yields an error block, that block is returned instead of any aggregation result.
 *
 * <p>Each output row contains the group-by column values (in {@code groupSet} order) followed by one value per
 * aggregation call (in {@code aggCalls} order). Numeric merges are done in {@code double}. A group built from
 * more than one input row therefore reports SUM/MIN/MAX/COUNT as a {@link Double}. A single-row group keeps the
 * original input value (or the COUNT seed of 1).
 */
public class AggregateOperator extends BaseOperator<TransferableBlock> {
    private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";

    private final Operator<TransferableBlock> _inputOperator;
    // Aggregation expressions, narrowed from RexExpression to FunctionCall in the constructor.
    private final List<RexExpression.FunctionCall> _aggCalls;
    // Group-by expressions; each must be an InputRef into the input row (see extraRowKey).
    private final List<RexExpression> _groupSet;
    private final DataSchema _resultSchema;
    // One accumulator per aggregation call, index-aligned with _aggCalls.
    private final Accumulator[] _accumulators;
    // Every group key seen so far, mapped to its group-by column values.
    private final Map<Key, Object[]> _groupByKeyHolder;
    private TransferableBlock _upstreamErrorBlock;
    // Set once the input produced end-of-stream, i.e. every input row has been accumulated.
    private boolean _readyToConstruct;
    private boolean _hasReturnedAggregateBlock;

    /**
     * Running state of a single aggregation call, keyed by group.
     */
    public static class Accumulator {
        private static final String COUNT = "COUNT";

        private static final Map<String, Merger> MERGERS = ImmutableMap.<String, Merger>builder()
            .put("SUM", AggregateOperator::mergeSum)
            .put("$SUM", AggregateOperator::mergeSum)
            .put("$SUM0", AggregateOperator::mergeSum)
            .put("MIN", AggregateOperator::mergeMin)
            .put("$MIN", AggregateOperator::mergeMin)
            .put("$MIN0", AggregateOperator::mergeMin)
            .put("MAX", AggregateOperator::mergeMax)
            .put("$MAX", AggregateOperator::mergeMax)
            .put("$MAX0", AggregateOperator::mergeMax)
            .put(COUNT, AggregateOperator::mergeCount)
            .put("BOOL_AND", AggregateOperator::mergeBoolAnd)
            .put("$BOOL_AND", AggregateOperator::mergeBoolAnd)
            .put("$BOOL_AND0", AggregateOperator::mergeBoolAnd)
            .put("BOOL_OR", AggregateOperator::mergeBoolOr)
            .put("$BOOL_OR", AggregateOperator::mergeBoolOr)
            .put("$BOOL_OR0", AggregateOperator::mergeBoolOr)
            .build();

        // Index of the input column to aggregate, or -1 when a constant (_literal) is aggregated instead.
        final int _inputRef;
        final Object _literal;
        final Map<Key, Object> _results = new HashMap<>();
        final Merger _merger;

        /**
         * @param functionCall the aggregation call; it may have at most one operand
         * @param merger combines the running value of a group with the next input value
         * @throws IllegalStateException if the call has more than one operand
         */
        Accumulator(RexExpression.FunctionCall functionCall, Merger merger) {
            _merger = merger;
            RexExpression operand = toAggregationFunctionOperand(functionCall);
            if (operand instanceof RexExpression.InputRef) {
                _inputRef = ((RexExpression.InputRef) operand).getIndex();
                _literal = null;
            } else {
                _inputRef = -1;
                _literal = ((RexExpression.Literal) operand).getValue();
            }
        }

        /**
         * Folds one input row into the running value of the given group.
         *
         * <p>The first value seen for a group is stored as-is. Later values are combined through the merger.
         */
        void accumulate(Key key, Object[] row) {
            Object value = _inputRef == -1 ? _literal : row[_inputRef];
            Object current = _results.get(key);
            _results.put(key, current == null ? value : _merger.apply(current, value));
        }

        /**
         * Returns the single operand to aggregate. The constant 1 is used for operand-less calls such as
         * {@code COUNT(*)}, and for every {@code COUNT}. COUNT only needs a per-row seed, because its merger
         * ignores the incoming value and adds 1. Seeding it with a column value (e.g. for {@code COUNT(col)})
         * would start the count at that value instead of 1.
         * NOTE(review): nulls are not skipped for COUNT(col); null handling is not implemented in this operator.
         */
        private RexExpression toAggregationFunctionOperand(RexExpression.FunctionCall functionCall) {
            List<RexExpression> operands = functionCall.getFunctionOperands();
            Preconditions.checkState(operands.size() < 2, "aggregate functions cannot have more than one operand");
            if (operands.isEmpty() || COUNT.equals(functionCall.getFunctionName())) {
                return new RexExpression.Literal(FieldSpec.DataType.INT, 1);
            }
            return operands.get(0);
        }
    }

    /**
     * Combines the running aggregation value of a group (first argument) with the next input value (second
     * argument) and returns the new running value.
     */
    @FunctionalInterface
    public interface Merger extends BiFunction<Object, Object, Object> {
    }

    /**
     * Creates an aggregate operator that uses the built-in mergers.
     *
     * @param inputOperator operator producing the rows to aggregate
     * @param dataSchema schema of the produced rows: group-by columns followed by aggregation results
     * @param aggCalls aggregation expressions; each must be a {@link RexExpression.FunctionCall}
     * @param groupSet group-by expressions; each must be a {@link RexExpression.InputRef}
     */
    public AggregateOperator(Operator<TransferableBlock> inputOperator, DataSchema dataSchema,
        List<RexExpression> aggCalls, List<RexExpression> groupSet) {
        this(inputOperator, dataSchema, aggCalls, groupSet, Accumulator.MERGERS);
    }

    /**
     * Creates an aggregate operator with an explicit function-name-to-merger table (used by tests).
     *
     * @throws IllegalStateException if an aggregation function has no entry in {@code mergers}
     */
    @VisibleForTesting
    AggregateOperator(Operator<TransferableBlock> inputOperator, DataSchema dataSchema, List<RexExpression> aggCalls,
        List<RexExpression> groupSet, Map<String, Merger> mergers) {
        _inputOperator = inputOperator;
        _groupSet = groupSet;
        _upstreamErrorBlock = null;
        _aggCalls = aggCalls.stream().map(RexExpression.FunctionCall.class::cast).collect(Collectors.toList());
        _accumulators = new Accumulator[_aggCalls.size()];
        for (int i = 0; i < _aggCalls.size(); i++) {
            RexExpression.FunctionCall aggCall = _aggCalls.get(i);
            String functionName = aggCall.getFunctionName();
            if (!mergers.containsKey(functionName)) {
                throw new IllegalStateException("Unexpected value: " + functionName);
            }
            _accumulators[i] = new Accumulator(aggCall, mergers.get(functionName));
        }
        _groupByKeyHolder = new HashMap<>();
        _resultSchema = dataSchema;
        _readyToConstruct = false;
        _hasReturnedAggregateBlock = false;
    }

    /**
     * Returns the single input operator, so that generic tree walkers do not receive a null list.
     */
    @Override
    public List<Operator> getChildOperators() {
        return Collections.<Operator>singletonList(_inputOperator);
    }

    @Override
    @Nullable
    public String toExplainString() {
        return EXPLAIN_NAME;
    }

    /**
     * Returns the next block of this operator.
     *
     * <p>The possible results are:
     * <ul>
     *   <li>a no-op block while the input has not yet reached end-of-stream;</li>
     *   <li>the upstream error block, if one was seen;</li>
     *   <li>the aggregated rows, exactly once;</li>
     *   <li>end-of-stream afterwards.</li>
     * </ul>
     * Any exception is converted into an error block.
     */
    @Override
    protected TransferableBlock getNextBlock() {
        try {
            if (!_readyToConstruct && !consumeInputBlocks()) {
                // The input yielded a no-op block before end-of-stream; try again on the next call.
                return TransferableBlockUtils.getNoOpTransferableBlock();
            }
            if (_upstreamErrorBlock != null) {
                return _upstreamErrorBlock;
            }
            if (!_hasReturnedAggregateBlock) {
                return produceAggregatedBlock();
            }
            return TransferableBlockUtils.getEndOfStreamTransferableBlock();
        } catch (Exception e) {
            return TransferableBlockUtils.getErrorTransferableBlock(e);
        }
    }

    /**
     * Builds one output row per group: the group-by values followed by each accumulator's result.
     * When no group was seen (including the no-GROUP-BY case over empty input), no row is emitted and
     * end-of-stream is returned instead.
     */
    private TransferableBlock produceAggregatedBlock() {
        int numGroupByColumns = _groupSet.size();
        List<Object[]> rows = new ArrayList<>(_groupByKeyHolder.size());
        for (Map.Entry<Key, Object[]> entry : _groupByKeyHolder.entrySet()) {
            Object[] row = new Object[numGroupByColumns + _aggCalls.size()];
            Object[] groupValues = entry.getValue();
            System.arraycopy(groupValues, 0, row, 0, groupValues.length);
            for (int i = 0; i < _accumulators.length; i++) {
                row[numGroupByColumns + i] = _accumulators[i]._results.get(entry.getKey());
            }
            rows.add(row);
        }
        _hasReturnedAggregateBlock = true;
        if (rows.isEmpty()) {
            return TransferableBlockUtils.getEndOfStreamTransferableBlock();
        }
        return new TransferableBlock(rows, _resultSchema, DataBlock.Type.ROW);
    }

    /**
     * Pulls blocks from the input and accumulates their rows until the input yields a no-op, error, or
     * end-of-stream block.
     *
     * @return {@code false} when a no-op block was seen (input not finished yet); {@code true} when an error block
     *     was recorded in {@code _upstreamErrorBlock} or end-of-stream set {@code _readyToConstruct}
     */
    private boolean consumeInputBlocks() {
        TransferableBlock block = _inputOperator.nextBlock();
        while (!block.isNoOpBlock()) {
            if (block.isErrorBlock()) {
                _upstreamErrorBlock = block;
                return true;
            }
            if (block.isEndOfStreamBlock()) {
                _readyToConstruct = true;
                return true;
            }
            for (Object[] row : block.getContainer()) {
                Key key = extraRowKey(row, _groupSet);
                // Equal keys carry equal values, so only the first occurrence needs to be stored.
                _groupByKeyHolder.putIfAbsent(key, key.getValues());
                for (Accumulator accumulator : _accumulators) {
                    accumulator.accumulate(key, row);
                }
            }
            block = _inputOperator.nextBlock();
        }
        return false;
    }

    /** Returns the sum of two numeric values as a {@link Double}. */
    public static Object mergeSum(Object left, Object right) {
        return ((Number) left).doubleValue() + ((Number) right).doubleValue();
    }

    /** Returns the smaller of two numeric values as a {@link Double}. */
    public static Object mergeMin(Object left, Object right) {
        return Math.min(((Number) left).doubleValue(), ((Number) right).doubleValue());
    }

    /** Returns the larger of two numeric values as a {@link Double}. */
    public static Object mergeMax(Object left, Object right) {
        return Math.max(((Number) left).doubleValue(), ((Number) right).doubleValue());
    }

    /** Increments the running count by one as a {@link Double}; the incoming value is ignored. */
    public static Object mergeCount(Object left, Object right) {
        return ((Number) left).doubleValue() + 1.0d;
    }

    /** Returns the logical AND of two {@link Boolean} values. */
    public static Boolean mergeBoolAnd(Object left, Object right) {
        return (Boolean) left && (Boolean) right;
    }

    /** Returns the logical OR of two {@link Boolean} values. */
    public static Boolean mergeBoolOr(Object left, Object right) {
        return (Boolean) left || (Boolean) right;
    }

    /**
     * Builds the group key of a row from the columns referenced by the group-by expressions.
     *
     * @throws ClassCastException if a group-by expression is not a {@link RexExpression.InputRef}
     */
    private static Key extraRowKey(Object[] row, List<RexExpression> groupSet) {
        Object[] keyValues = new Object[groupSet.size()];
        for (int i = 0; i < groupSet.size(); i++) {
            keyValues[i] = row[((RexExpression.InputRef) groupSet.get(i)).getIndex()];
        }
        return new Key(keyValues);
    }
}
