package org.apache.pinot.integration.tests.custom;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.avro.Schema;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.datasketches.theta.UpdateSketch;
import org.apache.datasketches.theta.UpdateSketchBuilder;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.testng.Assert;
import org.testng.annotations.Test;

@Test(suiteName = "CustomClusterIntegrationTest")
/* loaded from: input_file:org/apache/pinot/integration/tests/custom/ThetaSketchTest.class */
public class ThetaSketchTest extends CustomDataQueryClusterIntegrationTest {
    // Table name returned by getTableName(); also used as the schema name in createSchema().
    private static final String DEFAULT_TABLE_NAME = "ThetaSketchTest";
    // STRING column naming the dimension a row describes; the generated rows use "gender" or "course".
    private static final String DIM_NAME = "dimName";
    // STRING column holding the value of that dimension (e.g. "Female", "Math").
    private static final String DIM_VALUE = "dimValue";
    // INT column holding the index (0 or 1) of the data-generation pass that produced the row.
    private static final String SHARD_ID = "shardId";
    // BYTES column holding a serialized compact theta sketch of student ids.
    private static final String THETA_SKETCH = "thetaSketchCol";

    /**
     * Supplies the name of the table this test operates on.
     *
     * @return the constant {@code "ThetaSketchTest"}
     */
    @Override
    public String getTableName() {
        return ThetaSketchTest.DEFAULT_TABLE_NAME;
    }

    /**
     * Builds the table schema: four single-value dimension columns, named after the table.
     *
     * @return a schema with STRING {@code dimName}, STRING {@code dimValue}, INT {@code shardId}
     *         and BYTES {@code thetaSketchCol} columns
     */
    @Override
    public Schema createSchema() {
        return new Schema.SchemaBuilder()
                .setSchemaName(getTableName())
                .addSingleValueDimension(DIM_NAME, FieldSpec.DataType.STRING)
                .addSingleValueDimension(DIM_VALUE, FieldSpec.DataType.STRING)
                .addSingleValueDimension(SHARD_ID, FieldSpec.DataType.INT)
                .addSingleValueDimension(THETA_SKETCH, FieldSpec.DataType.BYTES)
                .build();
    }

    /**
     * Returns the fixed row count 10. This equals the number of records appended by
     * {@code createAvroFiles()}: 2 passes x (2 gender rows + 3 course rows).
     *
     * @return {@code 10}
     */
    protected long getCountStarResult() {
        final long expectedRowCount = 10L;
        return expectedRowCount;
    }

    /**
     * Generates the single Avro input file ({@code data.avro} under {@code _tempDir}) for this test.
     *
     * <p>Two passes ("shards", ids 0 and 1) are generated. In each pass every (gender, course)
     * combination receives a fresh, globally unique run of consecutive student ids; the run length
     * starts at 50 and grows by 10 after every combination (50, 60, ... 160 across both passes).
     * For each pass, one row per gender and one row per course is written, each carrying a compact
     * theta sketch of all student ids belonging to that gender or course in that pass. This yields
     * 10 rows in total, e.g. Female = 540 distinct ids, Male = 720, Math = 380.
     *
     * @return a one-element list containing the generated Avro file
     * @throws IOException if creating, writing or closing the Avro file fails
     */
    @Override
    public List<File> createAvroFiles() throws IOException {
        // The Avro schema type is fully qualified throughout: the simple name "Schema" is also
        // the Pinot schema type used elsewhere in this class.
        org.apache.avro.Schema avroSchema = org.apache.avro.Schema.createRecord("myRecord", null, null, false);
        // The (Object) cast pins the Field(String, Schema, String, Object) constructor.
        avroSchema.setFields(ImmutableList.of(
                new org.apache.avro.Schema.Field(DIM_NAME,
                        org.apache.avro.Schema.create(org.apache.avro.Schema.Type.STRING), null, (Object) null),
                new org.apache.avro.Schema.Field(DIM_VALUE,
                        org.apache.avro.Schema.create(org.apache.avro.Schema.Type.STRING), null, (Object) null),
                new org.apache.avro.Schema.Field(SHARD_ID,
                        org.apache.avro.Schema.create(org.apache.avro.Schema.Type.INT), null, (Object) null),
                new org.apache.avro.Schema.Field(THETA_SKETCH,
                        org.apache.avro.Schema.create(org.apache.avro.Schema.Type.BYTES), null, (Object) null)));

        String[] genders = {"Female", "Male"};
        String[] courses = {"Math", "History", "Biology"};
        File avroFile = new File(this._tempDir, "data.avro");

        // try-with-resources closes the writer on both the normal and the exceptional path.
        try (DataFileWriter<GenericData.Record> writer = new DataFileWriter<>(new GenericDatumWriter<>(avroSchema))) {
            writer.create(avroSchema, avroFile);

            int nextStudentId = 0;
            int groupSize = 50;
            for (int shardId = 0; shardId < 2; shardId++) {
                // (gender, course) -> student ids assigned to that combination in this pass.
                Map<ImmutablePair<String, String>, List<Integer>> idsByGenderCourse = new HashMap<>();
                for (String gender : genders) {
                    for (String course : courses) {
                        List<Integer> studentIds =
                                idsByGenderCourse.computeIfAbsent(ImmutablePair.of(gender, course), key -> new ArrayList<>());
                        for (int n = 0; n < groupSize; n++) {
                            studentIds.add(nextStudentId++);
                        }
                        groupSize += 10;
                    }
                }

                // One row per gender: sketch over every course's students of that gender.
                for (String gender : genders) {
                    UpdateSketch sketch = new UpdateSketchBuilder().build();
                    for (Map.Entry<ImmutablePair<String, String>, List<Integer>> entry : idsByGenderCourse.entrySet()) {
                        if (gender.equals(entry.getKey().getLeft())) {
                            for (int studentId : entry.getValue()) {
                                sketch.update(studentId);
                            }
                        }
                    }
                    appendSketchRecord(writer, avroSchema, "gender", gender, shardId, sketch);
                }

                // One row per course: sketch over every gender's students of that course.
                for (String course : courses) {
                    UpdateSketch sketch = new UpdateSketchBuilder().build();
                    for (Map.Entry<ImmutablePair<String, String>, List<Integer>> entry : idsByGenderCourse.entrySet()) {
                        if (course.equals(entry.getKey().getRight())) {
                            for (int studentId : entry.getValue()) {
                                sketch.update(studentId);
                            }
                        }
                    }
                    appendSketchRecord(writer, avroSchema, "course", course, shardId, sketch);
                }
            }
        }
        return List.of(avroFile);
    }

    /** Appends one row (dimension name/value, shard id, compact serialized sketch) to the Avro file. */
    private static void appendSketchRecord(DataFileWriter<GenericData.Record> writer, org.apache.avro.Schema avroSchema,
            String dimName, String dimValue, int shardId, UpdateSketch sketch) throws IOException {
        GenericData.Record row = new GenericData.Record(avroSchema);
        row.put(DIM_NAME, dimName);
        row.put(DIM_VALUE, dimValue);
        row.put(SHARD_ID, shardId);
        row.put(THETA_SKETCH, ByteBuffer.wrap(sketch.compact().toByteArray()));
        writer.append(row);
    }

    /**
     * Checks {@code distinctCountThetaSketch} results, including its inline filter/post-aggregation
     * expression form ({@code SET_INTERSECT}, {@code SET_UNION}, {@code SET_DIFF}), against the
     * cardinalities of the generated data.
     *
     * @param z value supplied by the {@code useV1QueryEngine} data provider and passed to
     *          {@code setUseMultiStageQueryEngine}
     * @throws Exception if a query fails
     */
    @Test(dataProvider = "useV1QueryEngine")
    public void testThetaSketchQueryV1(boolean z) throws Exception {
        setUseMultiStageQueryEngine(z);
        final String table = getTableName();

        // gender = Female: WHERE clause, single inline predicate, and intersected inline predicates.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol) from " + table + " where dimName = 'gender' and dimValue = 'Female'", 540);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender'' and dimValue = ''Female''', '$1') from " + table, 540);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender''', 'dimValue = ''Female''', 'SET_INTERSECT($1, $2)') from " + table, 540);

        // gender = Male.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol) from " + table + " where dimName = 'gender' and dimValue = 'Male'", 720);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender'' and dimValue = ''Male''', '$1') from " + table, 720);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender''', 'dimValue = ''Male''', 'SET_INTERSECT($1, $2)') from " + table, 720);

        // course = Math.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol) from " + table + " where dimName = 'course' AND dimValue = 'Math'", 380);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''course'' and dimValue = ''Math''', '$1') from " + table, 380);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''course''', 'dimValue = ''Math''', 'SET_INTERSECT($1, $2)') from " + table, 380);

        // Female INTERSECT Math.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender'' and dimValue = ''Female''', 'dimName = ''course'' and dimValue = ''Math''', 'SET_INTERSECT($1, $2)') from " + table, 160);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender''', 'dimValue = ''Female''', 'dimName = ''course''', 'dimValue = ''Math''', 'SET_INTERSECT($1, $2, $3, $4)') from " + table, 160);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender''', 'dimValue = ''Female''', 'dimName = ''course''', 'dimValue = ''Math''', 'SET_INTERSECT(SET_INTERSECT($1, $2), SET_INTERSECT($3, $4))') from " + table, 160);

        // Male UNION Biology.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender'' and dimValue = ''Male''', 'dimName = ''course'' and dimValue = ''Biology''', 'SET_UNION($1, $2)') from " + table, 920);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender''', 'dimValue = ''Male''', 'dimName = ''course''', 'dimValue = ''Biology''', 'SET_UNION(SET_INTERSECT($1, $2), SET_INTERSECT($3, $4))') from " + table, 920);

        // Female MINUS History.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender'' and dimValue = ''Female''', 'dimName = ''course'' and dimValue = ''History''', 'SET_DIFF($1, $2)') from " + table, 360);
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol, '', 'dimName = ''gender''', 'dimValue = ''Female''', 'dimName = ''course''', 'dimValue = ''History''', 'SET_DIFF(SET_INTERSECT($1, $2), SET_INTERSECT($3, $4))') from " + table, 360);

        // Group-by over the gender rows.
        Map<String, Integer> expectedByGender = ImmutableMap.of("Female", 540, "Male", 720);
        runAndAssert("select dimValue, distinctCountThetaSketch(thetaSketchCol) from " + table + " where dimName = 'gender' group by dimValue", expectedByGender);
    }

    /**
     * Checks theta-sketch queries written with {@code FILTER (WHERE ...)} aggregations and the
     * {@code THETA_SKETCH_INTERSECT/UNION/DIFF} and {@code GET_THETA_SKETCH_ESTIMATE} functions,
     * plus group-by, UNION ALL and JOIN variants, against the cardinalities of the generated data.
     *
     * @param z value supplied by the {@code useV2QueryEngine} data provider and passed to
     *          {@code setUseMultiStageQueryEngine}
     * @throws Exception if a query fails
     */
    @Test(dataProvider = "useV2QueryEngine")
    public void testThetaSketchQueryV2(boolean z) throws Exception {
        setUseMultiStageQueryEngine(z);
        final String table = getTableName();

        // gender = Female.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol) from " + table + " where dimName = 'gender' and dimValue = 'Female'", 540);
        runAndAssert("select getThetaSketchEstimate(distinctCountRAWThetaSketch(thetaSketchCol) FILTER (WHERE dimName = 'gender' and dimValue = 'Female')) from " + table, 540);
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT(     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'gender'),    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Female')))   FROM " + table, 540);

        // gender = Male.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol) from " + table + " where dimName = 'gender' and dimValue = 'Male'", 720);
        runAndAssert("select getThetaSketchEstimate(distinctCountRAWThetaSketch(thetaSketchCol) FILTER (WHERE dimName = 'gender' and dimValue = 'Male')) from " + table, 720);
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT(     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'gender'),    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Male')))   FROM " + table, 720);

        // course = Math.
        runAndAssert("select distinctCountThetaSketch(thetaSketchCol) from " + table + " where dimName = 'course' AND dimValue = 'Math'", 380);
        runAndAssert("select getThetaSketchEstimate(distinctCountRAWThetaSketch(thetaSketchCol) FILTER (WHERE dimName = 'course' and dimValue = 'Math')) from " + table, 380);
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT(     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'course'),    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Math')))   FROM " + table, 380);

        // Female INTERSECT Math.
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT(     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (        WHERE dimName = 'gender' and dimValue = 'Female'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (        WHERE dimName = 'course' and dimValue = 'Math')))   FROM " + table, 160);
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT(     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'gender'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Female'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'course'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Math')))   FROM " + table, 160);
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT(THETA_SKETCH_INTERSECT(    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'gender'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Female')),   THETA_SKETCH_INTERSECT(    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'course'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Math'))))   FROM " + table, 160);

        // Male UNION Biology.
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_UNION(     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (        WHERE dimName = 'gender' and dimValue = 'Male'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (        WHERE dimName = 'course' and dimValue = 'Biology')))   FROM " + table, 920);
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_UNION(  THETA_SKETCH_INTERSECT(    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'gender'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Male')),   THETA_SKETCH_INTERSECT(    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'course'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Biology'))))   FROM " + table, 920);

        // Female MINUS History.
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_DIFF(     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (        WHERE dimName = 'gender' and dimValue = 'Female'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (        WHERE dimName = 'course' and dimValue = 'History')))   FROM " + table, 360);
        runAndAssert("select GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_DIFF(  THETA_SKETCH_INTERSECT(    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'gender'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'Female')),   THETA_SKETCH_INTERSECT(    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'course'),     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimValue = 'History'))))   FROM " + table, 360);

        // Group-by: within one dimValue group a row is either a gender row or not, so the intersection is empty.
        Map<String, Integer> expectedAllZero = ImmutableMap.of("Female", 0, "Male", 0, "Math", 0, "History", 0, "Biology", 0);
        runAndAssert("select dimValue, GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT(     DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName = 'gender'),    DISTINCT_COUNT_RAW_THETA_SKETCH(thetaSketchCol, '') FILTER (WHERE dimName != 'gender')))   FROM " + table + " GROUP BY dimValue", expectedAllZero);

        // Group-by over gender rows, directly and through a UNION ALL sub-query.
        Map<String, Integer> expectedByGender = ImmutableMap.of("Female", 540, "Male", 720);
        runAndAssert("select dimValue, distinctCountThetaSketch(thetaSketchCol) from " + table + " where dimName = 'gender' group by dimValue", expectedByGender);
        runAndAssert("select dimValue, distinctCountThetaSketch(thetaSketchCol) from ( SELECT dimValue, thetaSketchCol FROM " + table + " where dimName = 'gender' and dimValue = 'Female' UNION ALL SELECT dimValue, thetaSketchCol FROM " + table + " where dimName = 'gender' and dimValue = 'Male' ) GROUP BY dimValue", expectedByGender);

        // Female rows joined to Male rows: the group key comes from one side, the sketch from the other.
        Map<String, Integer> expectedFemaleKey = ImmutableMap.of("Female", 720);
        runAndAssert("select a.dimValue, distinctCountThetaSketch(b.thetaSketchCol) FROM (SELECT dimName, dimValue, thetaSketchCol FROM " + table + " where dimName = 'gender' and dimValue = 'Female') a JOIN (SELECT dimName, dimValue, thetaSketchCol FROM " + table + " where dimName = 'gender' and dimValue = 'Male') b ON a.dimName = b.dimName GROUP BY a.dimValue", expectedFemaleKey);
        Map<String, Integer> expectedMaleKey = ImmutableMap.of("Male", 540);
        runAndAssert("select b.dimValue, distinctCountThetaSketch(a.thetaSketchCol) FROM (SELECT dimName, dimValue, thetaSketchCol FROM " + table + " where dimName = 'gender' and dimValue = 'Female') a JOIN (SELECT dimName, dimValue, thetaSketchCol FROM " + table + " where dimName = 'gender' and dimValue = 'Male') b ON a.dimName = b.dimName GROUP BY b.dimValue", expectedMaleKey);

        // Same join: Female and Male ids are disjoint, so intersection = 0 and union = 540 + 720.
        JsonNode response = postQuery("SELECT GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_INTERSECT(  DISTINCT_COUNT_RAW_THETA_SKETCH(a.thetaSketchCol, ''),   DISTINCT_COUNT_RAW_THETA_SKETCH(b.thetaSketchCol, ''))), GET_THETA_SKETCH_ESTIMATE(THETA_SKETCH_UNION(  DISTINCT_COUNT_RAW_THETA_SKETCH(a.thetaSketchCol, ''),   DISTINCT_COUNT_RAW_THETA_SKETCH(b.thetaSketchCol, ''))) FROM (SELECT dimName, dimValue, thetaSketchCol FROM " + table + " where dimName = 'gender' and dimValue = 'Female') a JOIN (SELECT dimName, dimValue, thetaSketchCol FROM " + table + " where dimName = 'gender' and dimValue = 'Male') b ON a.dimName = b.dimName");
        JsonNode firstRow = response.get("resultTable").get("rows").get(0);
        Assert.assertEquals(firstRow.get(0).longValue(), 0L);
        Assert.assertEquals(firstRow.get(1).longValue(), 1260L);
    }

    /**
     * Posts a query and asserts that the first column of the first result row, parsed as an
     * integer from its text form, equals the expected value.
     *
     * @param str the SQL query to post
     * @param i the expected integer in row 0, column 0 of {@code resultTable.rows}
     * @throws Exception if posting the query fails
     */
    private void runAndAssert(String str, int i) throws Exception {
        JsonNode firstCell = postQuery(str).get("resultTable").get("rows").get(0).get(0);
        int actual = Integer.parseInt(firstCell.asText());
        Assert.assertEquals(actual, i);
    }

    /**
     * Posts a two-column query and asserts that its rows, read as (text key, int value) pairs,
     * equal the expected map. Row order is irrelevant; a repeated key keeps its last value.
     *
     * @param str the SQL query to post; column 0 is read as text, column 1 as an int
     * @param map the expected key-to-value mapping
     * @throws Exception if posting the query fails
     */
    private void runAndAssert(String str, Map<String, Integer> map) throws Exception {
        // Typed map instead of a raw HashMap so the comparison is checked at compile time.
        Map<String, Integer> actual = new HashMap<>();
        for (JsonNode row : postQuery(str).get("resultTable").get("rows")) {
            actual.put(row.get(0).textValue(), row.get(1).intValue());
        }
        Assert.assertEquals(actual, map);
    }
}
