package org.apache.pinot.broker.routing.segmentpruner;

import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
import org.apache.helix.model.ExternalView;
import org.apache.helix.model.IdealState;
import org.apache.helix.zookeeper.datamodel.ZNRecord;
import org.apache.pinot.broker.routing.segmentpartition.SegmentPartitionInfo;
import org.apache.pinot.broker.routing.segmentpartition.SegmentPartitionUtils;
import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.Identifier;
import org.apache.pinot.sql.FilterKind;

/* loaded from: input_file:org/apache/pinot/broker/routing/segmentpruner/SinglePartitionColumnSegmentPruner.class */
/**
 * Segment pruner for a table with a single partition column.
 *
 * <p>Keeps a map from segment name to the {@link SegmentPartitionInfo} extracted for that segment, and uses it in
 * {@link #prune(BrokerRequest, Set)} to drop segments whose partitions cannot hold the values requested by an
 * {@code EQUALS} or {@code IN} predicate on the partition column. A segment is always kept when nothing can be
 * decided about it (no cached info, the {@link SegmentPartitionUtils#INVALID_PARTITION_INFO} marker, or a predicate
 * that is not handled here).
 *
 * <p>The cache is a {@link ConcurrentHashMap}; {@link #onAssignmentChange} and {@link #refreshSegment} are
 * additionally {@code synchronized}, while {@link #init} and {@link #prune(BrokerRequest, Set)} take no lock.
 */
public class SinglePartitionColumnSegmentPruner implements SegmentPruner {
    private final String _tableNameWithType;
    private final String _partitionColumn;
    private final Map<String, SegmentPartitionInfo> _partitionInfoMap = new ConcurrentHashMap<>();

    /**
     * @param str  table name with type, passed through to {@link SegmentPartitionUtils#extractPartitionInfo}
     * @param str2 name of the partition column that filter predicates are matched against
     */
    public SinglePartitionColumnSegmentPruner(String str, String str2) {
        _tableNameWithType = str;
        _partitionColumn = str2;
    }

    /**
     * Populates the cache from two index-aligned lists: {@code list2.get(i)} is used as the record for segment
     * {@code list.get(i)}. Segments for which no partition info can be extracted are not cached.
     *
     * @param idealState   not used by this pruner
     * @param externalView not used by this pruner
     * @param list         segment names
     * @param list2        records read in parallel with {@code list}; must have at least as many entries
     */
    @Override
    public void init(IdealState idealState, ExternalView externalView, List<String> list, List<ZNRecord> list2) {
        for (int i = 0; i < list.size(); i++) {
            String segment = list.get(i);
            SegmentPartitionInfo partitionInfo =
                    SegmentPartitionUtils.extractPartitionInfo(_tableNameWithType, _partitionColumn, segment, list2.get(i));
            if (partitionInfo != null) {
                _partitionInfoMap.put(segment, partitionInfo);
            }
        }
    }

    /**
     * Adds partition info for segments that are not cached yet and then evicts every cached segment that is not in
     * {@code set}.
     *
     * @param idealState   not used by this pruner
     * @param externalView not used by this pruner
     * @param set          segment names to retain in the cache; all other entries are removed
     * @param list         segment names whose records are supplied in {@code list2}
     * @param list2        records read in parallel with {@code list}
     */
    @Override
    public synchronized void onAssignmentChange(IdealState idealState, ExternalView externalView, Set<String> set,
            List<String> list, List<ZNRecord> list2) {
        for (int i = 0; i < list.size(); i++) {
            ZNRecord znRecord = list2.get(i);
            // computeIfAbsent leaves an existing entry untouched, so already-cached segments are not re-extracted here;
            // a null extraction result records no mapping.
            _partitionInfoMap.computeIfAbsent(list.get(i),
                    segment -> SegmentPartitionUtils.extractPartitionInfo(_tableNameWithType, _partitionColumn, segment,
                            znRecord));
        }
        _partitionInfoMap.keySet().retainAll(set);
    }

    /**
     * Re-extracts the partition info of one segment, replacing the cached entry or removing it when nothing can be
     * extracted.
     *
     * @param str      segment name
     * @param zNRecord record to extract from; may be {@code null}
     */
    @Override
    public synchronized void refreshSegment(String str, @Nullable ZNRecord zNRecord) {
        SegmentPartitionInfo partitionInfo =
                SegmentPartitionUtils.extractPartitionInfo(_tableNameWithType, _partitionColumn, str, zNRecord);
        if (partitionInfo == null) {
            _partitionInfoMap.remove(str);
        } else {
            _partitionInfoMap.put(str, partitionInfo);
        }
    }

    /**
     * Selects the segments that may contain rows matching the query filter.
     *
     * @param brokerRequest request whose Pinot query filter is evaluated
     * @param set           candidate segment names
     * @return {@code set} itself when the query has no filter; otherwise a new set holding the candidates that could
     *         not be ruled out
     */
    @Override
    public Set<String> prune(BrokerRequest brokerRequest, Set<String> set) {
        Expression filterExpression = brokerRequest.getPinotQuery().getFilterExpression();
        if (filterExpression == null) {
            return set;
        }
        Set<String> selectedSegments = new HashSet<>();
        for (String segment : set) {
            SegmentPartitionInfo partitionInfo = _partitionInfoMap.get(segment);
            // Keep the segment whenever there is no usable partition info to decide with. The marker is compared by
            // identity on purpose.
            if (partitionInfo == null || partitionInfo == SegmentPartitionUtils.INVALID_PARTITION_INFO
                    || isPartitionMatch(filterExpression, partitionInfo)) {
                selectedSegments.add(segment);
            }
        }
        return selectedSegments;
    }

    /**
     * Recursively evaluates a filter expression against the partitions of one segment.
     *
     * <p>{@code AND} requires every operand to match, {@code OR} requires at least one. {@code EQUALS} and
     * {@code IN} are checked only when their first operand is the partition column. Everything else is treated as
     * a match so the segment is kept.
     */
    private boolean isPartitionMatch(Expression filterExpression, SegmentPartitionInfo partitionInfo) {
        Function function = filterExpression.getFunctionCall();
        FilterKind filterKind = FilterKind.valueOf(function.getOperator());
        List<Expression> operands = function.getOperands();
        switch (filterKind) {
            case AND:
                for (Expression child : operands) {
                    if (!isPartitionMatch(child, partitionInfo)) {
                        return false;
                    }
                }
                return true;
            case OR:
                for (Expression child : operands) {
                    if (isPartitionMatch(child, partitionInfo)) {
                        return true;
                    }
                }
                return false;
            case EQUALS:
                return !isPartitionColumn(operands.get(0)) || hasPartitionFor(partitionInfo, operands.get(1));
            case IN:
                if (!isPartitionColumn(operands.get(0))) {
                    return true;
                }
                // Operand 0 is the column; the remaining operands are the candidate values.
                for (int i = 1; i < operands.size(); i++) {
                    if (hasPartitionFor(partitionInfo, operands.get(i))) {
                        return true;
                    }
                }
                return false;
            default:
                return true;
        }
    }

    /** Returns whether the expression is an identifier naming this pruner's partition column. */
    private boolean isPartitionColumn(Expression expression) {
        Identifier identifier = expression.getIdentifier();
        return identifier != null && identifier.getName().equals(_partitionColumn);
    }

    /**
     * Returns whether the partition computed from the literal's string form is among the segment's partitions.
     */
    private static boolean hasPartitionFor(SegmentPartitionInfo partitionInfo, Expression valueExpression) {
        String value = valueExpression.getLiteral().getFieldValue().toString();
        int partitionId = partitionInfo.getPartitionFunction().getPartition(value);
        return partitionInfo.getPartitions().contains(partitionId);
    }
}
