package org.apache.pinot.core.query.aggregation.function;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.datasketches.tuple.CompactSketch;
import org.apache.datasketches.tuple.Sketch;
import org.apache.datasketches.tuple.Union;
import org.apache.datasketches.tuple.aninteger.IntegerSummary;
import org.apache.datasketches.tuple.aninteger.IntegerSummarySetOperations;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.shaded.com.google.common.base.Preconditions;
import org.apache.pinot.spi.data.FieldSpec;

/* loaded from: input_file:org/apache/pinot/core/query/aggregation/function/IntegerTupleSketchAggregationFunction.class */
/**
 * Aggregation function over a BYTES column whose values are serialized Integer Tuple Sketches.
 *
 * <p>During aggregation every value is deserialized, compacted and appended to a list of sketches
 * (the intermediate result). {@link #extractFinalResult(List)} unions all collected sketches and
 * returns the Base64 encoding of the serialized union.
 */
public class IntegerTupleSketchAggregationFunction extends BaseSingleInputAggregationFunction<List<CompactSketch<IntegerSummary>>, Comparable> {
    /** Number of entries used for the final union when the query gives no second argument (2^16). */
    private static final int DEFAULT_NOMINAL_ENTRIES = 1 << 16;

    final ExpressionContext _expressionContext;
    final IntegerSummarySetOperations _setOps;
    final int _entries;

    /**
     * Creates the aggregation function.
     *
     * @param list function arguments: the expression to aggregate and, optionally, a numeric literal
     *             giving the number of entries to keep in the final union
     * @param mode summary mode, passed to {@link IntegerSummarySetOperations} for both of its mode arguments
     * @throws IllegalArgumentException if more than two arguments are given, or the second argument is
     *                                  not an INT or LONG literal
     */
    public IntegerTupleSketchAggregationFunction(List<ExpressionContext> list, IntegerSummary.Mode mode) {
        super(list.get(0));
        Preconditions.checkArgument(list.size() <= 2, "Tuple Sketch Aggregation Function expects at most 2 arguments, got: %s", list.size());
        this._expressionContext = list.get(0);
        this._setOps = new IntegerSummarySetOperations(mode, mode);
        if (list.size() == 2) {
            FieldSpec.DataType type = list.get(1).getLiteral().getType();
            Preconditions.checkArgument(type == FieldSpec.DataType.LONG || type == FieldSpec.DataType.INT, "Tuple Sketch Aggregation Function expected the second argument to be a number of entries to keep, but it was of type %s", type.toString());
            // The check above admits both INT and LONG literals, so the boxed value may be an Integer
            // or a Long; go through Number instead of casting to Long.
            this._entries = ((Number) list.get(1).getLiteral().getValue()).intValue();
        } else {
            this._entries = DEFAULT_NOMINAL_ENTRIES;
        }
    }

    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public AggregationFunctionType getType() {
        return AggregationFunctionType.DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH;
    }

    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public GroupByResultHolder createGroupByResultHolder(int i, int i2) {
        return new ObjectGroupByResultHolder(i, i2);
    }

    /**
     * Deserializes the first {@code i} values of the block and adds them to the sketches already held
     * by the result holder.
     *
     * @param i                       number of valid entries in the block
     * @param aggregationResultHolder holder of the sketch list accumulated so far (may hold {@code null})
     * @param map                     block value sets keyed by expression
     * @throws IllegalStateException if the column's stored type is not BYTES
     * @throws RuntimeException      wrapping any failure while deserializing or merging
     */
    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public void aggregate(int i, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> map) {
        BlockValSet blockValSet = map.get(this._expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        if (storedType != FieldSpec.DataType.BYTES) {
            throw new IllegalStateException("Illegal data type for " + getType() + " aggregation function: " + storedType);
        }
        byte[][] bytesValuesSV = blockValSet.getBytesValuesSV();
        try {
            // Only the first i entries are read, matching the group-by variants below.
            List<CompactSketch<IntegerSummary>> blockSketches = new ArrayList<>(i);
            for (int row = 0; row < i; row++) {
                blockSketches.add(deserializeCompact(bytesValuesSV[row]));
            }
            @SuppressWarnings("unchecked")
            List<CompactSketch<IntegerSummary>> existing =
                    (List<CompactSketch<IntegerSummary>>) aggregationResultHolder.getResult();
            aggregationResultHolder.setValue(existing == null ? blockSketches : merge(existing, blockSketches));
        } catch (Exception e) {
            throw new RuntimeException("Caught exception while merging Tuple Sketches", e);
        }
    }

    /**
     * Deserializes the first {@code i} values of the block and appends each one to the sketch list of
     * its group.
     *
     * @param i                   number of valid entries in the block
     * @param iArr                group key for each entry
     * @param groupByResultHolder holder of the per-group sketch lists
     * @param map                 block value sets keyed by expression
     * @throws IllegalStateException if the column's stored type is not BYTES
     * @throws RuntimeException      wrapping any failure while deserializing or storing a sketch
     */
    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public void aggregateGroupBySV(int i, int[] iArr, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> map) {
        BlockValSet blockValSet = map.get(this._expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        if (storedType != FieldSpec.DataType.BYTES) {
            throw new IllegalStateException("Illegal data type for " + getType() + " aggregation function: " + storedType);
        }
        byte[][] bytesValuesSV = blockValSet.getBytesValuesSV();
        for (int row = 0; row < i; row++) {
            try {
                addToGroup(groupByResultHolder, iArr[row], deserializeCompact(bytesValuesSV[row]));
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while merging Tuple Sketches", e);
            }
        }
    }

    /**
     * Deserializes the first {@code i} values of the block and appends each one to the sketch list of
     * every group listed for that entry.
     *
     * @param i                   number of valid entries in the block
     * @param iArr                group keys for each entry
     * @param groupByResultHolder holder of the per-group sketch lists
     * @param map                 block value sets keyed by expression
     */
    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public void aggregateGroupByMV(int i, int[][] iArr, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> map) {
        byte[][] bytesValuesSV = map.get(this._expression).getBytesValuesSV();
        for (int row = 0; row < i; row++) {
            CompactSketch<IntegerSummary> sketch = deserializeCompact(bytesValuesSV[row]);
            for (int groupKey : iArr[row]) {
                addToGroup(groupByResultHolder, groupKey, sketch);
            }
        }
    }

    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    @SuppressWarnings("unchecked")
    public List<CompactSketch<IntegerSummary>> extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        return (List<CompactSketch<IntegerSummary>>) aggregationResultHolder.getResult();
    }

    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    @SuppressWarnings("unchecked")
    public List<CompactSketch<IntegerSummary>> extractGroupByResult(GroupByResultHolder groupByResultHolder, int i) {
        return (List<CompactSketch<IntegerSummary>>) groupByResultHolder.getResult(i);
    }

    /**
     * Combines two intermediate results.
     *
     * @param list  first sketch list, may be {@code null}
     * @param list2 second sketch list, may be {@code null}
     * @return the non-null argument when exactly one is {@code null}; a new empty list when both are
     *         {@code null}; otherwise a new list holding the elements of {@code list} followed by
     *         those of {@code list2}
     */
    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public List<CompactSketch<IntegerSummary>> merge(List<CompactSketch<IntegerSummary>> list, List<CompactSketch<IntegerSummary>> list2) {
        if (list == null) {
            return list2 != null ? list2 : new ArrayList<>(0);
        }
        if (list2 == null) {
            return list;
        }
        List<CompactSketch<IntegerSummary>> combined = new ArrayList<>(list.size() + list2.size());
        combined.addAll(list);
        combined.addAll(list2);
        return combined;
    }

    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.STRING;
    }

    /**
     * Unions all collected sketches and encodes the result.
     *
     * @param list the collected sketches, may be {@code null}
     * @return the Base64 encoding of the serialized union, or {@code null} when {@code list} is {@code null}
     */
    @Override // org.apache.pinot.core.query.aggregation.function.AggregationFunction
    public Comparable extractFinalResult(List<CompactSketch<IntegerSummary>> list) {
        if (list == null) {
            return null;
        }
        Union<IntegerSummary> union = new Union<>(this._entries, this._setOps);
        for (CompactSketch<IntegerSummary> sketch : list) {
            union.union(sketch);
        }
        return Base64.getEncoder().encodeToString(union.getResult().toByteArray());
    }

    /** Deserializes one serialized Integer Tuple Sketch and returns its compact form. */
    private static CompactSketch<IntegerSummary> deserializeCompact(byte[] bytes) {
        // NOTE(review): "deserialize2" looks like a decompiler-assigned name for the serde's
        // byte[] deserialize method - confirm against ObjectSerDeUtils.ObjectSerDe.
        return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.deserialize2(bytes).compact();
    }

    /** Appends a sketch to the list stored for a group key, creating the list on first use. */
    private static void addToGroup(GroupByResultHolder groupByResultHolder, int groupKey, CompactSketch<IntegerSummary> sketch) {
        @SuppressWarnings("unchecked")
        List<CompactSketch<IntegerSummary>> groupSketches =
                (List<CompactSketch<IntegerSummary>>) groupByResultHolder.getResult(groupKey);
        if (groupSketches == null) {
            // Must be mutable: later entries for the same group are appended to this list.
            groupSketches = new ArrayList<>();
            groupByResultHolder.setValueForKey(groupKey, groupSketches);
        }
        groupSketches.add(sketch);
    }
}
