package org.apache.pinot.core.query.aggregation.function;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.datasketches.common.ArrayOfStringsSerDe;
import org.apache.datasketches.frequencies.ItemsSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.pinot.$internal.com.google.common.base.Preconditions;
import org.apache.pinot.common.CustomObject;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.local.customobject.SerializedFrequentStringsSketch;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec;

/* loaded from: input_file:org/apache/pinot/core/query/aggregation/function/FrequentStringsSketchAggregationFunction.class */
/**
 * Aggregation function {@code FREQUENTSTRINGSSKETCH(column[, maxMapSize])} that accumulates the values of a single
 * input expression into a DataSketches frequent-items {@link ItemsSketch} of strings.
 *
 * <p>Input handling depends on the value type of the input block:
 * <ul>
 *   <li>{@code BYTES}: each value is deserialized as an {@code ItemsSketch<String>} (using
 *       {@link ArrayOfStringsSerDe}) and merged into the accumulated sketch.</li>
 *   <li>any other type: values are read via {@code getStringValuesSV()} and each one is added to the sketch.</li>
 * </ul>
 *
 * <p>The intermediate result is the sketch itself (column type {@code OBJECT}). The final result wraps it in a
 * {@link SerializedFrequentStringsSketch} and is reported with column type {@code STRING}.
 */
public class FrequentStringsSketchAggregationFunction
        extends BaseSingleInputAggregationFunction<ItemsSketch<String>, Comparable<?>> {
    /** Max map size used for new sketches when the optional second argument is not supplied. */
    protected static final int DEFAULT_MAX_MAP_SIZE = 256;

    /** Max map size passed to every {@link ItemsSketch} created by this function. */
    protected int _maxMapSize;

    /**
     * Creates the function from its query arguments.
     *
     * @param list the input expression, optionally followed by an integer literal giving the sketch max map size
     * @throws IllegalArgumentException if the number of arguments is not 1 or 2
     */
    public FrequentStringsSketchAggregationFunction(List<ExpressionContext> list) {
        super(list.get(0));
        int numArguments = list.size();
        Preconditions.checkArgument(numArguments == 1 || numArguments == 2,
                "Expecting 1 or 2 arguments for FrequentItemsSketch function: FREQUENTSTRINGSSKETCH(column, maxMapSize)");
        // NOTE(review): maxMapSize is not validated here; any constraint (presumably a power of two) is only
        // enforced by the ItemsSketch constructor when the first sketch is created — confirm.
        _maxMapSize = numArguments == 2 ? list.get(1).getLiteral().getIntValue() : DEFAULT_MAX_MAP_SIZE;
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.FREQUENTSTRINGSSKETCH;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Adds the first {@code length} values of the input block to the single (non-grouped) sketch.
     *
     * <p>Only the first {@code length} entries of the value array are read, matching the group-by paths below. The
     * value array itself may be longer than the block (e.g. a reused buffer), and trailing entries must not be
     * counted.
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        ItemsSketch<String> sketch = getOrCreateSketch(aggregationResultHolder);
        if (blockValSet.getValueType() == FieldSpec.DataType.BYTES) {
            for (ItemsSketch<String> inputSketch : deserializeSketches(blockValSet.getBytesValuesSV(), length)) {
                sketch.merge(inputSketch);
            }
            return;
        }
        String[] values = blockValSet.getStringValuesSV();
        for (int row = 0; row < length; row++) {
            sketch.update(values[row]);
        }
    }

    /** Adds each of the first {@code length} values to the sketch of its (single) group key. */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        if (blockValSet.getValueType() == FieldSpec.DataType.BYTES) {
            ItemsSketch<String>[] inputSketches = deserializeSketches(blockValSet.getBytesValuesSV(), length);
            for (int row = 0; row < length; row++) {
                getOrCreateSketch(groupByResultHolder, groupKeyArray[row]).merge(inputSketches[row]);
            }
            return;
        }
        String[] values = blockValSet.getStringValuesSV();
        for (int row = 0; row < length; row++) {
            getOrCreateSketch(groupByResultHolder, groupKeyArray[row]).update(values[row]);
        }
    }

    /**
     * Adds each of the first {@code length} values to the sketch of every group key for that row. The input
     * expression is still read as single-valued; only the group keys are multi-valued.
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        if (blockValSet.getValueType() == FieldSpec.DataType.BYTES) {
            ItemsSketch<String>[] inputSketches = deserializeSketches(blockValSet.getBytesValuesSV(), length);
            for (int row = 0; row < length; row++) {
                for (int groupKey : groupKeysArray[row]) {
                    getOrCreateSketch(groupByResultHolder, groupKey).merge(inputSketches[row]);
                }
            }
            return;
        }
        String[] values = blockValSet.getStringValuesSV();
        for (int row = 0; row < length; row++) {
            for (int groupKey : groupKeysArray[row]) {
                getOrCreateSketch(groupByResultHolder, groupKey).update(values[row]);
            }
        }
    }

    /** Returns the sketch stored in the holder, creating and storing an empty one first if none exists. */
    @SuppressWarnings("unchecked")
    protected ItemsSketch<String> getOrCreateSketch(AggregationResultHolder aggregationResultHolder) {
        ItemsSketch<String> sketch = (ItemsSketch<String>) aggregationResultHolder.getResult();
        if (sketch == null) {
            sketch = new ItemsSketch<>(_maxMapSize);
            aggregationResultHolder.setValue(sketch);
        }
        return sketch;
    }

    /** Returns the sketch for {@code groupKey}, creating and storing an empty one first if none exists. */
    @SuppressWarnings("unchecked")
    protected ItemsSketch<String> getOrCreateSketch(GroupByResultHolder groupByResultHolder, int groupKey) {
        ItemsSketch<String> sketch = (ItemsSketch<String>) groupByResultHolder.getResult(groupKey);
        if (sketch == null) {
            sketch = new ItemsSketch<>(_maxMapSize);
            groupByResultHolder.setValueForKey(groupKey, sketch);
        }
        return sketch;
    }

    /** Deserializes every element of {@code bArr} as an {@code ItemsSketch<String>}. */
    protected ItemsSketch<String>[] deserializeSketches(byte[][] bArr) {
        return deserializeSketches(bArr, bArr.length);
    }

    /**
     * Deserializes the first {@code length} elements of {@code bArr} as {@code ItemsSketch<String>} instances.
     * Entries past {@code length} are ignored.
     */
    @SuppressWarnings("unchecked")
    protected ItemsSketch<String>[] deserializeSketches(byte[][] bArr, int length) {
        // One serde instance is reused for the whole batch instead of allocating one per row.
        ArrayOfStringsSerDe serDe = new ArrayOfStringsSerDe();
        ItemsSketch<String>[] sketches = new ItemsSketch[length];
        for (int row = 0; row < length; row++) {
            sketches[row] = ItemsSketch.getInstance(Memory.wrap(bArr[row]), serDe);
        }
        return sketches;
    }

    /** Returns the accumulated sketch, or {@code null} if nothing has been aggregated into the holder. */
    @Override
    @SuppressWarnings("unchecked")
    public ItemsSketch<String> extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        return (ItemsSketch<String>) aggregationResultHolder.getResult();
    }

    /** Returns the sketch for {@code groupKey}, or {@code null} if none has been created. */
    @Override
    @SuppressWarnings("unchecked")
    public ItemsSketch<String> extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        return (ItemsSketch<String>) groupByResultHolder.getResult(groupKey);
    }

    /**
     * Merges both inputs into a newly created sketch sized with {@code _maxMapSize}. A {@code null} input is
     * skipped, so the result is never {@code null}.
     */
    @Override
    public ItemsSketch<String> merge(ItemsSketch<String> itemsSketch, ItemsSketch<String> itemsSketch2) {
        ItemsSketch<String> union = new ItemsSketch<>(_maxMapSize);
        if (itemsSketch != null) {
            union.merge(itemsSketch);
        }
        if (itemsSketch2 != null) {
            union.merge(itemsSketch2);
        }
        return union;
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    /** Serializes the sketch with {@code ObjectSerDeUtils.FREQUENT_STRINGS_SKETCH_SER_DE}, tagged with its object type. */
    @Override
    public AggregationFunction.SerializedIntermediateResult serializeIntermediateResult(
            ItemsSketch<String> itemsSketch) {
        return new AggregationFunction.SerializedIntermediateResult(
                ObjectSerDeUtils.ObjectType.FrequentStringsSketch.getValue(),
                ObjectSerDeUtils.FREQUENT_STRINGS_SKETCH_SER_DE.serialize(itemsSketch));
    }

    /** Inverse of {@link #serializeIntermediateResult}: reads the sketch back from the custom object's buffer. */
    @Override
    public ItemsSketch<String> deserializeIntermediateResult(CustomObject customObject) {
        return ObjectSerDeUtils.FREQUENT_STRINGS_SKETCH_SER_DE.deserialize2(customObject.getBuffer());
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.STRING;
    }

    /**
     * Returns {@code <lower-cased function name>(<expression>)}. Lower-casing uses {@link Locale#ROOT} so the
     * column name does not depend on the JVM default locale.
     */
    @Override
    public String getResultColumnName() {
        return AggregationFunctionType.FREQUENTSTRINGSSKETCH.getName().toLowerCase(Locale.ROOT) + "(" + _expression
                + ")";
    }

    /**
     * Wraps the sketch in a {@link SerializedFrequentStringsSketch}.
     * NOTE(review): a {@code null} sketch (nothing aggregated) is passed to the wrapper unchanged — confirm it
     * tolerates {@code null}.
     */
    @Override
    public Comparable<?> extractFinalResult(ItemsSketch<String> itemsSketch) {
        return new SerializedFrequentStringsSketch(itemsSketch);
    }
}
