package org.apache.pinot.core.query.optimizer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.PinotQuery;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.sql.FilterKind;
import org.apache.pinot.sql.parsers.CalciteSqlParser;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/core/query/optimizer/QueryOptimizerTest.class */
public class QueryOptimizerTest {
    private static final QueryOptimizer OPTIMIZER = new QueryOptimizer();
    private static final Schema SCHEMA = new Schema.SchemaBuilder().setSchemaName("testTable").addSingleValueDimension("int", FieldSpec.DataType.INT).addSingleValueDimension("long", FieldSpec.DataType.LONG).addSingleValueDimension("float", FieldSpec.DataType.FLOAT).addSingleValueDimension("double", FieldSpec.DataType.DOUBLE).addSingleValueDimension("string", FieldSpec.DataType.STRING).addSingleValueDimension("bytes", FieldSpec.DataType.BYTES).addMultiValueDimension("mvInt", FieldSpec.DataType.INT).build();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.apache.pinot.core.query.optimizer.QueryOptimizerTest$1, reason: invalid class name */
    /* loaded from: input_file:org/apache/pinot/core/query/optimizer/QueryOptimizerTest$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$apache$pinot$sql$FilterKind = new int[FilterKind.values().length];

        static {
            try {
                $SwitchMap$org$apache$pinot$sql$FilterKind[FilterKind.GREATER_THAN.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$apache$pinot$sql$FilterKind[FilterKind.GREATER_THAN_OR_EQUAL.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$apache$pinot$sql$FilterKind[FilterKind.LESS_THAN.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$apache$pinot$sql$FilterKind[FilterKind.LESS_THAN_OR_EQUAL.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$apache$pinot$sql$FilterKind[FilterKind.BETWEEN.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$apache$pinot$sql$FilterKind[FilterKind.RANGE.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
        }
    }

    @Test
    public void testNoFilter() {
        PinotQuery compileToPinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT * FROM testTable");
        OPTIMIZER.optimize(compileToPinotQuery, SCHEMA);
        Assert.assertNull(compileToPinotQuery.getFilterExpression());
    }

    @Test
    public void testFlattenAndOrFilter() {
        PinotQuery compileToPinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT * FROM testTable WHERE ((int = 4 OR (long = 5 AND (float = 9 AND double = 7.5))) OR string = 'foo') OR bytes = 'abc'");
        OPTIMIZER.optimize(compileToPinotQuery, SCHEMA);
        Function functionCall = compileToPinotQuery.getFilterExpression().getFunctionCall();
        Assert.assertEquals(functionCall.getOperator(), FilterKind.OR.name());
        List operands = functionCall.getOperands();
        Assert.assertEquals(operands.size(), 4);
        Assert.assertEquals(operands.get(0), getEqFilterExpression("int", 4));
        Assert.assertEquals(operands.get(2), getEqFilterExpression("string", "foo"));
        Assert.assertEquals(operands.get(3), getEqFilterExpression("bytes", "abc"));
        Function functionCall2 = ((Expression) operands.get(1)).getFunctionCall();
        Assert.assertEquals(functionCall2.getOperator(), FilterKind.AND.name());
        List operands2 = functionCall2.getOperands();
        Assert.assertEquals(operands2.size(), 3);
        Assert.assertEquals(operands2.get(0), getEqFilterExpression("long", 5L));
        Assert.assertEquals(operands2.get(1), getEqFilterExpression("float", Float.valueOf(9.0f)));
        Assert.assertEquals(operands2.get(2), getEqFilterExpression("double", Double.valueOf(7.5d)));
    }

    private static Expression getEqFilterExpression(String str, Object obj) {
        Expression functionExpression = RequestUtils.getFunctionExpression(FilterKind.EQUALS.name());
        functionExpression.getFunctionCall().setOperands(Arrays.asList(RequestUtils.getIdentifierExpression(str), RequestUtils.getLiteralExpression(obj)));
        return functionExpression;
    }

    @Test
    public void testMergeEqInFilter() {
        PinotQuery compileToPinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT * FROM testTable WHERE int IN (1, 1) AND (long IN (2, 3) OR long IN (3, 4) OR long = 2) AND (float = 3.5 OR double IN (1.1, 1.2) OR float = 4.5 OR float > 5.5 OR double = 1.3)");
        OPTIMIZER.optimize(compileToPinotQuery, SCHEMA);
        Function functionCall = compileToPinotQuery.getFilterExpression().getFunctionCall();
        Assert.assertEquals(functionCall.getOperator(), FilterKind.AND.name());
        List operands = functionCall.getOperands();
        Assert.assertEquals(operands.size(), 3);
        Assert.assertEquals(operands.get(0), getEqFilterExpression("int", 1));
        checkInFilterFunction(((Expression) operands.get(1)).getFunctionCall(), "long", Arrays.asList(2L, 3L, 4L));
        Function functionCall2 = ((Expression) operands.get(2)).getFunctionCall();
        Assert.assertEquals(functionCall2.getOperator(), FilterKind.OR.name());
        List operands2 = functionCall2.getOperands();
        Assert.assertEquals(operands2.size(), 3);
        Assert.assertEquals(((Expression) operands2.get(0)).getFunctionCall().getOperator(), FilterKind.GREATER_THAN.name());
        Function functionCall3 = ((Expression) operands2.get(1)).getFunctionCall();
        Assert.assertEquals(functionCall3.getOperator(), FilterKind.IN.name());
        Function functionCall4 = ((Expression) operands2.get(2)).getFunctionCall();
        Assert.assertEquals(functionCall4.getOperator(), FilterKind.IN.name());
        if (((Expression) functionCall3.getOperands().get(0)).getIdentifier().getName().equals("float")) {
            checkInFilterFunction(functionCall3, "float", Arrays.asList(Double.valueOf(3.5d), Double.valueOf(4.5d)));
            checkInFilterFunction(functionCall4, "double", Arrays.asList(Double.valueOf(1.1d), Double.valueOf(1.2d), Double.valueOf(1.3d)));
        } else {
            checkInFilterFunction(functionCall3, "double", Arrays.asList(Double.valueOf(1.1d), Double.valueOf(1.2d), Double.valueOf(1.3d)));
            checkInFilterFunction(functionCall4, "float", Arrays.asList(Double.valueOf(3.5d), Double.valueOf(4.5d)));
        }
    }

    private static void checkInFilterFunction(Function function, String str, List<Object> list) {
        Assert.assertEquals(function.getOperator(), FilterKind.IN.name());
        List operands = function.getOperands();
        int size = operands.size();
        Assert.assertEquals(size, list.size() + 1);
        Assert.assertEquals(((Expression) operands.get(0)).getIdentifier().getName(), str);
        HashSet hashSet = new HashSet();
        Iterator<Object> it = list.iterator();
        while (it.hasNext()) {
            hashSet.add(RequestUtils.getLiteralExpression(it.next()));
        }
        for (int i = 1; i < size; i++) {
            Assert.assertTrue(hashSet.contains(operands.get(i)));
        }
    }

    @Test
    public void testMergeRangeFilter() {
        PinotQuery compileToPinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT * FROM testTable WHERE (int > 10 AND int <= 100 AND int BETWEEN 10 AND 20) OR (float BETWEEN 5.5 AND 7.5 AND float = 6 AND float < 6.5 AND float BETWEEN 6 AND 8) OR (string > '123' AND string > '23') OR (mvInt > 5 AND mvInt < 0)");
        OPTIMIZER.optimize(compileToPinotQuery, SCHEMA);
        Function functionCall = compileToPinotQuery.getFilterExpression().getFunctionCall();
        Assert.assertEquals(functionCall.getOperator(), FilterKind.OR.name());
        List operands = functionCall.getOperands();
        Assert.assertEquals(operands.size(), 4);
        Assert.assertEquals(operands.get(0), getRangeFilterExpression("int", "(10��20]"));
        Assert.assertEquals(operands.get(2), getRangeFilterExpression("string", "(23��*)"));
        Function functionCall2 = ((Expression) operands.get(1)).getFunctionCall();
        Assert.assertEquals(functionCall2.getOperator(), FilterKind.AND.name());
        List operands2 = functionCall2.getOperands();
        Assert.assertEquals(operands2.size(), 2);
        Assert.assertEquals(operands2.get(0), getEqFilterExpression("float", Float.valueOf(6.0f)));
        Assert.assertEquals(operands2.get(1), getRangeFilterExpression("float", "[6.0��6.5)"));
        Function functionCall3 = ((Expression) operands.get(3)).getFunctionCall();
        Assert.assertEquals(functionCall3.getOperator(), FilterKind.AND.name());
        List operands3 = functionCall3.getOperands();
        Assert.assertEquals(operands3.size(), 2);
        Assert.assertEquals(((Expression) operands3.get(0)).getFunctionCall().getOperator(), FilterKind.GREATER_THAN.name());
        Assert.assertEquals(((Expression) operands3.get(1)).getFunctionCall().getOperator(), FilterKind.LESS_THAN.name());
    }

    private static Expression getRangeFilterExpression(String str, String str2) {
        Expression functionExpression = RequestUtils.getFunctionExpression(FilterKind.RANGE.name());
        functionExpression.getFunctionCall().setOperands(Arrays.asList(RequestUtils.getIdentifierExpression(str), RequestUtils.getLiteralExpression(str2)));
        return functionExpression;
    }

    @Test
    public void testQueries() {
        testQuery("SELECT * FROM testTable WHERE int = 1 OR int = 2 OR int = 3", "SELECT * FROM testTable WHERE int IN (1, 2, 3)");
        testQuery("SELECT * FROM testTable WHERE int = 1 OR int = 2 OR int = 3 AND long = 4", "SELECT * FROM testTable WHERE int IN (1, 2) OR (int = 3 AND long = 4)");
        testQuery("SELECT * FROM testTable WHERE int = 1 OR int = 2 OR int = 3 OR long = 4 OR long = 5 OR long = 6", "SELECT * FROM testTable WHERE int IN (1, 2, 3) OR long IN (4, 5, 6)");
        testQuery("SELECT * FROM testTable WHERE int = 1 OR long = 4 OR int = 2 OR long = 5 OR int = 3 OR long = 6", "SELECT * FROM testTable WHERE int IN (1, 2, 3) OR long IN (4, 5, 6)");
        testQuery("SELECT * FROM testTable WHERE int = 1 OR int = 1", "SELECT * FROM testTable WHERE int = 1");
        testQuery("SELECT * FROM testTable WHERE (int = 1 OR int = 1) AND long = 2", "SELECT * FROM testTable WHERE int = 1 AND long = 2");
        testQuery("SELECT * FROM testTable WHERE int = 1 OR int IN (2, 3, 4, 5)", "SELECT * FROM testTable WHERE int IN (1, 2, 3, 4, 5)");
        testQuery("SELECT * FROM testTable WHERE int IN (1, 1) OR int = 1", "SELECT * FROM testTable WHERE int = 1");
        testQuery("SELECT * FROM testTable WHERE string = 'foo' OR string = 'bar' OR string = 'foobar'", "SELECT * FROM testTable WHERE string IN ('foo', 'bar', 'foobar')");
        testQuery("SELECT * FROM testTable WHERE bytes = 'dead' OR bytes = 'beef' OR bytes = 'deadbeef'", "SELECT * FROM testTable WHERE bytes IN ('dead', 'beef', 'deadbeef')");
        testQuery("SELECT * FROM testTable WHERE int >= 10 AND int <= 20", "SELECT * FROM testTable WHERE int BETWEEN 10 AND 20");
        testQuery("SELECT * FROM testTable WHERE int BETWEEN 10 AND 20 AND int > 7 AND int <= 17 OR int > 20", "SELECT * FROM testTable WHERE int BETWEEN 10 AND 17 OR int > 20");
        testQuery("SELECT * FROM testTable WHERE long BETWEEN 10 AND 20 AND long > 7 AND long <= 17 OR long > 20", "SELECT * FROM testTable WHERE long BETWEEN 10 AND 17 OR long > 20");
        testQuery("SELECT * FROM testTable WHERE float BETWEEN 10.5 AND 20 AND float > 7 AND float <= 17.5 OR float > 20", "SELECT * FROM testTable WHERE float BETWEEN 10.5 AND 17.5 OR float > 20");
        testQuery("SELECT * FROM testTable WHERE double BETWEEN 10.5 AND 20 AND double > 7 AND double <= 17.5 OR double > 20", "SELECT * FROM testTable WHERE double BETWEEN 10.5 AND 17.5 OR double > 20");
        testQuery("SELECT * FROM testTable WHERE string BETWEEN '10' AND '20' AND string > '7' AND string <= '17' OR string > '20'", "SELECT * FROM testTable WHERE string > '7' AND string <= '17' OR string > '20'");
        testQuery("SELECT * FROM testTable WHERE bytes BETWEEN '10' AND '20' AND bytes > '07' AND bytes <= '17' OR bytes > '20'", "SELECT * FROM testTable WHERE bytes BETWEEN '10' AND '17' OR bytes > '20'");
        testQuery("SELECT * FROM testTable WHERE int > 10 AND long > 20 AND int <= 30 AND long <= 40 AND int >= 15 AND long >= 25", "SELECT * FROM testTable WHERE int BETWEEN 15 AND 30 AND long BETWEEN 25 AND 40");
        testQuery("SELECT * FROM testTable WHERE int > 10 AND int > 20 OR int < 30 AND int < 40", "SELECT * FROM testTable WHERE int > 20 OR int < 30");
        testQuery("SELECT * FROM testTable WHERE int > 10 AND int > 20 OR long < 30 AND long < 40", "SELECT * FROM testTable WHERE int > 20 OR long < 30");
        testQuery("SELECT * FROM testTable WHERE int >= 20 AND (int > 10 AND (int IN (1, 2) OR (int = 2 OR int = 3)) AND int <= 30)", "SELECT * FROM testTable WHERE int BETWEEN 20 AND 30 AND int IN (1, 2, 3)");
        testQuery("SELECT * FROM testTable WHERE 1=1", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE 1!=1", "SELECT * FROM testTable WHERE false");
        testQuery("SELECT * FROM testTable WHERE 1=1 AND 1!=1", "SELECT * FROM testTable WHERE false");
        testQuery("SELECT * FROM testTable WHERE 1=1 OR 1!=1", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE \"a\"=\"a\"", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE \"a\"!=\"a\"", "SELECT * FROM testTable WHERE false");
        testQuery("SELECT * FROM testTable WHERE \"a\"=\"a\" AND \"a\"!=\"a\"", "SELECT * FROM testTable WHERE false");
        testQuery("SELECT * FROM testTable WHERE \"a\"=\"a\" OR \"a\"!=\"a\"", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE 1=1 AND \"a\"=\"a\"", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE 1=1 OR \"a\"=\"a\"", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE 1!=1 AND \"a\"=\"a\"", "SELECT * FROM testTable WHERE false");
        testQuery("SELECT * FROM testTable WHERE 1=1 AND \"a\"!=\"a\"", "SELECT * FROM testTable WHERE false");
        testQuery("SELECT * FROM testTable WHERE 1!=1 OR \"a\"=\"a\"", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE 1=1 OR \"a\"!=\"a\"", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE 1.0=1.0", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE 1.0=1", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE 1.01=1", "SELECT * FROM testTable WHERE false");
        testQuery("SELECT * FROM testTable WHERE 1=1 AND true", "SELECT * FROM testTable WHERE true");
        testQuery("SELECT * FROM testTable WHERE \"a\"=\"a\" AND true", "SELECT * FROM testTable WHERE true");
    }

    private static void testQuery(String str, String str2) {
        Assert.assertNotEquals(str, str2, "You must provide different queries to test");
        PinotQuery compileToPinotQuery = CalciteSqlParser.compileToPinotQuery(str);
        OPTIMIZER.optimize(compileToPinotQuery, SCHEMA);
        PinotQuery compileToPinotQuery2 = CalciteSqlParser.compileToPinotQuery(str2);
        OPTIMIZER.optimize(compileToPinotQuery2, SCHEMA);
        comparePinotQuery(compileToPinotQuery, compileToPinotQuery2);
    }

    private static void comparePinotQuery(PinotQuery pinotQuery, PinotQuery pinotQuery2) {
        if (pinotQuery2.getFilterExpression() == null) {
            Assert.assertNull(pinotQuery.getFilterExpression());
        } else {
            compareFilterExpression(pinotQuery.getFilterExpression(), pinotQuery2.getFilterExpression());
        }
    }

    private static void compareFilterExpression(Expression expression, Expression expression2) {
        if (expression.isSetLiteral()) {
            Assert.assertNull(expression.getFunctionCall());
            Assert.assertNull(expression2.getFunctionCall());
            Assert.assertTrue(expression2.isSetLiteral());
            Assert.assertEquals(expression.getLiteral(), expression2.getLiteral());
            return;
        }
        Function functionCall = expression.getFunctionCall();
        Function functionCall2 = expression2.getFunctionCall();
        FilterKind valueOf = FilterKind.valueOf(functionCall.getOperator());
        FilterKind valueOf2 = FilterKind.valueOf(functionCall2.getOperator());
        List operands = functionCall.getOperands();
        List operands2 = functionCall2.getOperands();
        if (valueOf.isRange()) {
            Assert.assertTrue(valueOf2.isRange());
            Assert.assertEquals(getRangeString(valueOf, operands), getRangeString(valueOf2, operands2));
            return;
        }
        Assert.assertEquals(valueOf, valueOf2);
        Assert.assertEquals(operands.size(), operands2.size());
        if (valueOf == FilterKind.AND || valueOf == FilterKind.OR) {
            compareFilterExpressionChildren(operands, operands2);
            return;
        }
        Assert.assertEquals(operands.get(0), operands2.get(0));
        if (valueOf == FilterKind.IN || valueOf == FilterKind.NOT_IN) {
            Assert.assertEqualsNoOrder(operands.toArray(), operands2.toArray());
        } else {
            Assert.assertEquals(operands, operands2);
        }
    }

    private static void compareFilterExpressionChildren(List<Expression> list, List<Expression> list2) {
        Assert.assertEquals(list.size(), list2.size());
        ArrayList arrayList = new ArrayList(list2);
        for (Expression expression : list) {
            Iterator it = arrayList.iterator();
            boolean z = false;
            while (it.hasNext()) {
                try {
                    compareFilterExpression(expression, (Expression) it.next());
                    it.remove();
                    z = true;
                    break;
                } catch (AssertionError e) {
                }
            }
            if (!z) {
                Assert.fail("Failed to find matching child");
            }
        }
    }

    private static String getRangeString(FilterKind filterKind, List<Expression> list) {
        switch (AnonymousClass1.$SwitchMap$org$apache$pinot$sql$FilterKind[filterKind.ordinal()]) {
            case 1:
                return "(" + list.get(1).getLiteral().getFieldValue().toString() + "��*)";
            case 2:
                return "[" + list.get(1).getLiteral().getFieldValue().toString() + "��*)";
            case 3:
                return "(*��" + list.get(1).getLiteral().getFieldValue().toString() + ")";
            case 4:
                return "(*��" + list.get(1).getLiteral().getFieldValue().toString() + "]";
            case 5:
                return "[" + list.get(1).getLiteral().getFieldValue().toString() + "��" + list.get(2).getLiteral().getFieldValue().toString() + "]";
            case 6:
                return list.get(1).getLiteral().getStringValue();
            default:
                throw new IllegalStateException();
        }
    }
}
