package org.apache.pinot.broker.routing.segmentpruner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.apache.helix.model.ExternalView;
import org.apache.helix.model.IdealState;
import org.apache.helix.zookeeper.datamodel.ZNRecord;
import org.apache.pinot.broker.routing.segmentpruner.interval.Interval;
import org.apache.pinot.broker.routing.segmentpruner.interval.IntervalTree;
import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.Identifier;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.data.DateTimeFormatSpec;
import org.apache.pinot.sql.FilterKind;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/pinot/broker/routing/segmentpruner/TimeSegmentPruner.class */
/**
 * Segment pruner that skips segments whose time range cannot satisfy the query's filter on the table's
 * time column.
 *
 * <p>Each segment is mapped to an {@link Interval} in milliseconds, built from the {@code segment.start.time},
 * {@code segment.end.time} and {@code segment.time.unit} fields of its ZK metadata. Segments with missing or
 * invalid metadata get {@code [0, Long.MAX_VALUE]} so they are never pruned.
 *
 * <p>At query time the filter is translated into a list of sorted, non-overlapping millisecond intervals.
 * That list is then looked up in an {@link IntervalTree} built over all segment intervals.
 *
 * <p>The mutating callbacks {@link #onAssignmentChange} and {@link #refreshSegment} are {@code synchronized}.
 * They publish a freshly built tree through a {@code volatile} field; {@link #prune} reads that field once
 * and takes no lock.
 */
public class TimeSegmentPruner implements SegmentPruner {
    private static final Logger LOGGER = LoggerFactory.getLogger(TimeSegmentPruner.class);
    // Lower bound of the time domain: segments with a negative start/end time are treated as invalid.
    private static final long MIN_START_TIME = 0;
    private static final long MAX_END_TIME = Long.MAX_VALUE;
    // Interval given to segments with unknown time range, so that they always match and are never pruned.
    private static final Interval DEFAULT_INTERVAL = new Interval(MIN_START_TIME, MAX_END_TIME);

    private final String _tableNameWithType;
    private final String _timeColumn;
    private final DateTimeFormatSpec _timeFormatSpec;
    // Rebuilt from _intervalMap after every change and swapped in whole, so prune() sees a consistent tree.
    private volatile IntervalTree<String> _intervalTree;
    private final Map<String, Interval> _intervalMap = new HashMap<>();

    /**
     * @param tableConfig config of the pruned table; its table name is used in log messages
     * @param timeColumn name of the time column whose predicates are used for pruning
     * @param timeFormatSpec format used to convert literal filter values on the time column to millis
     */
    public TimeSegmentPruner(TableConfig tableConfig, String timeColumn, DateTimeFormatSpec timeFormatSpec) {
        _tableNameWithType = tableConfig.getTableName();
        _timeColumn = timeColumn;
        _timeFormatSpec = timeFormatSpec;
    }

    /**
     * Builds the initial segment-to-interval mapping and interval tree.
     *
     * <p>{@code segments} and {@code znRecords} are parallel lists. The i-th record is the ZK metadata of
     * the i-th segment, or null if it is missing.
     *
     * <p>NOTE(review): unlike the other callbacks this method is not synchronized. It is presumably invoked
     * once, before any other callback; confirm.
     */
    @Override
    public void init(IdealState idealState, ExternalView externalView, List<String> segments,
            List<ZNRecord> znRecords) {
        for (int i = 0; i < segments.size(); i++) {
            String segment = segments.get(i);
            _intervalMap.put(segment, extractIntervalFromSegmentZKMetaZNRecord(segment, znRecords.get(i)));
        }
        _intervalTree = new IntervalTree<>(_intervalMap);
    }

    /**
     * Converts a segment's ZK metadata into its time interval in milliseconds.
     *
     * <p>Returns {@link #DEFAULT_INTERVAL} (never prune) when any of the following holds:
     * <ul>
     *   <li>the metadata is missing;</li>
     *   <li>the start or end time is absent or negative;</li>
     *   <li>the start time is after the end time.</li>
     * </ul>
     * The time unit defaults to {@link TimeUnit#DAYS} when it is not recorded.
     */
    private Interval extractIntervalFromSegmentZKMetaZNRecord(String segment, @Nullable ZNRecord znRecord) {
        if (znRecord == null) {
            LOGGER.warn("Failed to find segment ZK metadata for segment: {}, table: {}", segment,
                    _tableNameWithType);
            return DEFAULT_INTERVAL;
        }
        long startTime = znRecord.getLongField("segment.start.time", -1L);
        long endTime = znRecord.getLongField("segment.end.time", -1L);
        if (startTime < 0 || endTime < 0 || startTime > endTime) {
            LOGGER.warn("Failed to find valid time interval for segment: {}, table: {}", segment,
                    _tableNameWithType);
            return DEFAULT_INTERVAL;
        }
        TimeUnit timeUnit = (TimeUnit) znRecord.getEnumField("segment.time.unit", TimeUnit.class, TimeUnit.DAYS);
        return new Interval(timeUnit.toMillis(startTime), timeUnit.toMillis(endTime));
    }

    /**
     * Updates the mapping after a segment assignment change, then rebuilds the interval tree.
     *
     * <p>An interval is computed only for segments in {@code segments} that are not already tracked; existing
     * entries are kept as-is (presumably metadata updates arrive via {@link #refreshSegment}). Any tracked
     * segment not contained in {@code segmentsToKeep} is dropped.
     *
     * <p>{@code segments} and {@code znRecords} are parallel lists.
     */
    @Override
    public synchronized void onAssignmentChange(IdealState idealState, ExternalView externalView,
            Set<String> segmentsToKeep, List<String> segments, List<ZNRecord> znRecords) {
        for (int i = 0; i < segments.size(); i++) {
            ZNRecord znRecord = znRecords.get(i);
            _intervalMap.computeIfAbsent(segments.get(i),
                    segment -> extractIntervalFromSegmentZKMetaZNRecord(segment, znRecord));
        }
        _intervalMap.keySet().retainAll(segmentsToKeep);
        _intervalTree = new IntervalTree<>(_intervalMap);
    }

    /** Recomputes (or adds) the interval of one segment from its latest ZK metadata and rebuilds the tree. */
    @Override
    public synchronized void refreshSegment(String segment, @Nullable ZNRecord znRecord) {
        _intervalMap.put(segment, extractIntervalFromSegmentZKMetaZNRecord(segment, znRecord));
        _intervalTree = new IntervalTree<>(_intervalMap);
    }

    /**
     * Returns the subset of {@code segments} whose time interval can match the query's time filter.
     *
     * <p>Returns {@code segments} itself when the query has no filter, or when the filter cannot be
     * translated into time intervals. Returns an empty set when the filter matches no time at all.
     */
    @Override
    public Set<String> prune(BrokerRequest brokerRequest, Set<String> segments) {
        // Read the volatile field once so the whole call works on a single snapshot.
        IntervalTree<String> intervalTree = _intervalTree;
        Expression filterExpression = brokerRequest.getPinotQuery().getFilterExpression();
        if (filterExpression == null) {
            return segments;
        }
        List<Interval> filterIntervals = getFilterTimeIntervals(filterExpression);
        if (filterIntervals == null) {
            return segments;
        }
        if (filterIntervals.isEmpty()) {
            return Collections.emptySet();
        }
        Set<String> selectedSegments = new HashSet<>();
        for (Interval filterInterval : filterIntervals) {
            for (String segment : intervalTree.searchAll(filterInterval)) {
                if (segments.contains(segment)) {
                    selectedSegments.add(segment);
                }
            }
        }
        return selectedSegments;
    }

    /**
     * Translates a filter expression into the sorted, non-overlapping time intervals (millis) it can match.
     *
     * @return null if the filter does not constrain the time column in a usable way (cannot prune);
     *     an empty list if the filter can never match; otherwise the matching intervals.
     */
    @Nullable
    private List<Interval> getFilterTimeIntervals(Expression filterExpression) {
        Function function = filterExpression.getFunctionCall();
        if (function == null) {
            // Not a predicate (e.g. a bare literal or identifier): no time range can be derived.
            return null;
        }
        FilterKind filterKind = FilterKind.valueOf(function.getOperator());
        List<Expression> operands = function.getOperands();
        switch (filterKind) {
            case AND: {
                // Children that do not constrain the time column are ignored; the rest are intersected.
                List<List<Interval>> childIntervals = new ArrayList<>();
                for (Expression child : operands) {
                    List<Interval> intervals = getFilterTimeIntervals(child);
                    if (intervals != null) {
                        if (intervals.isEmpty()) {
                            return Collections.emptyList();
                        }
                        childIntervals.add(intervals);
                    }
                }
                return childIntervals.isEmpty() ? null : getIntersectionSortedIntervals(childIntervals);
            }
            case OR: {
                // A single unconstrained child makes the whole OR unconstrained.
                List<List<Interval>> childIntervals = new ArrayList<>();
                for (Expression child : operands) {
                    List<Interval> intervals = getFilterTimeIntervals(child);
                    if (intervals == null) {
                        return null;
                    }
                    childIntervals.add(intervals);
                }
                return childIntervals.isEmpty() ? null : getUnionSortedIntervals(childIntervals);
            }
            case NOT: {
                assert operands.size() == 1;
                List<Interval> intervals = getFilterTimeIntervals(operands.get(0));
                return intervals == null ? null : getComplementSortedIntervals(intervals);
            }
            case EQUALS: {
                if (!isTimeColumnLiteralPredicate(operands)) {
                    return null;
                }
                long value = toMillisSinceEpoch(operands.get(1));
                return Collections.singletonList(new Interval(value, value));
            }
            case IN: {
                if (!isTimeColumnLiteralPredicate(operands)) {
                    return null;
                }
                int numValues = operands.size() - 1;
                long[] values = new long[numValues];
                for (int i = 0; i < numValues; i++) {
                    values[i] = toMillisSinceEpoch(operands.get(i + 1));
                }
                // IN values arrive in query order and may repeat. The AND/OR/NOT interval operations need
                // sorted, non-overlapping input, so sort and de-duplicate them first.
                Arrays.sort(values);
                List<Interval> intervals = new ArrayList<>(numValues);
                for (int i = 0; i < numValues; i++) {
                    if (i == 0 || values[i] != values[i - 1]) {
                        intervals.add(new Interval(values[i], values[i]));
                    }
                }
                return intervals;
            }
            case GREATER_THAN: {
                if (!isTimeColumnLiteralPredicate(operands)) {
                    return null;
                }
                long value = toMillisSinceEpoch(operands.get(1));
                if (value >= MAX_END_TIME) {
                    // Nothing is greater than the maximum; also avoids overflow of value + 1.
                    return Collections.emptyList();
                }
                return Collections.singletonList(new Interval(value + 1, MAX_END_TIME));
            }
            case GREATER_THAN_OR_EQUAL: {
                if (!isTimeColumnLiteralPredicate(operands)) {
                    return null;
                }
                return Collections.singletonList(new Interval(toMillisSinceEpoch(operands.get(1)), MAX_END_TIME));
            }
            case LESS_THAN: {
                if (!isTimeColumnLiteralPredicate(operands)) {
                    return null;
                }
                long value = toMillisSinceEpoch(operands.get(1));
                if (value > MIN_START_TIME) {
                    return Collections.singletonList(new Interval(MIN_START_TIME, value - 1));
                }
                return Collections.emptyList();
            }
            case LESS_THAN_OR_EQUAL: {
                if (!isTimeColumnLiteralPredicate(operands)) {
                    return null;
                }
                long value = toMillisSinceEpoch(operands.get(1));
                if (value >= MIN_START_TIME) {
                    return Collections.singletonList(new Interval(MIN_START_TIME, value));
                }
                return Collections.emptyList();
            }
            case BETWEEN: {
                if (!isTimeColumnLiteralPredicate(operands)) {
                    return null;
                }
                long lower = toMillisSinceEpoch(operands.get(1));
                long upper = toMillisSinceEpoch(operands.get(2));
                if (upper >= lower) {
                    return Collections.singletonList(new Interval(lower, upper));
                }
                return Collections.emptyList();
            }
            case RANGE: {
                if (!isTimeColumnLiteralPredicate(operands)) {
                    return null;
                }
                return parseInterval(operands.get(1).getLiteral().getFieldValue().toString());
            }
            default:
                return null;
        }
    }

    /**
     * Returns whether a predicate has the shape {@code <timeColumn> OP literal[, literal...]}: the first
     * operand is the time column and every remaining operand is a literal. Other shapes cannot be used for
     * pruning.
     */
    private boolean isTimeColumnLiteralPredicate(List<Expression> operands) {
        Identifier identifier = operands.get(0).getIdentifier();
        if (identifier == null || !identifier.getName().equals(_timeColumn)) {
            return false;
        }
        for (int i = 1; i < operands.size(); i++) {
            if (operands.get(i).getLiteral() == null) {
                return false;
            }
        }
        return true;
    }

    /** Converts a literal operand, formatted per the time column's format spec, to milliseconds. */
    private long toMillisSinceEpoch(Expression literalOperand) {
        return _timeFormatSpec.fromFormatToMillis(literalOperand.getLiteral().getFieldValue().toString());
    }

    /** Intersects a non-empty list of sorted interval lists by pairwise divide-and-conquer. */
    private List<Interval> getIntersectionSortedIntervals(List<List<Interval>> intervalsList) {
        return getIntersectionSortedIntervals(intervalsList, 0, intervalsList.size());
    }

    /** Intersects {@code intervalsList[from, to)}; requires {@code from < to}. */
    private List<Interval> getIntersectionSortedIntervals(List<List<Interval>> intervalsList, int from, int to) {
        if (from + 1 == to) {
            return intervalsList.get(from);
        }
        int mid = from + (to - from) / 2;
        return getIntersectionTwoSortedIntervals(getIntersectionSortedIntervals(intervalsList, from, mid),
                getIntersectionSortedIntervals(intervalsList, mid, to));
    }

    /**
     * Two-pointer intersection of two sorted, non-overlapping interval lists. The result is sorted and
     * non-overlapping as well.
     */
    private List<Interval> getIntersectionTwoSortedIntervals(List<Interval> intervals1, List<Interval> intervals2) {
        List<Interval> intersection = new ArrayList<>();
        int size1 = intervals1.size();
        int size2 = intervals2.size();
        int i1 = 0;
        int i2 = 0;
        while (i1 < size1 && i2 < size2) {
            Interval interval1 = intervals1.get(i1);
            Interval interval2 = intervals2.get(i2);
            if (interval1.intersects(interval2)) {
                intersection.add(Interval.getIntersection(interval1, interval2));
            }
            // Advance whichever interval ends first; it cannot intersect anything further in the other list.
            if (interval1._max <= interval2._max) {
                i1++;
            } else {
                i2++;
            }
        }
        return intersection;
    }

    /** Unions a non-empty list of sorted interval lists by pairwise divide-and-conquer. */
    private List<Interval> getUnionSortedIntervals(List<List<Interval>> intervalsList) {
        return getUnionSortedIntervals(intervalsList, 0, intervalsList.size());
    }

    /** Unions {@code intervalsList[from, to)}; requires {@code from < to}. */
    private List<Interval> getUnionSortedIntervals(List<List<Interval>> intervalsList, int from, int to) {
        if (from + 1 == to) {
            return intervalsList.get(from);
        }
        int mid = from + (to - from) / 2;
        return getUnionTwoSortedIntervals(getUnionSortedIntervals(intervalsList, from, mid),
                getUnionSortedIntervals(intervalsList, mid, to));
    }

    /**
     * Merges two sorted interval lists, coalescing intervals that intersect. The result is sorted and
     * non-overlapping as well.
     */
    private List<Interval> getUnionTwoSortedIntervals(List<Interval> intervals1, List<Interval> intervals2) {
        List<Interval> union = new ArrayList<>();
        int size1 = intervals1.size();
        int size2 = intervals2.size();
        int i1 = 0;
        int i2 = 0;
        while (i1 < size1 || i2 < size2) {
            Interval next;
            if (i2 == size2 || (i1 < size1 && intervals1.get(i1).compareTo(intervals2.get(i2)) <= 0)) {
                next = intervals1.get(i1++);
            } else {
                next = intervals2.get(i2++);
            }
            int lastIndex = union.size() - 1;
            if (lastIndex < 0 || !next.intersects(union.get(lastIndex))) {
                union.add(next);
            } else {
                union.set(lastIndex, Interval.getUnion(next, union.get(lastIndex)));
            }
        }
        return union;
    }

    /**
     * Returns the complement of a sorted, non-overlapping interval list within
     * {@code [MIN_START_TIME, MAX_END_TIME]}.
     */
    private List<Interval> getComplementSortedIntervals(List<Interval> intervals) {
        List<Interval> complement = new ArrayList<>();
        long nextStart = MIN_START_TIME;
        for (Interval interval : intervals) {
            if (interval._min > nextStart) {
                complement.add(new Interval(nextStart, interval._min - 1));
            }
            if (interval._max == MAX_END_TIME) {
                return complement;
            }
            // Never move backwards (e.g. for intervals lying entirely below MIN_START_TIME).
            nextStart = Math.max(nextStart, interval._max + 1);
        }
        complement.add(new Interval(nextStart, MAX_END_TIME));
        return complement;
    }

    /**
     * Parses a RANGE literal into at most one interval in millis.
     *
     * <p>The literal looks like {@code [lower<NUL>upper]}, where {@code (}/{@code )} mark exclusive bounds
     * and {@code *} marks an unbounded side. The two bounds are split on the NUL character ({@code '\0'}),
     * presumably the delimiter used when the RANGE literal is built; confirm.
     *
     * @return an empty list if the range is empty, otherwise a singleton list
     */
    private List<Interval> parseInterval(String rangeString) {
        int length = rangeString.length();
        boolean lowerExclusive = rangeString.charAt(0) == '(';
        boolean upperExclusive = rangeString.charAt(length - 1) == ')';
        String[] bounds = StringUtils.split(rangeString.substring(1, length - 1), (char) 0);
        long start = MIN_START_TIME;
        if (!bounds[0].equals("*")) {
            start = _timeFormatSpec.fromFormatToMillis(bounds[0]);
            if (lowerExclusive) {
                if (start == Long.MAX_VALUE) {
                    // Nothing is greater than the maximum; also avoids overflow of start + 1.
                    return Collections.emptyList();
                }
                start++;
            }
        }
        long end = MAX_END_TIME;
        if (!bounds[1].equals("*")) {
            end = _timeFormatSpec.fromFormatToMillis(bounds[1]);
            if (upperExclusive) {
                if (end == Long.MIN_VALUE) {
                    // Nothing is less than the minimum; also avoids underflow of end - 1.
                    return Collections.emptyList();
                }
                end--;
            }
        }
        if (start > end) {
            return Collections.emptyList();
        }
        return Collections.singletonList(new Interval(start, end));
    }
}
