package org.apache.pinot.core.query.aggregation.function;

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.List;
import java.util.Map;
import org.apache.pinot.$internal.com.google.common.base.Preconditions;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.utils.BigDecimalUtils;
import org.roaringbitmap.RoaringBitmap;

/* loaded from: input_file:org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.class */
/**
 * Implements {@code SUMPRECISION(expression[, precision[, scale]])}: sums a single-valued expression using
 * {@link BigDecimal} arithmetic so that the running sum is exact.
 *
 * <p>The optional literal {@code precision} and {@code scale} arguments are only applied when the final result is
 * extracted ({@link #extractFinalResult(BigDecimal)}), rounding with {@link RoundingMode#HALF_EVEN}. {@code scale} is
 * only honoured when {@code precision} is also supplied.
 *
 * <p>When null handling is enabled, rows set in the block's null bitmap are skipped, and an aggregation or group that
 * never received a non-null value yields {@code null} instead of {@link BigDecimal#ZERO}.
 */
public class SumPrecisionAggregationFunction extends BaseSingleInputAggregationFunction<BigDecimal, BigDecimal> {
    /** Significant digits applied to the final result; {@code null} means the final sum is not rounded. */
    private final Integer _precision;
    /** Scale applied after rounding to {@link #_precision}; {@code null} keeps the scale produced by rounding. */
    private final Integer _scale;
    /** Whether null bitmaps are consulted and inputs with no non-null value produce {@code null}. */
    private final boolean _nullHandlingEnabled;

    /**
     * Creates the function from its parsed arguments.
     *
     * @param arguments the expression to sum, optionally followed by integer literals for precision and scale
     * @param nullHandlingEnabled whether to honour null bitmaps on the input blocks
     * @throws IllegalArgumentException if more than 3 arguments are supplied or the precision is negative
     */
    public SumPrecisionAggregationFunction(List<ExpressionContext> arguments, boolean nullHandlingEnabled) {
        super(arguments.get(0));
        int numArguments = arguments.size();
        Preconditions.checkArgument(numArguments <= 3, "SumPrecision expects at most 3 arguments, got: %s",
                numArguments);
        if (numArguments > 1) {
            int precision = arguments.get(1).getLiteral().getIntValue();
            // Fail fast: MathContext rejects a negative precision, but only when the final result is extracted.
            // A precision of 0 is valid and means "unlimited" for MathContext.
            Preconditions.checkArgument(precision >= 0, "SumPrecision expects a non-negative precision, got: %s",
                    precision);
            _precision = precision;
            _scale = numArguments > 2 ? Integer.valueOf(arguments.get(2).getLiteral().getIntValue()) : null;
        } else {
            _precision = null;
            _scale = null;
        }
        _nullHandlingEnabled = nullHandlingEnabled;
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.SUMPRECISION;
    }

    /** Results are {@link BigDecimal} objects, so an object-based holder is used. */
    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    /** Results are {@link BigDecimal} objects, so an object-based holder is used. */
    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Adds the first {@code length} values of the block to the single aggregation result.
     *
     * <p>FLOAT and DOUBLE values are read through {@code getStringValuesSV()} and parsed with
     * {@link BigDecimal#BigDecimal(String)}. A non-finite value would presumably surface as a string such as
     * {@code "NaN"}, which that constructor rejects with {@link NumberFormatException}.
     *
     * @throws IllegalStateException if the stored type is not a supported numeric type
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        if (_nullHandlingEnabled) {
            RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
            if (nullBitmap != null && !nullBitmap.isEmpty()) {
                aggregateNullHandlingEnabled(length, aggregationResultHolder, blockValSet, nullBitmap);
                return;
            }
        }
        BigDecimal sum = getDefaultResult(aggregationResultHolder);
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                int[] values = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    sum = sum.add(BigDecimal.valueOf(values[i]));
                }
                break;
            }
            case LONG: {
                long[] values = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    sum = sum.add(BigDecimal.valueOf(values[i]));
                }
                break;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                String[] values = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    sum = sum.add(new BigDecimal(values[i]));
                }
                break;
            }
            case BIG_DECIMAL: {
                BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; i++) {
                    sum = sum.add(values[i]);
                }
                break;
            }
            case BYTES: {
                byte[][] values = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; i++) {
                    sum = sum.add(BigDecimalUtils.deserialize(values[i]));
                }
                break;
            }
            default:
                throw new IllegalStateException();
        }
        aggregationResultHolder.setValue(sum);
    }

    /**
     * Null-aware variant of {@link #aggregate}: skips rows present in {@code nullBitmap}. It leaves the holder
     * untouched when every row of the block is null, so an all-null input keeps a {@code null} result.
     * Non-finite FLOAT/DOUBLE values are skipped.
     */
    private void aggregateNullHandlingEnabled(int length, AggregationResultHolder aggregationResultHolder,
            BlockValSet blockValSet, RoaringBitmap nullBitmap) {
        BigDecimal sum = BigDecimal.ZERO;
        switch (blockValSet.getValueType().getStoredType()) {
            case INT:
                if (nullBitmap.getCardinality() < length) {
                    int[] values = blockValSet.getIntValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            sum = sum.add(BigDecimal.valueOf(values[i]));
                        }
                    }
                    setAggregationResult(aggregationResultHolder, sum);
                }
                return;
            case LONG:
                if (nullBitmap.getCardinality() < length) {
                    long[] values = blockValSet.getLongValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            sum = sum.add(BigDecimal.valueOf(values[i]));
                        }
                    }
                    setAggregationResult(aggregationResultHolder, sum);
                }
                return;
            case FLOAT:
                if (nullBitmap.getCardinality() < length) {
                    float[] values = blockValSet.getFloatValuesSV();
                    for (int i = 0; i < length; i++) {
                        float value = values[i];
                        if (!nullBitmap.contains(i) && Float.isFinite(value)) {
                            // BigDecimal has no valueOf(float): the float would widen to double, so 0.1f would
                            // become 0.10000000149011612. Float.toString yields the shortest decimal form ("0.1").
                            // That presumably matches the string-based conversion used by the non-null path.
                            sum = sum.add(new BigDecimal(Float.toString(value)));
                        }
                    }
                    setAggregationResult(aggregationResultHolder, sum);
                }
                return;
            case DOUBLE:
                if (nullBitmap.getCardinality() < length) {
                    double[] values = blockValSet.getDoubleValuesSV();
                    for (int i = 0; i < length; i++) {
                        // BigDecimal.valueOf(double) goes through Double.toString; non-finite values are skipped
                        // because BigDecimal cannot represent them.
                        if (!nullBitmap.contains(i) && Double.isFinite(values[i])) {
                            sum = sum.add(BigDecimal.valueOf(values[i]));
                        }
                    }
                    setAggregationResult(aggregationResultHolder, sum);
                }
                return;
            case STRING:
                if (nullBitmap.getCardinality() < length) {
                    String[] values = blockValSet.getStringValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            sum = sum.add(new BigDecimal(values[i]));
                        }
                    }
                    setAggregationResult(aggregationResultHolder, sum);
                }
                return;
            case BIG_DECIMAL:
                if (nullBitmap.getCardinality() < length) {
                    BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            sum = sum.add(values[i]);
                        }
                    }
                    setAggregationResult(aggregationResultHolder, sum);
                }
                return;
            case BYTES:
                if (nullBitmap.getCardinality() < length) {
                    byte[][] values = blockValSet.getBytesValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            sum = sum.add(BigDecimalUtils.deserialize(values[i]));
                        }
                    }
                    setAggregationResult(aggregationResultHolder, sum);
                }
                return;
            default:
                throw new IllegalStateException();
        }
    }

    /**
     * Adds {@code blockSum} to the holder's current result, or stores it as-is when the holder has no result yet.
     */
    protected void setAggregationResult(AggregationResultHolder aggregationResultHolder, BigDecimal blockSum) {
        BigDecimal current = (BigDecimal) aggregationResultHolder.getResult();
        aggregationResultHolder.setValue(current == null ? blockSum : blockSum.add(current));
    }

    /**
     * Adds each of the first {@code length} values to the result of its group key ({@code groupKeyArray[i]}).
     *
     * @throws IllegalStateException if the stored type is not a supported numeric type
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        if (_nullHandlingEnabled) {
            RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
            if (nullBitmap != null && !nullBitmap.isEmpty()) {
                aggregateGroupBySVNullHandlingEnabled(length, groupKeyArray, groupByResultHolder, blockValSet,
                        nullBitmap);
                return;
            }
        }
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                int[] values = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    int groupKey = groupKeyArray[i];
                    groupByResultHolder.setValueForKey(groupKey,
                            getDefaultResult(groupByResultHolder, groupKey).add(BigDecimal.valueOf(values[i])));
                }
                return;
            }
            case LONG: {
                long[] values = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    int groupKey = groupKeyArray[i];
                    groupByResultHolder.setValueForKey(groupKey,
                            getDefaultResult(groupByResultHolder, groupKey).add(BigDecimal.valueOf(values[i])));
                }
                return;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                String[] values = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    int groupKey = groupKeyArray[i];
                    groupByResultHolder.setValueForKey(groupKey,
                            getDefaultResult(groupByResultHolder, groupKey).add(new BigDecimal(values[i])));
                }
                return;
            }
            case BIG_DECIMAL: {
                BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; i++) {
                    int groupKey = groupKeyArray[i];
                    groupByResultHolder.setValueForKey(groupKey,
                            getDefaultResult(groupByResultHolder, groupKey).add(values[i]));
                }
                return;
            }
            case BYTES: {
                byte[][] values = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; i++) {
                    int groupKey = groupKeyArray[i];
                    groupByResultHolder.setValueForKey(groupKey,
                            getDefaultResult(groupByResultHolder, groupKey).add(BigDecimalUtils.deserialize(values[i])));
                }
                return;
            }
            default:
                throw new IllegalStateException();
        }
    }

    /**
     * Null-aware variant of {@link #aggregateGroupBySV}: skips rows present in {@code nullBitmap}. A group with no
     * non-null value keeps a {@code null} result. Values are not fetched at all when every row of the block is null.
     */
    private void aggregateGroupBySVNullHandlingEnabled(int length, int[] groupKeyArray,
            GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) {
        switch (blockValSet.getValueType().getStoredType()) {
            case INT:
                if (nullBitmap.getCardinality() < length) {
                    int[] values = blockValSet.getIntValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            setGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimal.valueOf(values[i]));
                        }
                    }
                }
                return;
            case LONG:
                if (nullBitmap.getCardinality() < length) {
                    long[] values = blockValSet.getLongValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            setGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimal.valueOf(values[i]));
                        }
                    }
                }
                return;
            case FLOAT:
            case DOUBLE:
            case STRING:
                if (nullBitmap.getCardinality() < length) {
                    String[] values = blockValSet.getStringValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            setGroupByResult(groupKeyArray[i], groupByResultHolder, new BigDecimal(values[i]));
                        }
                    }
                }
                return;
            case BIG_DECIMAL:
                if (nullBitmap.getCardinality() < length) {
                    BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            setGroupByResult(groupKeyArray[i], groupByResultHolder, values[i]);
                        }
                    }
                }
                return;
            case BYTES:
                if (nullBitmap.getCardinality() < length) {
                    byte[][] values = blockValSet.getBytesValuesSV();
                    for (int i = 0; i < length; i++) {
                        if (!nullBitmap.contains(i)) {
                            setGroupByResult(groupKeyArray[i], groupByResultHolder,
                                    BigDecimalUtils.deserialize(values[i]));
                        }
                    }
                }
                return;
            default:
                throw new IllegalStateException();
        }
    }

    /** Adds {@code value} to the result of {@code groupKey}, or stores it as-is when the group has no result yet. */
    private void setGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, BigDecimal value) {
        BigDecimal current = (BigDecimal) groupByResultHolder.getResult(groupKey);
        groupByResultHolder.setValueForKey(groupKey, current == null ? value : current.add(value));
    }

    /** Applies {@link #setGroupByResult} for {@code value} to every key in {@code groupKeys}. */
    private void addToGroups(int[] groupKeys, GroupByResultHolder groupByResultHolder, BigDecimal value) {
        for (int groupKey : groupKeys) {
            setGroupByResult(groupKey, groupByResultHolder, value);
        }
    }

    /**
     * Adds each of the first {@code length} values to the result of every group key in {@code groupKeysArray[i]}.
     * When null handling is enabled and the block has nulls, null rows are skipped just like in the SV path.
     *
     * @throws IllegalStateException if the stored type is not a supported numeric type
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        if (_nullHandlingEnabled) {
            RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
            if (nullBitmap != null && !nullBitmap.isEmpty()) {
                aggregateGroupByMVNullHandlingEnabled(length, groupKeysArray, groupByResultHolder, blockValSet,
                        nullBitmap);
                return;
            }
        }
        // Each row's value is converted to BigDecimal once, then added to every one of the row's group keys.
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                int[] values = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    BigDecimal value = BigDecimal.valueOf(values[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        groupByResultHolder.setValueForKey(groupKey,
                                getDefaultResult(groupByResultHolder, groupKey).add(value));
                    }
                }
                return;
            }
            case LONG: {
                long[] values = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    BigDecimal value = BigDecimal.valueOf(values[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        groupByResultHolder.setValueForKey(groupKey,
                                getDefaultResult(groupByResultHolder, groupKey).add(value));
                    }
                }
                return;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                String[] values = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    BigDecimal value = new BigDecimal(values[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        groupByResultHolder.setValueForKey(groupKey,
                                getDefaultResult(groupByResultHolder, groupKey).add(value));
                    }
                }
                return;
            }
            case BIG_DECIMAL: {
                BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; i++) {
                    BigDecimal value = values[i];
                    for (int groupKey : groupKeysArray[i]) {
                        groupByResultHolder.setValueForKey(groupKey,
                                getDefaultResult(groupByResultHolder, groupKey).add(value));
                    }
                }
                return;
            }
            case BYTES: {
                byte[][] values = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; i++) {
                    BigDecimal value = BigDecimalUtils.deserialize(values[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        groupByResultHolder.setValueForKey(groupKey,
                                getDefaultResult(groupByResultHolder, groupKey).add(value));
                    }
                }
                return;
            }
            default:
                throw new IllegalStateException();
        }
    }

    /**
     * Null-aware variant of {@link #aggregateGroupByMV}: skips rows present in {@code nullBitmap}, so groups that
     * only see null rows keep a {@code null} result (mirrors {@link #aggregateGroupBySVNullHandlingEnabled}).
     */
    private void aggregateGroupByMVNullHandlingEnabled(int length, int[][] groupKeysArray,
            GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) {
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                int[] values = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    if (!nullBitmap.contains(i)) {
                        addToGroups(groupKeysArray[i], groupByResultHolder, BigDecimal.valueOf(values[i]));
                    }
                }
                return;
            }
            case LONG: {
                long[] values = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    if (!nullBitmap.contains(i)) {
                        addToGroups(groupKeysArray[i], groupByResultHolder, BigDecimal.valueOf(values[i]));
                    }
                }
                return;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                String[] values = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    if (!nullBitmap.contains(i)) {
                        addToGroups(groupKeysArray[i], groupByResultHolder, new BigDecimal(values[i]));
                    }
                }
                return;
            }
            case BIG_DECIMAL: {
                BigDecimal[] values = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; i++) {
                    if (!nullBitmap.contains(i)) {
                        addToGroups(groupKeysArray[i], groupByResultHolder, values[i]);
                    }
                }
                return;
            }
            case BYTES: {
                byte[][] values = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; i++) {
                    if (!nullBitmap.contains(i)) {
                        addToGroups(groupKeysArray[i], groupByResultHolder, BigDecimalUtils.deserialize(values[i]));
                    }
                }
                return;
            }
            default:
                throw new IllegalStateException();
        }
    }

    /**
     * Returns the accumulated sum. A missing result yields {@code null} with null handling enabled, zero otherwise.
     */
    @Override
    public BigDecimal extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        BigDecimal result = (BigDecimal) aggregationResultHolder.getResult();
        if (result != null) {
            return result;
        }
        return _nullHandlingEnabled ? null : BigDecimal.ZERO;
    }

    /**
     * Returns the accumulated sum for {@code groupKey}. A missing result yields {@code null} with null handling
     * enabled, zero otherwise.
     */
    @Override
    public BigDecimal extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        BigDecimal result = (BigDecimal) groupByResultHolder.getResult(groupKey);
        if (result != null) {
            return result;
        }
        return _nullHandlingEnabled ? null : BigDecimal.ZERO;
    }

    /**
     * Merges two intermediate sums. A {@code null} side (no non-null input seen) is treated as absent, so the other
     * side is returned. This avoids a NullPointerException even when null handling is disabled.
     */
    @Override
    public BigDecimal merge(BigDecimal intermediateResult1, BigDecimal intermediateResult2) {
        if (intermediateResult1 == null) {
            return intermediateResult2;
        }
        if (intermediateResult2 == null) {
            return intermediateResult1;
        }
        return intermediateResult1.add(intermediateResult2);
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    // NOTE(review): extractFinalResult returns a BigDecimal while the declared final type is STRING; presumably a
    // downstream serializer converts it — confirm.
    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.STRING;
    }

    /**
     * Applies the optional precision (and then scale) to the final sum, rounding {@link RoundingMode#HALF_EVEN}.
     *
     * @return {@code null} for a {@code null} sum; the sum unchanged when no precision was given
     */
    @Override
    public BigDecimal extractFinalResult(BigDecimal sum) {
        if (sum == null) {
            return null;
        }
        if (_precision == null) {
            return sum;
        }
        BigDecimal rounded = sum.round(new MathContext(_precision, RoundingMode.HALF_EVEN));
        return _scale == null ? rounded : rounded.setScale(_scale, RoundingMode.HALF_EVEN);
    }

    /** Returns the holder's current sum, or {@link BigDecimal#ZERO} when none has been recorded yet. */
    public BigDecimal getDefaultResult(AggregationResultHolder aggregationResultHolder) {
        BigDecimal result = (BigDecimal) aggregationResultHolder.getResult();
        return result != null ? result : BigDecimal.ZERO;
    }

    /** Returns the current sum for {@code groupKey}, or {@link BigDecimal#ZERO} when none has been recorded yet. */
    public BigDecimal getDefaultResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        BigDecimal result = (BigDecimal) groupByResultHolder.getResult(groupKey);
        return result != null ? result : BigDecimal.ZERO;
    }
}
