package org.apache.pinot.query.planner.logical;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.rel.RelDistribution;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
import org.apache.pinot.query.planner.partitioning.KeySelector;
import org.apache.pinot.query.planner.stage.AggregateNode;
import org.apache.pinot.query.planner.stage.FilterNode;
import org.apache.pinot.query.planner.stage.JoinNode;
import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
import org.apache.pinot.query.planner.stage.MailboxSendNode;
import org.apache.pinot.query.planner.stage.ProjectNode;
import org.apache.pinot.query.planner.stage.SortNode;
import org.apache.pinot.query.planner.stage.StageNode;
import org.apache.pinot.query.planner.stage.StageNodeVisitor;
import org.apache.pinot.query.planner.stage.TableScanNode;
import org.apache.pinot.query.planner.stage.ValueNode;

/**
 * Walks a {@link StageNode} tree and removes shuffles that are not needed.
 *
 * <p>Each {@code visit*} method returns the set of column indices (in that node's output schema)
 * by which the node's output is known to be partitioned. A mailbox send/receive whose partition
 * key selector covers every one of those columns has its exchange type rewritten to
 * {@link RelDistribution.Type#SINGLETON}. NOTE(review): presumably SINGLETON here means "no
 * re-shuffle needed"; confirm against the exchange implementation.
 *
 * <p>The only intended entry point is {@link #optimizeShuffles(StageNode)}. The other public
 * methods exist only to satisfy {@link StageNodeVisitor}.
 */
public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Void> {

    /**
     * Rewrites the tree rooted at {@code stageNode} in place, changing the exchange type of
     * removable shuffles to {@link RelDistribution.Type#SINGLETON}.
     *
     * @param stageNode root of the stage-node tree to rewrite
     */
    public static void optimizeShuffles(StageNode stageNode) {
        stageNode.visit(new ShuffleRewriteVisitor(), null);
    }

    /** Instances are only created by {@link #optimizeShuffles(StageNode)}. */
    private ShuffleRewriteVisitor() {
    }

    /**
     * An aggregation stays partitioned on a key only if that key is carried through its group
     * set as a direct input reference.
     */
    @Override
    public Set<Integer> visitAggregate(AggregateNode aggregateNode, Void context) {
        Set<Integer> inputPartitionKeys = aggregateNode.getInputs().get(0).visit(this, context);
        return deriveNewPartitionKeysFromRexExpressions(aggregateNode.getGroupSet(), inputPartitionKeys);
    }

    /** A filter drops rows but not columns, so partition keys pass through unchanged. */
    @Override
    public Set<Integer> visitFilter(FilterNode filterNode, Void context) {
        return filterNode.getInputs().get(0).visit(this, context);
    }

    /**
     * The join output keeps each left join-key column that was already a partition key of the
     * left input.
     */
    @Override
    public Set<Integer> visitJoin(JoinNode joinNode, Void context) {
        // NOTE(review): only the left input (index 0) is visited. The right input's subtree is
        // therefore never analysed or rewritten, and right-side keys never contribute to the
        // result. Confirm this is intentional (e.g. to avoid breaking join co-partitioning).
        Set<Integer> leftPartitionKeys = joinNode.getInputs().get(0).visit(this, context);
        FieldSelectionKeySelector leftJoinKey =
                (FieldSelectionKeySelector) joinNode.getJoinKeys().getLeftJoinKeySelector();

        // Left-side indices are returned unshifted. This presumes the join output schema lists
        // the left columns first; confirm against JoinNode's schema construction.
        Set<Integer> partitionKeys = new HashSet<>();
        for (int leftIdx : leftJoinKey.getColumnIndices()) {
            if (leftPartitionKeys.contains(leftIdx)) {
                partitionKeys.add(leftIdx);
            }
        }
        return partitionKeys;
    }

    /**
     * If the sender's data is already partitioned on a subset of this receiver's key columns, the
     * exchange becomes SINGLETON and the sender's keys are returned. Otherwise the data is
     * (re)partitioned by the receiver's key selector, whose columns are returned. An empty set is
     * returned when the receiver has no selector.
     */
    @Override
    public Set<Integer> visitMailboxReceive(MailboxReceiveNode mailboxReceiveNode, Void context) {
        Set<Integer> senderPartitionKeys = mailboxReceiveNode.getSender().visit(this, context);
        KeySelector<Object[], Object[]> selector = mailboxReceiveNode.getPartitionKeySelector();

        if (canSkipShuffle(senderPartitionKeys, selector)) {
            mailboxReceiveNode.setExchangeType(RelDistribution.Type.SINGLETON);
            return senderPartitionKeys;
        }
        if (selector == null) {
            return new HashSet<>();
        }
        return new HashSet<>(((FieldSelectionKeySelector) selector).getColumnIndices());
    }

    /**
     * Marks the send as SINGLETON when its input is already partitioned compatibly with its key
     * selector, and returns the input keys. Otherwise it returns an empty set.
     */
    @Override
    public Set<Integer> visitMailboxSend(MailboxSendNode mailboxSendNode, Void context) {
        Set<Integer> inputPartitionKeys = mailboxSendNode.getInputs().get(0).visit(this, context);
        if (canSkipShuffle(inputPartitionKeys, mailboxSendNode.getPartitionKeySelector())) {
            mailboxSendNode.setExchangeType(RelDistribution.Type.SINGLETON);
            return inputPartitionKeys;
        }
        return new HashSet<>();
    }

    /**
     * A projection stays partitioned on a key only if the key is projected as a direct input
     * reference.
     */
    @Override
    public Set<Integer> visitProject(ProjectNode projectNode, Void context) {
        Set<Integer> inputPartitionKeys = projectNode.getInputs().get(0).visit(this, context);
        return deriveNewPartitionKeysFromRexExpressions(projectNode.getProjects(), inputPartitionKeys);
    }

    /** Sorting reorders rows without changing columns, so partition keys pass through unchanged. */
    @Override
    public Set<Integer> visitSort(SortNode sortNode, Void context) {
        return sortNode.getInputs().get(0).visit(this, context);
    }

    /** Table scans report no known partitioning. */
    @Override
    public Set<Integer> visitTableScan(TableScanNode tableScanNode, Void context) {
        return new HashSet<>();
    }

    /** Literal value nodes report no known partitioning. */
    @Override
    public Set<Integer> visitValue(ValueNode valueNode, Void context) {
        return new HashSet<>();
    }

    /**
     * Decides whether a shuffle can be skipped. This is the case when the data is already
     * partitioned on a non-empty subset of the selector's key columns, so rows with equal
     * selector keys are already colocated.
     *
     * @param partitionKeys columns the incoming data is known to be partitioned on
     * @param keySelector   the exchange's key selector; may be {@code null}
     * @return {@code false} if there are no known partition keys or no selector
     */
    private static boolean canSkipShuffle(Set<Integer> partitionKeys, KeySelector<Object[], Object[]> keySelector) {
        if (partitionKeys.isEmpty() || keySelector == null) {
            return false;
        }
        // Copied into a HashSet so that containsAll does constant-time lookups.
        Set<Integer> selectorKeys = new HashSet<>(((FieldSelectionKeySelector) keySelector).getColumnIndices());
        return selectorKeys.containsAll(partitionKeys);
    }

    /**
     * Maps partition keys through a list of output expressions.
     *
     * <p>Every expression that is a bare {@link RexExpression.InputRef} links an input column to
     * its output position. If the same input column appears more than once, the last position
     * wins. If every old partition key is referenced this way, the corresponding output positions
     * are returned. Otherwise partitioning is lost and an empty set is returned.
     *
     * @param expressions      the node's output expressions, in output order
     * @param oldPartitionKeys partition keys of the node's input
     * @return partition keys in the node's output schema
     */
    private static Set<Integer> deriveNewPartitionKeysFromRexExpressions(List<RexExpression> expressions,
            Set<Integer> oldPartitionKeys) {
        Map<Integer, Integer> inputToOutput = new HashMap<>();
        for (int i = 0; i < expressions.size(); i++) {
            RexExpression expression = expressions.get(i);
            if (expression instanceof RexExpression.InputRef) {
                inputToOutput.put(((RexExpression.InputRef) expression).getIndex(), i);
            }
        }
        if (!inputToOutput.keySet().containsAll(oldPartitionKeys)) {
            return new HashSet<>();
        }
        Set<Integer> newPartitionKeys = new HashSet<>();
        for (Integer oldKey : oldPartitionKeys) {
            newPartitionKeys.add(inputToOutput.get(oldKey));
        }
        return newPartitionKeys;
    }
}
