package org.apache.pinot.query.planner.logical;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.calcite.rel.RelDistribution;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.plannode.AggregateNode;
import org.apache.pinot.query.planner.plannode.ExchangeNode;
import org.apache.pinot.query.planner.plannode.FilterNode;
import org.apache.pinot.query.planner.plannode.JoinNode;
import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
import org.apache.pinot.query.planner.plannode.MailboxSendNode;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.planner.plannode.PlanNodeVisitor;
import org.apache.pinot.query.planner.plannode.ProjectNode;
import org.apache.pinot.query.planner.plannode.SetOpNode;
import org.apache.pinot.query.planner.plannode.SortNode;
import org.apache.pinot.query.planner.plannode.TableScanNode;
import org.apache.pinot.query.planner.plannode.ValueNode;
import org.apache.pinot.query.planner.plannode.WindowNode;

/* loaded from: input_file:org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.class */
/**
 * Plan visitor that removes redundant shuffles.
 *
 * <p>Every {@code visitXxx} method returns the set of output column indices by which the node's output is known to
 * be partitioned. An empty set means "no partitioning is known". When a {@link MailboxSendNode} or
 * {@link MailboxReceiveNode} would redistribute on keys that already contain every known partition key, its
 * distribution type is rewritten in place to {@link RelDistribution.Type#SINGLETON} and the existing partition keys
 * are propagated upward.
 *
 * <p>The visitor holds no state; obtain behaviour through {@link #optimizeShuffles(PlanNode)}.
 */
public class ShuffleRewriteVisitor implements PlanNodeVisitor<Set<Integer>, Void> {

    /**
     * Walks the plan rooted at {@code planNode} and rewrites, in place, every exchange whose shuffle is redundant.
     *
     * @param planNode root of the plan to optimize
     */
    public static void optimizeShuffles(PlanNode planNode) {
        planNode.visit(new ShuffleRewriteVisitor(), null);
    }

    private ShuffleRewriteVisitor() {
    }

    /**
     * Group-set entries that are plain input references carry the input's partitioning, renumbered to their
     * position in the group set. If any input partition key is not passed through, the partitioning is dropped.
     */
    @Override
    public Set<Integer> visitAggregate(AggregateNode aggregateNode, Void context) {
        Set<Integer> inputPartitionKeys = aggregateNode.getInputs().get(0).visit(this, context);
        return deriveNewPartitionKeysFromRexExpressions(aggregateNode.getGroupSet(), inputPartitionKeys);
    }

    @Override
    public Set<Integer> visitWindow(WindowNode windowNode, Void context) {
        throw new UnsupportedOperationException("Window not yet supported!");
    }

    /**
     * Visits every input, so exchanges in all branches are rewritten, and reports a partitioning only when all
     * inputs agree on it.
     *
     * <p>Rows from different inputs are placed by their own input's partitioning. A union of differing key sets
     * (or a set reported by only some inputs) would claim a co-location that rows from different inputs do not
     * actually share, and could make a downstream exchange be skipped incorrectly.
     */
    @Override
    public Set<Integer> visitSetOp(SetOpNode setOpNode, Void context) {
        Set<Integer> sharedKeys = null;
        boolean inputsAgree = true;
        for (PlanNode input : setOpNode.getInputs()) {
            Set<Integer> inputKeys = input.visit(this, context);
            if (sharedKeys == null) {
                sharedKeys = inputKeys;
            } else if (!sharedKeys.equals(inputKeys)) {
                // Keep looping: the remaining inputs must still be visited for their rewrites.
                inputsAgree = false;
            }
        }
        if (!inputsAgree || sharedKeys == null) {
            return new HashSet<>();
        }
        return new HashSet<>(sharedKeys);
    }

    @Override
    public Set<Integer> visitExchange(ExchangeNode exchangeNode, Void context) {
        throw new UnsupportedOperationException("Exchange not yet supported!");
    }

    /** A filter only drops rows, so it reports its input's partition keys unchanged. */
    @Override
    public Set<Integer> visitFilter(FilterNode filterNode, Void context) {
        return filterNode.getInputs().get(0).visit(this, context);
    }

    /**
     * Reports the left input's partition keys when every one of them is also a left join key, otherwise nothing.
     *
     * <p>Returning only the intersection of the partition keys and the join keys would report a strict subset of
     * the real partition keys. A subset is not a valid partitioning claim, so in that case nothing is reported.
     *
     * <p>NOTE(review): only the left input is traversed. Exchanges beneath the right input are never rewritten,
     * and right-side keys are not propagated. This also assumes left columns keep their indices in the join output
     * and are not null-padded (confirm for RIGHT/FULL joins).
     */
    @Override
    public Set<Integer> visitJoin(JoinNode joinNode, Void context) {
        Set<Integer> leftPartitionKeys = joinNode.getInputs().get(0).visit(this, context);
        List<Integer> leftJoinKeys = joinNode.getJoinKeys().getLeftKeys();
        if (leftPartitionKeys.isEmpty() || !leftJoinKeys.containsAll(leftPartitionKeys)) {
            return new HashSet<>();
        }
        return new HashSet<>(leftPartitionKeys);
    }

    /**
     * Visits the sender and, if the incoming partitioning already satisfies this exchange, marks it SINGLETON.
     * Otherwise it reports this exchange's own distribution keys as the new partition keys.
     */
    @Override
    public Set<Integer> visitMailboxReceive(MailboxReceiveNode mailboxReceiveNode, Void context) {
        Set<Integer> senderPartitionKeys = mailboxReceiveNode.getSender().visit(this, context);
        List<Integer> distributionKeys = mailboxReceiveNode.getDistributionKeys();
        if (canSkipShuffle(senderPartitionKeys, distributionKeys)) {
            mailboxReceiveNode.setDistributionType(RelDistribution.Type.SINGLETON);
            return senderPartitionKeys;
        }
        return distributionKeys == null ? new HashSet<>() : new HashSet<>(distributionKeys);
    }

    /**
     * Marks the send SINGLETON when its input's partitioning already satisfies its distribution keys. Otherwise it
     * reports nothing; the matching receive (see {@link #visitMailboxReceive}) reports the new keys.
     */
    @Override
    public Set<Integer> visitMailboxSend(MailboxSendNode mailboxSendNode, Void context) {
        Set<Integer> inputPartitionKeys = mailboxSendNode.getInputs().get(0).visit(this, context);
        if (canSkipShuffle(inputPartitionKeys, mailboxSendNode.getDistributionKeys())) {
            mailboxSendNode.setDistributionType(RelDistribution.Type.SINGLETON);
            return inputPartitionKeys;
        }
        return new HashSet<>();
    }

    /**
     * Input references in the projection carry the partitioning, renumbered to their output position. If any
     * input partition key is not projected as-is, the partitioning is dropped.
     */
    @Override
    public Set<Integer> visitProject(ProjectNode projectNode, Void context) {
        Set<Integer> inputPartitionKeys = projectNode.getInputs().get(0).visit(this, context);
        return deriveNewPartitionKeysFromRexExpressions(projectNode.getProjects(), inputPartitionKeys);
    }

    /** A sort does not move rows between partitions here, so it reports its input's partition keys unchanged. */
    @Override
    public Set<Integer> visitSort(SortNode sortNode, Void context) {
        return sortNode.getInputs().get(0).visit(this, context);
    }

    /** Leaf: no partitioning is known for a table scan. */
    @Override
    public Set<Integer> visitTableScan(TableScanNode tableScanNode, Void context) {
        return new HashSet<>();
    }

    /** Leaf: no partitioning is known for literal values. */
    @Override
    public Set<Integer> visitValue(ValueNode valueNode, Void context) {
        return new HashSet<>();
    }

    /**
     * Decides whether an exchange on {@code distributionKeys} is redundant.
     *
     * @param partitionKeys column indices the incoming data is known to be partitioned by
     * @param distributionKeys column indices the exchange would redistribute on; may be {@code null}
     * @return {@code true} only when the partitioning is known (non-empty) and every partition key is among the
     *     distribution keys, so rows sharing the distribution keys also share the partition keys
     */
    private static boolean canSkipShuffle(Set<Integer> partitionKeys, @Nullable List<Integer> distributionKeys) {
        if (partitionKeys.isEmpty() || distributionKeys == null) {
            return false;
        }
        return distributionKeys.containsAll(partitionKeys);
    }

    /**
     * Maps partition keys through a list of output expressions.
     *
     * @param expressions output expressions; position {@code i} is output column {@code i}
     * @param oldPartitionKeys input column indices the input is partitioned by
     * @return the output positions of the old partition keys, if every old key appears as a plain
     *     {@link RexExpression.InputRef}; otherwise an empty set. When an input column is referenced more than
     *     once, the last occurrence wins.
     */
    private static Set<Integer> deriveNewPartitionKeysFromRexExpressions(List<RexExpression> expressions,
            Set<Integer> oldPartitionKeys) {
        Map<Integer, Integer> inputIndexToOutputPosition = new HashMap<>();
        for (int i = 0; i < expressions.size(); i++) {
            RexExpression expression = expressions.get(i);
            if (expression instanceof RexExpression.InputRef) {
                inputIndexToOutputPosition.put(((RexExpression.InputRef) expression).getIndex(), i);
            }
        }
        if (!inputIndexToOutputPosition.keySet().containsAll(oldPartitionKeys)) {
            return new HashSet<>();
        }
        Set<Integer> newPartitionKeys = new HashSet<>();
        for (int oldKey : oldPartitionKeys) {
            newPartitionKeys.add(inputIndexToOutputPosition.get(oldKey));
        }
        return newPartitionKeys;
    }
}
