package org.apache.pinot.core.query.aggregation.function;

import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import java.util.List;
import java.util.Map;
import org.apache.pinot.$internal.com.google.common.base.Preconditions;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.segment.spi.Constants;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.spi.data.FieldSpec;
import org.roaringbitmap.PeekableIntIterator;
import org.roaringbitmap.RoaringBitmap;

/* loaded from: input_file:org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.class */
/**
 * Aggregation function for {@code DISTINCTCOUNTHLL}: estimates the number of distinct values of a single-valued
 * expression with a HyperLogLog sketch.
 *
 * <p>An optional second argument sets the sketch's log2m; otherwise a default of 8 is used. Input is handled in one
 * of three ways:
 * <ul>
 *   <li>{@code BYTES} values are treated as serialized HyperLogLogs and merged.</li>
 *   <li>Dictionary-encoded values are collected as dictionary ids in a {@link RoaringBitmap} and only turned into a
 *   HyperLogLog when the result is extracted.</li>
 *   <li>Raw INT/LONG/FLOAT/DOUBLE/STRING values are offered to the sketch directly.</li>
 * </ul>
 */
public class DistinctCountHLLAggregationFunction extends BaseSingleInputAggregationFunction<HyperLogLog, Long> {
    /**
     * log2m used when the query does not pass one, and assumed for star-trees whose function parameters do not set
     * {@code HLL_LOG2M_KEY}. NOTE(review): presumably mirrors Pinot's default HyperLogLog log2m; confirm.
     */
    private static final int DEFAULT_LOG2M = 8;

    private static final String MERGE_FAILURE_MESSAGE = "Caught exception while merging HyperLogLogs";

    /** log2m of every HyperLogLog this function creates. */
    protected final int _log2m;

    /**
     * Accumulator for dictionary-encoded input: the dictionary the ids belong to and the set of ids seen so far.
     * The ids are resolved to values only when the result is converted into a HyperLogLog.
     */
    public static final class DictIdsWrapper {
        final Dictionary _dictionary;
        final RoaringBitmap _dictIdBitmap = new RoaringBitmap();

        private DictIdsWrapper(Dictionary dictionary) {
            _dictionary = dictionary;
        }
    }

    /**
     * Creates the function from its arguments.
     *
     * @param arguments the input expression, optionally followed by a literal log2m
     * @throws IllegalArgumentException if more than two arguments are given
     */
    public DistinctCountHLLAggregationFunction(List<ExpressionContext> arguments) {
        super(arguments.get(0));
        int numArguments = arguments.size();
        Preconditions.checkArgument(numArguments <= 2, "DistinctCountHLL expects 1 or 2 arguments, got: %s",
                numArguments);
        _log2m = numArguments == 2 ? arguments.get(1).getLiteral().getIntValue() : DEFAULT_LOG2M;
    }

    /** Returns the log2m used for HyperLogLogs created by this function. */
    public int getLog2m() {
        return _log2m;
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.DISTINCTCOUNTHLL;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Adds the first {@code length} values of the block to the single (non-group-by) result.
     *
     * @throws RuntimeException if serialized HyperLogLogs cannot be deserialized or merged
     * @throws IllegalStateException if the stored type is not supported
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        // An empty block contributes nothing. Returning early also keeps the BYTES path from deserializing
        // element 0, which lies outside the block when length == 0.
        if (length == 0) {
            return;
        }
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        if (storedType == FieldSpec.DataType.BYTES) {
            mergeSerializedValues(length, aggregationResultHolder, blockValSet.getBytesValuesSV());
            return;
        }
        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            // Defer hashing: just remember which dictionary ids were seen.
            getDictIdBitmap(aggregationResultHolder, dictionary).addN(blockValSet.getDictionaryIdsSV(), 0, length);
            return;
        }
        HyperLogLog hyperLogLog = getHyperLogLog(aggregationResultHolder);
        switch (storedType) {
            case INT: {
                int[] values = blockValSet.getIntValuesSV();
                for (int row = 0; row < length; row++) {
                    hyperLogLog.offer(Integer.valueOf(values[row]));
                }
                return;
            }
            case LONG: {
                long[] values = blockValSet.getLongValuesSV();
                for (int row = 0; row < length; row++) {
                    hyperLogLog.offer(Long.valueOf(values[row]));
                }
                return;
            }
            case FLOAT: {
                float[] values = blockValSet.getFloatValuesSV();
                for (int row = 0; row < length; row++) {
                    hyperLogLog.offer(Float.valueOf(values[row]));
                }
                return;
            }
            case DOUBLE: {
                double[] values = blockValSet.getDoubleValuesSV();
                for (int row = 0; row < length; row++) {
                    hyperLogLog.offer(Double.valueOf(values[row]));
                }
                return;
            }
            case STRING: {
                String[] values = blockValSet.getStringValuesSV();
                for (int row = 0; row < length; row++) {
                    hyperLogLog.offer(values[row]);
                }
                return;
            }
            default:
                throw new IllegalStateException(
                        "Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
        }
    }

    /**
     * Merges the first {@code length} serialized HyperLogLogs into the holder's result.
     *
     * @throws RuntimeException wrapping any deserialization or merge failure
     */
    private static void mergeSerializedValues(int length, AggregationResultHolder aggregationResultHolder,
            byte[][] serializedValues) {
        try {
            HyperLogLog result = (HyperLogLog) aggregationResultHolder.getResult();
            int firstRowToMerge = 0;
            if (result == null) {
                // Adopt the first deserialized sketch as the result instead of a fresh one built from _log2m;
                // presumably so sketches serialized with a different log2m can still be merged. Confirm.
                result = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize2(serializedValues[0]);
                aggregationResultHolder.setValue(result);
                firstRowToMerge = 1;
            }
            for (int row = firstRowToMerge; row < length; row++) {
                result.addAll(ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize2(serializedValues[row]));
            }
        } catch (Exception e) {
            throw new RuntimeException(MERGE_FAILURE_MESSAGE, e);
        }
    }

    /**
     * Adds the first {@code length} values of the block to the results of their (single) group keys.
     *
     * @throws RuntimeException if serialized HyperLogLogs cannot be deserialized or merged
     * @throws IllegalStateException if the stored type is not supported
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        if (storedType == FieldSpec.DataType.BYTES) {
            byte[][] serializedValues = blockValSet.getBytesValuesSV();
            for (int row = 0; row < length; row++) {
                try {
                    HyperLogLog value = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize2(serializedValues[row]);
                    int groupKey = groupKeyArray[row];
                    HyperLogLog existing = (HyperLogLog) groupByResultHolder.getResult(groupKey);
                    if (existing != null) {
                        existing.addAll(value);
                    } else {
                        groupByResultHolder.setValueForKey(groupKey, value);
                    }
                } catch (Exception e) {
                    throw new RuntimeException(MERGE_FAILURE_MESSAGE, e);
                }
            }
            return;
        }
        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int row = 0; row < length; row++) {
                getDictIdBitmap(groupByResultHolder, groupKeyArray[row], dictionary).add(dictIds[row]);
            }
            return;
        }
        switch (storedType) {
            case INT: {
                int[] values = blockValSet.getIntValuesSV();
                for (int row = 0; row < length; row++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[row]).offer(Integer.valueOf(values[row]));
                }
                return;
            }
            case LONG: {
                long[] values = blockValSet.getLongValuesSV();
                for (int row = 0; row < length; row++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[row]).offer(Long.valueOf(values[row]));
                }
                return;
            }
            case FLOAT: {
                float[] values = blockValSet.getFloatValuesSV();
                for (int row = 0; row < length; row++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[row]).offer(Float.valueOf(values[row]));
                }
                return;
            }
            case DOUBLE: {
                double[] values = blockValSet.getDoubleValuesSV();
                for (int row = 0; row < length; row++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[row]).offer(Double.valueOf(values[row]));
                }
                return;
            }
            case STRING: {
                String[] values = blockValSet.getStringValuesSV();
                for (int row = 0; row < length; row++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[row]).offer(values[row]);
                }
                return;
            }
            default:
                throw new IllegalStateException(
                        "Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
        }
    }

    /**
     * Adds the first {@code length} values of the block to the results of every group key listed for each row.
     *
     * @throws RuntimeException if serialized HyperLogLogs cannot be deserialized or merged
     * @throws IllegalStateException if the stored type is not supported
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        if (storedType == FieldSpec.DataType.BYTES) {
            byte[][] serializedValues = blockValSet.getBytesValuesSV();
            for (int row = 0; row < length; row++) {
                try {
                    HyperLogLog value = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize2(serializedValues[row]);
                    for (int groupKey : groupKeysArray[row]) {
                        HyperLogLog existing = (HyperLogLog) groupByResultHolder.getResult(groupKey);
                        if (existing != null) {
                            existing.addAll(value);
                        } else {
                            // Give each new group its own copy so groups never share (and mutate) one sketch.
                            groupByResultHolder.setValueForKey(groupKey,
                                    ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize2(serializedValues[row]));
                        }
                    }
                } catch (Exception e) {
                    throw new RuntimeException(MERGE_FAILURE_MESSAGE, e);
                }
            }
            return;
        }
        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int row = 0; row < length; row++) {
                setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[row], dictionary, dictIds[row]);
            }
            return;
        }
        switch (storedType) {
            case INT: {
                int[] values = blockValSet.getIntValuesSV();
                for (int row = 0; row < length; row++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[row], Integer.valueOf(values[row]));
                }
                return;
            }
            case LONG: {
                long[] values = blockValSet.getLongValuesSV();
                for (int row = 0; row < length; row++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[row], Long.valueOf(values[row]));
                }
                return;
            }
            case FLOAT: {
                float[] values = blockValSet.getFloatValuesSV();
                for (int row = 0; row < length; row++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[row], Float.valueOf(values[row]));
                }
                return;
            }
            case DOUBLE: {
                double[] values = blockValSet.getDoubleValuesSV();
                for (int row = 0; row < length; row++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[row], Double.valueOf(values[row]));
                }
                return;
            }
            case STRING: {
                String[] values = blockValSet.getStringValuesSV();
                for (int row = 0; row < length; row++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[row], values[row]);
                }
                return;
            }
            default:
                throw new IllegalStateException(
                        "Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
        }
    }

    @Override
    public HyperLogLog extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        return toHyperLogLog(aggregationResultHolder.getResult());
    }

    @Override
    public HyperLogLog extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        return toHyperLogLog(groupByResultHolder.getResult(groupKey));
    }

    /**
     * Normalizes a raw holder entry into a HyperLogLog. An absent entry becomes an empty sketch, and a
     * {@link DictIdsWrapper} is converted by offering every recorded dictionary value.
     */
    private HyperLogLog toHyperLogLog(Object result) {
        if (result == null) {
            return new HyperLogLog(_log2m);
        }
        if (result instanceof DictIdsWrapper) {
            return convertToHyperLogLog((DictIdsWrapper) result);
        }
        return (HyperLogLog) result;
    }

    /**
     * Merges two intermediate results, reusing {@code first} when possible.
     *
     * <p>When the sketches differ in size, they can only be reconciled if one of them is empty; the non-empty one is
     * returned as-is.
     *
     * @throws IllegalStateException if the sizes differ and both sketches are non-empty
     * @throws RuntimeException wrapping any merge failure
     */
    @Override
    public HyperLogLog merge(HyperLogLog first, HyperLogLog second) {
        if (first.sizeof() != second.sizeof()) {
            if (first.cardinality() == 0) {
                return second;
            }
            Preconditions.checkState(second.cardinality() == 0, "Cannot merge HyperLogLogs of different sizes");
            return first;
        }
        try {
            first.addAll(second);
            return first;
        } catch (Exception e) {
            throw new RuntimeException(MERGE_FAILURE_MESSAGE, e);
        }
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.LONG;
    }

    /** Returns the sketch's cardinality estimate. */
    @Override
    public Long extractFinalResult(HyperLogLog hyperLogLog) {
        return hyperLogLog.cardinality();
    }

    /**
     * Adds two final cardinalities. NOTE(review): a plain sum is only meaningful when the two results cover disjoint
     * value sets (e.g. partitioned data) — confirm with callers.
     */
    @Override
    public Long mergeFinalResult(Long first, Long second) {
        return first + second;
    }

    /**
     * A star-tree can only serve this function if it was built with the same log2m. When the star-tree parameters
     * omit {@code HLL_LOG2M_KEY}, the default log2m is assumed.
     */
    @Override
    public boolean canUseStarTree(Map<String, Object> functionParameters) {
        Object log2mParameter = functionParameters.get(Constants.HLL_LOG2M_KEY);
        int starTreeLog2m = log2mParameter != null ? Integer.parseInt(String.valueOf(log2mParameter)) : DEFAULT_LOG2M;
        return _log2m == starTreeLog2m;
    }

    /** Returns the dictionary-id bitmap of the single result, creating the wrapper on first use. */
    public static RoaringBitmap getDictIdBitmap(AggregationResultHolder aggregationResultHolder,
            Dictionary dictionary) {
        DictIdsWrapper wrapper = (DictIdsWrapper) aggregationResultHolder.getResult();
        if (wrapper == null) {
            wrapper = new DictIdsWrapper(dictionary);
            aggregationResultHolder.setValue(wrapper);
        }
        return wrapper._dictIdBitmap;
    }

    /** Returns the single result's HyperLogLog, creating one with {@code _log2m} on first use. */
    public HyperLogLog getHyperLogLog(AggregationResultHolder aggregationResultHolder) {
        HyperLogLog hyperLogLog = (HyperLogLog) aggregationResultHolder.getResult();
        if (hyperLogLog == null) {
            hyperLogLog = new HyperLogLog(_log2m);
            aggregationResultHolder.setValue(hyperLogLog);
        }
        return hyperLogLog;
    }

    /** Returns the dictionary-id bitmap for {@code groupKey}, creating the wrapper on first use. */
    public static RoaringBitmap getDictIdBitmap(GroupByResultHolder groupByResultHolder, int groupKey,
            Dictionary dictionary) {
        DictIdsWrapper wrapper = (DictIdsWrapper) groupByResultHolder.getResult(groupKey);
        if (wrapper == null) {
            wrapper = new DictIdsWrapper(dictionary);
            groupByResultHolder.setValueForKey(groupKey, wrapper);
        }
        return wrapper._dictIdBitmap;
    }

    /** Returns the HyperLogLog for {@code groupKey}, creating one with {@code _log2m} on first use. */
    public HyperLogLog getHyperLogLog(GroupByResultHolder groupByResultHolder, int groupKey) {
        HyperLogLog hyperLogLog = (HyperLogLog) groupByResultHolder.getResult(groupKey);
        if (hyperLogLog == null) {
            hyperLogLog = new HyperLogLog(_log2m);
            groupByResultHolder.setValueForKey(groupKey, hyperLogLog);
        }
        return hyperLogLog;
    }

    /** Records {@code dictId} in the bitmap of every group key in {@code groupKeys}. */
    private static void setDictIdForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys,
            Dictionary dictionary, int dictId) {
        for (int groupKey : groupKeys) {
            getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictId);
        }
    }

    /** Offers {@code value} to the HyperLogLog of every group key in {@code groupKeys}. */
    private void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, Object value) {
        for (int groupKey : groupKeys) {
            getHyperLogLog(groupByResultHolder, groupKey).offer(value);
        }
    }

    /**
     * Builds a HyperLogLog with {@code _log2m} by looking up each recorded dictionary id and offering its value, so
     * dictionary and raw inputs hash the same values.
     */
    private HyperLogLog convertToHyperLogLog(DictIdsWrapper wrapper) {
        HyperLogLog hyperLogLog = new HyperLogLog(_log2m);
        Dictionary dictionary = wrapper._dictionary;
        PeekableIntIterator dictIdIterator = wrapper._dictIdBitmap.getIntIterator();
        while (dictIdIterator.hasNext()) {
            hyperLogLog.offer(dictionary.get(dictIdIterator.next()));
        }
        return hyperLogLog;
    }
}
