package org.apache.pinot.core.query.optimizer.filter;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.sql.FilterKind;

/* loaded from: input_file:org/apache/pinot/core/query/optimizer/filter/MergeEqInFilterOptimizer.class */
/**
 * Filter optimizer that merges EQUALS and IN predicates sharing the same left-hand side when they are
 * combined by OR, and that simplifies stand-alone IN predicates.
 *
 * <p>Rewrites performed:
 * <ul>
 *   <li>Under an OR, all EQUALS/IN children with an equal left-hand-side expression are collapsed into one
 *       predicate over the union of their values.</li>
 *   <li>An IN predicate with exactly one distinct value becomes an EQUALS predicate.</li>
 *   <li>An IN predicate with duplicate values is rebuilt with the duplicates removed.</li>
 * </ul>
 *
 * <p>AND and NOT are traversed so that predicates nested below them are optimized as well. Function calls
 * are modified in place where possible; a new expression is returned only when a predicate is rebuilt.
 * Values are collected in hash-based sets, so the order of values in a rebuilt IN predicate is unspecified.
 */
public class MergeEqInFilterOptimizer implements FilterOptimizer {

    /**
     * Optimizes the given filter expression.
     *
     * @param expression the filter expression; returned as-is when its type is not {@code FUNCTION}
     * @param schema not used by this optimizer
     * @return the optimized expression, which may be the (mutated) input or a newly built predicate
     */
    @Override // org.apache.pinot.core.query.optimizer.filter.FilterOptimizer
    public Expression optimize(Expression expression, @Nullable Schema schema) {
        return expression.getType() == ExpressionType.FUNCTION ? optimize(expression) : expression;
    }

    /**
     * Recursive worker: dispatches on the operator of the function call carried by {@code expression}.
     */
    private Expression optimize(Expression expression) {
        Function function = expression.getFunctionCall();
        if (function == null) {
            // This method is also applied to AND/NOT operands, which are not type-checked by the public
            // entry point. Presumably getFunctionCall() is null for identifiers/literals; leave them alone.
            return expression;
        }
        String operator = function.getOperator();
        if (operator.equals(FilterKind.OR.name())) {
            return optimizeOr(expression, function);
        }
        if (operator.equals(FilterKind.AND.name()) || operator.equals(FilterKind.NOT.name())) {
            // NOT is traversed here exactly like AND so that a NOT is handled the same way regardless of
            // whether it sits at the top level, under an AND, or under an OR.
            function.getOperands().replaceAll(this::optimize);
            return expression;
        }
        if (operator.equals(FilterKind.IN.name())) {
            List<Expression> operands = function.getOperands();
            Expression lhs = operands.get(0);
            Set<Expression> values = collectInValues(operands);
            return shouldRewriteIn(values.size(), operands.size() - 1) ? getFilterExpression(lhs, values) : expression;
        }
        return expression;
    }

    /**
     * Merges the EQUALS/IN children of an OR by left-hand side.
     *
     * @param expression the OR expression
     * @param function the function call of {@code expression}
     * @return {@code expression} (possibly with replaced operands), or a single merged predicate when the
     *     whole OR collapses into one EQUALS/IN
     */
    private Expression optimizeOr(Expression expression, Function function) {
        // Key: left-hand side of the EQUALS/IN predicate; value: distinct right-hand-side values seen so far.
        Map<Expression, Set<Expression>> valuesByLhs = new HashMap<>();
        // Children that are kept as they are (after optimizing their own operands where applicable).
        List<Expression> otherChildren = new ArrayList<>();
        boolean recreateFilter = false;

        for (Expression child : function.getOperands()) {
            Function childFunction = child.getFunctionCall();
            if (childFunction == null) {
                // Not a function call: nothing to merge, keep the child untouched.
                otherChildren.add(child);
                continue;
            }
            String childOperator = childFunction.getOperator();
            // NOTE(review): a directly nested OR is not expected here, presumably because an earlier step
            // flattens nested ORs. With assertions disabled such a child is simply kept unchanged.
            assert !childOperator.equals(FilterKind.OR.name());
            if (childOperator.equals(FilterKind.AND.name()) || childOperator.equals(FilterKind.NOT.name())) {
                childFunction.getOperands().replaceAll(this::optimize);
                otherChildren.add(child);
            } else if (childOperator.equals(FilterKind.EQUALS.name())) {
                List<Expression> operands = childFunction.getOperands();
                Expression lhs = operands.get(0);
                Expression value = operands.get(1);
                Set<Expression> values = valuesByLhs.get(lhs);
                if (values == null) {
                    values = new HashSet<>();
                    values.add(value);
                    valuesByLhs.put(lhs, values);
                } else {
                    values.add(value);
                    // A second predicate on the same left-hand side: the OR has to be rebuilt.
                    recreateFilter = true;
                }
            } else if (childOperator.equals(FilterKind.IN.name())) {
                List<Expression> operands = childFunction.getOperands();
                Expression lhs = operands.get(0);
                Set<Expression> inValues = collectInValues(operands);
                if (shouldRewriteIn(inValues.size(), operands.size() - 1)) {
                    // The IN itself can be simplified (single value or duplicates).
                    recreateFilter = true;
                }
                Set<Expression> values = valuesByLhs.get(lhs);
                if (values == null) {
                    valuesByLhs.put(lhs, inValues);
                } else {
                    values.addAll(inValues);
                    recreateFilter = true;
                }
            } else {
                otherChildren.add(child);
            }
        }

        if (!recreateFilter) {
            return expression;
        }
        if (otherChildren.isEmpty() && valuesByLhs.size() == 1) {
            // Every child merged into one predicate: the OR itself is no longer needed.
            Map.Entry<Expression, Set<Expression>> entry = valuesByLhs.entrySet().iterator().next();
            return getFilterExpression(entry.getKey(), entry.getValue());
        }
        for (Map.Entry<Expression, Set<Expression>> entry : valuesByLhs.entrySet()) {
            otherChildren.add(getFilterExpression(entry.getKey(), entry.getValue()));
        }
        function.setOperands(otherChildren);
        return expression;
    }

    /**
     * Collects the distinct values of an IN predicate, i.e. every operand except the first one (the
     * left-hand side).
     */
    private static Set<Expression> collectInValues(List<Expression> operands) {
        Set<Expression> values = new HashSet<>();
        int numOperands = operands.size();
        for (int i = 1; i < numOperands; i++) {
            values.add(operands.get(i));
        }
        return values;
    }

    /**
     * Tells whether an IN predicate should be rebuilt: either it has exactly one distinct value (it becomes
     * an EQUALS), or the distinct count differs from the listed count (duplicates can be dropped).
     */
    private static boolean shouldRewriteIn(int numUniqueValues, int numListedValues) {
        return numUniqueValues == 1 || numUniqueValues != numListedValues;
    }

    /**
     * Builds the predicate for the given left-hand side and values: EQUALS when there is exactly one value,
     * IN otherwise (left-hand side first, followed by the values in the set's iteration order).
     */
    private static Expression getFilterExpression(Expression expression, Set<Expression> set) {
        int size = set.size();
        if (size == 1) {
            Expression eqExpression = RequestUtils.getFunctionExpression(FilterKind.EQUALS.name());
            eqExpression.getFunctionCall().setOperands(Arrays.asList(expression, set.iterator().next()));
            return eqExpression;
        }
        Expression inExpression = RequestUtils.getFunctionExpression(FilterKind.IN.name());
        List<Expression> operands = new ArrayList<>(size + 1);
        operands.add(expression);
        operands.addAll(set);
        inExpression.getFunctionCall().setOperands(operands);
        return inExpression;
    }
}
