package org.apache.pinot.query.runtime.executor;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.pinot.common.utils.NamedThreadFactory;
import org.apache.pinot.query.mailbox.MailboxService;
import org.apache.pinot.query.routing.StageMetadata;
import org.apache.pinot.query.routing.WorkerMetadata;
import org.apache.pinot.query.runtime.blocks.TransferableBlockTestUtils;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.MultiStageOperator;
import org.apache.pinot.query.runtime.operator.OpChain;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.query.runtime.plan.pipeline.PipelineBreakerResult;
import org.apache.pinot.spi.accounting.ThreadExecutionContext;
import org.apache.pinot.spi.executor.ExecutorServiceUtils;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.class */
/**
 * Tests for {@link OpChainSchedulerService}.
 *
 * <p>Each test wraps a mocked {@link MultiStageOperator} in an {@link OpChain}, registers the chain with a
 * scheduler backed by a shared cached thread pool, and then uses {@link CountDownLatch}es to wait until the
 * expected operator method ({@code nextBlock}, {@code close} or {@code cancel}) has been invoked.
 */
public class OpChainSchedulerServiceTest {
    /** Request id given to the execution context in {@link #getChain} and used when cancelling by id. */
    private static final long REQUEST_ID = 123L;
    /** Upper bound, in seconds, that a test waits for an expected operator callback. */
    private static final long TIMEOUT_SECONDS = 10L;

    private ExecutorService _executor;
    private AutoCloseable _mocks;
    private MultiStageOperator _operatorA;

    /** Opens the Mockito annotation session and creates the thread pool shared by every test in this class. */
    @BeforeClass
    public void beforeClass() {
        _mocks = MockitoAnnotations.openMocks(this);
        _executor = Executors.newCachedThreadPool(new NamedThreadFactory("worker_on_" + getClass().getSimpleName()));
    }

    /**
     * Releases the Mockito session and the thread pool.
     *
     * @throws Exception if closing the Mockito session fails
     */
    @AfterClass
    public void afterClass() throws Exception {
        try {
            _mocks.close();
        } finally {
            // Close the executor even when closing the mocks throws, so worker threads are not leaked.
            ExecutorServiceUtils.close(_executor);
        }
    }

    /** Gives every test a fresh operator mock, so stubbing and recorded invocations never carry over. */
    @BeforeMethod
    public void beforeMethod() {
        _operatorA = Mockito.mock(MultiStageOperator.class);
    }

    /**
     * Builds an {@link OpChain} around the given root operator, using a mocked {@link MailboxService} that
     * reports {@code localhost:1234} and an execution context carrying {@link #REQUEST_ID}.
     *
     * @param multiStageOperator the root operator of the chain
     * @return the chain to register with the scheduler under test
     */
    private OpChain getChain(MultiStageOperator multiStageOperator) {
        MailboxService mailboxService = Mockito.mock(MailboxService.class);
        Mockito.when(mailboxService.getHostname()).thenReturn("localhost");
        Mockito.when(mailboxService.getPort()).thenReturn(1234);

        WorkerMetadata workerMetadata = new WorkerMetadata(0, ImmutableMap.of(), ImmutableMap.of());
        StageMetadata stageMetadata = new StageMetadata(0, ImmutableList.of(workerMetadata), ImmutableMap.of());
        OpChainExecutionContext context = new OpChainExecutionContext(mailboxService, REQUEST_ID, Long.MAX_VALUE,
                ImmutableMap.of(), stageMetadata, workerMetadata, (PipelineBreakerResult) null,
                (ThreadExecutionContext) null, true);
        return new OpChain(context, multiStageOperator);
    }

    /**
     * Fails with {@code message} unless {@code latch} reaches zero within {@link #TIMEOUT_SECONDS} seconds.
     *
     * @param latch the latch counted down by the stubbed operator method
     * @param message the assertion message reported on timeout
     * @throws InterruptedException if the test thread is interrupted while waiting
     */
    private static void assertCountedDown(CountDownLatch latch, String message) throws InterruptedException {
        Assert.assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS), message);
    }

    /** Verifies that registering a chain leads to {@code nextBlock} being invoked on its operator. */
    @Test
    public void shouldScheduleSingleOpChainRegisteredAfterStart() throws InterruptedException {
        OpChain opChain = getChain(_operatorA);
        OpChainSchedulerService schedulerService = new OpChainSchedulerService(_executor);

        CountDownLatch nextBlockCalled = new CountDownLatch(1);
        Mockito.when(_operatorA.nextBlock()).thenAnswer(invocation -> {
            nextBlockCalled.countDown();
            return TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0);
        });

        schedulerService.register(opChain);

        assertCountedDown(nextBlockCalled, "expected nextBlock to be called within the timeout");
    }

    /**
     * Verifies that registering a chain leads to {@code nextBlock} being invoked on its operator.
     *
     * <p>NOTE(review): the body is currently identical to
     * {@link #shouldScheduleSingleOpChainRegisteredAfterStart()}; it is kept as a separate test case.
     */
    @Test
    public void shouldScheduleSingleOpChainRegisteredBeforeStart() throws InterruptedException {
        OpChain opChain = getChain(_operatorA);
        OpChainSchedulerService schedulerService = new OpChainSchedulerService(_executor);

        CountDownLatch nextBlockCalled = new CountDownLatch(1);
        Mockito.when(_operatorA.nextBlock()).thenAnswer(invocation -> {
            nextBlockCalled.countDown();
            return TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0);
        });

        schedulerService.register(opChain);

        assertCountedDown(nextBlockCalled, "expected nextBlock to be called within the timeout");
    }

    /** Verifies that {@code close} is invoked on an operator whose {@code nextBlock} returns end-of-stream. */
    @Test
    public void shouldCallCloseOnOperatorsThatFinishSuccessfully() throws InterruptedException {
        OpChain opChain = getChain(_operatorA);
        OpChainSchedulerService schedulerService = new OpChainSchedulerService(_executor);

        CountDownLatch closeCalled = new CountDownLatch(1);
        Mockito.when(_operatorA.nextBlock()).thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        Mockito.doAnswer(invocation -> {
            closeCalled.countDown();
            return null;
        }).when(_operatorA).close();

        schedulerService.register(opChain);

        assertCountedDown(closeCalled, "expected close to be called within the timeout");
    }

    /** Verifies that {@code cancel} is invoked on an operator whose {@code nextBlock} returns an error block. */
    @Test
    public void shouldCallCancelOnOperatorsThatReturnErrorBlock() throws InterruptedException {
        OpChain opChain = getChain(_operatorA);
        OpChainSchedulerService schedulerService = new OpChainSchedulerService(_executor);

        CountDownLatch cancelCalled = new CountDownLatch(1);
        Mockito.when(_operatorA.nextBlock())
                .thenReturn(TransferableBlockUtils.getErrorTransferableBlock(new RuntimeException("foo")));
        Mockito.doAnswer(invocation -> {
            cancelCalled.countDown();
            return null;
        }).when(_operatorA).cancel(Mockito.<Throwable>any());

        schedulerService.register(opChain);

        assertCountedDown(cancelCalled, "expected cancel to be called within the timeout");
    }

    /**
     * Verifies that cancelling by request id invokes {@code cancel} exactly once on an operator that is still
     * running, i.e. whose {@code nextBlock} has started but never returns on its own.
     */
    @Test
    public void shouldCallCancelOnOpChainsWhenItIsCancelledByDispatch() throws InterruptedException {
        OpChain opChain = getChain(_operatorA);
        OpChainSchedulerService schedulerService = new OpChainSchedulerService(_executor);

        CountDownLatch opChainStarted = new CountDownLatch(1);
        Mockito.doAnswer(invocation -> {
            opChainStarted.countDown();
            // Never returns normally: the loop only ends if Thread.sleep throws (e.g. on interruption).
            while (true) {
                Thread.sleep(1000L);
            }
        }).when(_operatorA).nextBlock();

        CountDownLatch cancelCalled = new CountDownLatch(1);
        Mockito.doAnswer(invocation -> {
            cancelCalled.countDown();
            return null;
        }).when(_operatorA).cancel(Mockito.<Throwable>any());

        schedulerService.register(opChain);
        assertCountedDown(opChainStarted, "op chain doesn't seem to be started");

        // Cancel only after the operator is known to be running.
        schedulerService.cancel(REQUEST_ID);

        assertCountedDown(cancelCalled, "expected OpChain to be cancelled");
        Mockito.verify(_operatorA, Mockito.times(1)).cancel(Mockito.<Throwable>any());
    }

    /** Verifies that {@code cancel} is invoked exactly once on an operator whose {@code nextBlock} throws. */
    @Test
    public void shouldCallCancelOnOpChainsThatThrow() throws InterruptedException {
        OpChain opChain = getChain(_operatorA);
        OpChainSchedulerService schedulerService = new OpChainSchedulerService(_executor);

        CountDownLatch cancelCalled = new CountDownLatch(1);
        Mockito.when(_operatorA.nextBlock()).thenThrow(new RuntimeException("foo"));
        Mockito.doAnswer(invocation -> {
            cancelCalled.countDown();
            return null;
        }).when(_operatorA).cancel(Mockito.<Throwable>any());

        schedulerService.register(opChain);

        assertCountedDown(cancelCalled, "expected OpChain to be cancelled");
        Mockito.verify(_operatorA, Mockito.times(1)).cancel(Mockito.<Throwable>any());
    }
}
