package org.apache.pinot.plugin.stream.kinesis;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.pinot.spi.stream.StreamConfig;
import org.apache.pinot.spi.stream.StreamConfigProperties;
import org.apache.pinot.spi.stream.StreamConsumerFactory;
import org.apache.pinot.spi.stream.StreamPartitionMsgOffset;
import org.easymock.Capture;
import org.easymock.EasyMock;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.KinesisClient;
import software.amazon.awssdk.services.kinesis.model.ChildShard;
import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;

/**
 * Unit tests for {@link KinesisConsumer} that run against an EasyMock-backed {@link KinesisClient}
 * rather than a live Kinesis stream.
 */
public class KinesisConsumerTest {
    private static final String STREAM_TYPE = "kinesis";
    private static final String TABLE_NAME_WITH_TYPE = "kinesisTest_REALTIME";
    private static final String STREAM_NAME = "kinesis-test";
    private static final String AWS_REGION = "us-west-2";
    private static final int TIMEOUT = 1000;
    private static final int NUM_RECORDS = 10;
    private static final String DUMMY_RECORD_PREFIX = "DUMMY_RECORD-";
    private static final String PARTITION_KEY_PREFIX = "PARTITION_KEY-";
    private static final String PLACEHOLDER = "DUMMY";
    private static final int MAX_RECORDS_TO_FETCH = 20;
    // Shard id and starting sequence number used as the start offset by every test.
    private static final String SHARD_ID = "0";
    private static final String START_SEQUENCE_NUMBER = "1";

    // NOTE(review): these two mocks are created in setupTest() but not used by any test in this class.
    private KinesisConnectionHandler _kinesisConnectionHandler;
    private StreamConsumerFactory _streamConsumerFactory;
    private KinesisClient _kinesisClient;
    private List<Record> _recordList;

    /**
     * Builds a low-level Kinesis stream config for {@link #STREAM_NAME} in {@link #AWS_REGION},
     * capped at {@link #MAX_RECORDS_TO_FETCH} records per fetch.
     */
    private KinesisConfig getKinesisConfig() {
        HashMap<String, String> props = new HashMap<>();
        props.put("streamType", STREAM_TYPE);
        props.put(StreamConfigProperties.constructStreamProperty(STREAM_TYPE, "topic.name"), STREAM_NAME);
        props.put(StreamConfigProperties.constructStreamProperty(STREAM_TYPE, "consumer.type"),
            StreamConfig.ConsumerType.LOWLEVEL.toString());
        props.put(StreamConfigProperties.constructStreamProperty(STREAM_TYPE, "consumer.factory.class.name"),
            KinesisConsumerFactory.class.getName());
        props.put(StreamConfigProperties.constructStreamProperty(STREAM_TYPE, "decoder.class.name"),
            "org.apache.pinot.plugin.inputformat.json.JSONMessageDecoder");
        props.put("region", AWS_REGION);
        props.put("maxRecordsToFetch", String.valueOf(MAX_RECORDS_TO_FETCH));
        props.put("shardIteratorType", ShardIteratorType.AT_SEQUENCE_NUMBER.toString());
        return new KinesisConfig(new StreamConfig(TABLE_NAME_WITH_TYPE, props));
    }

    /**
     * Creates fresh mocks and a list of {@link #NUM_RECORDS} records whose payloads are
     * {@code DUMMY_RECORD-<i>} and whose sequence numbers are {@code 1..NUM_RECORDS}.
     */
    @BeforeMethod
    public void setupTest() {
        _kinesisConnectionHandler = EasyMock.createMock(KinesisConnectionHandler.class);
        _kinesisClient = EasyMock.createMock(KinesisClient.class);
        _streamConsumerFactory = EasyMock.createMock(StreamConsumerFactory.class);
        _recordList = new ArrayList<>(NUM_RECORDS);
        for (int i = 0; i < NUM_RECORDS; i++) {
            _recordList.add(Record.builder()
                .data(SdkBytes.fromUtf8String(DUMMY_RECORD_PREFIX + i))
                .partitionKey(PARTITION_KEY_PREFIX + i)
                .sequenceNumber(String.valueOf(i + 1))
                .build());
        }
    }

    /**
     * A single getRecords response with a null next iterator yields exactly the stubbed records and
     * does not mark the partition group as ended.
     */
    @Test
    public void testBasicConsumer() {
        GetRecordsResponse getRecordsResponse =
            GetRecordsResponse.builder().nextShardIterator((String) null).records(_recordList).build();

        KinesisRecordsBatch batch = fetchWithStubbedResponse(getRecordsResponse);

        Assert.assertEquals(batch.getMessageCount(), NUM_RECORDS);
        assertPayloads(batch, NUM_RECORDS);
        Assert.assertFalse(batch.isEndOfPartitionGroup());
    }

    /**
     * When every getRecords response points to a further shard iterator, the consumer stops
     * once it has collected {@link #MAX_RECORDS_TO_FETCH} records.
     */
    @Test
    public void testBasicConsumerWithMaxRecordsLimit() {
        GetRecordsResponse getRecordsResponse =
            GetRecordsResponse.builder().nextShardIterator(PLACEHOLDER).records(_recordList).build();

        KinesisRecordsBatch batch = fetchWithStubbedResponse(getRecordsResponse);

        Assert.assertEquals(batch.getMessageCount(), MAX_RECORDS_TO_FETCH);
        // The stub returns the same NUM_RECORDS records on every call, so check every fetched message,
        // not just the first NUM_RECORDS.
        assertPayloads(batch, MAX_RECORDS_TO_FETCH);
    }

    /**
     * A response that closes the shard (null next iterator) and reports a child shard marks the
     * partition group as ended, while still returning the records in that response.
     */
    @Test
    public void testBasicConsumerWithChildShard() {
        List<ChildShard> childShards = new ArrayList<>();
        childShards.add(ChildShard.builder().shardId(PLACEHOLDER).parentShards(SHARD_ID).build());
        GetRecordsResponse getRecordsResponse = GetRecordsResponse.builder()
            .nextShardIterator((String) null)
            .records(_recordList)
            .childShards(childShards)
            .build();

        KinesisRecordsBatch batch = fetchWithStubbedResponse(getRecordsResponse);

        Assert.assertTrue(batch.isEndOfPartitionGroup());
        Assert.assertEquals(batch.getMessageCount(), NUM_RECORDS);
        assertPayloads(batch, NUM_RECORDS);
    }

    /** Decodes a message payload as a UTF-8 string. */
    public String baToString(byte[] bArr) {
        return SdkBytes.fromByteArray(bArr).asUtf8String();
    }

    /**
     * Sets up and replays the mocked client, then fetches one batch from shard {@link #SHARD_ID}
     * starting at {@link #START_SEQUENCE_NUMBER}.
     *
     * <p>Every getRecords call returns {@code getRecordsResponse}. Every getShardIterator call returns
     * {@link #PLACEHOLDER}.
     *
     * <p>Also checks two things about the client calls:
     * <ul>
     *   <li>a shard iterator was requested for {@link #SHARD_ID};</li>
     *   <li>records were read with the stubbed iterator.</li>
     * </ul>
     */
    private KinesisRecordsBatch fetchWithStubbedResponse(GetRecordsResponse getRecordsResponse) {
        Capture<GetRecordsRequest> getRecordsCapture = Capture.newInstance();
        Capture<GetShardIteratorRequest> getShardIteratorCapture = Capture.newInstance();
        GetShardIteratorResponse getShardIteratorResponse =
            GetShardIteratorResponse.builder().shardIterator(PLACEHOLDER).build();
        EasyMock.expect(_kinesisClient.getRecords(EasyMock.capture(getRecordsCapture)))
            .andReturn(getRecordsResponse).anyTimes();
        EasyMock.expect(_kinesisClient.getShardIterator(EasyMock.capture(getShardIteratorCapture)))
            .andReturn(getShardIteratorResponse).anyTimes();
        EasyMock.replay(_kinesisClient);

        KinesisConsumer kinesisConsumer = new KinesisConsumer(getKinesisConfig(), _kinesisClient);
        HashMap<String, String> startOffsets = new HashMap<>();
        startOffsets.put(SHARD_ID, START_SEQUENCE_NUMBER);
        KinesisRecordsBatch batch = kinesisConsumer.fetchMessages(
            new KinesisPartitionGroupOffset(startOffsets), (StreamPartitionMsgOffset) null, TIMEOUT);

        Assert.assertTrue(getShardIteratorCapture.hasCaptured());
        Assert.assertEquals(getShardIteratorCapture.getValue().shardId(), SHARD_ID);
        // The only non-null iterator the stubs ever hand out is PLACEHOLDER.
        Assert.assertTrue(getRecordsCapture.hasCaptured());
        Assert.assertEquals(getRecordsCapture.getValue().shardIterator(), PLACEHOLDER);
        return batch;
    }

    /**
     * Asserts that message {@code i} equals {@code DUMMY_RECORD-<i % NUM_RECORDS>} for every i
     * below {@code count}.
     */
    private void assertPayloads(KinesisRecordsBatch batch, int count) {
        for (int i = 0; i < count; i++) {
            Assert.assertEquals(baToString(batch.getMessageAtIndex(i)), DUMMY_RECORD_PREFIX + (i % NUM_RECORDS));
        }
    }
}
