package org.apache.pinot.query;

import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.sql.SqlNode;
import org.apache.pinot.query.context.PlannerContext;
import org.apache.pinot.query.planner.PlannerUtils;
import org.apache.pinot.query.planner.QueryPlan;
import org.apache.pinot.query.planner.StageMetadata;
import org.apache.pinot.query.planner.stage.AbstractStageNode;
import org.apache.pinot.query.planner.stage.AggregateNode;
import org.apache.pinot.query.planner.stage.FilterNode;
import org.apache.pinot.query.planner.stage.JoinNode;
import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
import org.apache.pinot.query.planner.stage.ProjectNode;
import org.apache.pinot.query.planner.stage.StageNode;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Compilation tests for the query environment: SQL parsing/validation, planning into stages,
 * stage-to-server assignment, and the exchange types chosen between stages.
 *
 * <p>The {@code _queryEnvironment} fixture is inherited from {@link QueryEnvironmentTestBase}. The
 * {@code testQueryDataProvider} data provider is not declared in this class, so it must come from the
 * superclass hierarchy.
 */
public class QueryCompilationTest extends QueryEnvironmentTestBase {

    /**
     * Parses and validates {@code query}, then checks the unparsed SQL text against {@code digest}.
     *
     * @param query SQL to parse
     * @param digest expected {@link SqlNode#toString()} of the validated parse tree
     */
    @Test(dataProvider = "testQueryParserDataProvider")
    public void testQueryParser(String query, String digest) throws Exception {
        SqlNode sqlNode = this._queryEnvironment.parse(query, new PlannerContext());
        this._queryEnvironment.validate(sqlNode);
        Assert.assertEquals(sqlNode.toString(), digest);
    }

    /**
     * Verifies that {@code query} can be planned and that planning yields a non-null plan.
     *
     * @param query SQL expected to plan successfully
     */
    @Test(dataProvider = "testQueryDataProvider")
    public void testQueryPlanWithoutException(String query) throws Exception {
        try {
            Assert.assertNotNull(this._queryEnvironment.planQuery(query));
        } catch (RuntimeException e) {
            Assert.fail("failed to plan query: " + query, e);
        }
    }

    /**
     * Verifies that planning {@code query} throws a {@link RuntimeException} whose cause's message
     * contains {@code expectedMessageFragment}.
     *
     * @param query SQL expected to fail planning
     * @param expectedMessageFragment substring expected in the cause's message
     */
    @Test(dataProvider = "testQueryExceptionDataProvider")
    public void testQueryWithException(String query, String expectedMessageFragment) {
        try {
            this._queryEnvironment.planQuery(query);
            // AssertionError is not a RuntimeException, so this is not swallowed by the catch below.
            Assert.fail("query plan should throw exception");
        } catch (RuntimeException e) {
            // The expected text is checked on the cause, i.e. planQuery is expected to wrap the underlying
            // error. Assert that explicitly so a missing cause/message fails clearly instead of with an NPE.
            Throwable cause = e.getCause();
            Assert.assertNotNull(cause, "expected a wrapped cause when planning: " + query);
            String message = cause.getMessage();
            Assert.assertNotNull(message, "expected a cause message when planning: " + query);
            Assert.assertTrue(message.contains(expectedMessageFragment),
                "expected message containing '" + expectedMessageFragment + "' but got: " + message);
        }
    }

    /**
     * Checks the exchange types of a GROUP BY on the join keys after a JOIN.
     *
     * <p>For every stage that scans no table and is not the root stage, the stage tree is walked down the
     * first input. A {@link JoinNode} must receive both inputs via HASH_DISTRIBUTED exchanges. An
     * {@link AggregateNode} fed directly by a {@link MailboxReceiveNode} must use a SINGLETON exchange,
     * i.e. no second data shuffle.
     */
    @Test
    public void testQueryGroupByAfterJoinShouldNotDoDataShuffle() throws Exception {
        String query = "SELECT a.col1, a.col2, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 "
            + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1, a.col2";
        QueryPlan queryPlan = this._queryEnvironment.planQuery(query);
        Assert.assertEquals(queryPlan.getQueryStageMap().size(), 5);
        Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 5);
        for (Map.Entry<Integer, StageMetadata> entry : queryPlan.getStageMetadataMap().entrySet()) {
            // Only inspect stages that scan no table and are not the root stage.
            if (!entry.getValue().getScannedTables().isEmpty() || PlannerUtils.isRootStage(entry.getKey())) {
                continue;
            }
            StageNode node = queryPlan.getQueryStageMap().get(entry.getKey());
            while (node != null) {
                if (node instanceof JoinNode) {
                    MailboxReceiveNode left = (MailboxReceiveNode) node.getInputs().get(0);
                    MailboxReceiveNode right = (MailboxReceiveNode) node.getInputs().get(1);
                    Assert.assertEquals(left.getExchangeType(), RelDistribution.Type.HASH_DISTRIBUTED);
                    Assert.assertEquals(right.getExchangeType(), RelDistribution.Type.HASH_DISTRIBUTED);
                    break;
                }
                if (node instanceof AggregateNode && node.getInputs().get(0) instanceof MailboxReceiveNode) {
                    MailboxReceiveNode input = (MailboxReceiveNode) node.getInputs().get(0);
                    Assert.assertEquals(input.getExchangeType(), RelDistribution.Type.SINGLETON);
                    break;
                }
                node = node.getInputs().get(0);
            }
        }
    }

    /**
     * Checks the server instances assigned to each stage of a simple two-table join.
     *
     * <p>Leaf stages scanning {@code a} run on servers 1 and 2, and leaf stages scanning {@code b} run on
     * server 1. The root stage runs on server 3, and the remaining stage runs on servers 1 and 2.
     */
    @Test
    public void testQueryAndAssertStageContentForJoin() throws Exception {
        QueryPlan queryPlan = this._queryEnvironment.planQuery("SELECT * FROM a JOIN b ON a.col1 = b.col2");
        Assert.assertEquals(queryPlan.getQueryStageMap().size(), 4);
        Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 4);
        for (Map.Entry<Integer, StageMetadata> entry : queryPlan.getStageMetadataMap().entrySet()) {
            List<String> scannedTables = entry.getValue().getScannedTables();
            List<String> expectedServers;
            if (scannedTables.size() == 1) {
                expectedServers = scannedTables.get(0).equals("a")
                    ? ImmutableList.of("Server_localhost_1", "Server_localhost_2")
                    : ImmutableList.of("Server_localhost_1");
            } else if (PlannerUtils.isRootStage(entry.getKey())) {
                expectedServers = ImmutableList.of("Server_localhost_3");
            } else {
                expectedServers = ImmutableList.of("Server_localhost_1", "Server_localhost_2");
            }
            Assert.assertEquals(serverInstanceNames(entry.getValue()), expectedServers);
        }
    }

    /**
     * Verifies that no {@link ProjectNode} or {@link FilterNode} remains in any stage that scans no table.
     * The projection and filters should have been pushed down into the table-scanning stages.
     */
    @Test
    public void testQueryProjectFilterPushDownForJoin() {
        QueryPlan queryPlan = this._queryEnvironment.planQuery("SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0 AND a.col2 IN  ('a', 'b') AND b.col3 < 0");
        for (Map.Entry<Integer, StageMetadata> entry : queryPlan.getStageMetadataMap().entrySet()) {
            if (entry.getValue().getScannedTables().isEmpty()) {
                assertNodeTypeNotIn(queryPlan.getQueryStageMap().get(entry.getKey()),
                    ImmutableList.of(ProjectNode.class, FilterNode.class));
            }
        }
    }

    /**
     * Maps a stage's assigned server instances to their {@code toString()} names, preserving order.
     *
     * @param stageMetadata metadata of the stage
     * @return server instance names in the order reported by the metadata
     */
    private static List<String> serverInstanceNames(StageMetadata stageMetadata) {
        return stageMetadata.getServerInstances().stream()
            .map(Object::toString)
            .collect(Collectors.toList());
    }

    /**
     * Recursively asserts that neither {@code stageNode} nor any node reachable through its inputs has a
     * runtime class that exactly matches one of {@code list}. Subclasses are not matched.
     *
     * @param stageNode root of the subtree to check
     * @param list forbidden node classes
     */
    private static void assertNodeTypeNotIn(StageNode stageNode, List<Class<? extends AbstractStageNode>> list) {
        Assert.assertFalse(isOneOf(list, stageNode),
            "unexpected node type in stage: " + stageNode.getClass().getSimpleName());
        for (StageNode input : stageNode.getInputs()) {
            assertNodeTypeNotIn(input, list);
        }
    }

    /**
     * Returns whether the runtime class of {@code stageNode} is exactly one of {@code list}.
     * {@code Class.equals} is identity, so this matches the exact class, not subclasses.
     */
    private static boolean isOneOf(List<Class<? extends AbstractStageNode>> list, StageNode stageNode) {
        return list.contains(stageNode.getClass());
    }

    /** Supplies {@code {query, expectedDigest}} pairs for {@link #testQueryParser}. */
    @DataProvider(name = "testQueryParserDataProvider")
    private Object[][] provideQueriesAndDigest() {
        return new Object[][]{
            new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0", "SELECT *\nFROM `a`\nINNER JOIN `b` ON `a`.`col1` = `b`.`col2`\nWHERE `a`.`col3` >= 0"}
        };
    }

    /** Supplies {@code {query, expectedMessageFragment}} pairs for {@link #testQueryWithException}. */
    @DataProvider(name = "testQueryExceptionDataProvider")
    private Object[][] provideQueriesWithException() {
        return new Object[][]{
            // Table 'b' is referenced in the projection but is not part of the FROM clause.
            new Object[]{"SELECT b.col1 - a.col3 FROM a JOIN c ON a.col1 = c.col3", "Table 'b' not found"},
            // Non-aggregated column selected alongside an aggregate without GROUP BY.
            new Object[]{"SELECT a.col1, SUM(a.col3) FROM a", "'a.col1' is not being grouped"}
        };
    }
}
