package org.apache.pinot.core.query.optimizer.statement;

import java.util.Iterator;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.PinotQuery;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.sql.FilterKind;

/* loaded from: input_file:org/apache/pinot/core/query/optimizer/statement/StringPredicateFilterOptimizer.class */
/**
 * Statement optimizer that rewrites {@code minus(a, b)} into {@code strcmp(a, b)} when both
 * arguments are STRING-typed.
 *
 * <p>Only the first operand of each non-logical predicate found in the filter and having
 * expressions is inspected. AND / OR / NOT nodes are descended into recursively.
 *
 * <p>NOTE(review): presumably the {@code minus} call is produced by an earlier rewrite of
 * expression-to-expression comparisons; that rewrite is not visible from this class - confirm.
 */
public class StringPredicateFilterOptimizer implements StatementOptimizer {
    private static final String MINUS_OPERATOR_NAME = "minus";
    private static final String STRCMP_OPERATOR_NAME = "strcmp";

    /**
     * Rewrites the filter expression and then the having expression of the query in place.
     *
     * @param pinotQuery  query whose filter / having expressions are rewritten
     * @param tableConfig not used by this optimizer
     * @param schema      schema used to resolve column data types; when {@code null} the query is
     *                    left untouched
     */
    @Override
    public void optimize(PinotQuery pinotQuery, @Nullable TableConfig tableConfig, @Nullable Schema schema) {
        if (schema != null) {
            optimizeIfPresent(pinotQuery.getFilterExpression(), schema);
            optimizeIfPresent(pinotQuery.getHavingExpression(), schema);
        }
    }

    /** Runs {@link #optimizeExpression} on {@code root} unless it is {@code null}. */
    private static void optimizeIfPresent(@Nullable Expression root, Schema schema) {
        if (root != null) {
            optimizeExpression(root, schema);
        }
    }

    /**
     * Walks a predicate tree. Logical nodes (AND / OR / NOT) recurse into every operand; for any
     * other function node the first operand is handed to
     * {@link #replaceMinusWithCompareForStrings}. Non-function expressions are ignored.
     */
    private static void optimizeExpression(Expression expression, Schema schema) {
        if (expression.getType() != ExpressionType.FUNCTION) {
            // Nothing to rewrite unless the node is a function call.
            return;
        }
        Function predicate = expression.getFunctionCall();
        String operatorName = predicate.getOperator();
        List<Expression> children = predicate.getOperands();
        if (isLogicalOperator(operatorName)) {
            for (Expression child : children) {
                optimizeExpression(child, schema);
            }
        } else {
            // Only the left-hand side (first operand) of a predicate is a rewrite candidate.
            replaceMinusWithCompareForStrings(children.get(0), schema);
        }
    }

    /** Returns whether {@code operatorName} is exactly the name of the AND, OR or NOT filter kind. */
    private static boolean isLogicalOperator(String operatorName) {
        return operatorName.equals(FilterKind.AND.name())
            || operatorName.equals(FilterKind.OR.name())
            || operatorName.equals(FilterKind.NOT.name());
    }

    /**
     * Renames the operator of {@code expression} from {@code minus} to {@code strcmp} when it is a
     * two-argument {@code minus} call whose arguments are both strings according to
     * {@link #isString}. Any other expression is left unchanged.
     */
    private static void replaceMinusWithCompareForStrings(Expression expression, Schema schema) {
        if (expression.getType() != ExpressionType.FUNCTION) {
            return;
        }
        Function call = expression.getFunctionCall();
        String operatorName = call.getOperator();
        List<Expression> args = call.getOperands();
        if (!operatorName.equals(MINUS_OPERATOR_NAME) || args.size() != 2) {
            return;
        }
        if (isString(args.get(0), schema) && isString(args.get(1), schema)) {
            call.setOperator(STRCMP_OPERATOR_NAME);
        }
    }

    /**
     * Decides whether an expression is STRING-typed: an identifier must name a schema column whose
     * data type is STRING, a function must resolve in the {@link FunctionRegistry} to a method that
     * returns {@link String}. Every other kind of expression is reported as non-string.
     */
    private static boolean isString(Expression expression, Schema schema) {
        ExpressionType kind = expression.getType();
        if (kind == ExpressionType.FUNCTION) {
            return returnsString(expression.getFunctionCall());
        }
        if (kind == ExpressionType.IDENTIFIER) {
            return isStringColumn(expression.getIdentifier().getName(), schema);
        }
        return false;
    }

    /** Returns whether {@code columnName} is present in the schema with data type STRING. */
    private static boolean isStringColumn(String columnName, Schema schema) {
        FieldSpec columnSpec = schema.getFieldSpecFor(columnName);
        if (columnSpec == null) {
            return false;
        }
        return columnSpec.getDataType() == FieldSpec.DataType.STRING;
    }

    /**
     * Returns whether the registry entry matching the call's canonicalized name and operand count
     * exists and its backing method has return type {@link String}.
     */
    private static boolean returnsString(Function call) {
        String canonicalName = FunctionRegistry.canonicalize(call.getOperator());
        int argumentCount = call.getOperands().size();
        FunctionInfo registered = FunctionRegistry.lookupFunctionInfo(canonicalName, argumentCount);
        if (registered == null) {
            return false;
        }
        return registered.getMethod().getReturnType() == String.class;
    }
}
