package org.apache.pinot.core.query.aggregation.groupby;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.core.data.table.Record;
import org.apache.pinot.core.data.table.Table;
import org.apache.pinot.core.operator.combine.GroupByOrderByCombineOperator;
import org.apache.pinot.core.plan.AggregationGroupByOrderByPlanNode;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
import org.apache.pinot.segment.spi.IndexSegment;
import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.MetricFieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.spi.utils.ReadMode;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/core/query/aggregation/groupby/GroupByTrimTest.class */
/**
 * Verifies the results of a {@code GROUP BY ... ORDER BY max(...) DESC} query executed through
 * {@link GroupByOrderByCombineOperator} under different combinations of
 * {@code minSegmentGroupTrimSize} / {@code minServerGroupTrimSize}.
 *
 * <p>A single segment with {@value #NUM_ROWS} rows and {@value #NUM_COLUMNS} DOUBLE metric columns is built.
 * Every row has a distinct {@code metric_0} value, so every row forms its own group, and
 * {@code metric_1 == metric_0 + 1}.
 */
public class GroupByTrimTest {
    private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "GroupByTrimTest");
    private static final String SEGMENT_NAME = "testSegment";
    private static final String METRIC_PREFIX = "metric_";
    private static final int NUM_COLUMNS = 2;
    private static final int NUM_ROWS = 10000;
    /** Time budget given to each query execution, in milliseconds. */
    private static final long QUERY_TIMEOUT_MS = 15000L;

    private final ExecutorService _executorService = Executors.newCachedThreadPool();
    private IndexSegment _indexSegment;
    /** Column names, populated by {@link #buildSchema()}. */
    private String[] _columns;
    /** Generated values, indexed as {@code [column][row]}. */
    private double[][] _inputData;
    /** Expected aggregation: metric_0 value -> max(metric_1) for that group. */
    private Map<Double, Double> _resultMap;

    /** Cleans any stale index directory, allocates the data holders and builds/loads the test segment. */
    @BeforeClass
    public void setUp() throws Exception {
        FileUtils.deleteQuietly(INDEX_DIR);
        _resultMap = new HashMap<>();
        _inputData = new double[NUM_COLUMNS][NUM_ROWS];
        _columns = new String[NUM_COLUMNS];
        setupSegment();
    }

    /**
     * Releases the segment, stops the executor and removes the index directory.
     *
     * <p>The segment may be {@code null} if {@link #setUp()} failed part-way. The guard prevents an NPE
     * here from skipping the executor shutdown and the directory cleanup.
     */
    @AfterClass
    public void tearDown() {
        if (_indexSegment != null) {
            _indexSegment.destroy();
        }
        _executorService.shutdown();
        FileUtils.deleteQuietly(INDEX_DIR);
    }

    /**
     * Runs the query with the given trim settings and compares the combined table with the expected groups.
     *
     * <p>NOTE(review): several provider rows share one {@link QueryContext} instance, and this method mutates it.
     * That is safe only while the data provider is not run in parallel.
     *
     * @param queryContext            query to execute (mutated: end time and trim sizes are set here)
     * @param minSegmentGroupTrimSize value passed to {@code setMinSegmentGroupTrimSize}
     * @param minServerGroupTrimSize  value passed to {@code setMinServerGroupTrimSize}
     * @param expectedResult          expected (metric_0, max(metric_1)) pairs, sorted by value descending
     */
    @Test(dataProvider = "groupByTrimTestDataProvider")
    void testGroupByTrim(QueryContext queryContext, int minSegmentGroupTrimSize, int minServerGroupTrimSize,
        List<Pair<Double, Double>> expectedResult) throws Exception {
        queryContext.setEndTimeMs(System.currentTimeMillis() + QUERY_TIMEOUT_MS);
        queryContext.setMinSegmentGroupTrimSize(minSegmentGroupTrimSize);
        queryContext.setMinServerGroupTrimSize(minServerGroupTrimSize);

        AggregationGroupByOrderByPlanNode planNode = new AggregationGroupByOrderByPlanNode(_indexSegment, queryContext);
        GroupByOrderByCombineOperator combineOperator =
            new GroupByOrderByCombineOperator(Collections.singletonList(planNode.run()), queryContext, _executorService);
        Table table = combineOperator.nextBlock().getTable();

        Assert.assertEquals(extractTestResult(table), expectedResult);
    }

    /**
     * Generates {@value #NUM_ROWS} rows, records the expected per-group maxima, and builds and loads the segment.
     *
     * <p>Row {@code r}, column {@code c} holds {@code 10 * (r + 1) + r + c}.
     */
    private void setupSegment() throws Exception {
        // buildSchema() also fills _columns, so it must run before the rows below are generated.
        Schema schema = buildSchema();
        SegmentGeneratorConfig segmentGeneratorConfig =
            new SegmentGeneratorConfig(new TableConfigBuilder(TableType.OFFLINE).setTableName("test").build(), schema);
        segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
        segmentGeneratorConfig.setOutDir(INDEX_DIR.getAbsolutePath());

        List<GenericRow> rows = new ArrayList<>(NUM_ROWS);
        int baseValue = 10;
        for (int row = 0; row < NUM_ROWS; row++) {
            GenericRow genericRow = new GenericRow();
            for (int col = 0; col < NUM_COLUMNS; col++) {
                double value = baseValue + row + col;
                _inputData[col][row] = value;
                genericRow.putValue(_columns[col], value);
            }
            computeMaxResult(_inputData[0][row], _inputData[1][row]);
            rows.add(genericRow);
            baseValue += 10;
        }

        SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
        driver.init(segmentGeneratorConfig, new GenericRowRecordReader(rows));
        driver.build();
        _indexSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, driver.getSegmentName()), ReadMode.heap);
    }

    /**
     * Builds a schema of {@value #NUM_COLUMNS} DOUBLE metric columns named {@code metric_0, metric_1, ...}.
     * As a side effect, it stores the column names into {@link #_columns}.
     */
    private Schema buildSchema() {
        Schema schema = new Schema();
        for (int i = 0; i < NUM_COLUMNS; i++) {
            String columnName = METRIC_PREFIX + i;
            schema.addField(new MetricFieldSpec(columnName, FieldSpec.DataType.DOUBLE));
            _columns[i] = columnName;
        }
        return schema;
    }

    /**
     * Folds one row into the expected result, keeping the maximum value seen for the group key.
     *
     * @param groupKey the metric_0 value (group-by key)
     * @param value    the metric_1 value (aggregated with max)
     */
    private void computeMaxResult(double groupKey, double value) {
        Double currentMax = _resultMap.get(groupKey);
        if (currentMax == null || currentMax < value) {
            _resultMap.put(groupKey, value);
        }
    }

    /**
     * Converts the combined result table into (metric_0, max(metric_1)) pairs.
     * The pairs are sorted the same way as {@link #computeExpectedResult()}.
     */
    private List<Pair<Double, Double>> extractTestResult(Table table) {
        List<Pair<Double, Double>> result = new ArrayList<>(table.size());
        Iterator<Record> iterator = table.iterator();
        while (iterator.hasNext()) {
            Object[] values = iterator.next().getValues();
            result.add(Pair.of((Double) values[0], (Double) values[1]));
        }
        result.sort(GroupByTrimTest::compareByValueDescending);
        return result;
    }

    /**
     * Supplies (queryContext, minSegmentGroupTrimSize, minServerGroupTrimSize, expectedResult) rows.
     *
     * <p>TestNG requires a true {@code Object[][]}. The list is therefore converted with an {@code Object[0][]} seed.
     * A seed of {@code Object[n]} would produce a plain {@code Object[]} and fail the cast with
     * {@code ClassCastException}.
     */
    @DataProvider
    public Object[][] groupByTrimTestDataProvider() {
        List<Object[]> data = new ArrayList<>();
        List<Pair<Double, Double>> expectedResult = computeExpectedResult();

        // LIMIT 1: every trim-size combination below is expected to return the top 100 groups.
        QueryContext limit1Query = QueryContextConverterUtils.getQueryContext(
            "SELECT metric_0, max(metric_1) FROM testTable GROUP BY metric_0 ORDER BY max(metric_1) DESC LIMIT 1");
        List<Pair<Double, Double>> top100 = expectedResult.subList(0, 100);
        data.add(new Object[]{limit1Query, 100, 5000, top100});
        data.add(new Object[]{limit1Query, 100, -1, top100});
        data.add(new Object[]{limit1Query, -1, 100, top100});
        data.add(new Object[]{limit1Query, 5000, 100, top100});

        // LIMIT 50: every trim-size combination below is expected to return the top 250 groups.
        QueryContext limit50Query = QueryContextConverterUtils.getQueryContext(
            "SELECT metric_0, max(metric_1) FROM testTable GROUP BY metric_0 ORDER BY max(metric_1) DESC LIMIT 50");
        List<Pair<Double, Double>> top250 = expectedResult.subList(0, 250);
        data.add(new Object[]{limit50Query, 50, 5000, top250});
        data.add(new Object[]{limit50Query, 200, -1, top250});
        data.add(new Object[]{limit50Query, -1, 150, top250});
        data.add(new Object[]{limit50Query, 5000, 10, top250});
        data.add(new Object[]{limit50Query, 20, 30, top250});

        // LIMIT 10 with both trim sizes set to -1: all groups are expected back.
        data.add(new Object[]{
            QueryContextConverterUtils.getQueryContext(
                "SELECT metric_0, max(metric_1) FROM testTable GROUP BY metric_0 ORDER BY max(metric_1) DESC LIMIT 10"),
            -1, -1, expectedResult
        });

        return data.toArray(new Object[0][]);
    }

    /** Returns every expected (metric_0, max(metric_1)) pair, sorted by the aggregated value descending. */
    private List<Pair<Double, Double>> computeExpectedResult() {
        List<Pair<Double, Double>> expected = new ArrayList<>(_resultMap.size());
        for (Map.Entry<Double, Double> entry : _resultMap.entrySet()) {
            expected.add(Pair.of(entry.getKey(), entry.getValue()));
        }
        expected.sort(GroupByTrimTest::compareByValueDescending);
        return expected;
    }

    /**
     * Orders pairs by their right (aggregated) value, largest first.
     * Ties are not broken. Here every metric_1 value is distinct, so the order is total.
     */
    private static int compareByValueDescending(Pair<Double, Double> left, Pair<Double, Double> right) {
        return Double.compare(right.getRight(), left.getRight());
    }
}
