package org.apache.pinot.query.service;

import com.google.common.collect.Lists;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.core.transport.ServerInstance;
import org.apache.pinot.query.QueryEnvironment;
import org.apache.pinot.query.QueryEnvironmentTestBase;
import org.apache.pinot.query.QueryTestSet;
import org.apache.pinot.query.planner.QueryPlan;
import org.apache.pinot.query.planner.StageMetadata;
import org.apache.pinot.query.planner.stage.StageNode;
import org.apache.pinot.query.routing.WorkerInstance;
import org.apache.pinot.query.runtime.QueryRunner;
import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
import org.apache.pinot.query.runtime.plan.serde.QueryPlanSerDeUtils;
import org.apache.pinot.query.testutils.QueryTestUtils;
import org.apache.pinot.util.TestUtils;
import org.mockito.Mockito;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/query/service/QueryServerTest.class */
/**
 * Verifies that a {@link QueryServer} accepts worker requests over gRPC and forwards the deserialized
 * stage plan and request metadata to its {@link QueryRunner}.
 *
 * <p>Every server is backed by a Mockito mock runner, so this test only checks what the runner is
 * handed; it does not execute queries.
 */
public class QueryServerTest extends QueryTestSet {
    private static final Random RANDOM_REQUEST_ID_GEN = new Random();
    private static final int QUERY_SERVER_COUNT = 2;
    private static final String KEY_OF_SERVER_INSTANCE_HOST = "pinot.query.runner.server.hostname";
    private static final String KEY_OF_SERVER_INSTANCE_PORT = "pinot.query.runner.server.port";
    private static final String KEY_OF_BROKER_REQUEST_ID = "pinot.query.runner.broker.request.id";
    private static final String KEY_OF_BROKER_REQUEST_TIMEOUT_MS = "pinot.query.runner.broker.request.timeout.ms";
    private static final long REQUEST_TIMEOUT_MS = 10000L;
    private static final long VERIFY_TIMEOUT_MS = 10000L;

    // All three maps are keyed by the port the corresponding query server listens on.
    private final Map<Integer, QueryServer> _queryServerMap = new HashMap<>();
    private final Map<Integer, ServerInstance> _queryServerInstanceMap = new HashMap<>();
    private final Map<Integer, QueryRunner> _queryRunnerMap = new HashMap<>();
    private QueryEnvironment _queryEnvironment;

    /**
     * Starts {@value #QUERY_SERVER_COUNT} query servers on free ports, each wrapping a mock
     * {@link QueryRunner}, and builds a query environment that routes to the first two of them.
     */
    @BeforeClass
    public void setUp() throws Exception {
        for (int i = 0; i < QUERY_SERVER_COUNT; i++) {
            int availablePort = QueryTestUtils.getAvailablePort();
            QueryRunner queryRunner = Mockito.mock(QueryRunner.class);
            QueryServer queryServer = new QueryServer(availablePort, queryRunner);
            queryServer.start();
            _queryServerMap.put(availablePort, queryServer);
            _queryRunnerMap.put(availablePort, queryRunner);
            _queryServerInstanceMap.put(availablePort,
                new WorkerInstance("localhost", availablePort, availablePort, availablePort, availablePort));
        }
        List<Integer> ports = Lists.newArrayList(_queryServerMap.keySet());
        _queryEnvironment = QueryEnvironmentTestBase.getQueryEnvironment(1, ports.get(0), ports.get(1),
            QueryEnvironmentTestBase.TABLE_SCHEMAS, QueryEnvironmentTestBase.SERVER1_SEGMENTS,
            QueryEnvironmentTestBase.SERVER2_SEGMENTS);
    }

    /** Shuts down every query server started in {@link #setUp()}. */
    @AfterClass
    public void tearDown() {
        for (QueryServer queryServer : _queryServerMap.values()) {
            queryServer.shutdown();
        }
    }

    /**
     * For every non-reduce stage (stage id > 0) of the planned query, submits a worker request to the
     * stage's first server instance and waits until that server's mock runner has received a plan
     * and metadata matching what was sent.
     *
     * @param sql query supplied by the {@code testSql} data provider
     */
    @Test(dataProvider = "testSql")
    public void testWorkerAcceptsWorkerRequestCorrect(String sql) throws Exception {
        QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
        for (int stageId : queryPlan.getStageMetadataMap().keySet()) {
            if (stageId <= 0) {
                // Stage 0 is not dispatched to a worker, so there is nothing to verify.
                continue;
            }
            Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, stageId);
            submitRequest(queryRequest);

            StageMetadata stageMetadata = queryPlan.getStageMetadataMap().get(stageId);
            StageNode expectedRoot = queryPlan.getQueryStageMap().get(stageId);
            QueryRunner queryRunner =
                _queryRunnerMap.get(Integer.parseInt(queryRequest.getMetadataOrThrow(KEY_OF_SERVER_INSTANCE_PORT)));
            String requestId = queryRequest.getMetadataOrThrow(KEY_OF_BROKER_REQUEST_ID);

            TestUtils.waitForCondition(aVoid -> {
                try {
                    Mockito.verify(queryRunner).processQuery(
                        Mockito.argThat(distributedStagePlan ->
                            isStageNodesEqual(expectedRoot, distributedStagePlan.getStageRoot())
                                && isMetadataMapsEqual(stageMetadata,
                                distributedStagePlan.getMetadataMap().get(stageId))),
                        Mockito.argThat(requestMetadata ->
                            requestId.equals(requestMetadata.get(KEY_OF_BROKER_REQUEST_ID))));
                    return true;
                } catch (Throwable ignored) {
                    // verify() fails until the server has handed the request to the runner; treat any
                    // failure as "not yet" and let waitForCondition retry until the timeout.
                    return false;
                }
            }, VERIFY_TIMEOUT_MS, "Error verifying mock QueryRunner intercepted query payload!");
        }
    }

    /**
     * Compares the routing-relevant parts of two stage metadata objects: server instances,
     * server-to-segments assignment and scanned tables.
     */
    private static boolean isMetadataMapsEqual(StageMetadata expected, StageMetadata actual) {
        return expected.getServerInstances().equals(actual.getServerInstances())
            && expected.getServerInstanceToSegmentsMap().equals(actual.getServerInstanceToSegmentsMap())
            && expected.getScannedTables().equals(actual.getScannedTables());
    }

    /**
     * Recursively compares two stage-node trees by stage id, node class and input count. Inputs are
     * compared in stage-id order, so their original order does not matter.
     *
     * <p>The input lists are copied before sorting so that the comparison does not mutate either
     * tree (one of them belongs to the shared {@link QueryPlan}).
     */
    private static boolean isStageNodesEqual(StageNode expected, StageNode actual) {
        if (expected.getStageId() != actual.getStageId()
            || expected.getClass() != actual.getClass()
            || expected.getInputs().size() != actual.getInputs().size()) {
            return false;
        }
        List<StageNode> expectedInputs = sortedByStageId(expected.getInputs());
        List<StageNode> actualInputs = sortedByStageId(actual.getInputs());
        for (int i = 0; i < expectedInputs.size(); i++) {
            if (!isStageNodesEqual(expectedInputs.get(i), actualInputs.get(i))) {
                return false;
            }
        }
        return true;
    }

    /** Returns a new list holding {@code inputs} ordered by stage id; the argument is left untouched. */
    private static List<StageNode> sortedByStageId(List<? extends StageNode> inputs) {
        List<StageNode> sorted = new ArrayList<>(inputs);
        sorted.sort(Comparator.comparingInt(StageNode::getStageId));
        return sorted;
    }

    /**
     * Sends {@code queryRequest} over a plaintext gRPC channel to the host/port recorded in its
     * metadata and asserts the response metadata contains an {@code "OK"} entry. The channel is
     * always shut down, even if the call or the assertion fails.
     */
    private void submitRequest(Worker.QueryRequest queryRequest) {
        String host = queryRequest.getMetadataOrThrow(KEY_OF_SERVER_INSTANCE_HOST);
        int port = Integer.parseInt(queryRequest.getMetadataOrThrow(KEY_OF_SERVER_INSTANCE_PORT));
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        try {
            Assert.assertNotNull(
                PinotQueryWorkerGrpc.newBlockingStub(channel).submit(queryRequest).getMetadataMap().get("OK"));
        } finally {
            channel.shutdown();
        }
    }

    /**
     * Builds a worker request for {@code stageId}, targeted at the stage's first server instance.
     * The request carries the serialized stage plan, a random request id, the request timeout and
     * the target host/port.
     */
    private Worker.QueryRequest getQueryRequest(QueryPlan queryPlan, int stageId) {
        ServerInstance serverInstance = queryPlan.getStageMetadataMap().get(stageId).getServerInstances().get(0);
        return Worker.QueryRequest.newBuilder()
            .setStagePlan(QueryPlanSerDeUtils.serialize(
                QueryDispatcher.constructDistributedStagePlan(queryPlan, stageId, serverInstance)))
            .putMetadata(KEY_OF_BROKER_REQUEST_ID, String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()))
            .putMetadata(KEY_OF_BROKER_REQUEST_TIMEOUT_MS, String.valueOf(REQUEST_TIMEOUT_MS))
            .putMetadata(KEY_OF_SERVER_INSTANCE_HOST, serverInstance.getHostname())
            .putMetadata(KEY_OF_SERVER_INSTANCE_PORT, String.valueOf(serverInstance.getPort()))
            .build();
    }
}
