package org.apache.pinot.sql.parsers.rewriter;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.EnumUtils;
import org.apache.pinot.$internal.com.google.common.base.Preconditions;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.PinotQuery;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.sql.FilterKind;
import org.apache.pinot.sql.parsers.SqlCompilationException;

/* loaded from: input_file:org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.class */
/**
 * Normalizes the predicates of a query's WHERE (filter) and HAVING clauses.
 *
 * <ul>
 *   <li>A bare identifier, or a function call whose operator is not a {@link FilterKind}, used directly as a
 *       predicate is rewritten to {@code EQUALS(expr, true)}.</li>
 *   <li>Operands of {@code AND}, {@code OR} and {@code NOT} are rewritten recursively.</li>
 *   <li>For binary comparisons ({@code =, !=, >, >=, <, <=}), a literal on the left-hand side is moved to the
 *       right-hand side (mirroring the operator), and a comparison between two non-literals {@code a OP b} becomes
 *       {@code minus(a, b) OP 0}.</li>
 *   <li>{@code VECTOR_SIMILARITY} and all remaining filter kinds only have their operands validated.</li>
 * </ul>
 *
 * <p>Function expressions are modified in place (operators and operand lists are mutated).
 */
public class PredicateComparisonRewriter implements QueryRewriter {

    @Override
    public PinotQuery rewrite(PinotQuery pinotQuery) {
        Expression filterExpression = pinotQuery.getFilterExpression();
        if (filterExpression != null) {
            pinotQuery.setFilterExpression(updatePredicate(filterExpression));
        }
        Expression havingExpression = pinotQuery.getHavingExpression();
        if (havingExpression != null) {
            pinotQuery.setHavingExpression(updatePredicate(havingExpression));
        }
        return pinotQuery;
    }

    /**
     * Rewrites a single predicate expression.
     *
     * @param expression the predicate to rewrite
     * @return the rewritten predicate; either {@code expression} itself (possibly mutated) or a new wrapping
     *     {@code EQUALS(expression, true)} expression
     * @throws IllegalStateException if the expression type is not FUNCTION, IDENTIFIER or LITERAL
     */
    private static Expression updatePredicate(Expression expression) {
        switch (expression.getType()) {
            case FUNCTION:
                return updateFunctionExpression(expression);
            case IDENTIFIER:
                // A bare column used as a predicate (e.g. "WHERE col") is turned into "col = true".
                return convertPredicateToEqualsBooleanExpression(expression);
            case LITERAL:
                return expression;
            default:
                throw new IllegalStateException("Unsupported predicate expression type: " + expression.getType());
        }
    }

    /**
     * Rewrites or validates a function-call predicate according to its {@link FilterKind}.
     *
     * @param expression a FUNCTION expression
     * @return the same expression (mutated in place), or a new {@code EQUALS(expression, true)} wrapper when the
     *     operator is not a filter kind
     * @throws SqlCompilationException if the operands are invalid for the filter kind
     * @throws IllegalArgumentException if the operand count is invalid for a comparison or VECTOR_SIMILARITY
     */
    private static Expression updateFunctionExpression(Expression expression) {
        Function functionCall = expression.getFunctionCall();
        // getEnum returns null for names that are not FilterKind constants (and for a null name).
        FilterKind filterKind = EnumUtils.getEnum(FilterKind.class, functionCall.getOperator());
        if (filterKind == null) {
            // Not a filter operator (presumably a boolean-valued function): compare its result against true.
            return convertPredicateToEqualsBooleanExpression(expression);
        }
        List<Expression> operands = functionCall.getOperands();
        switch (filterKind) {
            case AND:
            case OR:
            case NOT:
                for (int i = 0; i < operands.size(); i++) {
                    operands.set(i, updatePredicate(operands.get(i)));
                }
                break;
            case EQUALS:
            case NOT_EQUALS:
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                rewriteComparison(expression, functionCall, filterKind, operands);
                break;
            case VECTOR_SIMILARITY:
                validateVectorSimilarity(expression, filterKind, operands);
                break;
            default:
                validateNonFirstOperandsAreLiterals(expression, filterKind, operands);
                break;
        }
        return expression;
    }

    /**
     * Puts a binary comparison into "non-literal OP literal" form where possible, mutating {@code functionCall}
     * and {@code operands} in place.
     */
    private static void rewriteComparison(Expression expression, Function functionCall, FilterKind filterKind,
        List<Expression> operands) {
        Preconditions.checkArgument(operands.size() == 2,
            "For %s predicate, the number of operands must be 2, got: %s", filterKind, expression);
        Expression left = operands.get(0);
        Expression right = operands.get(1);
        boolean leftIsLiteral = left.isSetLiteral();
        boolean rightIsLiteral = right.isSetLiteral();
        if (leftIsLiteral && !rightIsLiteral) {
            // "5 < col" -> "col > 5": swap the operands and mirror the operator.
            functionCall.setOperator(getOppositeOperator(filterKind).name());
            operands.set(0, right);
            operands.set(1, left);
        } else if (!leftIsLiteral && !rightIsLiteral) {
            // "a OP b" -> "minus(a, b) OP 0" so that the right-hand side is always a literal.
            Expression minusExpression = RequestUtils.getFunctionExpression("minus");
            minusExpression.getFunctionCall().setOperands(Arrays.asList(left, right));
            operands.set(0, minusExpression);
            operands.set(1, RequestUtils.getLiteralExpression(0L));
        }
        // Otherwise the literal is already on the right, or both sides are literals: leave unchanged.
    }

    /** Validates the operand count and the kinds of the second/third operands of a VECTOR_SIMILARITY predicate. */
    private static void validateVectorSimilarity(Expression expression, FilterKind filterKind,
        List<Expression> operands) {
        Preconditions.checkArgument(operands.size() >= 2 && operands.size() <= 3,
            "For %s predicate, the number of operands must be at either 2 or 3, got: %s", filterKind, expression);
        Expression queryVector = operands.get(1);
        // Reject a function call other than ARRAYVALUECONSTRUCTOR, or a literal that is not a float/double array.
        // NOTE(review): an operand that is neither a function call nor a literal (e.g. an identifier) is not
        // rejected here even though the error message requires an array literal -- confirm this is intended.
        boolean isNonArrayFunction = queryVector.getFunctionCall() != null
            && !queryVector.getFunctionCall().getOperator().equalsIgnoreCase("arrayvalueconstructor");
        boolean isNonArrayLiteral = queryVector.getLiteral() != null
            && !queryVector.getLiteral().isSetFloatArrayValue()
            && !queryVector.getLiteral().isSetDoubleArrayValue();
        if (isNonArrayFunction || isNonArrayLiteral) {
            throw new SqlCompilationException(String.format("For %s predicate, the second operand must be a float/double array literal, got: %s", filterKind, expression));
        }
        if (operands.size() == 3 && operands.get(2).getLiteral() == null) {
            throw new SqlCompilationException(String.format("For %s predicate, the third operand must be a literal, got: %s", filterKind, expression));
        }
    }

    /** Ensures every operand after the first one is a literal. */
    private static void validateNonFirstOperandsAreLiterals(Expression expression, FilterKind filterKind,
        List<Expression> operands) {
        for (int i = 1; i < operands.size(); i++) {
            if (!operands.get(i).isSetLiteral()) {
                throw new SqlCompilationException(String.format("For %s predicate, the operands except for the first one must be literal, got: %s", filterKind, expression));
            }
        }
    }

    /**
     * Wraps {@code expression} as {@code EQUALS(expression, true)}.
     *
     * @param expression the boolean-valued expression to wrap
     * @return a new EQUALS function expression
     */
    private static Expression convertPredicateToEqualsBooleanExpression(Expression expression) {
        Expression equalsExpression = RequestUtils.getFunctionExpression(FilterKind.EQUALS.name());
        // Mutable list: operands are later updated in place via List.set by this and other code paths.
        List<Expression> operands = new ArrayList<>(2);
        operands.add(expression);
        operands.add(RequestUtils.getLiteralExpression(true));
        equalsExpression.getFunctionCall().setOperands(operands);
        return equalsExpression;
    }

    /**
     * Returns the operator to use after swapping the two sides of a comparison, so that {@code a OP b} is
     * equivalent to {@code b OPPOSITE a}. EQUALS, NOT_EQUALS and any other kind are returned unchanged.
     */
    private static FilterKind getOppositeOperator(FilterKind filterKind) {
        switch (filterKind) {
            case GREATER_THAN:
                return FilterKind.LESS_THAN;
            case GREATER_THAN_OR_EQUAL:
                return FilterKind.LESS_THAN_OR_EQUAL;
            case LESS_THAN:
                return FilterKind.GREATER_THAN;
            case LESS_THAN_OR_EQUAL:
                return FilterKind.GREATER_THAN_OR_EQUAL;
            default:
                return filterKind;
        }
    }
}
