package org.apache.pinot.query.planner.logical;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.pinot.query.context.PlannerContext;
import org.apache.pinot.query.planner.QueryPlan;
import org.apache.pinot.query.planner.StageMetadata;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
import org.apache.pinot.query.planner.partitioning.KeySelector;
import org.apache.pinot.query.planner.stage.AggregateNode;
import org.apache.pinot.query.planner.stage.FilterNode;
import org.apache.pinot.query.planner.stage.JoinNode;
import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
import org.apache.pinot.query.planner.stage.MailboxSendNode;
import org.apache.pinot.query.planner.stage.ProjectNode;
import org.apache.pinot.query.planner.stage.StageNode;
import org.apache.pinot.query.planner.stage.TableScanNode;
import org.apache.pinot.query.routing.WorkerManager;

/* loaded from: input_file:org/apache/pinot/query/planner/logical/StagePlanner.class */
/**
 * Splits a Calcite logical plan ({@link RelNode} tree) into query stages connected by mailboxes.
 *
 * <p>Every {@link LogicalExchange} in the tree becomes a stage boundary. The subtree below the exchange is planned
 * as a new stage whose root is a {@link MailboxSendNode}. In the parent stage the exchange is replaced by a
 * {@link MailboxReceiveNode}. {@link #makePlan(RelNode)} additionally creates stage 0, which receives the output of
 * the top-level stage.
 *
 * <p>While walking the plan, each {@link StageNode} is assigned a set of partition keys: output column indices on
 * which the node's rows are treated as partitioned. An exchange whose hash keys cover the child stage's partition
 * keys is planned as {@code SINGLETON} instead of a re-shuffle.
 *
 * <p>{@link #makePlan(RelNode)} resets and mutates per-plan fields without any synchronization, so a single instance
 * must not be used for concurrent {@code makePlan} calls.
 */
public class StagePlanner {
    private final PlannerContext _plannerContext;
    private final WorkerManager _workerManager;
    // Per-plan state; reinitialized at the start of every makePlan call.
    private Map<Integer, StageNode> _queryStageMap;
    private Map<Integer, StageMetadata> _stageMetadataMap;
    private int _stageIdCounter;

    public StagePlanner(PlannerContext plannerContext, WorkerManager workerManager) {
        _plannerContext = plannerContext;
        _workerManager = workerManager;
    }

    /**
     * Builds a {@link QueryPlan} from a logical plan.
     *
     * @param relNode root of the Calcite logical plan
     * @return the plan holding each stage's root node and its {@link StageMetadata}, after workers have been assigned
     *     to every stage by the {@link WorkerManager}
     */
    public QueryPlan makePlan(RelNode relNode) {
        _queryStageMap = new HashMap<>();
        _stageMetadataMap = new HashMap<>();
        _stageIdCounter = 1;

        StageNode topStageRoot = walkRelPlan(relNode, getNewStageId());

        // Stage 0 is never handed out by getNewStageId() (the counter starts at 1); it collects the top stage's output.
        MailboxReceiveNode resultReceiveNode = new MailboxReceiveNode(0, topStageRoot.getDataSchema(),
                topStageRoot.getStageId(), RelDistribution.Type.RANDOM_DISTRIBUTED, null);
        MailboxSendNode resultSendNode = new MailboxSendNode(topStageRoot.getStageId(), topStageRoot.getDataSchema(),
                resultReceiveNode.getStageId(), RelDistribution.Type.RANDOM_DISTRIBUTED, null);
        resultSendNode.addInput(topStageRoot);
        _queryStageMap.put(resultSendNode.getStageId(), resultSendNode);
        // walkRelPlan always registers metadata for the stage id it was given, so this lookup is non-null.
        _stageMetadataMap.get(resultSendNode.getStageId()).attach(resultSendNode);

        _queryStageMap.put(resultReceiveNode.getStageId(), resultReceiveNode);
        StageMetadata resultStageMetadata = new StageMetadata();
        resultStageMetadata.attach(resultReceiveNode);
        _stageMetadataMap.put(resultReceiveNode.getStageId(), resultStageMetadata);

        for (Map.Entry<Integer, StageMetadata> entry : _stageMetadataMap.entrySet()) {
            _workerManager.assignWorkerToStage(entry.getKey(), entry.getValue());
        }
        return new QueryPlan(_queryStageMap, _stageMetadataMap);
    }

    /**
     * Converts {@code relNode} and its subtree into stage nodes belonging to {@code currentStageId}.
     *
     * <p>An exchange starts a new stage (see {@link #planExchange}). Any other node is converted in place and its
     * inputs are walked within the same stage.
     *
     * @return the stage node that replaces {@code relNode} in stage {@code currentStageId}
     */
    private StageNode walkRelPlan(RelNode relNode, int currentStageId) {
        if (isExchangeNode(relNode)) {
            return planExchange((LogicalExchange) relNode, currentStageId);
        }
        StageNode stageNode = RelToStageConverter.toStageNode(relNode, currentStageId);
        for (RelNode input : relNode.getInputs()) {
            stageNode.addInput(walkRelPlan(input, currentStageId));
        }
        // Inputs are attached first because partition keys are derived from them.
        updateStageMetadata(currentStageId, stageNode, _stageMetadataMap);
        return stageNode;
    }

    /**
     * Plans an exchange as a send/receive mailbox pair.
     *
     * <p>The exchange's input is planned as a new stage rooted at a {@link MailboxSendNode}. The returned
     * {@link MailboxReceiveNode} takes the exchange's place in stage {@code currentStageId}.
     */
    private MailboxReceiveNode planExchange(LogicalExchange exchange, int currentStageId) {
        StageNode childStageRoot = walkRelPlan(exchange.getInput(0), getNewStageId());
        RelDistribution distribution = exchange.getDistribution();
        RelDistribution.Type distributionType = distribution.getType();
        FieldSelectionKeySelector keySelector = distributionType == RelDistribution.Type.HASH_DISTRIBUTED
                ? new FieldSelectionKeySelector(distribution.getKeys()) : null;

        // When the child stage's partition keys are a subset of the hash keys, the shuffle is skipped and the
        // mailboxes are planned as SINGLETON (the key selector is still attached).
        RelDistribution.Type mailboxType = canSkipShuffle(childStageRoot, keySelector)
                ? RelDistribution.Type.SINGLETON : distributionType;
        MailboxReceiveNode receiveNode = new MailboxReceiveNode(currentStageId, childStageRoot.getDataSchema(),
                childStageRoot.getStageId(), mailboxType, keySelector);
        MailboxSendNode sendNode = new MailboxSendNode(childStageRoot.getStageId(), childStageRoot.getDataSchema(),
                receiveNode.getStageId(), mailboxType, keySelector);
        sendNode.addInput(childStageRoot);
        _queryStageMap.put(sendNode.getStageId(), sendNode);
        updateStageMetadata(sendNode.getStageId(), sendNode, _stageMetadataMap);
        updateStageMetadata(receiveNode.getStageId(), receiveNode, _stageMetadataMap);
        return receiveNode;
    }

    /**
     * Returns {@code true} when {@code stageNode} has partition keys and all of them are among the column indices
     * of {@code keySelector}. Returns {@code false} when either is absent or empty.
     */
    private boolean canSkipShuffle(StageNode stageNode, KeySelector<Object[], Object[]> keySelector) {
        Set<Integer> partitionKeys = stageNode.getPartitionKeys();
        if (partitionKeys.isEmpty() || keySelector == null) {
            return false;
        }
        Set<Integer> targetKeys = new HashSet<>(((FieldSelectionKeySelector) keySelector).getColumnIndices());
        return targetKeys.containsAll(partitionKeys);
    }

    /** Refreshes {@code stageNode}'s partition keys and attaches it to the metadata of stage {@code stageId}. */
    private static void updateStageMetadata(int stageId, StageNode stageNode, Map<Integer, StageMetadata> map) {
        updatePartitionKeys(stageNode);
        map.computeIfAbsent(stageId, id -> new StageMetadata()).attach(stageNode);
    }

    /**
     * Derives {@code stageNode}'s partition keys (output column indices) from its inputs or its key selector.
     * Table scans and unrecognized node types are left unchanged.
     */
    private static void updatePartitionKeys(StageNode stageNode) {
        if (stageNode instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) stageNode;
            projectNode.setPartitionKeys(mapInputRefPartitionKeys(projectNode.getProjects(),
                    projectNode.getInputs().get(0).getPartitionKeys()));
        } else if (stageNode instanceof FilterNode) {
            // Filters pass their input's partition keys through unchanged.
            stageNode.setPartitionKeys(stageNode.getInputs().get(0).getPartitionKeys());
        } else if (stageNode instanceof AggregateNode) {
            AggregateNode aggregateNode = (AggregateNode) stageNode;
            aggregateNode.setPartitionKeys(mapInputRefPartitionKeys(aggregateNode.getGroupSet(),
                    aggregateNode.getInputs().get(0).getPartitionKeys()));
        } else if (stageNode instanceof JoinNode) {
            updateJoinPartitionKeys((JoinNode) stageNode);
        } else if (stageNode instanceof MailboxReceiveNode) {
            setPartitionKeysFromSelector(stageNode, ((MailboxReceiveNode) stageNode).getPartitionKeySelector());
        } else if (stageNode instanceof MailboxSendNode) {
            setPartitionKeysFromSelector(stageNode, ((MailboxSendNode) stageNode).getPartitionKeySelector());
        }
    }

    /**
     * Returns the output positions whose expression is a plain input reference to one of
     * {@code inputPartitionKeys}. Computed expressions never carry partitioning forward.
     */
    private static Set<Integer> mapInputRefPartitionKeys(List<? extends RexExpression> expressions,
            Set<Integer> inputPartitionKeys) {
        Set<Integer> outputKeys = new HashSet<>();
        for (int position = 0; position < expressions.size(); position++) {
            RexExpression expression = expressions.get(position);
            if (expression instanceof RexExpression.InputRef
                    && inputPartitionKeys.contains(((RexExpression.InputRef) expression).getIndex())) {
                outputKeys.add(position);
            }
        }
        return outputKeys;
    }

    /**
     * Computes a join's partition keys as column indices of the join output.
     *
     * <p>Assumes the output lays out all left-input columns followed by all right-input columns, so a right column
     * {@code c} appears at {@code leftSchemaSize + c}. A join key column keeps its partitioning only if that
     * column was already a partition key of its own input.
     */
    private static void updateJoinPartitionKeys(JoinNode joinNode) {
        Set<Integer> joinPartitionKeys = new HashSet<>();
        // NOTE(review): only the first join criterion is consulted. With no criteria, no partitioning is claimed,
        // which conservatively disables shuffle skipping above this join.
        if (!joinNode.getCriteria().isEmpty()) {
            StageNode left = joinNode.getInputs().get(0);
            StageNode right = joinNode.getInputs().get(1);
            int leftSchemaSize = left.getDataSchema().size();
            Set<Integer> leftPartitionKeys = left.getPartitionKeys();
            Set<Integer> rightPartitionKeys = right.getPartitionKeys();
            List<Integer> leftJoinColumns = ((FieldSelectionKeySelector) joinNode.getCriteria().get(0)
                    .getLeftJoinKeySelector()).getColumnIndices();
            List<Integer> rightJoinColumns = ((FieldSelectionKeySelector) joinNode.getCriteria().get(0)
                    .getRightJoinKeySelector()).getColumnIndices();
            for (int k = 0; k < leftJoinColumns.size(); k++) {
                int leftColumn = leftJoinColumns.get(k);
                int rightColumn = rightJoinColumns.get(k);
                // Record the actual output column index, not the position of the key within the join-key list.
                if (leftPartitionKeys.contains(leftColumn)) {
                    joinPartitionKeys.add(leftColumn);
                }
                if (rightPartitionKeys.contains(rightColumn)) {
                    joinPartitionKeys.add(leftSchemaSize + rightColumn);
                }
            }
        }
        joinNode.setPartitionKeys(joinPartitionKeys);
    }

    /**
     * Sets {@code stageNode}'s partition keys to the column indices of {@code selector}. Leaves them unchanged
     * when {@code selector} is {@code null}.
     */
    private static void setPartitionKeysFromSelector(StageNode stageNode, KeySelector<?, ?> selector) {
        if (selector == null) {
            return;
        }
        stageNode.setPartitionKeys(new HashSet<>(((FieldSelectionKeySelector) selector).getColumnIndices()));
    }

    private boolean isExchangeNode(RelNode relNode) {
        return relNode instanceof LogicalExchange;
    }

    /** Returns the next unused stage id and advances the counter. */
    private int getNewStageId() {
        return _stageIdCounter++;
    }
}
