package org.apache.pinot.query.planner.explain;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.pinot.query.planner.logical.TransformationTracker;
import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
import org.apache.pinot.query.planner.plannode.PlanNode;

/* loaded from: input_file:org/apache/pinot/query/planner/explain/MultiStageExplainAskingServersUtils.class */
/**
 * Helpers for rewriting a multi-stage query's {@link RelNode} tree so that the nodes which created
 * table-bound stage roots are replaced by the explanation returned from an
 * {@link AskingServerStageExplainer}.
 *
 * <p>Not instantiable; all members are static.
 */
public final class MultiStageExplainAskingServersUtils {
    private MultiStageExplainAskingServersUtils() {
    }

    /**
     * Substitutes explained plans for table-bound stages inside {@code relNode}.
     *
     * <p>Only fragments whose {@code getTableName()} is non-null are considered (the code refers to
     * these as leaf nodes). For each one, the fragment root {@link PlanNode} is mapped back to the
     * {@link RelNode} that created it. That RelNode is then replaced by
     * {@code askingServerStageExplainer.explainFragment(fragment)}.
     *
     * <p>The tree rooted at {@code relNode} is modified in place through
     * {@link RelNode#replaceInput(int, RelNode)}.
     *
     * @param relNode root of the plan to rewrite
     * @param collection the dispatchable fragments of the query
     * @param transformationTracker resolves a {@link PlanNode} to the {@link RelNode} that created it
     * @param askingServerStageExplainer supplies the replacement RelNode for a fragment
     * @return the replacement for {@code relNode} itself if it is one of the substituted nodes;
     *     otherwise {@code relNode}, whose descendants may have been replaced
     * @throws IllegalStateException if a fragment root has no creating RelNode, or if two fragment
     *     roots resolve to the same RelNode
     */
    public static RelNode modifyRel(RelNode relNode, Collection<DispatchablePlanFragment> collection,
            TransformationTracker<PlanNode, RelNode> transformationTracker,
            AskingServerStageExplainer askingServerStageExplainer) {
        Map<DispatchablePlanFragment, PlanNode> leafRoots = collection.stream()
                .filter(fragment -> fragment.getTableName() != null)
                .collect(Collectors.toMap(Function.identity(),
                        fragment -> fragment.getPlanFragment().getFragmentRoot()));
        Map<RelNode, RelNode> substitutions =
                createSubstitutionMap(leafRoots, transformationTracker, askingServerStageExplainer);
        return replace(relNode, substitutions);
    }

    /**
     * Builds the "original RelNode -> explained RelNode" substitution map.
     *
     * @param map fragment to its root PlanNode
     * @param transformationTracker resolves each PlanNode to its creating RelNode
     * @param askingServerStageExplainer produces the explained RelNode for each fragment
     * @return a new mutable map keyed by the creating RelNode of each fragment root
     * @throws IllegalStateException if a creating RelNode is missing or appears more than once
     */
    private static Map<RelNode, RelNode> createSubstitutionMap(Map<DispatchablePlanFragment, PlanNode> map,
            TransformationTracker<PlanNode, RelNode> transformationTracker,
            AskingServerStageExplainer askingServerStageExplainer) {
        Map<RelNode, RelNode> substitutions = new HashMap<>(map.size());
        for (Map.Entry<DispatchablePlanFragment, PlanNode> entry : map.entrySet()) {
            DispatchablePlanFragment fragment = entry.getKey();
            PlanNode fragmentRoot = entry.getValue();
            RelNode creator = transformationTracker.getCreatorOf(fragmentRoot);
            if (creator == null) {
                throw new IllegalStateException("Cannot find the corresponding RelNode for PlanNode: " + fragmentRoot);
            }
            // Two fragments mapping to the same RelNode would make the substitution ambiguous.
            if (substitutions.containsKey(creator)) {
                throw new IllegalStateException("Duplicate RelNode found in the leaf nodes: " + creator);
            }
            substitutions.put(creator, askingServerStageExplainer.explainFragment(fragment));
        }
        return substitutions;
    }

    /**
     * Applies {@code map} to the tree rooted at {@code relNode}.
     *
     * @return the substitute for {@code relNode} if it is itself a key; otherwise {@code relNode}
     *     after its descendants have been substituted in place
     */
    private static RelNode replace(RelNode relNode, Map<RelNode, RelNode> map) {
        RelNode substitute = map.get(relNode);
        if (substitute != null) {
            return substitute;
        }
        replaceRecursive(relNode, map);
        return relNode;
    }

    /**
     * Depth-first walk that swaps any input found in {@code map} for its substitute.
     *
     * <p>It does not descend into a substituted input.
     */
    private static void replaceRecursive(RelNode relNode, Map<RelNode, RelNode> map) {
        // Index-based loop because replaceInput addresses inputs by ordinal.
        for (int i = 0; i < relNode.getInputs().size(); i++) {
            RelNode input = relNode.getInput(i);
            RelNode substitute = map.get(input);
            if (substitute != null) {
                relNode.replaceInput(i, substitute);
            } else {
                replaceRecursive(input, map);
            }
        }
    }
}
