package org.apache.pinot.queries;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.pinot.core.operator.query.AggregationGroupByOrderByOperator;
import org.apache.pinot.core.operator.query.AggregationOperator;
import org.apache.pinot.core.operator.query.SelectionOnlyOperator;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
import org.apache.pinot.segment.spi.ImmutableSegment;
import org.apache.pinot.segment.spi.IndexSegment;
import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.spi.utils.ReadMode;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/queries/CastQueriesTest.class */
/**
 * Query tests that exercise {@code CAST} in aggregations, group-by aggregations, and in
 * selection/filter expressions.
 *
 * <p>{@link #setUp()} builds one segment of {@link #NUM_RECORDS} rows in which {@code x} is always
 * {@code 0.5}, {@code y} is always {@code 0.25}, and {@code class} cycles through the string forms
 * of {@code 0 .. BUCKET_SIZE - 1}. Every expected value asserted below follows from that layout.
 *
 * <p>NOTE(review): this class has no teardown; the loaded segment is never destroyed and
 * {@link #INDEX_DIR} is only cleaned at the start of the next run. Consider adding one.
 */
public class CastQueriesTest extends BaseQueriesTest {
    private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "CastQueriesTest");
    private static final String RAW_TABLE_NAME = "testTable";
    private static final String SEGMENT_NAME = "testSegment";

    /** Number of rows written into the test segment. */
    private static final int NUM_RECORDS = 1000;
    /** Number of distinct values of the {@code class} column. */
    private static final int BUCKET_SIZE = 8;

    private static final String X_COL = "x";
    private static final String Y_COL = "y";
    private static final String CLASSIFICATION_COLUMN = "class";

    private static final Schema SCHEMA = new Schema.SchemaBuilder()
            .addSingleValueDimension(X_COL, FieldSpec.DataType.DOUBLE)
            .addSingleValueDimension(Y_COL, FieldSpec.DataType.DOUBLE)
            .addSingleValueDimension(CLASSIFICATION_COLUMN, FieldSpec.DataType.STRING)
            .build();
    private static final TableConfig TABLE_CONFIG =
            new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();

    private IndexSegment _indexSegment;
    private List<IndexSegment> _indexSegments;

    /** Returns an empty filter; each test supplies its complete query text. */
    @Override
    protected String getFilter() {
        return "";
    }

    /** Returns the single segment loaded by {@link #setUp()}. */
    @Override
    protected IndexSegment getIndexSegment() {
        return _indexSegment;
    }

    /** Returns the segment list prepared by {@link #setUp()} (the same segment, listed twice). */
    @Override
    protected List<IndexSegment> getIndexSegments() {
        return _indexSegments;
    }

    /**
     * Generates the rows, builds a segment from them under {@link #INDEX_DIR}, and loads it.
     *
     * @throws Exception if segment creation or loading fails
     */
    @BeforeClass
    public void setUp() throws Exception {
        // Start from a clean directory so output from a previous run cannot interfere.
        FileUtils.deleteQuietly(INDEX_DIR);

        List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
        for (int i = 0; i < NUM_RECORDS; i++) {
            GenericRow record = new GenericRow();
            record.putValue(X_COL, 0.5d);
            record.putValue(Y_COL, 0.25d);
            // The column is declared STRING in SCHEMA, so store the bucket id as a String rather
            // than a boxed Integer.
            record.putValue(CLASSIFICATION_COLUMN, String.valueOf(i % BUCKET_SIZE));
            records.add(record);
        }

        SegmentGeneratorConfig generatorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
        generatorConfig.setTableName(RAW_TABLE_NAME);
        generatorConfig.setSegmentName(SEGMENT_NAME);
        generatorConfig.setOutDir(INDEX_DIR.getPath());

        SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
        driver.init(generatorConfig, new GenericRowRecordReader(records));
        driver.build();

        ImmutableSegment segment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
        _indexSegment = segment;
        _indexSegments = Arrays.asList(segment, segment);
    }

    /** Checks the two sums over all rows: 1000 * 0.5 = 500 and 1000 * 0.25 = 250. */
    @Test
    public void testCastSum() {
        String query = "select cast(sum(x) as int), cast(sum(y) as int) from testTable";
        AggregationOperator operator = (AggregationOperator) getOperator(query);
        List<Object> results = operator.nextBlock().getResults();
        Assert.assertNotNull(results);
        Assert.assertEquals(results.size(), 2);
        Assert.assertEquals(((Number) results.get(0)).intValue(), 500);
        Assert.assertEquals(((Number) results.get(1)).intValue(), 250);
    }

    /**
     * Checks the per-group sums. Each of the {@link #BUCKET_SIZE} groups holds 125 rows, so the
     * sums are 62.5 and 31.25, which {@code intValue()} truncates to 62 and 31.
     */
    @Test
    public void testCastSumGroupBy() {
        String query = "select cast(sum(x) as int), cast(sum(y) as int) from testTable group by class";
        AggregationGroupByOrderByOperator operator = (AggregationGroupByOrderByOperator) getOperator(query);
        AggregationGroupByResult groupByResult = operator.nextBlock().getAggregationGroupByResult();
        Assert.assertNotNull(groupByResult);

        int numGroups = 0;
        Iterator<GroupKeyGenerator.GroupKey> groupKeys = groupByResult.getGroupKeyIterator();
        while (groupKeys.hasNext()) {
            GroupKeyGenerator.GroupKey groupKey = groupKeys.next();
            Assert.assertEquals(((Number) groupByResult.getResultForGroupId(0, groupKey._groupId)).intValue(), 62);
            Assert.assertEquals(((Number) groupByResult.getResultForGroupId(1, groupKey._groupId)).intValue(), 31);
            numGroups++;
        }
        // Guard against the loop passing vacuously when no groups are produced.
        Assert.assertEquals(numGroups, BUCKET_SIZE);
    }

    /**
     * Checks a cast in both the projection and the filter: the filter keeps the 125 rows whose
     * {@code class} equals the string form of 0, and each projected value is 0.
     */
    @Test
    public void testCastFilterAndProject() {
        String query = "select cast(class as int) from testTable where class = cast(0 as string) limit 1000";
        SelectionOnlyOperator operator = (SelectionOnlyOperator) getOperator(query);
        Collection<Object[]> rows = operator.nextBlock().getRows();
        Assert.assertNotNull(rows);
        Assert.assertEquals(rows.size(), 125);
        for (Object[] row : rows) {
            Assert.assertEquals(row.length, 1);
            Assert.assertEquals(row[0], (Object) 0);
        }
    }
}
