package org.apache.pinot.query.planner.physical.colocated;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.rel.RelDistribution;
import org.apache.pinot.common.config.provider.TableCache;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.partitioning.KeySelector;
import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.query.planner.plannode.ExchangeNode;
import org.apache.pinot.query.planner.plannode.FilterNode;
import org.apache.pinot.query.planner.plannode.JoinNode;
import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
import org.apache.pinot.query.planner.plannode.MailboxSendNode;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.planner.plannode.PlanNodeVisitor;
import org.apache.pinot.query.planner.plannode.ProjectNode;
import org.apache.pinot.query.planner.plannode.SetOpNode;
import org.apache.pinot.query.planner.plannode.SortNode;
import org.apache.pinot.query.planner.plannode.TableScanNode;
import org.apache.pinot.query.planner.plannode.ValueNode;
import org.apache.pinot.query.planner.plannode.WindowNode;
import org.apache.pinot.spi.config.table.ColumnPartitionConfig;
import org.apache.pinot.spi.config.table.IndexingConfig;
import org.apache.pinot.spi.config.table.TableConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

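/**
 * {@link PlanNodeVisitor} that walks the plan tree of a stage and returns, for each visited node, the set of
 * {@link ColocationKey}s that the node's output carries.
 *
 * <p>While visiting mailbox send/receive nodes and joins it may switch their distribution type to
 * {@code SINGLETON} and copy the sender stage's worker-to-server assignment onto the receiver stage, so that
 * the exchange between the two stages no longer has to shuffle data.
 *
 * <p>Decompiled from org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.class.
 */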
public class GreedyShuffleRewriteVisitor implements PlanNodeVisitor<Set<ColocationKey>, GreedyShuffleRewriteContext> {
    private static final Logger LOGGER = LoggerFactory.getLogger(GreedyShuffleRewriteVisitor.class);
    // Table config lookup; read in visitTableScan to find the segment partition configuration of a table.
    private final TableCache _tableCache;
    // Dispatch metadata keyed by plan fragment (stage) id. The worker-to-server maps stored here are both
    // read (to compare server sets) and overwritten (when a receiver stage adopts its sender's assignment).
    private final Map<Integer, DispatchablePlanMetadata> _dispatchablePlanMetadataMap;
    // Set to true by visitJoin once all shuffle-skip conditions hold for the join; visitMailboxReceive reads
    // it to decide whether the receive nodes of a join stage can be switched to SINGLETON.
    private boolean _canSkipShuffleForJoin = false;

    /**
     * Runs the shuffle-rewrite pass over every stage of a plan.
     *
     * <p>The rewrite context is pre-computed once from {@code planNode}. Stages are then processed from the
     * highest fragment id ({@code map.size() - 1}) down to 0, each with a freshly created visitor so that
     * per-visitor state (the join shuffle-skip flag) never leaks from one stage into the next.
     *
     * @param planNode plan node handed to the pre-compute visitor
     * @param map dispatch metadata per plan fragment id; its size determines which fragment ids are visited
     * @param tableCache table config lookup used when table scans are visited
     */
    public static void optimizeShuffles(PlanNode planNode, Map<Integer, DispatchablePlanMetadata> map, TableCache tableCache) {
        GreedyShuffleRewriteContext context = GreedyShuffleRewritePreComputeVisitor.preComputeContext(planNode);
        int lastFragmentId = map.size() - 1;
        for (int fragmentId = lastFragmentId; fragmentId >= 0; fragmentId--) {
            PlanNode stageRoot = context.getRootStageNode(fragmentId);
            GreedyShuffleRewriteVisitor stageVisitor = new GreedyShuffleRewriteVisitor(tableCache, map);
            stageRoot.visit(stageVisitor, context);
        }
    }

    /**
     * Creates a visitor for a single stage; instances are only created by {@code optimizeShuffles}.
     *
     * @param tableCache table config lookup
     * @param map dispatch metadata per plan fragment id (stored by reference, not copied)
     */
    private GreedyShuffleRewriteVisitor(TableCache tableCache, Map<Integer, DispatchablePlanMetadata> map) {
        _dispatchablePlanMetadataMap = map;
        _tableCache = tableCache;
    }

    /**
     * Propagates colocation keys through an aggregation.
     *
     * <p>The input is visited first. Every group-set entry that is a plain input reference contributes a
     * mapping from its input column index to its position in the group set; the input's keys are then
     * re-expressed through that mapping, and keys that use an unmapped column are dropped.
     */
    @Override
    public Set<ColocationKey> visitAggregate(AggregateNode aggregateNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        Set<ColocationKey> inputKeys = aggregateNode.getInputs().get(0).visit(this, greedyShuffleRewriteContext);
        List<RexExpression> groupSet = aggregateNode.getGroupSet();
        Map<Integer, Integer> inputToOutputIndex = new HashMap<>();
        for (int outputIndex = 0; outputIndex < groupSet.size(); outputIndex++) {
            RexExpression groupExpression = groupSet.get(outputIndex);
            if (!(groupExpression instanceof RexExpression.InputRef)) {
                continue;
            }
            int inputIndex = ((RexExpression.InputRef) groupExpression).getIndex();
            inputToOutputIndex.put(inputIndex, outputIndex);
        }
        return computeNewColocationKeys(inputKeys, inputToOutputIndex);
    }

    /**
     * A filter does not change column positions, so the colocation keys of its single input are returned
     * unchanged.
     */
    @Override
    public Set<ColocationKey> visitFilter(FilterNode filterNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        PlanNode input = filterNode.getInputs().get(0);
        return input.visit(this, greedyShuffleRewriteContext);
    }

    /**
     * Decides whether the shuffle into a join stage can be skipped, then computes the join's colocation keys.
     *
     * <p>The two leaf nodes of the join's stage must be mailbox receive nodes. The shuffle is skipped only if
     * all of the following hold, checked in this order with short-circuiting:
     * <ol>
     *   <li>{@code canJoinBeColocated} accepts the join;</li>
     *   <li>both sender stages run on the same set of servers and the join stage's servers cover them;</li>
     *   <li>for each receiver, the sender stage's colocation keys match both the sender's and the receiver's
     *       distribution keys;</li>
     *   <li>the matched keys of both sides agree on partition count and hash algorithm.</li>
     * </ol>
     * When they do, the join stage adopts the worker-to-server map of the first receiver's sender stage and
     * the skip flag is raised; the flag is consumed when the receive nodes are visited below.
     *
     * <p>The result is the first input's keys plus the second input's keys with every index shifted by the
     * column count of the first input's schema.
     *
     * @throws IllegalStateException if the join stage does not have exactly two leaf nodes
     */
    @Override
    public Set<ColocationKey> visitJoin(JoinNode joinNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        int joinStageId = joinNode.getPlanFragmentId();
        List<MailboxReceiveNode> receivers = greedyShuffleRewriteContext.getLeafNodes(joinStageId).stream()
            .map(leaf -> (MailboxReceiveNode) leaf)
            .collect(Collectors.toList());
        Preconditions.checkState(receivers.size() == 2,
            "Join stage %s must have exactly 2 mailbox receive leaf nodes, but found %s", joinStageId,
            receivers.size());
        MailboxReceiveNode firstReceiver = receivers.get(0);
        MailboxReceiveNode secondReceiver = receivers.get(1);

        // Order matters: checkPartitionScheme throws unless both partitionKeyConditionForJoin checks passed.
        boolean shuffleCanBeSkipped = canJoinBeColocated(joinNode)
            && canServerAssignmentAllowShuffleSkip(joinStageId, firstReceiver.getSenderStageId(),
                secondReceiver.getSenderStageId())
            && partitionKeyConditionForJoin(firstReceiver, firstReceiver.getSender(), greedyShuffleRewriteContext)
            && partitionKeyConditionForJoin(secondReceiver, secondReceiver.getSender(), greedyShuffleRewriteContext)
            && checkPartitionScheme(firstReceiver, secondReceiver, greedyShuffleRewriteContext);
        if (shuffleCanBeSkipped) {
            DispatchablePlanMetadata joinMetadata = _dispatchablePlanMetadataMap.get(joinStageId);
            DispatchablePlanMetadata firstSenderMetadata =
                _dispatchablePlanMetadataMap.get(firstReceiver.getSenderStageId());
            joinMetadata.setWorkerIdToServerInstanceMap(firstSenderMetadata.getWorkerIdToServerInstanceMap());
            _canSkipShuffleForJoin = true;
        }

        // Inputs are visited only after the flag is settled, because visitMailboxReceive reads it.
        Set<ColocationKey> firstInputKeys = joinNode.getInputs().get(0).visit(this, greedyShuffleRewriteContext);
        Set<ColocationKey> secondInputKeys = joinNode.getInputs().get(1).visit(this, greedyShuffleRewriteContext);
        int firstInputWidth = joinNode.getInputs().get(0).getDataSchema().size();

        Set<ColocationKey> joinKeys = new HashSet<>(firstInputKeys);
        for (ColocationKey secondInputKey : secondInputKeys) {
            ColocationKey shiftedKey =
                new ColocationKey(secondInputKey.getNumPartitions(), secondInputKey.getHashAlgorithm());
            for (Integer index : secondInputKey.getIndices()) {
                shiftedKey.addIndex(firstInputWidth + index);
            }
            joinKeys.add(shiftedKey);
        }
        return joinKeys;
    }

    /**
     * Computes the colocation keys available after a mailbox receive and, where possible, removes the shuffle.
     *
     * <p>In a join stage the decision was already taken by {@code visitJoin}: if the skip flag is set, both
     * this node and its sender are switched to {@code SINGLETON} and the sender stage's keys are returned.
     * Otherwise the keys implied by a hash shuffle on the distribution keys are returned.
     *
     * <p>Outside a join stage the shuffle is removed when the sender stage's keys match the distribution keys
     * and the receiver stage's servers cover the sender stage's servers. In that case this node is switched to
     * {@code SINGLETON}, the receiver stage adopts the sender stage's worker-to-server map, and the sender
     * stage's keys are returned. Otherwise the hash-shuffle keys are returned.
     */
    @Override
    public Set<ColocationKey> visitMailboxReceive(MailboxReceiveNode mailboxReceiveNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        int receiverStageId = mailboxReceiveNode.getPlanFragmentId();
        int senderStageId = mailboxReceiveNode.getSenderStageId();
        Set<ColocationKey> senderKeys = greedyShuffleRewriteContext.getColocationKeys(senderStageId);
        List<Integer> distributionKeys = mailboxReceiveNode.getDistributionKeys();

        if (greedyShuffleRewriteContext.isJoinStage(receiverStageId)) {
            if (_canSkipShuffleForJoin) {
                mailboxReceiveNode.setDistributionType(RelDistribution.Type.SINGLETON);
                mailboxReceiveNode.getSender().setDistributionType(RelDistribution.Type.SINGLETON);
                return senderKeys;
            }
            return hashShuffleColocationKeys(receiverStageId, distributionKeys);
        }

        boolean shuffleCanBeSkipped = distributionKeys != null
            && colocationKeyCondition(senderKeys, distributionKeys)
            && areServersSuperset(receiverStageId, senderStageId);
        if (!shuffleCanBeSkipped) {
            return hashShuffleColocationKeys(receiverStageId, distributionKeys);
        }
        mailboxReceiveNode.setDistributionType(RelDistribution.Type.SINGLETON);
        DispatchablePlanMetadata receiverMetadata = _dispatchablePlanMetadataMap.get(receiverStageId);
        DispatchablePlanMetadata senderMetadata = _dispatchablePlanMetadataMap.get(senderStageId);
        receiverMetadata.setWorkerIdToServerInstanceMap(senderMetadata.getWorkerIdToServerInstanceMap());
        return senderKeys;
    }

    /**
     * Builds the colocation keys that result when the shuffle is kept: one single-column key per distribution
     * key, using the number of distinct servers of the receiving stage as the partition count and the default
     * hash algorithm.
     *
     * @param receiverStageId plan fragment id of the receiving stage
     * @param distributionKeys distribution keys of the receive node, or {@code null} if it has none
     * @return a new mutable set; empty when {@code distributionKeys} is {@code null}
     */
    private Set<ColocationKey> hashShuffleColocationKeys(int receiverStageId, @Nullable List<Integer> distributionKeys) {
        Set<ColocationKey> keys = new HashSet<>();
        if (distributionKeys == null) {
            return keys;
        }
        int numDistinctServers = new HashSet<>(
            _dispatchablePlanMetadataMap.get(receiverStageId).getWorkerIdToServerInstanceMap().values()).size();
        for (int distributionKey : distributionKeys) {
            keys.add(new ColocationKey(distributionKey, numDistinctServers, KeySelector.DEFAULT_HASH_ALGORITHM));
        }
        return keys;
    }

    /**
     * Visits the sender's input and records in the context which colocation keys this sender stage exposes.
     *
     * <p>If the receiving stage is a join stage, the input keys are kept when they match the send distribution
     * keys and the distribution type is left untouched here. Otherwise the input keys are kept, and the send
     * node is switched to {@code SINGLETON}, only when the keys match and the receiver stage's servers cover
     * this stage's servers. In every other case an empty set is recorded.
     *
     * @return the set that was stored in the context for this stage
     */
    @Override
    public Set<ColocationKey> visitMailboxSend(MailboxSendNode mailboxSendNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        Set<ColocationKey> inputKeys = mailboxSendNode.getInputs().get(0).visit(this, greedyShuffleRewriteContext);
        boolean keysMatch = colocationKeyCondition(inputKeys, mailboxSendNode.getDistributionKeys());
        int receiverStageId = mailboxSendNode.getReceiverStageId();
        int senderStageId = mailboxSendNode.getPlanFragmentId();

        Set<ColocationKey> stageKeys;
        if (greedyShuffleRewriteContext.isJoinStage(receiverStageId)) {
            if (keysMatch) {
                stageKeys = inputKeys;
            } else {
                stageKeys = new HashSet<>();
            }
        } else if (keysMatch && areServersSuperset(receiverStageId, senderStageId)) {
            mailboxSendNode.setDistributionType(RelDistribution.Type.SINGLETON);
            stageKeys = inputKeys;
        } else {
            stageKeys = new HashSet<>();
        }
        greedyShuffleRewriteContext.setColocationKeys(senderStageId, stageKeys);
        return stageKeys;
    }

    /**
     * Propagates colocation keys through a projection.
     *
     * <p>The input is visited first. Every projected expression that is a plain input reference contributes a
     * mapping from its input column index to its position in the projection; the input's keys are then
     * re-expressed through that mapping, and keys that use an unmapped column are dropped.
     */
    @Override
    public Set<ColocationKey> visitProject(ProjectNode projectNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        Set<ColocationKey> inputKeys = projectNode.getInputs().get(0).visit(this, greedyShuffleRewriteContext);
        List<RexExpression> projects = projectNode.getProjects();
        Map<Integer, Integer> inputToOutputIndex = new HashMap<>();
        for (int outputIndex = 0; outputIndex < projects.size(); outputIndex++) {
            RexExpression projection = projects.get(outputIndex);
            if (!(projection instanceof RexExpression.InputRef)) {
                continue;
            }
            int inputIndex = ((RexExpression.InputRef) projection).getIndex();
            inputToOutputIndex.put(inputIndex, outputIndex);
        }
        return computeNewColocationKeys(inputKeys, inputToOutputIndex);
    }

    /**
     * Returns the colocation keys of the sort's single input unchanged.
     */
    @Override
    public Set<ColocationKey> visitSort(SortNode sortNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        PlanNode input = sortNode.getInputs().get(0);
        return input.visit(this, greedyShuffleRewriteContext);
    }

    /**
     * Returns the colocation keys of the window node's single input unchanged.
     */
    @Override
    public Set<ColocationKey> visitWindow(WindowNode windowNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        PlanNode input = windowNode.getInputs().get(0);
        return input.visit(this, greedyShuffleRewriteContext);
    }

    /**
     * Set operations expose no colocation keys: the node's inputs are not visited here and an empty
     * immutable set is returned.
     */
    @Override // org.apache.pinot.query.planner.plannode.PlanNodeVisitor
    public Set<ColocationKey> visitSetOp(SetOpNode setOpNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        return ImmutableSet.of();
    }

    /**
     * Exchange nodes are not supported by this visitor.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.apache.pinot.query.planner.plannode.PlanNodeVisitor
    public Set<ColocationKey> visitExchange(ExchangeNode exchangeNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        throw new UnsupportedOperationException("ExchangeNode should not be visited by this visitor");
    }

    /**
     * Derives colocation keys for a table scan from the table's segment partition configuration.
     *
     * <p>For every scanned column that has an entry in the table's column partition map, a single-column key is
     * produced with the column's position in the scan, its configured partition count and its partition
     * function name. An empty set is returned when the table config, indexing config, segment partition config
     * or column partition map is missing; a missing table config is additionally logged as a warning.
     *
     * @return a new mutable set of keys, possibly empty
     */
    @Override
    public Set<ColocationKey> visitTableScan(TableScanNode tableScanNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        TableConfig tableConfig = _tableCache.getTableConfig(tableScanNode.getTableName());
        if (tableConfig == null) {
            LOGGER.warn("Couldn't find tableConfig for {}", tableScanNode.getTableName());
            return new HashSet<>();
        }
        IndexingConfig indexingConfig = tableConfig.getIndexingConfig();
        if (indexingConfig == null || indexingConfig.getSegmentPartitionConfig() == null) {
            return new HashSet<>();
        }
        Map<String, ColumnPartitionConfig> columnPartitionMap =
            indexingConfig.getSegmentPartitionConfig().getColumnPartitionMap();
        if (columnPartitionMap == null) {
            return new HashSet<>();
        }

        List<String> scanColumns = tableScanNode.getTableScanColumns();
        Set<ColocationKey> keys = new HashSet<>();
        for (int columnIndex = 0; columnIndex < scanColumns.size(); columnIndex++) {
            // Single lookup; a column that is absent (or mapped to null) is simply not a partition column.
            ColumnPartitionConfig partitionConfig = columnPartitionMap.get(scanColumns.get(columnIndex));
            if (partitionConfig != null) {
                keys.add(new ColocationKey(columnIndex, partitionConfig.getNumPartitions(),
                    partitionConfig.getFunctionName()));
            }
        }
        return keys;
    }

    /**
     * A value node exposes no colocation keys.
     *
     * @return a new, empty, mutable set
     */
    @Override
    public Set<ColocationKey> visitValue(ValueNode valueNode, GreedyShuffleRewriteContext greedyShuffleRewriteContext) {
        Set<ColocationKey> noKeys = new HashSet<>();
        return noKeys;
    }

    /**
     * Reports whether the given join is eligible for colocation. The current implementation ignores
     * {@code joinNode} and accepts every join.
     */
    // NOTE(review): presumably a placeholder for restricting colocation to certain join kinds — confirm.
    private boolean canJoinBeColocated(JoinNode joinNode) {
        return true;
    }

    /**
     * Checks whether every server assigned to the second stage is also assigned to the first stage.
     *
     * @param receiverStageId plan fragment id whose servers must cover the other stage's servers
     * @param senderStageId plan fragment id whose servers must all be covered
     * @return {@code true} if the first stage's server set contains the second stage's servers
     */
    private boolean areServersSuperset(int receiverStageId, int senderStageId) {
        DispatchablePlanMetadata receiverMetadata = _dispatchablePlanMetadataMap.get(receiverStageId);
        Set<Object> receiverServers = new HashSet<>(receiverMetadata.getWorkerIdToServerInstanceMap().values());
        DispatchablePlanMetadata senderMetadata = _dispatchablePlanMetadataMap.get(senderStageId);
        return receiverServers.containsAll(senderMetadata.getWorkerIdToServerInstanceMap().values());
    }

    /**
     * Checks the server-assignment precondition for skipping the shuffle into a join stage: both sender stages
     * must be assigned exactly the same set of servers, and the join stage's servers must contain that set.
     *
     * @param joinStageId plan fragment id of the join stage
     * @param firstSenderStageId plan fragment id of the first sender stage
     * @param secondSenderStageId plan fragment id of the second sender stage
     */
    private boolean canServerAssignmentAllowShuffleSkip(int joinStageId, int firstSenderStageId, int secondSenderStageId) {
        Set<Object> firstSenderServers =
            new HashSet<>(_dispatchablePlanMetadataMap.get(firstSenderStageId).getWorkerIdToServerInstanceMap().values());
        Set<Object> secondSenderServers =
            new HashSet<>(_dispatchablePlanMetadataMap.get(secondSenderStageId).getWorkerIdToServerInstanceMap().values());
        // Same size plus containment is exactly set equality.
        if (!firstSenderServers.equals(secondSenderServers)) {
            return false;
        }
        // The join stage's metadata is only looked up once the two sender sets are known to agree.
        Set<Object> joinServers =
            new HashSet<>(_dispatchablePlanMetadataMap.get(joinStageId).getWorkerIdToServerInstanceMap().values());
        return joinServers.containsAll(firstSenderServers);
    }

    /**
     * Re-expresses colocation keys in terms of a node's output columns.
     *
     * <p>Each key is rebuilt with the same partition count and hash algorithm, translating every index through
     * the given mapping. A key is dropped entirely as soon as one of its indices has no entry in the mapping.
     *
     * @param set keys expressed in input column indices
     * @param map mapping from input column index to output column index
     * @return a new mutable set holding the translated keys
     */
    private static Set<ColocationKey> computeNewColocationKeys(Set<ColocationKey> set, Map<Integer, Integer> map) {
        Set<ColocationKey> translatedKeys = new HashSet<>();
        for (ColocationKey inputKey : set) {
            ColocationKey outputKey = new ColocationKey(inputKey.getNumPartitions(), inputKey.getHashAlgorithm());
            boolean everyIndexMapped = true;
            for (Integer inputIndex : inputKey.getIndices()) {
                if (!map.containsKey(inputIndex)) {
                    everyIndexMapped = false;
                    break;
                }
                outputKey.addIndex(map.get(inputIndex));
            }
            if (everyIndexMapped) {
                translatedKeys.add(outputKey);
            }
        }
        return translatedKeys;
    }

    /**
     * Checks whether at least one colocation key is a prefix of the given distribution keys, i.e. the key's
     * index list equals the first {@code k} entries of {@code list}, where {@code k} is the key's length.
     *
     * @param set candidate colocation keys
     * @param list distribution keys, or {@code null} if there are none
     * @return {@code false} when {@code set} is empty, {@code list} is {@code null}, or no key matches
     */
    private static boolean colocationKeyCondition(Set<ColocationKey> set, @Nullable List<Integer> list) {
        if (set.isEmpty() || list == null) {
            return false;
        }
        for (ColocationKey candidate : set) {
            int keyLength = candidate.getIndices().size();
            if (keyLength > list.size()) {
                continue;
            }
            if (list.subList(0, keyLength).equals(candidate.getIndices())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks one side of a join: the colocation keys recorded for the sender's stage must match the sender's
     * distribution keys and, if so, also the receiver's distribution keys.
     *
     * @param receiver mailbox receive node in the join stage
     * @param sender the send node paired with {@code receiver}
     * @param context rewrite context holding the per-stage colocation keys
     */
    private static boolean partitionKeyConditionForJoin(MailboxReceiveNode receiver, MailboxSendNode sender, GreedyShuffleRewriteContext context) {
        Set<ColocationKey> senderStageKeys = context.getColocationKeys(sender.getPlanFragmentId());
        return colocationKeyCondition(senderStageKeys, sender.getDistributionKeys())
            && colocationKeyCondition(senderStageKeys, receiver.getDistributionKeys());
    }

    /**
     * Returns the first colocation key (in the set's iteration order) whose index list is a prefix of the given
     * distribution keys. Uses the same matching rule as {@code colocationKeyCondition}.
     *
     * @param set colocation keys of the sender stage
     * @param list distribution keys of the receiver; may be {@code null}
     * @throws IllegalStateException if the set is empty, {@code list} is {@code null}, or nothing matches
     */
    private static ColocationKey getEquivalentSenderKey(Set<ColocationKey> set, List<Integer> list) {
        final String noMatchMessage =
            "Receiver's Equivalent Key in Sender Can't be Determined. This indicates a bug.";
        if (set.isEmpty() || list == null) {
            throw new IllegalStateException(noMatchMessage);
        }
        for (ColocationKey candidate : set) {
            int keyLength = candidate.getIndices().size();
            if (keyLength <= list.size() && list.subList(0, keyLength).equals(candidate.getIndices())) {
                return candidate;
            }
        }
        throw new IllegalStateException(noMatchMessage);
    }

    /**
     * Checks that both sides of a join are partitioned the same way: the sender-side key matched for each
     * receiver must have the same partition count and an equal hash algorithm.
     *
     * @param firstReceiver first mailbox receive node of the join stage
     * @param secondReceiver second mailbox receive node of the join stage
     * @param context rewrite context holding the per-stage colocation keys
     * @throws IllegalStateException if no sender key matches either receiver's distribution keys
     */
    private static boolean checkPartitionScheme(MailboxReceiveNode firstReceiver, MailboxReceiveNode secondReceiver, GreedyShuffleRewriteContext context) {
        ColocationKey firstSenderKey = getEquivalentSenderKey(
            context.getColocationKeys(firstReceiver.getSenderStageId()), firstReceiver.getDistributionKeys());
        ColocationKey secondSenderKey = getEquivalentSenderKey(
            context.getColocationKeys(secondReceiver.getSenderStageId()), secondReceiver.getDistributionKeys());
        return firstSenderKey.getNumPartitions() == secondSenderKey.getNumPartitions()
            && firstSenderKey.getHashAlgorithm().equals(secondSenderKey.getHashAlgorithm());
    }
}
