package org.apache.pinot.query.service.server;

import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Plan;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.core.routing.TimeBoundaryInfo;
import org.apache.pinot.query.QueryEnvironment;
import org.apache.pinot.query.QueryEnvironmentTestBase;
import org.apache.pinot.query.QueryTestSet;
import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
import org.apache.pinot.query.routing.QueryPlanSerDeUtils;
import org.apache.pinot.query.routing.QueryServerInstance;
import org.apache.pinot.query.routing.StageMetadata;
import org.apache.pinot.query.routing.StagePlan;
import org.apache.pinot.query.routing.WorkerMetadata;
import org.apache.pinot.query.runtime.QueryRunner;
import org.apache.pinot.query.testutils.QueryTestUtils;
import org.apache.pinot.spi.utils.EqualityUtils;
import org.apache.pinot.util.TestUtils;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/query/service/server/QueryServerTest.class */
/**
 * Tests for {@link QueryServer}.
 *
 * <p>Starts {@code QUERY_SERVER_COUNT} gRPC query servers, each backed by a mocked {@link QueryRunner}. It then
 * submits serialized stage plans to them and verifies that the plan, stage metadata and request metadata reaching
 * the runner match what the planner produced.
 */
public class QueryServerTest extends QueryTestSet {
    private static final Random RANDOM_REQUEST_ID_GEN = new Random();
    private static final int QUERY_SERVER_COUNT = 2;
    private static final String KEY_OF_SERVER_INSTANCE_HOST = "pinot.query.runner.server.hostname";
    private static final String KEY_OF_SERVER_INSTANCE_PORT = "pinot.query.runner.server.port";
    // Used both as the gRPC deadline of each submitted request and as the wait budget when verifying the mock.
    private static final long DEFAULT_TIMEOUT_MS = 10_000L;

    // Keyed by the port each server listens on.
    private final Map<Integer, QueryServer> _queryServerMap = new HashMap<>();
    private final Map<Integer, QueryRunner> _queryRunnerMap = new HashMap<>();
    private QueryEnvironment _queryEnvironment;

    /**
     * Starts the query servers on free ports, each with its own mocked runner.
     * Then builds a query environment whose two servers point at those ports.
     */
    @BeforeClass
    public void setUp() throws Exception {
        for (int i = 0; i < QUERY_SERVER_COUNT; i++) {
            int availablePort = QueryTestUtils.getAvailablePort();
            QueryRunner queryRunner = Mockito.mock(QueryRunner.class);
            QueryServer queryServer = new QueryServer(availablePort, queryRunner);
            queryServer.start();
            _queryServerMap.put(availablePort, queryServer);
            _queryRunnerMap.put(availablePort, queryRunner);
        }
        List<Integer> portList = Lists.newArrayList(_queryServerMap.keySet());
        _queryEnvironment = QueryEnvironmentTestBase.getQueryEnvironment(1, portList.get(0), portList.get(1),
                QueryEnvironmentTestBase.TABLE_SCHEMAS, QueryEnvironmentTestBase.SERVER1_SEGMENTS,
                QueryEnvironmentTestBase.SERVER2_SEGMENTS, null);
    }

    /** Shuts down every query server started in {@link #setUp()}. */
    @AfterClass
    public void tearDown() {
        for (QueryServer queryServer : _queryServerMap.values()) {
            queryServer.shutdown();
        }
    }

    /**
     * Checks that an exception thrown by the runner is reported back in the response metadata under {@code "ERROR"}.
     */
    @Test
    public void testException() throws Exception {
        DispatchableSubPlan queryPlan = _queryEnvironment.planQuery("SELECT * FROM a");
        Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, 1);
        Map<String, String> requestMetadata = QueryPlanSerDeUtils.fromProtoProperties(queryRequest.getMetadata());
        QueryRunner mockRunner =
                _queryRunnerMap.get(Integer.parseInt(requestMetadata.get(KEY_OF_SERVER_INSTANCE_PORT)));
        Mockito.doThrow(new RuntimeException("foo")).when(mockRunner)
                .processQuery(ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any());
        Worker.QueryResponse response;
        try {
            response = submitRequest(queryRequest, requestMetadata);
        } finally {
            // Always restore the shared mock so a failed submit cannot leak the stubbed exception into other tests.
            Mockito.reset(mockRunner);
        }
        String error = response.getMetadataMap().get("ERROR");
        Assert.assertNotNull(error, "Expected an ERROR entry in the response metadata");
        Assert.assertTrue(error.contains("foo"), "Unexpected error message: " + error);
    }

    /**
     * For each stage of the planned query from stage 1 onward, this test does the following:
     * <ul>
     *   <li>submits the stage to the server hosting it and expects an {@code "OK"} response;</li>
     *   <li>waits until the server's mocked runner has been invoked once per worker;</li>
     *   <li>checks each invocation carries the expected plan tree, stage metadata and request id.</li>
     * </ul>
     *
     * @param sql query supplied by the {@code testSql} data provider
     */
    @Test(dataProvider = "testSql")
    public void testWorkerAcceptsWorkerRequestCorrect(String sql) throws Exception {
        DispatchableSubPlan queryPlan = _queryEnvironment.planQuery(sql);
        List<DispatchablePlanFragment> stagePlans = queryPlan.getQueryStageList();
        // Stage 0 is not submitted by this test; presumably it runs outside the worker servers - TODO confirm.
        for (int stageId = 1; stageId < stagePlans.size(); stageId++) {
            Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, stageId);
            Map<String, String> requestMetadata = QueryPlanSerDeUtils.fromProtoProperties(queryRequest.getMetadata());
            QueryRunner mockRunner =
                    _queryRunnerMap.get(Integer.parseInt(requestMetadata.get(KEY_OF_SERVER_INSTANCE_PORT)));
            DispatchablePlanFragment fragment = stagePlans.get(stageId);
            List<WorkerMetadata> workerMetadataList = fragment.getWorkerMetadataList();
            StageMetadata expectedStageMetadata =
                    new StageMetadata(stageId, workerMetadataList, fragment.getCustomProperties());
            PlanNode expectedRoot = fragment.getPlanFragment().getFragmentRoot();
            String requestId = requestMetadata.get("requestId");
            try {
                Assert.assertTrue(submitRequest(queryRequest, requestMetadata).getMetadataMap().containsKey("OK"));
                TestUtils.waitForCondition(aVoid -> {
                    try {
                        Mockito.verify(mockRunner, Mockito.times(workerMetadataList.size())).processQuery(
                                ArgumentMatchers.any(),
                                ArgumentMatchers.argThat(stagePlan -> isStageNodesEqual(expectedRoot,
                                        stagePlan.getRootNode())
                                        && isStageMetadataEqual(expectedStageMetadata, stagePlan.getStageMetadata())),
                                ArgumentMatchers.argThat(metadata -> requestId.equals(metadata.get("requestId"))));
                        return true;
                    } catch (AssertionError e) {
                        // Verification not satisfied on this attempt; keep polling until the timeout.
                        return false;
                    }
                }, DEFAULT_TIMEOUT_MS, "Error verifying mock QueryRunner intercepted query payload!");
            } finally {
                // Clear recorded invocations even on failure so later stages/tests start from a clean mock.
                Mockito.reset(mockRunner);
            }
        }
    }

    /**
     * Compares the parts of two stage metadata objects that the server is expected to forward unchanged.
     * These are the table name, the time boundary, and the per-worker metadata, compared in order.
     */
    private static boolean isStageMetadataEqual(StageMetadata expected, StageMetadata actual) {
        if (!Objects.equals(expected.getTableName(), actual.getTableName())) {
            return false;
        }
        if (!isTimeBoundaryEqual(expected.getTimeBoundary(), actual.getTimeBoundary())) {
            return false;
        }
        List<WorkerMetadata> expectedWorkers = expected.getWorkerMetadataList();
        List<WorkerMetadata> actualWorkers = actual.getWorkerMetadataList();
        if (expectedWorkers.size() != actualWorkers.size()) {
            return false;
        }
        for (int i = 0; i < expectedWorkers.size(); i++) {
            if (!isWorkerMetadataEqual(expectedWorkers.get(i), actualWorkers.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Null-safe time boundary comparison. Two nulls are equal, one null is unequal, and otherwise the column and
     * value are compared.
     */
    private static boolean isTimeBoundaryEqual(TimeBoundaryInfo expected, TimeBoundaryInfo actual) {
        if (expected == null || actual == null) {
            return expected == actual;
        }
        return Objects.equals(expected.getTimeColumn(), actual.getTimeColumn())
                && Objects.equals(expected.getTimeValue(), actual.getTimeValue());
    }

    /** Compares worker id and table-to-segments assignment of two worker metadata objects. */
    private static boolean isWorkerMetadataEqual(WorkerMetadata expected, WorkerMetadata actual) {
        return expected.getWorkerId() == actual.getWorkerId()
                && EqualityUtils.isEqual(expected.getTableSegmentsMap(), actual.getTableSegmentsMap());
    }

    /**
     * Recursively compares two plan trees by plan fragment id and node class.
     * Inputs are matched by fragment id rather than by their original order.
     * The input lists are copied before sorting, so neither tree is mutated.
     */
    private static boolean isStageNodesEqual(PlanNode expected, PlanNode actual) {
        if (expected.getPlanFragmentId() != actual.getPlanFragmentId() || expected.getClass() != actual.getClass()) {
            return false;
        }
        List<PlanNode> expectedInputs = sortedByFragmentId(expected.getInputs());
        List<PlanNode> actualInputs = sortedByFragmentId(actual.getInputs());
        if (expectedInputs.size() != actualInputs.size()) {
            return false;
        }
        for (int i = 0; i < expectedInputs.size(); i++) {
            if (!isStageNodesEqual(expectedInputs.get(i), actualInputs.get(i))) {
                return false;
            }
        }
        return true;
    }

    /** Returns a new list with the given plan nodes ordered by plan fragment id; the argument is left untouched. */
    private static List<PlanNode> sortedByFragmentId(List<? extends PlanNode> inputs) {
        List<PlanNode> sorted = new ArrayList<>(inputs);
        sorted.sort(Comparator.comparingInt(PlanNode::getPlanFragmentId));
        return sorted;
    }

    /**
     * Sends {@code queryRequest} over a fresh plaintext gRPC channel.
     * Host, port and deadline come from {@code requestMetadata}.
     *
     * @param queryRequest    the serialized request to submit
     * @param requestMetadata metadata built by {@link #getQueryRequest}; must contain host, port and
     *                        {@code "timeoutMs"}
     * @return the server's response
     */
    private Worker.QueryResponse submitRequest(Worker.QueryRequest queryRequest, Map<String, String> requestMetadata) {
        String host = requestMetadata.get(KEY_OF_SERVER_INSTANCE_HOST);
        int port = Integer.parseInt(requestMetadata.get(KEY_OF_SERVER_INSTANCE_PORT));
        long timeoutMs = Long.parseLong(requestMetadata.get("timeoutMs"));
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        try {
            return PinotQueryWorkerGrpc.newBlockingStub(channel)
                    .withDeadline(Deadline.after(timeoutMs, TimeUnit.MILLISECONDS))
                    .submit(queryRequest);
        } finally {
            // Release the channel even when the call fails (e.g. deadline exceeded) to avoid leaking it.
            channel.shutdown();
        }
    }

    /**
     * Serializes stage {@code stageId} of {@code queryPlan} into a worker request.
     * The request is addressed to the first server instance assigned to that stage and tagged with a random
     * request id.
     */
    private Worker.QueryRequest getQueryRequest(DispatchableSubPlan queryPlan, int stageId) {
        DispatchablePlanFragment fragment = queryPlan.getQueryStageList().get(stageId);
        Plan.StageNode rootNode = StageNodeSerDeUtils.serializeStageNode(fragment.getPlanFragment().getFragmentRoot());
        ByteString customProperties = QueryPlanSerDeUtils.toProtoProperties(fragment.getCustomProperties());
        QueryServerInstance serverInstance = fragment.getServerInstanceToWorkerIdMap().keySet().iterator().next();
        Worker.StageMetadata protoStageMetadata = Worker.StageMetadata.newBuilder()
                .setStageId(stageId)
                .addAllWorkerMetadata(QueryPlanSerDeUtils.toProtoWorkerMetadataList(fragment.getWorkerMetadataList()))
                .setCustomProperty(customProperties)
                .build();
        Worker.StagePlan protoStagePlan = Worker.StagePlan.newBuilder()
                .setRootNode(rootNode.toByteString())
                .setStageMetadata(protoStageMetadata)
                .build();

        Map<String, String> requestMetadata = new HashMap<>();
        requestMetadata.put("requestId", String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()));
        requestMetadata.put("timeoutMs", String.valueOf(DEFAULT_TIMEOUT_MS));
        requestMetadata.put(KEY_OF_SERVER_INSTANCE_HOST, serverInstance.getHostname());
        requestMetadata.put(KEY_OF_SERVER_INSTANCE_PORT, Integer.toString(serverInstance.getQueryServicePort()));
        return Worker.QueryRequest.newBuilder()
                .addStagePlan(protoStagePlan)
                .setMetadata(QueryPlanSerDeUtils.toProtoProperties(requestMetadata))
                .build();
    }
}
