package org.apache.pinot.query.runtime.operator;

import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.sql.SqlKind;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.metrics.ServerMetrics;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.query.planner.plannode.JoinNode;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.planner.plannode.SortNode;
import org.apache.pinot.query.planner.plannode.WindowNode;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockTestUtils;
import org.apache.pinot.spi.accounting.QueryResourceTracker;
import org.apache.pinot.spi.accounting.ThreadExecutionContext;
import org.apache.pinot.spi.accounting.ThreadResourceTracker;
import org.apache.pinot.spi.accounting.ThreadResourceUsageAccountant;
import org.apache.pinot.spi.accounting.ThreadResourceUsageProvider;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.apache.pinot.spi.trace.Tracing;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.testng.Assert;
import org.testng.ITest;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Factory;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.class */
/**
 * Verifies that running a multi-stage operator is reflected in the thread resource usage accountant.
 *
 * <p>The class is instantiated once per operator through the TestNG {@link Factory} constructor, fed by the
 * {@code operatorProvider} data provider. Each instance reports the operator name as its test name via
 * {@link ITest#getTestName()}.
 */
public class MultiStageAccountingTest implements ITest {
    /** Handle returned by {@link MockitoAnnotations#openMocks(Object)}; closed after every test method. */
    private AutoCloseable _mocks;

    @Mock
    private VirtualServerAddress _serverAddress;
    protected String _testName;
    protected MultiStageOperator _operator;

    /**
     * Creates one test instance for a single operator under test.
     *
     * @param str name reported by {@link #getTestName()}
     * @param multiStageOperator operator whose accounting is checked
     */
    @Factory(dataProvider = "operatorProvider")
    public MultiStageAccountingTest(String str, MultiStageOperator multiStageOperator) {
        _testName = str;
        _operator = multiStageOperator;
    }

    /**
     * Enables thread memory measurement, registers a mocked {@link ServerMetrics}, initializes the thread accountant
     * from an explicit configuration and registers the current thread as both runner and worker.
     */
    @BeforeClass
    public static void setUpClass() {
        ThreadResourceUsageProvider.setThreadMemoryMeasurementEnabled(true);
        ServerMetrics.register(Mockito.mock(ServerMetrics.class));

        HashMap<String, Object> accountingConfig = new HashMap<>();
        accountingConfig.put("accounting.oom.alarming.usage.ratio", 0.0f);
        accountingConfig.put("accounting.oom.critical.heap.usage.ratio", 0.0f);
        accountingConfig.put("accounting.factory.name",
                "org.apache.pinot.core.accounting.PerQueryCPUMemAccountantFactory");
        accountingConfig.put("accounting.enable.thread.memory.sampling", true);
        accountingConfig.put("accounting.enable.thread.cpu.sampling", false);
        accountingConfig.put("accounting.oom.enable.killing.query", true);
        Tracing.ThreadAccountantOps.initializeThreadAccountant(new PinotConfiguration(accountingConfig), "testGroupBy");

        // The worker is attached to the execution context created by the runner set up on this same thread.
        Tracing.ThreadAccountantOps.setupRunner("MultiStageAccountingTest", ThreadExecutionContext.TaskType.MSE);
        ThreadExecutionContext runnerContext = Tracing.getThreadAccountant().getThreadExecutionContext();
        Tracing.ThreadAccountantOps.setupWorker(1, ThreadExecutionContext.TaskType.MSE,
                new ThreadResourceUsageProvider(), runnerContext);
    }

    /** Opens the annotated mocks and stubs the server address string representation. */
    @BeforeMethod
    public void setUp() {
        _mocks = MockitoAnnotations.openMocks(this);
        Mockito.when(_serverAddress.toString()).thenReturn(new VirtualServerAddress("mock", 80, 0).toString());
    }

    /** Releases the mocks opened in {@link #setUp()}. */
    @AfterMethod
    public void tearDown() throws Exception {
        _mocks.close();
    }

    /**
     * Pulls the first block from the operator and asserts that exactly one query tracker and one thread tracker
     * exist, each reporting a positive number of allocated bytes, and that the following block is a successful
     * end-of-stream block.
     */
    @Test
    public void testOperatorAccounting() {
        _operator.nextBlock().getContainer();

        ThreadResourceUsageAccountant accountant = Tracing.getThreadAccountant();
        Collection<?> queryTrackers = accountant.getQueryResources().values();
        Collection<?> threadTrackers = accountant.getThreadResources();

        Assert.assertEquals(queryTrackers.size(), 1);
        Assert.assertEquals(threadTrackers.size(), 1);

        QueryResourceTracker queryTracker = (QueryResourceTracker) queryTrackers.iterator().next();
        ThreadResourceTracker threadTracker = (ThreadResourceTracker) threadTrackers.iterator().next();
        Assert.assertTrue(queryTracker.getAllocatedBytes() > 0);
        Assert.assertTrue(threadTracker.getAllocatedBytes() > 0);

        Assert.assertTrue(_operator.nextBlock().isSuccessfulEndOfStreamBlock(),
                "Second block is EOS (done processing)");
    }

    /**
     * Supplies the {@code (test name, operator)} pairs consumed by the factory constructor.
     *
     * @return one row per operator under test
     */
    @DataProvider(name = "operatorProvider")
    public static Object[][] getOperators() {
        return new Object[][]{
            {"AggregateOperator", getAggregateOperator()},
            {"SortOperator", getSortOperator()},
            {"HashJoinOperator", getHashJoinOperator()},
            {"WindowAggregateOperator", getWindowAggregateOperator()},
            {"SetOperator", getIntersectOperator()}
        };
    }

    /** Builds an aggregate operator over a mocked input that yields one data row followed by end-of-stream. */
    private static MultiStageOperator getAggregateOperator() {
        MultiStageOperator input = Mockito.mock(MultiStageOperator.class);
        DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"},
                new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});
        DataSchema resultSchema = new DataSchema(new String[]{"group", "sum"},
                new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});

        // Each row is passed as its own argument so the block holds rows rather than one row of nested arrays.
        Mockito.when(input.nextBlock())
                .thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{2, 1.0d}))
                .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));

        AggregateNode aggregateNode = new AggregateNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(),
                List.of(getSum(new RexExpression.InputRef(1))), List.of(-1), List.of(0),
                AggregateNode.AggType.DIRECT, false);
        return new AggregateOperator(OperatorTestUtil.getTracingContext(), input, aggregateNode);
    }

    /** Builds an inner hash join on the first column of two mocked inputs, each yielding one block then EOS. */
    private static MultiStageOperator getHashJoinOperator() {
        MultiStageOperator leftInput = Mockito.mock(MultiStageOperator.class);
        MultiStageOperator rightInput = Mockito.mock(MultiStageOperator.class);
        DataSchema leftSchema = new DataSchema(new String[]{"int_col", "string_col"},
                new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING});
        DataSchema rightSchema = new DataSchema(new String[]{"int_col", "string_col"},
                new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING});
        DataSchema resultSchema = new DataSchema(
                new String[]{"int_col1", "string_col1", "int_col2", "string_co2"},
                new DataSchema.ColumnDataType[]{
                    DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING,
                    DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING
                });

        Mockito.when(leftInput.nextBlock())
                .thenReturn(OperatorTestUtil.block(leftSchema, new Object[]{1, "Aa"}, new Object[]{2, "BB"}))
                .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        Mockito.when(rightInput.nextBlock())
                .thenReturn(OperatorTestUtil.block(rightSchema,
                        new Object[]{2, "Aa"}, new Object[]{2, "BB"}, new Object[]{3, "BB"}))
                .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));

        JoinNode joinNode = new JoinNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), JoinRelType.INNER,
                List.of(0), List.of(0), List.of(), JoinNode.JoinStrategy.HASH);
        return new HashJoinOperator(OperatorTestUtil.getTracingContext(), leftInput, leftSchema, rightInput,
                joinNode);
    }

    /** Builds an ascending, nulls-last sort on the single column of a mocked two-row input. */
    private static MultiStageOperator getSortOperator() {
        MultiStageOperator input = Mockito.mock(MultiStageOperator.class);
        DataSchema schema = new DataSchema(new String[]{"sort"},
                new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT});

        TransferableBlock dataBlock =
                new TransferableBlock(List.of(new Object[]{2}, new Object[]{1}), schema, DataBlock.Type.ROW);
        Mockito.when(input.nextBlock())
                .thenReturn(dataBlock)
                .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));

        RelFieldCollation collation = new RelFieldCollation(0, RelFieldCollation.Direction.ASCENDING,
                RelFieldCollation.NullDirection.LAST);
        SortNode sortNode = new SortNode(-1, schema, PlanNode.NodeHint.EMPTY, List.of(), List.of(collation), 10, 0);
        return new SortOperator(OperatorTestUtil.getTracingContext(), input, sortNode);
    }

    /** Builds a RANGE-framed window SUM keyed on column 0 over a mocked single-row input. */
    private static MultiStageOperator getWindowAggregateOperator() {
        MultiStageOperator input = Mockito.mock(MultiStageOperator.class);
        DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"},
                new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.INT});
        DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "sum"},
                new DataSchema.ColumnDataType[]{
                    DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE
                });

        Mockito.when(input.nextBlock())
                .thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{2, 1}))
                .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));

        WindowNode windowNode = new WindowNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), List.of(0),
                List.of(), List.of(getSum(new RexExpression.InputRef(1))), WindowNode.WindowFrameType.RANGE,
                Integer.MIN_VALUE, Integer.MAX_VALUE, List.of());
        return new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), input, inputSchema, windowNode);
    }

    /** Builds an intersect operator over two mocked inputs that share a schema and overlap on two rows. */
    private static MultiStageOperator getIntersectOperator() {
        MultiStageOperator leftInput = Mockito.mock(MultiStageOperator.class);
        MultiStageOperator rightInput = Mockito.mock(MultiStageOperator.class);
        DataSchema schema = new DataSchema(new String[]{"int_col", "string_col"},
                new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING});

        Mockito.when(leftInput.nextBlock())
                .thenReturn(OperatorTestUtil.block(schema,
                        new Object[]{1, "AA"}, new Object[]{2, "BB"}, new Object[]{3, "CC"}))
                .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        Mockito.when(rightInput.nextBlock())
                .thenReturn(OperatorTestUtil.block(schema,
                        new Object[]{1, "AA"}, new Object[]{2, "BB"}, new Object[]{4, "DD"}))
                .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));

        return new IntersectOperator(OperatorTestUtil.getTracingContext(), ImmutableList.of(leftInput, rightInput),
                schema);
    }

    /**
     * Wraps an expression in a SUM function call declared with an INT result type.
     *
     * @param rexExpression the single operand of the SUM call
     * @return the SUM function call
     */
    private static RexExpression.FunctionCall getSum(RexExpression rexExpression) {
        return new RexExpression.FunctionCall(DataSchema.ColumnDataType.INT, SqlKind.SUM.name(),
                List.of(rexExpression));
    }

    /** Returns the name given to this instance by the factory constructor. */
    @Override
    public String getTestName() {
        return _testName;
    }
}
