package org.apache.pinot.core.startree.v2;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.io.FileUtils;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.core.common.BlockDocIdIterator;
import org.apache.pinot.core.plan.FilterPlanNode;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
import org.apache.pinot.core.startree.StarTreeUtils;
import org.apache.pinot.core.startree.plan.StarTreeFilterPlanNode;
import org.apache.pinot.segment.local.aggregator.ValueAggregator;
import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
import org.apache.pinot.segment.local.startree.v2.builder.MultipleTreesBuilder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.segment.spi.IndexSegment;
import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
import org.apache.pinot.segment.spi.datasource.DataSource;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.segment.spi.index.reader.ForwardIndexReader;
import org.apache.pinot.segment.spi.index.reader.ForwardIndexReaderContext;
import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
import org.apache.pinot.spi.config.table.FieldConfig;
import org.apache.pinot.spi.config.table.StarTreeAggregationConfig;
import org.apache.pinot.spi.config.table.StarTreeIndexConfig;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.spi.utils.ReadMode;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.class */
abstract class BaseStarTreeV2Test<R, A> {
    private static final String TABLE_NAME = "testTable";
    private static final String SEGMENT_NAME = "testSegment";
    private static final int NUM_SEGMENT_RECORDS = 100000;
    private static final String DIMENSION_D1 = "d1";
    private static final String DIMENSION_D2 = "d2";
    private static final String METRIC = "m";
    private static final String QUERY_FILTER_AND = " WHERE d1 = 0 AND d2 < 10";
    private static final String QUERY_FILTER_OR = " WHERE d1 > 10 OR d1 < 50";
    private static final String QUERY_FILTER_COMPLEX_OR_MULTIPLE_DIMENSIONS = " WHERE d2 < 95 AND (d1 > 10 OR d1 < 50)";
    private static final String QUERY_FILTER_COMPLEX_AND_MULTIPLE_DIMENSIONS_THREE_PREDICATES = " WHERE d2 < 95 AND d2 > 25 AND (d1 > 10 OR d1 < 50)";
    private static final String QUERY_FILTER_COMPLEX_OR_MULTIPLE_DIMENSIONS_THREE_PREDICATES = " WHERE (d2 > 95 OR d2 < 25) AND (d1 > 10 OR d1 < 50)";
    private static final String QUERY_FILTER_COMPLEX_OR_SINGLE_DIMENSION = " WHERE d1 = 95 AND (d1 > 90 OR d1 < 100)";
    private static final String QUERY_FILTER_OR_MULTIPLE_DIMENSIONS = " WHERE d1 > 10 OR d2 < 50";
    private static final String QUERY_FILTER_OR_ON_AND = " WHERE (d1 > 10 AND d1 < 50) OR d1 < 50";
    private static final String QUERY_FILTER_OR_ON_NOT = " WHERE (NOT d1 > 10) OR d1 < 50";
    private static final String QUERY_FILTER_ALWAYS_FALSE = " WHERE d1 > 100";
    private static final String QUERY_FILTER_OR_ALWAYS_FALSE = " WHERE d1 > 100 OR d1 < 0";
    private static final String QUERY_GROUP_BY = " GROUP BY d2";
    private static final String FILTER_AGG_CLAUSE = " FILTER(WHERE d1 > 10)";
    private ValueAggregator _valueAggregator;
    private FieldSpec.DataType _aggregatedValueType;
    private String _aggregation;
    private IndexSegment _indexSegment;
    private StarTreeV2 _starTreeV2;
    private static final Random RANDOM = new Random();
    private static final File TEMP_DIR = new File(FileUtils.getTempDirectory(), "BaseStarTreeV2Test");
    private static final int DIMENSION_CARDINALITY = 100;
    private static final int MAX_LEAF_RECORDS = RANDOM.nextInt(DIMENSION_CARDINALITY) + 1;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.apache.pinot.core.startree.v2.BaseStarTreeV2Test$1, reason: invalid class name */
    /* loaded from: input_file:org/apache/pinot/core/startree/v2/BaseStarTreeV2Test$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$apache$pinot$spi$data$FieldSpec$DataType = new int[FieldSpec.DataType.values().length];

        static {
            try {
                $SwitchMap$org$apache$pinot$spi$data$FieldSpec$DataType[FieldSpec.DataType.LONG.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$apache$pinot$spi$data$FieldSpec$DataType[FieldSpec.DataType.DOUBLE.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$apache$pinot$spi$data$FieldSpec$DataType[FieldSpec.DataType.BYTES.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
        }
    }

    @BeforeClass
    public void setUp() throws Exception {
        this._valueAggregator = getValueAggregator();
        this._aggregatedValueType = this._valueAggregator.getAggregatedValueType();
        this._aggregation = getAggregation(this._valueAggregator.getAggregationType());
        Schema.SchemaBuilder addSingleValueDimension = new Schema.SchemaBuilder().addSingleValueDimension(DIMENSION_D1, FieldSpec.DataType.INT).addSingleValueDimension(DIMENSION_D2, FieldSpec.DataType.INT);
        FieldSpec.DataType rawValueType = getRawValueType();
        if (rawValueType != null) {
            addSingleValueDimension.addMetric(METRIC, rawValueType);
        }
        Schema build = addSingleValueDimension.build();
        TableConfig build2 = new TableConfigBuilder(TableType.OFFLINE).setTableName(TABLE_NAME).build();
        ArrayList arrayList = new ArrayList(NUM_SEGMENT_RECORDS);
        for (int i = 0; i < NUM_SEGMENT_RECORDS; i++) {
            GenericRow genericRow = new GenericRow();
            genericRow.putValue(DIMENSION_D1, Integer.valueOf(RANDOM.nextInt(DIMENSION_CARDINALITY)));
            genericRow.putValue(DIMENSION_D2, Integer.valueOf(RANDOM.nextInt(DIMENSION_CARDINALITY)));
            if (rawValueType != null) {
                genericRow.putValue(METRIC, getRandomRawValue(RANDOM));
            }
            arrayList.add(genericRow);
        }
        SegmentIndexCreationDriverImpl segmentIndexCreationDriverImpl = new SegmentIndexCreationDriverImpl();
        SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(build2, build);
        segmentGeneratorConfig.setOutDir(TEMP_DIR.getPath());
        segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
        segmentIndexCreationDriverImpl.init(segmentGeneratorConfig, new GenericRowRecordReader(arrayList));
        segmentIndexCreationDriverImpl.build();
        StarTreeIndexConfig starTreeIndexConfig = new StarTreeIndexConfig(Arrays.asList(DIMENSION_D1, DIMENSION_D2), (List) null, (List) null, Collections.singletonList(new StarTreeAggregationConfig(METRIC, this._valueAggregator.getAggregationType().getName(), getCompressionCodec())), MAX_LEAF_RECORDS);
        File file = new File(TEMP_DIR, SEGMENT_NAME);
        MultipleTreesBuilder multipleTreesBuilder = new MultipleTreesBuilder(Collections.singletonList(starTreeIndexConfig), false, file, RANDOM.nextBoolean() ? MultipleTreesBuilder.BuildMode.ON_HEAP : MultipleTreesBuilder.BuildMode.OFF_HEAP);
        try {
            multipleTreesBuilder.build();
            multipleTreesBuilder.close();
            this._indexSegment = ImmutableSegmentLoader.load(file, ReadMode.mmap);
            this._starTreeV2 = (StarTreeV2) this._indexSegment.getStarTrees().get(0);
        } catch (Throwable th) {
            try {
                multipleTreesBuilder.close();
            } catch (Throwable th2) {
                th.addSuppressed(th2);
            }
            throw th;
        }
    }

    String getAggregation(AggregationFunctionType aggregationFunctionType) {
        return aggregationFunctionType == AggregationFunctionType.COUNT ? "COUNT(*)" : (aggregationFunctionType == AggregationFunctionType.PERCENTILEEST || aggregationFunctionType == AggregationFunctionType.PERCENTILETDIGEST) ? String.format("%s(%s, 50)", aggregationFunctionType.getName(), METRIC) : String.format("%s(%s)", aggregationFunctionType.getName(), METRIC);
    }

    @Test
    public void testUnsupportedFilters() {
        String format = String.format("SELECT %s FROM %s", this._aggregation, TABLE_NAME);
        testUnsupportedFilter(format + " WHERE d1 > 10 OR d2 < 50");
        testUnsupportedFilter(format + " WHERE (d1 > 10 AND d1 < 50) OR d1 < 50");
        testUnsupportedFilter(format + " WHERE (NOT d1 > 10) OR d1 < 50");
        testUnsupportedFilter(format + " WHERE d1 > 100");
        testUnsupportedFilter(format + " WHERE d1 > 100 OR d1 < 0");
    }

    @Test
    public void testQueries() throws IOException {
        String format = String.format("SELECT %s FROM %s", this._aggregation, TABLE_NAME);
        for (String str : Arrays.asList(format, String.format("SELECT %s%s FROM %s", this._aggregation, FILTER_AGG_CLAUSE, TABLE_NAME))) {
            testQuery(str);
            testQuery(str + " WHERE d1 = 0 AND d2 < 10");
            testQuery(str + " WHERE d1 > 10 OR d1 < 50");
            testQuery(str + " WHERE d2 < 95 AND (d1 > 10 OR d1 < 50)");
            testQuery(str + " WHERE d2 < 95 AND d2 > 25 AND (d1 > 10 OR d1 < 50)");
            testQuery(str + " WHERE (d2 > 95 OR d2 < 25) AND (d1 > 10 OR d1 < 50)");
            testQuery(str + " WHERE d1 = 95 AND (d1 > 90 OR d1 < 100)");
        }
        testQuery(format + " GROUP BY d2");
    }

    @AfterClass
    public void tearDown() throws IOException {
        this._indexSegment.destroy();
        FileUtils.deleteDirectory(TEMP_DIR);
    }

    private void testUnsupportedFilter(String str) {
        QueryContext queryContext = QueryContextConverterUtils.getQueryContext(str);
        FilterPlanNode filterPlanNode = new FilterPlanNode(this._indexSegment, queryContext);
        filterPlanNode.run();
        Assert.assertNull(StarTreeUtils.extractPredicateEvaluatorsMap(this._indexSegment, queryContext.getFilter(), filterPlanNode.getPredicateEvaluators()));
    }

    private void testQuery(String str) throws IOException {
        QueryContext queryContext = QueryContextConverterUtils.getQueryContext(str);
        AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions();
        Assert.assertNotNull(aggregationFunctions);
        int length = aggregationFunctions.length;
        AggregationFunctionColumnPair[] extractAggregationFunctionPairs = StarTreeUtils.extractAggregationFunctionPairs(aggregationFunctions);
        Assert.assertNotNull(extractAggregationFunctionPairs);
        HashSet hashSet = new HashSet();
        List groupByExpressions = queryContext.getGroupByExpressions();
        if (groupByExpressions != null) {
            Iterator it = groupByExpressions.iterator();
            while (it.hasNext()) {
                ((ExpressionContext) it.next()).getColumns(hashSet);
            }
        }
        int size = hashSet.size();
        ArrayList arrayList = new ArrayList(hashSet);
        FilterPlanNode filterPlanNode = new FilterPlanNode(this._indexSegment, queryContext);
        filterPlanNode.run();
        Map extractPredicateEvaluatorsMap = StarTreeUtils.extractPredicateEvaluatorsMap(this._indexSegment, queryContext.getFilter(), filterPlanNode.getPredicateEvaluators());
        Assert.assertNotNull(extractPredicateEvaluatorsMap);
        StarTreeFilterPlanNode starTreeFilterPlanNode = new StarTreeFilterPlanNode(queryContext, this._starTreeV2, extractPredicateEvaluatorsMap, hashSet);
        ArrayList arrayList2 = new ArrayList(length);
        for (AggregationFunctionColumnPair aggregationFunctionColumnPair : extractAggregationFunctionPairs) {
            arrayList2.add(this._starTreeV2.getDataSource(aggregationFunctionColumnPair.toColumnName()).getForwardIndex());
        }
        ArrayList arrayList3 = new ArrayList(size);
        Iterator it2 = arrayList.iterator();
        while (it2.hasNext()) {
            arrayList3.add(this._starTreeV2.getDataSource((String) it2.next()).getForwardIndex());
        }
        Map<List<Integer>, List<Object>> computeStarTreeResult = computeStarTreeResult(starTreeFilterPlanNode, arrayList2, arrayList3);
        FilterPlanNode filterPlanNode2 = new FilterPlanNode(this._indexSegment, queryContext);
        ArrayList arrayList4 = new ArrayList(length);
        ArrayList arrayList5 = new ArrayList(length);
        for (AggregationFunctionColumnPair aggregationFunctionColumnPair2 : extractAggregationFunctionPairs) {
            if (aggregationFunctionColumnPair2.getFunctionType() == AggregationFunctionType.COUNT) {
                arrayList4.add(null);
                arrayList5.add(null);
            } else {
                DataSource dataSource = this._indexSegment.getDataSource(aggregationFunctionColumnPair2.getColumn());
                arrayList4.add(dataSource.getForwardIndex());
                arrayList5.add(dataSource.getDictionary());
            }
        }
        ArrayList arrayList6 = new ArrayList(size);
        Iterator it3 = arrayList.iterator();
        while (it3.hasNext()) {
            arrayList6.add(this._indexSegment.getDataSource((String) it3.next()).getForwardIndex());
        }
        Map<List<Integer>, List<Object>> computeNonStarTreeResult = computeNonStarTreeResult(filterPlanNode2, arrayList4, arrayList5, arrayList6);
        Assert.assertEquals(computeStarTreeResult.size(), computeNonStarTreeResult.size());
        for (Map.Entry<List<Integer>, List<Object>> entry : computeStarTreeResult.entrySet()) {
            List<Integer> key = entry.getKey();
            Assert.assertTrue(computeNonStarTreeResult.containsKey(key));
            List<Object> value = entry.getValue();
            List<Object> list = computeNonStarTreeResult.get(key);
            for (int i = 0; i < length; i++) {
                assertAggregatedValue(value.get(i), list.get(i));
            }
        }
    }

    private Map<List<Integer>, List<Object>> computeStarTreeResult(StarTreeFilterPlanNode starTreeFilterPlanNode, List<ForwardIndexReader> list, List<ForwardIndexReader> list2) throws IOException {
        HashMap hashMap = new HashMap();
        int size = list.size();
        int size2 = list2.size();
        ArrayList<ForwardIndexReaderContext> arrayList = new ArrayList(size);
        ArrayList<ForwardIndexReaderContext> arrayList2 = new ArrayList(size2);
        try {
            Iterator<ForwardIndexReader> it = list.iterator();
            while (it.hasNext()) {
                arrayList.add(it.next().createContext());
            }
            Iterator<ForwardIndexReader> it2 = list2.iterator();
            while (it2.hasNext()) {
                arrayList2.add(it2.next().createContext());
            }
            BlockDocIdIterator it3 = starTreeFilterPlanNode.run().nextBlock().getBlockDocIdSet().iterator();
            while (true) {
                int next = it3.next();
                if (next == Integer.MIN_VALUE) {
                    break;
                }
                ArrayList arrayList3 = new ArrayList(size2);
                for (int i = 0; i < size2; i++) {
                    arrayList3.add(Integer.valueOf(list2.get(i).getDictId(next, (ForwardIndexReaderContext) arrayList2.get(i))));
                }
                List list3 = (List) hashMap.computeIfAbsent(arrayList3, list4 -> {
                    return new ArrayList(size);
                });
                if (list3.isEmpty()) {
                    for (int i2 = 0; i2 < size; i2++) {
                        list3.add(getAggregatedValue(next, list.get(i2), (ForwardIndexReaderContext) arrayList.get(i2)));
                    }
                } else {
                    for (int i3 = 0; i3 < size; i3++) {
                        list3.set(i3, this._valueAggregator.applyAggregatedValue(list3.get(i3), getAggregatedValue(next, list.get(i3), (ForwardIndexReaderContext) arrayList.get(i3))));
                    }
                }
            }
            for (ForwardIndexReaderContext forwardIndexReaderContext : arrayList) {
                if (forwardIndexReaderContext != null) {
                    forwardIndexReaderContext.close();
                }
            }
            for (ForwardIndexReaderContext forwardIndexReaderContext2 : arrayList2) {
                if (forwardIndexReaderContext2 != null) {
                    forwardIndexReaderContext2.close();
                }
            }
            return hashMap;
        } catch (Throwable th) {
            for (ForwardIndexReaderContext forwardIndexReaderContext3 : arrayList) {
                if (forwardIndexReaderContext3 != null) {
                    forwardIndexReaderContext3.close();
                }
            }
            for (ForwardIndexReaderContext forwardIndexReaderContext4 : arrayList2) {
                if (forwardIndexReaderContext4 != null) {
                    forwardIndexReaderContext4.close();
                }
            }
            throw th;
        }
    }

    private Object getAggregatedValue(int i, ForwardIndexReader forwardIndexReader, ForwardIndexReaderContext forwardIndexReaderContext) {
        switch (AnonymousClass1.$SwitchMap$org$apache$pinot$spi$data$FieldSpec$DataType[this._aggregatedValueType.ordinal()]) {
            case 1:
                return Long.valueOf(forwardIndexReader.getLong(i, forwardIndexReaderContext));
            case 2:
                return Double.valueOf(forwardIndexReader.getDouble(i, forwardIndexReaderContext));
            case 3:
                return this._valueAggregator.deserializeAggregatedValue(forwardIndexReader.getBytes(i, forwardIndexReaderContext));
            default:
                throw new IllegalStateException();
        }
    }

    private Map<List<Integer>, List<Object>> computeNonStarTreeResult(FilterPlanNode filterPlanNode, List<ForwardIndexReader> list, List<Dictionary> list2, List<ForwardIndexReader> list3) throws IOException {
        HashMap hashMap = new HashMap();
        int size = list.size();
        int size2 = list3.size();
        ArrayList<ForwardIndexReaderContext> arrayList = new ArrayList(size);
        ArrayList<ForwardIndexReaderContext> arrayList2 = new ArrayList(size2);
        try {
            for (ForwardIndexReader forwardIndexReader : list) {
                if (forwardIndexReader != null) {
                    arrayList.add(forwardIndexReader.createContext());
                } else {
                    arrayList.add(null);
                }
            }
            Iterator<ForwardIndexReader> it = list3.iterator();
            while (it.hasNext()) {
                arrayList2.add(it.next().createContext());
            }
            BlockDocIdIterator it2 = filterPlanNode.run().nextBlock().getBlockDocIdSet().iterator();
            while (true) {
                int next = it2.next();
                if (next == Integer.MIN_VALUE) {
                    break;
                }
                ArrayList arrayList3 = new ArrayList(size2);
                for (int i = 0; i < size2; i++) {
                    arrayList3.add(Integer.valueOf(list3.get(i).getDictId(next, (ForwardIndexReaderContext) arrayList2.get(i))));
                }
                List list4 = (List) hashMap.computeIfAbsent(arrayList3, list5 -> {
                    return new ArrayList(size);
                });
                if (list4.isEmpty()) {
                    for (int i2 = 0; i2 < size; i2++) {
                        ForwardIndexReader forwardIndexReader2 = list.get(i2);
                        if (forwardIndexReader2 == null) {
                            list4.add(1L);
                        } else {
                            list4.add(this._valueAggregator.getInitialAggregatedValue(getNextRawValue(next, forwardIndexReader2, (ForwardIndexReaderContext) arrayList.get(i2), list2.get(i2))));
                        }
                    }
                } else {
                    for (int i3 = 0; i3 < size; i3++) {
                        Object obj = list4.get(i3);
                        ForwardIndexReader forwardIndexReader3 = list.get(i3);
                        list4.set(i3, forwardIndexReader3 == null ? Long.valueOf(((Long) obj).longValue() + 1) : this._valueAggregator.applyRawValue(obj, getNextRawValue(next, forwardIndexReader3, (ForwardIndexReaderContext) arrayList.get(i3), list2.get(i3))));
                    }
                }
            }
            for (ForwardIndexReaderContext forwardIndexReaderContext : arrayList) {
                if (forwardIndexReaderContext != null) {
                    forwardIndexReaderContext.close();
                }
            }
            for (ForwardIndexReaderContext forwardIndexReaderContext2 : arrayList2) {
                if (forwardIndexReaderContext2 != null) {
                    forwardIndexReaderContext2.close();
                }
            }
            return hashMap;
        } catch (Throwable th) {
            for (ForwardIndexReaderContext forwardIndexReaderContext3 : arrayList) {
                if (forwardIndexReaderContext3 != null) {
                    forwardIndexReaderContext3.close();
                }
            }
            for (ForwardIndexReaderContext forwardIndexReaderContext4 : arrayList2) {
                if (forwardIndexReaderContext4 != null) {
                    forwardIndexReaderContext4.close();
                }
            }
            throw th;
        }
    }

    private Object getNextRawValue(int i, ForwardIndexReader forwardIndexReader, ForwardIndexReaderContext forwardIndexReaderContext, Dictionary dictionary) {
        return dictionary.get(forwardIndexReader.getDictId(i, forwardIndexReaderContext));
    }

    FieldConfig.CompressionCodec getCompressionCodec() {
        FieldConfig.CompressionCodec compressionCodec;
        FieldConfig.CompressionCodec[] values = FieldConfig.CompressionCodec.values();
        do {
            compressionCodec = values[RANDOM.nextInt(values.length)];
        } while (!compressionCodec.isApplicableToRawIndex());
        return compressionCodec;
    }

    abstract ValueAggregator<R, A> getValueAggregator();

    abstract FieldSpec.DataType getRawValueType();

    abstract R getRandomRawValue(Random random);

    abstract void assertAggregatedValue(A a, A a2);
}
