package org.apache.pinot.core.query.aggregation.function;

import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.datasketches.common.Util;
import org.apache.datasketches.cpc.CpcSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.pinot.$internal.com.google.common.base.Preconditions;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.local.customobject.CpcSketchAccumulator;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.segment.spi.Constants;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.spi.data.FieldSpec;
import org.roaringbitmap.PeekableIntIterator;
import org.roaringbitmap.RoaringBitmap;

/**
 * Aggregation function that estimates the number of distinct values of an expression with an Apache DataSketches
 * {@link CpcSketch}.
 *
 * <p>Accepts one or two arguments. The optional second argument must be a literal: a non-STRING literal is used
 * directly as {@code lgNominalEntries}; a STRING literal is parsed as {@code key=value} pairs separated by
 * {@code ';'} (whitespace is ignored) with the keys {@code nominalEntries} and {@code accumulatorThreshold}, e.g.
 * {@code 'nominalEntries=4096;accumulatorThreshold=2'}. {@code nominalEntries} is converted with
 * {@code Util.exactLog2OfInt}.
 *
 * <p>Depending on how the input is stored, a result-holder slot contains one of:
 * <ul>
 *   <li>a {@link CpcSketchAccumulator} when values are BYTES (treated as serialized CPC sketches);</li>
 *   <li>a {@link DictIdsWrapper} when the input is dictionary-encoded;</li>
 *   <li>a {@link CpcSketch} for other raw (INT/LONG/FLOAT/DOUBLE/STRING) values.</li>
 * </ul>
 * {@link #extractAggregationResult} and {@link #extractGroupByResult} normalize all three to a
 * {@link CpcSketchAccumulator}. The final result is the rounded sketch estimate as a LONG.
 */
public class DistinctCountCPCSketchAggregationFunction
    extends BaseSingleInputAggregationFunction<CpcSketchAccumulator, Comparable> {
    private static final int DEFAULT_ACCUMULATOR_THRESHOLD = 2;
    /** lgK used when the query does not specify one; also the assumed lgK of star-trees without the key. */
    private static final int DEFAULT_LG_NOMINAL_ENTRIES = 12;

    protected int _accumulatorThreshold;
    protected int _lgNominalEntries;

    /** Pairs a segment's dictionary with the bitmap of dictionary ids collected so far. */
    public static final class DictIdsWrapper {
        final Dictionary _dictionary;
        final RoaringBitmap _dictIdBitmap = new RoaringBitmap();

        private DictIdsWrapper(Dictionary dictionary) {
            _dictionary = dictionary;
        }
    }

    /** Parser for the {@code key=value;key=value} parameter string accepted as the second argument. */
    private static class Parameters {
        private static final char PARAMETER_DELIMITER = ';';
        private static final char PARAMETER_KEY_VALUE_SEPARATOR = '=';
        private static final String NOMINAL_ENTRIES_KEY = "nominalEntries";
        private static final String ACCUMULATOR_THRESHOLD_KEY = "accumulatorThreshold";

        private int _nominalEntries = 1 << DEFAULT_LG_NOMINAL_ENTRIES;
        private int _accumulatorThreshold = DEFAULT_ACCUMULATOR_THRESHOLD;

        /**
         * Parses the parameter string; keys are matched case-insensitively.
         *
         * @throws IllegalArgumentException if a parameter is not a {@code key=value} pair or the key is unknown
         * @throws NumberFormatException if a value is not an integer
         */
        Parameters(String parametersString) {
            // deleteWhitespace returns a new String (Strings are immutable); it must be assigned, otherwise input
            // such as "nominalEntries = 1024" is rejected because of the surrounding spaces.
            String normalized = StringUtils.deleteWhitespace(parametersString);
            for (String parameter : StringUtils.split(normalized, PARAMETER_DELIMITER)) {
                String[] keyAndValue = StringUtils.split(parameter, PARAMETER_KEY_VALUE_SEPARATOR);
                Preconditions.checkArgument(keyAndValue.length == 2, "Invalid parameter: %s", parameter);
                String key = keyAndValue[0];
                String value = keyAndValue[1];
                if (key.equalsIgnoreCase(NOMINAL_ENTRIES_KEY)) {
                    _nominalEntries = Integer.parseInt(value);
                } else if (key.equalsIgnoreCase(ACCUMULATOR_THRESHOLD_KEY)) {
                    _accumulatorThreshold = Integer.parseInt(value);
                } else {
                    throw new IllegalArgumentException("Invalid parameter key: " + key);
                }
            }
        }

        /** Returns log2 of the configured nominal entries, computed with {@code Util.exactLog2OfInt}. */
        int getLgNominalEntries() {
            return Util.exactLog2OfInt(_nominalEntries);
        }

        int getAccumulatorThreshold() {
            return _accumulatorThreshold;
        }
    }

    /**
     * @param arguments the expression to count, optionally followed by a literal carrying either lgNominalEntries
     *                  (non-STRING literal) or a parameter string (STRING literal)
     * @throws IllegalArgumentException if more than two arguments are given or the second one is not a literal
     */
    public DistinctCountCPCSketchAggregationFunction(List<ExpressionContext> arguments) {
        super(arguments.get(0));
        int numArguments = arguments.size();
        Preconditions.checkArgument(numArguments <= 2, "DistinctCountCPC expects 1 or 2 arguments, got: %s",
            numArguments);
        _accumulatorThreshold = DEFAULT_ACCUMULATOR_THRESHOLD;
        if (numArguments != 2) {
            _lgNominalEntries = DEFAULT_LG_NOMINAL_ENTRIES;
            return;
        }
        ExpressionContext paramsExpression = arguments.get(1);
        Preconditions.checkArgument(paramsExpression.getType() == ExpressionContext.Type.LITERAL,
            "CPC Sketch Aggregation Function expects the second argument to be a literal (parameters), but got: %s",
            paramsExpression.getType());
        if (paramsExpression.getLiteral().getType() == FieldSpec.DataType.STRING) {
            Parameters parameters = new Parameters(paramsExpression.getLiteral().getStringValue());
            _accumulatorThreshold = parameters.getAccumulatorThreshold();
            _lgNominalEntries = parameters.getLgNominalEntries();
        } else {
            _lgNominalEntries = paramsExpression.getLiteral().getIntValue();
        }
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.DISTINCTCOUNTCPCSKETCH;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Aggregates the first {@code length} single values of the block into {@code aggregationResultHolder}.
     *
     * @throws RuntimeException wrapping any failure while deserializing/merging BYTES sketches
     * @throws IllegalStateException for unsupported stored types
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();

        // BYTES values are serialized CPC sketches: merge them into the accumulator.
        if (storedType == FieldSpec.DataType.BYTES) {
            byte[][] bytesValues = blockValSet.getBytesValuesSV();
            try {
                CpcSketchAccumulator accumulator = getAccumulator(aggregationResultHolder);
                for (CpcSketch sketch : deserializeSketches(bytesValues, length)) {
                    accumulator.apply(sketch);
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while merging CPC sketches", e);
            }
            return;
        }

        // Dictionary-encoded: only collect dictionary ids; values are resolved once in extractAggregationResult.
        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            getDictIdBitmap(aggregationResultHolder, dictionary).addN(blockValSet.getDictionaryIdsSV(), 0, length);
            return;
        }

        // Raw values: update the CpcSketch kept in the holder across blocks. It is turned into an accumulator by
        // extractAggregationResult. Do NOT call getAccumulator(holder) here: the holder already contains this
        // CpcSketch, so the cast to CpcSketchAccumulator would throw ClassCastException.
        CpcSketch cpcSketch = getCpcSketch(aggregationResultHolder);
        switch (storedType) {
            case INT:
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update(intValues[i]);
                }
                break;
            case LONG:
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update(longValues[i]);
                }
                break;
            case FLOAT:
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update(floatValues[i]);
                }
                break;
            case DOUBLE:
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update(doubleValues[i]);
                }
                break;
            case STRING:
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update(stringValues[i]);
                }
                break;
            default:
                throw new IllegalStateException(
                    "Illegal data type for DISTINCT_COUNT_CPC aggregation function: " + String.valueOf(storedType));
        }
    }

    /**
     * Aggregates the first {@code length} single values of the block, row {@code i} going to group
     * {@code groupKeyArray[i]}.
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();

        if (storedType == FieldSpec.DataType.BYTES) {
            try {
                CpcSketch[] sketches = deserializeSketches(blockValSet.getBytesValuesSV(), length);
                for (int i = 0; i < length; i++) {
                    getAccumulator(groupByResultHolder, groupKeyArray[i]).apply(sketches[i]);
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while aggregating CPC Sketches", e);
            }
            return;
        }

        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int i = 0; i < length; i++) {
                getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]);
            }
            return;
        }

        switch (storedType) {
            case INT:
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update(intValues[i]);
                }
                break;
            case LONG:
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update(longValues[i]);
                }
                break;
            case FLOAT:
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update(floatValues[i]);
                }
                break;
            case DOUBLE:
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update(doubleValues[i]);
                }
                break;
            case STRING:
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update(stringValues[i]);
                }
                break;
            default:
                throw new IllegalStateException(
                    "Illegal data type for DISTINCT_COUNT_CPC aggregation function: " + String.valueOf(storedType));
        }
    }

    /**
     * Aggregates the first {@code length} single values of the block; the value of row {@code i} is added to every
     * group key in {@code groupKeysArray[i]}.
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();

        if (blockValSet.isSingleValue() && storedType == FieldSpec.DataType.BYTES) {
            try {
                CpcSketch[] sketches = deserializeSketches(blockValSet.getBytesValuesSV(), length);
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getAccumulator(groupByResultHolder, groupKey).apply(sketches[i]);
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while aggregating CPC sketches", e);
            }
            return;
        }

        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int i = 0; i < length; i++) {
                setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[i], dictionary, dictIds[i]);
            }
            return;
        }

        switch (storedType) {
            case INT:
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update(intValues[i]);
                    }
                }
                break;
            case LONG:
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update(longValues[i]);
                    }
                }
                break;
            case FLOAT:
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update(floatValues[i]);
                    }
                }
                break;
            case DOUBLE:
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update(doubleValues[i]);
                    }
                }
                break;
            case STRING:
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update(stringValues[i]);
                    }
                }
                break;
            default:
                throw new IllegalStateException(
                    "Illegal data type for DISTINCT_COUNT_CPC aggregation function: " + String.valueOf(storedType));
        }
    }

    @Override
    public CpcSketchAccumulator extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        Object result = aggregationResultHolder.getResult();
        return toAccumulator(result);
    }

    @Override
    public CpcSketchAccumulator extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        Object result = groupByResultHolder.getResult(groupKey);
        return toAccumulator(result);
    }

    /**
     * Merges {@code accumulator2} into {@code accumulator1} (after applying this function's lgK and threshold to it)
     * and returns it; if either side is null or empty the other side is returned unchanged.
     */
    @Override
    public CpcSketchAccumulator merge(CpcSketchAccumulator accumulator1, CpcSketchAccumulator accumulator2) {
        if (accumulator1 == null || accumulator1.isEmpty()) {
            return accumulator2;
        }
        if (accumulator2 == null || accumulator2.isEmpty()) {
            return accumulator1;
        }
        accumulator1.setLgNominalEntries(_lgNominalEntries);
        accumulator1.setThreshold(_accumulatorThreshold);
        accumulator1.merge(accumulator2);
        return accumulator1;
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.LONG;
    }

    /** Returns the sketch estimate rounded to the nearest {@code Long}. */
    @Override
    public Comparable extractFinalResult(CpcSketchAccumulator accumulator) {
        accumulator.setLgNominalEntries(_lgNominalEntries);
        accumulator.setThreshold(_accumulatorThreshold);
        return Long.valueOf(Math.round(accumulator.getResult().getEstimate()));
    }

    /**
     * Adds two final counts. NOTE(review): a sum of distinct counts is only exact when the two inputs cover disjoint
     * value sets (presumably partitioned data) — confirm with the callers of mergeFinalResult.
     */
    @Override
    public Comparable mergeFinalResult(Comparable finalResult1, Comparable finalResult2) {
        return Long.valueOf(((Long) finalResult1).longValue() + ((Long) finalResult2).longValue());
    }

    /**
     * Star-tree pre-aggregates are usable only when the query's lgK does not exceed the star-tree's lgK (read from
     * {@link Constants#CPCSKETCH_LGK_KEY}, defaulting to {@value #DEFAULT_LG_NOMINAL_ENTRIES}).
     */
    @Override
    public boolean canUseStarTree(Map<String, Object> functionParameters) {
        Object starTreeLgK = functionParameters.get(Constants.CPCSKETCH_LGK_KEY);
        int starTreeLgNominalEntries =
            starTreeLgK != null ? Integer.parseInt(String.valueOf(starTreeLgK)) : DEFAULT_LG_NOMINAL_ENTRIES;
        return _lgNominalEntries <= starTreeLgNominalEntries;
    }

    /** Returns the holder's CpcSketch, creating and storing an empty one on first use. */
    protected CpcSketch getCpcSketch(AggregationResultHolder aggregationResultHolder) {
        CpcSketch cpcSketch = (CpcSketch) aggregationResultHolder.getResult();
        if (cpcSketch == null) {
            cpcSketch = new CpcSketch(_lgNominalEntries);
            aggregationResultHolder.setValue(cpcSketch);
        }
        return cpcSketch;
    }

    /** Returns the holder's dictionary-id bitmap, creating a {@link DictIdsWrapper} on first use. */
    protected static RoaringBitmap getDictIdBitmap(AggregationResultHolder aggregationResultHolder,
        Dictionary dictionary) {
        DictIdsWrapper wrapper = (DictIdsWrapper) aggregationResultHolder.getResult();
        if (wrapper == null) {
            wrapper = new DictIdsWrapper(dictionary);
            aggregationResultHolder.setValue(wrapper);
        }
        return wrapper._dictIdBitmap;
    }

    /** Returns the dictionary-id bitmap of {@code groupKey}, creating a {@link DictIdsWrapper} on first use. */
    protected static RoaringBitmap getDictIdBitmap(GroupByResultHolder groupByResultHolder, int groupKey,
        Dictionary dictionary) {
        DictIdsWrapper wrapper = (DictIdsWrapper) groupByResultHolder.getResult(groupKey);
        if (wrapper == null) {
            wrapper = new DictIdsWrapper(dictionary);
            groupByResultHolder.setValueForKey(groupKey, wrapper);
        }
        return wrapper._dictIdBitmap;
    }

    /** Returns the CpcSketch of {@code groupKey}, creating and storing an empty one on first use. */
    protected CpcSketch getCpcSketch(GroupByResultHolder groupByResultHolder, int groupKey) {
        CpcSketch cpcSketch = (CpcSketch) groupByResultHolder.getResult(groupKey);
        if (cpcSketch == null) {
            cpcSketch = new CpcSketch(_lgNominalEntries);
            groupByResultHolder.setValueForKey(groupKey, cpcSketch);
        }
        return cpcSketch;
    }

    /** Records {@code dictId} in the bitmap of every key in {@code groupKeys}. */
    private static void setDictIdForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys,
        Dictionary dictionary, int dictId) {
        for (int groupKey : groupKeys) {
            getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictId);
        }
    }

    /**
     * Normalizes a result-holder value (null, CpcSketch, DictIdsWrapper or CpcSketchAccumulator) into an accumulator.
     * A null value yields an empty accumulator configured with this function's lgK and threshold.
     */
    private CpcSketchAccumulator toAccumulator(Object result) {
        if (result == null) {
            return new CpcSketchAccumulator(_lgNominalEntries, _accumulatorThreshold);
        }
        if (result instanceof DictIdsWrapper) {
            return convertSketchAccumulator(dictionaryToCpcSketch((DictIdsWrapper) result));
        }
        // convertSketchAccumulator wraps a CpcSketch and casts anything else to CpcSketchAccumulator.
        return convertSketchAccumulator(result);
    }

    /** Builds a sketch from the dictionary values of every id recorded in the wrapper's bitmap. */
    private CpcSketch dictionaryToCpcSketch(DictIdsWrapper dictIdsWrapper) {
        CpcSketch cpcSketch = new CpcSketch(_lgNominalEntries);
        Dictionary dictionary = dictIdsWrapper._dictionary;
        PeekableIntIterator dictIdIterator = dictIdsWrapper._dictIdBitmap.getIntIterator();
        while (dictIdIterator.hasNext()) {
            addObjectToSketch(dictionary.get(dictIdIterator.next()), cpcSketch);
        }
        return cpcSketch;
    }

    /**
     * Adds a single dictionary value (String, Integer, Long, Double, Float) or an array of them to the sketch.
     *
     * @throws IllegalStateException for any other value type, including null
     */
    private void addObjectToSketch(Object value, CpcSketch cpcSketch) {
        if (value instanceof String) {
            cpcSketch.update((String) value);
        } else if (value instanceof Integer) {
            cpcSketch.update(((Integer) value).intValue());
        } else if (value instanceof Long) {
            cpcSketch.update(((Long) value).longValue());
        } else if (value instanceof Double) {
            cpcSketch.update(((Double) value).doubleValue());
        } else if (value instanceof Float) {
            cpcSketch.update(((Float) value).floatValue());
        } else if (value instanceof Object[]) {
            addObjectsToSketch((Object[]) value, cpcSketch);
        } else {
            throw new IllegalStateException("Unsupported data type for CPC Sketch aggregation: "
                + (value == null ? "null" : value.getClass().getSimpleName()));
        }
    }

    /**
     * Adds every element of a String[], Integer[], Long[], Double[] or Float[] array to the sketch.
     *
     * @throws IllegalStateException for any other array type
     */
    private void addObjectsToSketch(Object[] values, CpcSketch cpcSketch) {
        if (values instanceof String[]) {
            for (String value : (String[]) values) {
                cpcSketch.update(value);
            }
        } else if (values instanceof Integer[]) {
            for (Integer value : (Integer[]) values) {
                cpcSketch.update(value.intValue());
            }
        } else if (values instanceof Long[]) {
            for (Long value : (Long[]) values) {
                cpcSketch.update(value.longValue());
            }
        } else if (values instanceof Double[]) {
            for (Double value : (Double[]) values) {
                cpcSketch.update(value.doubleValue());
            }
        } else if (values instanceof Float[]) {
            for (Float value : (Float[]) values) {
                cpcSketch.update(value.floatValue());
            }
        } else {
            throw new IllegalStateException(
                "Unsupported data type for CPC Sketch aggregation: " + values.getClass().getSimpleName());
        }
    }

    /** Returns the holder's accumulator, creating and storing an empty one on first use. */
    private CpcSketchAccumulator getAccumulator(AggregationResultHolder aggregationResultHolder) {
        CpcSketchAccumulator accumulator = (CpcSketchAccumulator) aggregationResultHolder.getResult();
        if (accumulator == null) {
            accumulator = new CpcSketchAccumulator(_lgNominalEntries, _accumulatorThreshold);
            aggregationResultHolder.setValue(accumulator);
        }
        return accumulator;
    }

    /** Returns the accumulator of {@code groupKey}, creating and storing an empty one on first use. */
    private CpcSketchAccumulator getAccumulator(GroupByResultHolder groupByResultHolder, int groupKey) {
        CpcSketchAccumulator accumulator = (CpcSketchAccumulator) groupByResultHolder.getResult(groupKey);
        if (accumulator == null) {
            accumulator = new CpcSketchAccumulator(_lgNominalEntries, _accumulatorThreshold);
            groupByResultHolder.setValueForKey(groupKey, accumulator);
        }
        return accumulator;
    }

    /** Heapifies the first {@code length} serialized sketches. */
    private CpcSketch[] deserializeSketches(byte[][] serializedSketches, int length) {
        CpcSketch[] sketches = new CpcSketch[length];
        for (int i = 0; i < length; i++) {
            sketches[i] = CpcSketch.heapify(Memory.wrap(serializedSketches[i]));
        }
        return sketches;
    }

    /**
     * Wraps a {@link CpcSketch} into a new accumulator configured with this function's lgK and threshold; any other
     * object is cast to {@link CpcSketchAccumulator} and returned as-is.
     */
    protected CpcSketchAccumulator convertSketchAccumulator(Object obj) {
        if (!(obj instanceof CpcSketch)) {
            return (CpcSketchAccumulator) obj;
        }
        CpcSketchAccumulator accumulator = new CpcSketchAccumulator(_lgNominalEntries, _accumulatorThreshold);
        accumulator.apply((CpcSketch) obj);
        return accumulator;
    }
}
