package org.apache.pinot.plugin.stream.kinesis;

import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.pinot.spi.stream.StreamConfig;
import org.apache.pinot.spi.stream.StreamConfigProperties;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.KinesisClient;
import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;

/* loaded from: input_file:org/apache/pinot/plugin/stream/kinesis/KinesisConsumerTest.class */
/**
 * Unit tests for {@link KinesisConsumer} that run against a Mockito-mocked {@link KinesisClient},
 * so no real Kinesis stream is contacted.
 *
 * <p>Every mocked {@code getRecords} call returns the same {@link #NUM_RECORDS} dummy records built
 * in {@link #setUp()}; the tests differ only in the "next shard iterator" the mock hands back.
 */
public class KinesisConsumerTest {
    private static final String STREAM_TYPE = "kinesis";
    private static final String TABLE_NAME_WITH_TYPE = "kinesisTest_REALTIME";
    private static final String STREAM_NAME = "kinesis-test";
    private static final String AWS_REGION = "us-west-2";
    private static final int TIMEOUT = 1000;
    private static final int NUM_RECORDS = 10;
    private static final String DUMMY_RECORD_PREFIX = "DUMMY_RECORD-";
    private static final String PARTITION_KEY_PREFIX = "PARTITION_KEY-";
    private static final String PLACEHOLDER = "DUMMY";
    private static final int MAX_RECORDS_TO_FETCH = 20;
    private KinesisConfig _kinesisConfig;
    private List<Record> _records;

    /**
     * Builds the {@link KinesisConfig} shared by all tests from a plain stream-config property map.
     *
     * @return a config for table {@link #TABLE_NAME_WITH_TYPE} reading stream {@link #STREAM_NAME}
     */
    private KinesisConfig getKinesisConfig() {
        HashMap<String, String> props = new HashMap<>();
        props.put("streamType", STREAM_TYPE);
        props.put(StreamConfigProperties.constructStreamProperty(STREAM_TYPE, "topic.name"), STREAM_NAME);
        props.put(StreamConfigProperties.constructStreamProperty(STREAM_TYPE, "consumer.factory.class.name"),
            KinesisConsumerFactory.class.getName());
        props.put(StreamConfigProperties.constructStreamProperty(STREAM_TYPE, "decoder.class.name"),
            "org.apache.pinot.plugin.inputformat.json.JSONMessageDecoder");
        props.put("region", AWS_REGION);
        props.put("maxRecordsToFetch", String.valueOf(MAX_RECORDS_TO_FETCH));
        props.put("shardIteratorType", ShardIteratorType.AT_SEQUENCE_NUMBER.toString());
        return new KinesisConfig(new StreamConfig(TABLE_NAME_WITH_TYPE, props));
    }

    /**
     * Creates the shared config and the {@link #NUM_RECORDS} dummy records. Record {@code i} carries
     * payload {@code DUMMY_RECORD-i}, partition key {@code PARTITION_KEY-i} and sequence number
     * {@code i + 1}.
     */
    @BeforeClass
    public void setUp() {
        _kinesisConfig = getKinesisConfig();
        _records = new ArrayList<>(NUM_RECORDS);
        for (int i = 0; i < NUM_RECORDS; i++) {
            Record record = Record.builder()
                .data(SdkBytes.fromUtf8String(DUMMY_RECORD_PREFIX + i))
                .partitionKey(PARTITION_KEY_PREFIX + i)
                .approximateArrivalTimestamp(Instant.now())
                .sequenceNumber(String.valueOf(i + 1))
                .build();
            _records.add(record);
        }
    }

    /**
     * Two consecutive fetches on an open shard: asserts that each returns all dummy records and is
     * not flagged as end-of-partition-group, and verifies that the shard iterator was requested once
     * while {@code getRecords} was called twice.
     */
    @Test
    public void testBasicConsumer() {
        KinesisClient kinesisClient = mockKinesisClient(PLACEHOLDER);
        KinesisConsumer kinesisConsumer = new KinesisConsumer(_kinesisConfig, kinesisClient);

        KinesisMessageBatch firstBatch =
            kinesisConsumer.fetchMessages(new KinesisPartitionGroupOffset("0", "1"), TIMEOUT);
        assertContainsAllRecords(firstBatch);
        Assert.assertFalse(firstBatch.isEndOfPartitionGroup());

        KinesisMessageBatch secondBatch = kinesisConsumer.fetchMessages(firstBatch.getOffsetOfNextBatch(), TIMEOUT);
        assertContainsAllRecords(secondBatch);
        Assert.assertFalse(secondBatch.isEndOfPartitionGroup());

        verifyClientCalls(kinesisClient, 1, 2);
    }

    /**
     * The mock returns a {@code null} next shard iterator: asserts that the first fetch still
     * delivers all records but is flagged as end-of-partition-group, that the following fetch is
     * empty and also flagged, and verifies that {@code getRecords} was called only once overall.
     */
    @Test
    public void testEndOfShard() {
        KinesisClient kinesisClient = mockKinesisClient(null);
        KinesisConsumer kinesisConsumer = new KinesisConsumer(_kinesisConfig, kinesisClient);

        KinesisMessageBatch firstBatch =
            kinesisConsumer.fetchMessages(new KinesisPartitionGroupOffset("0", "1"), TIMEOUT);
        assertContainsAllRecords(firstBatch);
        Assert.assertTrue(firstBatch.isEndOfPartitionGroup());

        KinesisMessageBatch secondBatch = kinesisConsumer.fetchMessages(firstBatch.getOffsetOfNextBatch(), TIMEOUT);
        Assert.assertEquals(secondBatch.getMessageCount(), 0);
        Assert.assertTrue(secondBatch.isEndOfPartitionGroup());

        verifyClientCalls(kinesisClient, 1, 1);
    }

    /**
     * Decodes a message payload as UTF-8 text.
     *
     * @param bArr raw payload bytes
     * @return the payload decoded as a UTF-8 string
     */
    public String baToString(byte[] bArr) {
        return SdkBytes.fromByteArray(bArr).asUtf8String();
    }

    /**
     * Creates a mocked client whose {@code getShardIterator} always yields {@link #PLACEHOLDER} and
     * whose {@code getRecords} always yields the shared dummy records.
     *
     * @param nextShardIterator value reported as the next shard iterator; may be {@code null}
     * @return the stubbed mock
     */
    private KinesisClient mockKinesisClient(String nextShardIterator) {
        KinesisClient kinesisClient = Mockito.mock(KinesisClient.class);
        Mockito.when(kinesisClient.getShardIterator(ArgumentMatchers.any(GetShardIteratorRequest.class)))
            .thenReturn(GetShardIteratorResponse.builder().shardIterator(PLACEHOLDER).build());
        Mockito.when(kinesisClient.getRecords(ArgumentMatchers.any(GetRecordsRequest.class)))
            .thenReturn(GetRecordsResponse.builder().nextShardIterator(nextShardIterator).records(_records).build());
        return kinesisClient;
    }

    /** Asserts that {@code batch} holds exactly the dummy records, in order, with intact payloads. */
    private void assertContainsAllRecords(KinesisMessageBatch batch) {
        Assert.assertEquals(batch.getMessageCount(), NUM_RECORDS);
        for (int i = 0; i < NUM_RECORDS; i++) {
            // NOTE(review): cast retained because the declared value type is not visible in this file.
            byte[] payload = (byte[]) batch.getStreamMessage(i).getValue();
            Assert.assertEquals(baToString(payload), DUMMY_RECORD_PREFIX + i);
        }
    }

    /** Verifies how many times the consumer called each of the two stubbed client methods. */
    private void verifyClientCalls(KinesisClient kinesisClient, int shardIteratorCalls, int getRecordsCalls) {
        Mockito.verify(kinesisClient, Mockito.times(shardIteratorCalls))
            .getShardIterator(ArgumentMatchers.any(GetShardIteratorRequest.class));
        Mockito.verify(kinesisClient, Mockito.times(getRecordsCalls))
            .getRecords(ArgumentMatchers.any(GetRecordsRequest.class));
    }
}
