package org.apache.pinot.queries;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.pinot.common.utils.HashUtil;
import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.operator.query.AggregationOperator;
import org.apache.pinot.core.operator.query.GroupByOperator;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import org.apache.pinot.segment.spi.IndexSegment;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/queries/BaseFunnelCountQueriesTest.class */
/**
 * Shared scaffolding for FUNNEL_COUNT query tests.
 *
 * <p>{@link #setUp()} generates {@link #NUM_RECORDS} random rows that alternate between the two
 * funnel steps in {@link #STEPS}, hands them to {@link #buildSegment(List)} and keeps track of the
 * generated ids so that each test can compute the expected funnel counts independently of the
 * query engine. Subclasses supply the table config, the segment builder, the FUNNEL_COUNT settings
 * clause and the strategy-specific expectations.
 */
public abstract class BaseFunnelCountQueriesTest extends BaseQueriesTest {
    protected static final String SEGMENT_NAME = "testSegment";
    protected static final int NUM_RECORDS = 2000;
    protected static final int MAX_VALUE = 1000;
    protected static final int NUM_GROUPS = 100;
    protected static final int FILTER_LIMIT = 50;
    protected static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "FunnelCountQueriesTest");
    protected static final Random RANDOM = new Random();
    protected static final String[] STEPS = {"A", "B"};
    protected static final String ID_COLUMN = "idColumn";
    protected static final String STEP_COLUMN = "stepColumn";
    protected static final Schema SCHEMA = new Schema.SchemaBuilder()
        .addSingleValueDimension(ID_COLUMN, FieldSpec.DataType.INT)
        .addSingleValueDimension(STEP_COLUMN, FieldSpec.DataType.STRING)
        .build();
    protected static final String RAW_TABLE_NAME = "testTable";
    protected static final TableConfigBuilder TABLE_CONFIG_BUILDER =
        new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME);

    /** Distinct ids generated for each step; index {@code i} corresponds to {@code STEPS[i]}. */
    @SuppressWarnings("unchecked") // generic array creation; only Set<Integer> instances are ever stored
    private final Set<Integer>[] _values = new Set[STEPS.length];

    /** Every generated id, duplicates included (one entry per generated row). */
    private final List<Integer> _all = new ArrayList<>();

    private IndexSegment _indexSegment;
    private List<IndexSegment> _indexSegments;

    /** Returns the value expected for the "entries scanned in filter" statistic of a single segment. */
    protected abstract int getExpectedNumEntriesScannedInFilter();

    /** Returns the factor applied to the single-segment counts to obtain the inter-segment expectation. */
    protected abstract int getExpectedInterSegmentMultiplier();

    protected abstract TableConfig getTableConfig();

    /**
     * Builds the segment under test from the generated rows.
     *
     * @param list the rows produced for this test class
     * @return the segment the queries are executed against
     * @throws Exception if the segment cannot be built
     */
    protected abstract IndexSegment buildSegment(List<GenericRow> list) throws Exception;

    /**
     * Asserts that an intermediate aggregation result matches the expected per-step counts.
     *
     * @param obj the intermediate result taken from the results block
     * @param jArr expected counts, one entry per funnel step
     */
    protected abstract void assertIntermediateResult(Object obj, long[] jArr);

    /** Returns the text appended as the last argument of the FUNNEL_COUNT call. */
    protected abstract String getSettings();

    @Override
    protected String getFilter() {
        return " WHERE idColumn >= " + FILTER_LIMIT;
    }

    @Override
    protected IndexSegment getIndexSegment() {
        return _indexSegment;
    }

    @Override
    protected List<IndexSegment> getIndexSegments() {
        return _indexSegments;
    }

    @BeforeClass
    public void setUp() throws Exception {
        FileUtils.deleteDirectory(INDEX_DIR);
        _indexSegment = buildSegment(generateRows());
        // The same segment is registered twice for the multi-segment code path.
        _indexSegments = Arrays.asList(_indexSegment, _indexSegment);
    }

    /**
     * Generates {@link #NUM_RECORDS} rows with a random id in {@code [0, MAX_VALUE)} and a step that
     * alternates between the entries of {@link #STEPS}, recording every id in {@code _all} and in the
     * per-step set of {@code _values}.
     */
    private List<GenericRow> generateRows() {
        List<GenericRow> rows = new ArrayList<>(NUM_RECORDS);
        int capacity = HashUtil.getHashMapCapacity(MAX_VALUE);
        for (int step = 0; step < STEPS.length; step++) {
            _values[step] = new HashSet<>(capacity);
        }
        for (int i = 0; i < NUM_RECORDS; i++) {
            int id = RANDOM.nextInt(MAX_VALUE);
            int step = i % STEPS.length;
            GenericRow row = new GenericRow();
            row.putValue(ID_COLUMN, id);
            row.putValue(STEP_COLUMN, STEPS[step]);
            rows.add(row);
            // Integer.hashCode(int) returns its argument, so the recorded value equals the id.
            int recorded = Integer.hashCode(id);
            _all.add(recorded);
            _values[step].add(recorded);
        }
        return rows;
    }

    private String getFunnelCountSql() {
        return "FUNNEL_COUNT( STEPS(stepColumn = 'A', stepColumn = 'B'), CORRELATE_BY(idColumn), "
            + getSettings() + ") ";
    }

    /** Returns the recorded ids of the given step that satisfy {@code filter}. */
    private Set<Integer> filterStep(int step, Predicate<Integer> filter) {
        return _values[step].stream().filter(filter).collect(Collectors.toSet());
    }

    /**
     * Computes the expected funnel counts for the ids matching {@code filter}: the number of distinct
     * ids seen in the first step, followed by the number of distinct ids seen in both steps.
     *
     * @return the two counts, or {@code null} when no id of either step matches the filter
     */
    private long[] expectedFunnel(Predicate<Integer> filter) {
        Set<Integer> firstStep = filterStep(0, filter);
        Set<Integer> secondStep = filterStep(1, filter);
        if (firstStep.isEmpty() && secondStep.isEmpty()) {
            return null;
        }
        Set<Integer> bothSteps = new HashSet<>(firstStep);
        bothSteps.retainAll(secondStep);
        return new long[]{firstStep.size(), bothSteps.size()};
    }

    /** Scales the expected per-step counts in place by the inter-segment multiplier. */
    private void applyInterSegmentMultiplier(long[] counts) {
        int multiplier = getExpectedInterSegmentMultiplier();
        for (int i = 0; i < counts.length; i++) {
            counts[i] *= multiplier;
        }
    }

    @Test
    public void testAggregationOnly() {
        // Plain concatenation: the settings text must not be interpreted as a format pattern.
        String query = "SELECT " + getFunnelCountSql() + "FROM testTable";

        Predicate<Integer> filter = id -> id >= FILTER_LIMIT;
        long numDocsMatched = _all.stream().filter(filter).count();
        long[] expected = expectedFunnel(filter);
        if (expected == null) {
            // No generated id passed the filter: both counts are zero.
            expected = new long[STEPS.length];
        }

        // Inner segment
        AggregationOperator operator = getOperatorWithFilter(query);
        Assert.assertTrue(operator instanceof AggregationOperator);
        AggregationResultsBlock resultsBlock = operator.nextBlock();
        // Entries scanned post filter is asserted as 2 per matched doc (presumably one per column read;
        // confirm against QueriesTestUtils). The last argument equals the number of generated rows.
        QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), numDocsMatched,
            getExpectedNumEntriesScannedInFilter(), 2 * numDocsMatched, NUM_RECORDS);
        List<Object> results = resultsBlock.getResults();
        Assert.assertNotNull(results);
        Assert.assertEquals(results.size(), 1);
        assertIntermediateResult(results.get(0), expected);

        // Inter segments: statistics are asserted at 4x the single-segment values
        // (NOTE(review): presumably the 2 segments are queried twice by the base class; confirm).
        applyInterSegmentMultiplier(expected);
        QueriesTestUtils.testInterSegmentsResult(getBrokerResponseWithFilter(query), 4 * numDocsMatched,
            4 * getExpectedNumEntriesScannedInFilter(), 8 * numDocsMatched, 4L * NUM_RECORDS,
            new Object[]{expected});
    }

    @Test
    public void testAggregationGroupBy() {
        // Plain concatenation: the settings text must not be interpreted as a format pattern.
        String query = "SELECT MOD(idColumn, " + NUM_GROUPS + "), " + getFunnelCountSql()
            + "FROM testTable WHERE idColumn >= " + FILTER_LIMIT
            + " GROUP BY 1 ORDER BY 1 LIMIT " + NUM_GROUPS;

        long numDocsMatched = _all.stream().filter(id -> id >= FILTER_LIMIT).count();

        // expected[g] holds the funnel counts of group g, or null when the group has no matching id.
        long[][] expected = new long[NUM_GROUPS][];
        int expectedNumGroups = 0;
        for (int group = 0; group < NUM_GROUPS; group++) {
            final int groupId = group;
            expected[group] = expectedFunnel(id -> id >= FILTER_LIMIT && id % NUM_GROUPS == groupId);
            if (expected[group] != null) {
                expectedNumGroups++;
            }
        }

        // Inner segment
        GroupByOperator operator = getOperator(query);
        GroupByResultsBlock resultsBlock = operator.nextBlock();
        QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), numDocsMatched,
            getExpectedNumEntriesScannedInFilter(), 2 * numDocsMatched, NUM_RECORDS);
        AggregationGroupByResult groupByResult = resultsBlock.getAggregationGroupByResult();
        Assert.assertNotNull(groupByResult);
        int actualNumGroups = 0;
        Iterator<GroupKeyGenerator.GroupKey> groupKeys = groupByResult.getGroupKeyIterator();
        while (groupKeys.hasNext()) {
            GroupKeyGenerator.GroupKey groupKey = groupKeys.next();
            actualNumGroups++;
            // The group key is the MOD result, surfaced as a Double.
            int group = ((Double) groupKey._keys[0]).intValue();
            assertIntermediateResult(groupByResult.getResultForGroupId(0, groupKey._groupId), expected[group]);
        }
        Assert.assertEquals(actualNumGroups, expectedNumGroups);

        // Inter segments: one expected row per non-empty group, in ascending group order.
        List<Object[]> expectedRows = new ArrayList<>();
        for (int group = 0; group < NUM_GROUPS; group++) {
            if (expected[group] == null) {
                continue;
            }
            applyInterSegmentMultiplier(expected[group]);
            expectedRows.add(new Object[]{(double) group, expected[group]});
        }
        QueriesTestUtils.testInterSegmentsResult(getBrokerResponse(query), 4 * numDocsMatched,
            4 * getExpectedNumEntriesScannedInFilter(), 8 * numDocsMatched, 4L * NUM_RECORDS, expectedRows);
    }

    @AfterClass
    public void tearDown() throws IOException {
        // Guard against a setUp() that failed before the segment was built.
        if (_indexSegment != null) {
            _indexSegment.destroy();
        }
        FileUtils.deleteDirectory(INDEX_DIR);
    }
}
