package org.apache.pinot.query.runtime.operator;

import java.util.List;
import java.util.Map;
import org.apache.calcite.sql.SqlKind;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockTestUtils;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.AggregateOperator;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Unit tests for {@link AggregateOperator}.
 *
 * <p>Covers:
 * <ul>
 *   <li>propagation of upstream error and end-of-stream blocks;</li>
 *   <li>aggregation over one and several input blocks, with and without a filter argument;</li>
 *   <li>group keys whose string hash codes collide;</li>
 *   <li>rejection of unknown aggregation functions;</li>
 *   <li>error reporting for unexpected input types;</li>
 *   <li>the {@code num_groups_limit} aggregation hint.</li>
 * </ul>
 */
public class AggregateOperatorTest {
    private AutoCloseable _mocks;

    @Mock
    private MultiStageOperator _input;

    @Mock
    private VirtualServerAddress _serverAddress;

    @BeforeMethod
    public void setUp() {
        _mocks = MockitoAnnotations.openMocks(this);
        Mockito.when(_serverAddress.toString()).thenReturn(new VirtualServerAddress("mock", 80, 0).toString());
    }

    @AfterMethod
    public void tearDown() throws Exception {
        _mocks.close();
    }

    @Test
    public void shouldHandleUpstreamErrorBlocks() {
        // SUM($1), no filter, grouped by $0
        List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new RexExpression.InputRef(1)));
        List<Integer> filterArgs = List.of(-1);
        List<Integer> groupKeys = List.of(0);
        Mockito.when(_input.nextBlock())
            .thenReturn(TransferableBlockUtils.getErrorTransferableBlock(new Exception("foo!")));
        DataSchema outSchema = new DataSchema(new String[]{"group", "sum"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});

        TransferableBlock block = getOperator(outSchema, aggCalls, filterArgs, groupKeys).nextBlock();

        Mockito.verify(_input, Mockito.times(1)).nextBlock();
        Assert.assertTrue(block.isErrorBlock(), "Input errors should propagate immediately");
    }

    @Test
    public void shouldHandleEndOfStreamBlockWithNoOtherInputs() {
        List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new RexExpression.InputRef(1)));
        List<Integer> filterArgs = List.of(-1);
        List<Integer> groupKeys = List.of(0);
        Mockito.when(_input.nextBlock()).thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        DataSchema outSchema = new DataSchema(new String[]{"group", "sum"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});

        TransferableBlock block = getOperator(outSchema, aggCalls, filterArgs, groupKeys).nextBlock();

        Mockito.verify(_input, Mockito.times(1)).nextBlock();
        Assert.assertTrue(block.isEndOfStreamBlock(), "EOS blocks should propagate");
    }

    @Test
    public void testAggregateSingleInputBlock() {
        List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new RexExpression.InputRef(1)));
        List<Integer> filterArgs = List.of(-1);
        List<Integer> groupKeys = List.of(0);
        DataSchema inSchema = new DataSchema(new String[]{"group", "arg"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});
        // Each row is its own varargs element; wrapping rows in a single Object[] would make
        // the whole array one (nested) row.
        Mockito.when(_input.nextBlock())
            .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1.0}))
            .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        DataSchema outSchema = new DataSchema(new String[]{"group", "sum"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});
        AggregateOperator operator = getOperator(outSchema, aggCalls, filterArgs, groupKeys);

        List<Object[]> rows = operator.nextBlock().getContainer();

        Assert.assertEquals(rows.size(), 1);
        Assert.assertEquals(rows.get(0), new Object[]{2, 1.0},
            "Expected two columns (group by key, agg value), agg value is final result");
        Assert.assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
    }

    @Test
    public void testAggregateMultipleInputBlocks() {
        List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new RexExpression.InputRef(1)));
        List<Integer> filterArgs = List.of(-1);
        List<Integer> groupKeys = List.of(0);
        DataSchema inSchema = new DataSchema(new String[]{"group", "arg"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});
        Mockito.when(_input.nextBlock())
            .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1.0}, new Object[]{2, 2.0}))
            .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 3.0}))
            .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        DataSchema outSchema = new DataSchema(new String[]{"group", "sum"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});
        AggregateOperator operator = getOperator(outSchema, aggCalls, filterArgs, groupKeys);

        List<Object[]> rows = operator.nextBlock().getContainer();

        // 1.0 + 2.0 (first block) + 3.0 (second block), all in group 2
        Assert.assertEquals(rows.size(), 1);
        Assert.assertEquals(rows.get(0), new Object[]{2, 6.0},
            "Expected two columns (group by key, agg value), agg value is final result");
        Assert.assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
    }

    @Test
    public void testAggregateWithFilter() {
        // SUM($1) unfiltered and SUM($1) filtered on column 2
        List<RexExpression.FunctionCall> aggCalls =
            List.of(getSum(new RexExpression.InputRef(1)), getSum(new RexExpression.InputRef(1)));
        List<Integer> filterArgs = List.of(-1, 2);
        List<Integer> groupKeys = List.of(0);
        DataSchema inSchema = new DataSchema(new String[]{"group", "arg", "filterArg"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE,
                DataSchema.ColumnDataType.BOOLEAN});
        // The BOOLEAN filter column is fed as INT 0/1 values.
        Mockito.when(_input.nextBlock())
            .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1.0, 0}, new Object[]{2, 2.0, 1}))
            .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 3.0, 1}))
            .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        DataSchema outSchema = new DataSchema(new String[]{"group", "sum", "sumWithFilter"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE,
                DataSchema.ColumnDataType.DOUBLE});
        AggregateOperator operator = getOperator(outSchema, aggCalls, filterArgs, groupKeys);

        List<Object[]> rows = operator.nextBlock().getContainer();

        // Unfiltered: 1.0 + 2.0 + 3.0 = 6.0; filtered (filterArg == 1): 2.0 + 3.0 = 5.0
        Assert.assertEquals(rows.size(), 1);
        Assert.assertEquals(rows.get(0), new Object[]{2, 6.0, 5.0},
            "Expected three columns (group by key, agg value, agg value with filter), agg value is final result");
        Assert.assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
    }

    @Test
    public void testGroupByAggregateWithHashCollision() {
        // Replace the mock with a canned operator; "Aa" and "BB" have the same String.hashCode().
        _input = OperatorTestUtil.getOperator(OperatorTestUtil.OP_1);
        DataSchema outSchema = new DataSchema(new String[]{"group", "sum"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
        AggregateOperator operator = getOperator(outSchema, List.of(getSum(new RexExpression.InputRef(0))),
            List.of(-1), List.of(1));

        List<Object[]> rows = operator.nextBlock().getContainer();

        Assert.assertEquals(rows.size(), 2);
        // Group output order is not guaranteed, so accept either ordering.
        if (rows.get(0)[0].equals("Aa")) {
            Assert.assertEquals(rows.get(0), new Object[]{"Aa", 1.0});
            Assert.assertEquals(rows.get(1), new Object[]{"BB", 5.0});
        } else {
            Assert.assertEquals(rows.get(0), new Object[]{"BB", 5.0});
            Assert.assertEquals(rows.get(1), new Object[]{"Aa", 1.0});
        }
        Assert.assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock());
    }

    @Test(expectedExceptions = {IllegalStateException.class}, expectedExceptionsMessageRegExp = ".*AVERAGE.*")
    public void shouldThrowOnUnknownAggFunction() {
        DataSchema outSchema = new DataSchema(new String[]{"unknown"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE});
        getOperator(outSchema, List.of(new RexExpression.FunctionCall(DataSchema.ColumnDataType.INT, "AVERAGE", List.of())),
            List.of(-1), List.of(0));
    }

    @Test
    public void shouldReturnErrorBlockOnUnexpectedInputType() {
        List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new RexExpression.InputRef(1)));
        List<Integer> filterArgs = List.of(-1);
        List<Integer> groupKeys = List.of(0);
        // The summed column is declared and filled as STRING, which SUM cannot consume.
        DataSchema inSchema = new DataSchema(new String[]{"group", "arg"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.STRING});
        Mockito.when(_input.nextBlock())
            .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, "foo"}, new Object[]{2, "foo"}))
            .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        DataSchema outSchema = new DataSchema(new String[]{"sum"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE});

        TransferableBlock block = getOperator(outSchema, aggCalls, filterArgs, groupKeys).nextBlock();

        Assert.assertTrue(block.isErrorBlock(), "expected ERROR block from invalid computation");
        // 1000 is the error-code key the message is stored under (presumably the generic/unknown
        // query error code — confirm against QueryException).
        String errorMessage = (String) block.getExceptions().get(1000);
        Assert.assertTrue(errorMessage.contains("cannot be cast to class"),
            "expected it to fail with class cast exception");
    }

    @Test
    public void shouldHandleGroupLimitExceed() {
        List<RexExpression.FunctionCall> aggCalls = List.of(getSum(new RexExpression.InputRef(1)));
        List<Integer> filterArgs = List.of(-1);
        List<Integer> groupKeys = List.of(0);
        // Cap the number of groups at 1 via the aggregation hint.
        PlanNode.NodeHint nodeHint = new PlanNode.NodeHint(Map.of("aggOptions", Map.of("num_groups_limit", "1")));
        DataSchema inSchema = new DataSchema(new String[]{"group", "arg"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});
        Mockito.when(_input.nextBlock())
            .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1.0}, new Object[]{3, 2.0}))
            .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{3, 3.0}))
            .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
        DataSchema outSchema = new DataSchema(new String[]{"group", "sum"},
            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE});
        AggregateOperator operator = getOperator(outSchema, aggCalls, filterArgs, groupKeys, nodeHint);

        TransferableBlock dataBlock = operator.nextBlock();
        TransferableBlock eosBlock = operator.nextBlock();

        Mockito.verify(_input).earlyTerminate();
        Assert.assertEquals(dataBlock.getNumRows(), 1, "when group limit reach it should only return that many groups");
        Assert.assertTrue(eosBlock.isEndOfStreamBlock(), "Second block is EOS (done processing)");
        Assert.assertTrue(OperatorTestUtil.getStatMap(AggregateOperator.StatKey.class, eosBlock)
                .getBoolean(AggregateOperator.StatKey.NUM_GROUPS_LIMIT_REACHED),
            "num groups limit should be reached");
    }

    /**
     * Builds a {@code SUM} function call over the given operand.
     *
     * @param operand the expression to sum (typically an {@link RexExpression.InputRef})
     * @return a {@code SUM} call declared with an {@code INT} return type
     */
    private static RexExpression.FunctionCall getSum(RexExpression operand) {
        return new RexExpression.FunctionCall(DataSchema.ColumnDataType.INT, SqlKind.SUM.name(), List.of(operand));
    }

    /**
     * Creates a DIRECT {@link AggregateOperator} that reads from {@link #_input}.
     *
     * @param dataSchema output schema of the aggregate node
     * @param aggCalls aggregation calls to evaluate
     * @param filterArgs per-call filter column index; the tests use {@code -1} for unfiltered calls
     * @param groupKeys indexes of the input columns to group by
     * @param nodeHint plan hints, e.g. {@code aggOptions}
     * @return the operator under test
     */
    private AggregateOperator getOperator(DataSchema dataSchema, List<RexExpression.FunctionCall> aggCalls,
        List<Integer> filterArgs, List<Integer> groupKeys, PlanNode.NodeHint nodeHint) {
        AggregateNode node = new AggregateNode(-1, dataSchema, nodeHint, List.of(), aggCalls, filterArgs, groupKeys,
            AggregateNode.AggType.DIRECT);
        return new AggregateOperator(OperatorTestUtil.getTracingContext(), _input, node);
    }

    /**
     * Same as {@link #getOperator(DataSchema, List, List, List, PlanNode.NodeHint)} with
     * {@link PlanNode.NodeHint#EMPTY}.
     */
    private AggregateOperator getOperator(DataSchema dataSchema, List<RexExpression.FunctionCall> aggCalls,
        List<Integer> filterArgs, List<Integer> groupKeys) {
        return getOperator(dataSchema, aggCalls, filterArgs, groupKeys, PlanNode.NodeHint.EMPTY);
    }
}
