package org.apache.pinot.controller.helix.core.relocation;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.hc.client5.http.io.HttpClientConnectionManager;
import org.apache.helix.ClusterMessagingService;
import org.apache.helix.Criteria;
import org.apache.helix.messaging.AsyncCallback;
import org.apache.helix.model.Message;
import org.apache.helix.zookeeper.datamodel.ZNRecord;
import org.apache.pinot.common.messages.SegmentReloadMessage;
import org.apache.pinot.common.metrics.ControllerMetrics;
import org.apache.pinot.controller.ControllerConf;
import org.apache.pinot.controller.LeadControllerManager;
import org.apache.pinot.controller.helix.core.PinotHelixResourceManager;
import org.apache.pinot.controller.util.TableTierReader;
import org.apache.pinot.util.TestUtils;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/controller/helix/core/relocation/SegmentRelocatorTest.class */
/**
 * Unit tests for {@link SegmentRelocator}: the reload messages sent by
 * {@code triggerLocalTierMigration} and the behavior of the waiting queue used when tables are
 * rebalanced sequentially.
 */
public class SegmentRelocatorTest {
    /** Number of distinct tables ("t_0" .. "t_9") enqueued by the sequential-rebalance tests. */
    private static final int NUM_TABLES = 10;

    /**
     * Stubs the current and target tiers of four segments hosted on two servers, calls
     * {@code SegmentRelocator.triggerLocalTierMigration}, and checks that exactly two reload
     * messages are sent (null callback, {@code -1} as the last argument): one per server,
     * carrying the segment list expected for that server.
     */
    @Test
    public void testTriggerLocalTierMigration() {
        TableTierReader.TableTierDetails tableTierDetails = Mockito.mock(TableTierReader.TableTierDetails.class);

        // Current tier of each segment, keyed by segment name and then by server name.
        Map<String, Map<String, String>> currentTiers = new HashMap<>();
        // seg01: hotTier on both servers, target coldTier -> expected in both reload lists.
        currentTiers.put("seg01", serverTiers("hotTier", "hotTier"));
        // seg02: only server02 still differs from the target -> expected for server02 only.
        currentTiers.put("seg02", serverTiers("coldTier", "hotTier"));
        // seg03: already on the target tier on both servers -> expected in neither list.
        currentTiers.put("seg03", serverTiers("coldTier", "coldTier"));
        // seg04: coldTier on both servers but has no target tier entry -> expected in both lists.
        currentTiers.put("seg04", serverTiers("coldTier", "coldTier"));

        Map<String, String> targetTiers = new HashMap<>();
        targetTiers.put("seg01", "coldTier");
        targetTiers.put("seg02", "coldTier");
        targetTiers.put("seg03", "coldTier");

        Mockito.when(tableTierDetails.getSegmentCurrentTiers()).thenReturn(currentTiers);
        Mockito.when(tableTierDetails.getSegmentTargetTiers()).thenReturn(targetTiers);
        ClusterMessagingService messagingService = Mockito.mock(ClusterMessagingService.class);

        SegmentRelocator.triggerLocalTierMigration("table01", tableTierDetails, messagingService);

        ArgumentCaptor<Criteria> criteriaCaptor = ArgumentCaptor.forClass(Criteria.class);
        ArgumentCaptor<SegmentReloadMessage> messageCaptor = ArgumentCaptor.forClass(SegmentReloadMessage.class);
        Mockito.verify(messagingService, Mockito.times(2)).send(criteriaCaptor.capture(), messageCaptor.capture(),
                ArgumentMatchers.eq((AsyncCallback) null), ArgumentMatchers.eq(-1));

        // The two captors record values in call order, so index i pairs a recipient with its message.
        List<Criteria> recipients = criteriaCaptor.getAllValues();
        List<SegmentReloadMessage> messages = messageCaptor.getAllValues();
        Set<String> serversSeen = new HashSet<>();
        for (int i = 0; i < recipients.size(); i++) {
            String instanceName = recipients.get(i).getInstanceName();
            List<String> segmentList = messages.get(i).getSegmentList();
            // Constant-first comparison so a null instance name reaches Assert.fail instead of an NPE.
            if ("server01".equals(instanceName)) {
                Assert.assertEquals(segmentList.size(), 2);
                Assert.assertTrue(segmentList.containsAll(Arrays.asList("seg01", "seg04")));
            } else if ("server02".equals(instanceName)) {
                Assert.assertEquals(segmentList.size(), 3);
                Assert.assertTrue(segmentList.containsAll(Arrays.asList("seg01", "seg02", "seg04")));
            } else {
                Assert.fail("Unexpected server: " + instanceName);
            }
            serversSeen.add(instanceName);
        }
        // Two sends to the same server would pass the checks above; require one message per server.
        Assert.assertEquals(serversSeen.size(), 2, "Expected one reload message per server, got: " + serversSeen);
    }

    /**
     * Enqueues ten distinct tables followed by ten randomly chosen repeats, asserts that the
     * waiting queue still holds exactly the ten distinct tables, then drains the queue one table
     * at a time and asserts the tables come out in the order they were first enqueued.
     *
     * @throws InterruptedException declared by {@code rebalanceWaitingTable}
     */
    @Test
    public void testRebalanceTablesSequentially() throws InterruptedException {
        SegmentRelocator relocator = newSequentialRelocator();
        Random random = new Random();
        for (int i = 0; i < NUM_TABLES; i++) {
            relocator.putTableToWait("t_" + i);
        }
        // Re-request tables that are already waiting; the size assertion below expects no growth.
        for (int i = 0; i < NUM_TABLES; i++) {
            relocator.putTableToWait("t_" + random.nextInt(NUM_TABLES));
        }

        BlockingQueue<String> waitingQueue = relocator.getWaitingQueue();
        Assert.assertEquals(waitingQueue.size(), NUM_TABLES);
        assertAllTablesWaiting(waitingQueue);

        // Single-element array lets the callback lambda hand the table name back to the test.
        String[] rebalanced = new String[1];
        for (int i = 0; i < NUM_TABLES; i++) {
            Assert.assertEquals(waitingQueue.size(), NUM_TABLES - i);
            relocator.rebalanceWaitingTable(table -> rebalanced[0] = table);
            Assert.assertEquals(rebalanced[0], "t_" + i);
        }
        Assert.assertEquals(waitingQueue.size(), 0);
    }

    /**
     * Runs three concurrent requesters (two enqueueing random tables, one enqueueing
     * "t_0" .. "t_9" in order after a short delay), waits until the queue holds all ten distinct
     * tables, adds "t_X", and then keeps rebalancing waiting tables until "t_X" is the one handed
     * to the callback.
     */
    @Test
    public void testRebalanceTablesSequentiallyWithMultiRequesters() {
        SegmentRelocator relocator = newSequentialRelocator();
        ExecutorService requesters = Executors.newCachedThreadPool();
        Random random = new Random();
        // try/finally guarantees the pool is shut down on success and on any failure alike.
        try {
            // The lambdas return null so they resolve to Callable, which may throw InterruptedException.
            requesters.submit(() -> {
                enqueueRandomTables(relocator, random);
                return null;
            });
            requesters.submit(() -> {
                enqueueRandomTables(relocator, random);
                return null;
            });
            requesters.submit(() -> {
                Thread.sleep(100L);
                for (int i = 0; i < NUM_TABLES; i++) {
                    relocator.putTableToWait("t_" + i);
                    pauseBriefly(random);
                }
                return null;
            });

            BlockingQueue<String> waitingQueue = relocator.getWaitingQueue();
            TestUtils.waitForCondition(ignored -> waitingQueue.size() == NUM_TABLES, 100L, 3000L,
                    "Expecting all tables get in waiting queue");
            assertAllTablesWaiting(waitingQueue);

            relocator.putTableToWait("t_X");
            Assert.assertEquals(waitingQueue.size(), NUM_TABLES + 1);

            TestUtils.waitForCondition(ignored -> {
                String[] rebalanced = new String[1];
                try {
                    relocator.rebalanceWaitingTable(table -> rebalanced[0] = table);
                } catch (InterruptedException e) {
                    // Restore the interrupt flag before converting to an unchecked exception.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
                // Null-safe: if the callback was not invoked, keep polling instead of throwing an NPE.
                return "t_X".equals(rebalanced[0]);
            }, 10L, 3000L, "Table t_X should get its turn");
        } finally {
            requesters.shutdownNow();
        }
    }

    /**
     * Builds the per-server tier map used by {@link #testTriggerLocalTierMigration()}.
     *
     * @param server01Tier tier recorded for "server01"
     * @param server02Tier tier recorded for "server02"
     * @return a mutable map from server name to tier
     */
    private static Map<String, String> serverTiers(String server01Tier, String server02Tier) {
        Map<String, String> tiers = new HashMap<>();
        tiers.put("server01", server01Tier);
        tiers.put("server02", server02Tier);
        return tiers;
    }

    /**
     * Creates a {@link SegmentRelocator} whose collaborators are all Mockito mocks and whose
     * configuration reports {@code isSegmentRelocatorRebalanceTablesSequentially() == true}.
     */
    private static SegmentRelocator newSequentialRelocator() {
        ControllerConf controllerConf = Mockito.mock(ControllerConf.class);
        Mockito.when(controllerConf.isSegmentRelocatorRebalanceTablesSequentially()).thenReturn(true);
        return new SegmentRelocator(Mockito.mock(PinotHelixResourceManager.class),
                Mockito.mock(LeadControllerManager.class), controllerConf, Mockito.mock(ControllerMetrics.class),
                Mockito.mock(ExecutorService.class), Mockito.mock(HttpClientConnectionManager.class));
    }

    /** Asserts that every table "t_0" .. "t_9" is currently present in the given waiting queue. */
    private static void assertAllTablesWaiting(BlockingQueue<String> waitingQueue) {
        Set<String> waiting = new HashSet<>(waitingQueue);
        for (int i = 0; i < NUM_TABLES; i++) {
            Assert.assertTrue(waiting.contains("t_" + i), "Missing table in waiting queue: t_" + i);
        }
    }

    /**
     * Enqueues {@link #NUM_TABLES} randomly chosen tables, pausing briefly after each one.
     *
     * @throws InterruptedException if the thread is interrupted while pausing
     */
    private static void enqueueRandomTables(SegmentRelocator relocator, Random random) throws InterruptedException {
        for (int i = 0; i < NUM_TABLES; i++) {
            relocator.putTableToWait("t_" + random.nextInt(NUM_TABLES));
            pauseBriefly(random);
        }
    }

    /**
     * Sleeps for a random 10-29 ms so the concurrent requesters interleave.
     *
     * @throws InterruptedException if the thread is interrupted while sleeping
     */
    private static void pauseBriefly(Random random) throws InterruptedException {
        Thread.sleep(10 + random.nextInt(20));
    }

    /**
     * Builds a {@link ZNRecord} with the given id, version 10, and the simple field
     * {@code "segment.tier"} set to the given tier. Not referenced by any test currently in this
     * class.
     *
     * @param segmentName id of the record
     * @param tier value stored under {@code "segment.tier"}
     * @return the populated record
     */
    private static ZNRecord createSegmentMetadataZNRecord(String segmentName, String tier) {
        ZNRecord record = new ZNRecord(segmentName);
        record.setVersion(10);
        record.setSimpleField("segment.tier", tier);
        return record;
    }
}
