package org.apache.pinot.sql.parsers.rewriter;

import com.google.common.annotations.VisibleForTesting;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionInvoker;
import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.PinotQuery;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.sql.parsers.SqlCompilationException;

/* loaded from: input_file:org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.class */
/**
 * Query rewriter that evaluates function calls at query compile time when all of their operands are literals
 * (after the operands themselves have been rewritten recursively) and {@link FunctionRegistry} has a matching
 * function. Each such call is replaced by a literal expression holding the invocation result.
 *
 * <p>The rewrite is applied to the select list, group-by list, order-by list, filter expression and having
 * expression of the query.
 */
public class CompileTimeFunctionsInvoker implements QueryRewriter {

    /**
     * Rewrites every supported clause of {@code pinotQuery} in place. Function calls that can be evaluated at
     * compile time are replaced with literals.
     *
     * @param pinotQuery the query to rewrite; it is modified in place
     * @return the same {@code pinotQuery} instance
     * @throws SqlCompilationException if invoking a matched function fails
     */
    @Override // org.apache.pinot.sql.parsers.rewriter.QueryRewriter
    public PinotQuery rewrite(PinotQuery pinotQuery) {
        rewriteInPlace(pinotQuery.getSelectList());
        rewriteInPlace(pinotQuery.getGroupByList());
        rewriteInPlace(pinotQuery.getOrderByList());
        pinotQuery.setFilterExpression(invokeCompileTimeFunctionExpression(pinotQuery.getFilterExpression()));
        pinotQuery.setHavingExpression(invokeCompileTimeFunctionExpression(pinotQuery.getHavingExpression()));
        return pinotQuery;
    }

    /**
     * Replaces each element of {@code expressions} with its compile-time-folded form. A {@code null} list is a no-op.
     */
    private static void rewriteInPlace(@Nullable List<Expression> expressions) {
        if (expressions == null) {
            return;
        }
        expressions.replaceAll(CompileTimeFunctionsInvoker::invokeCompileTimeFunctionExpression);
    }

    /**
     * Recursively folds compile-time-evaluable function calls within {@code expression}.
     *
     * <p>The operands of a function call are rewritten first and written back into the call's operand list. This
     * means nested calls are folded even when the outer call cannot be. If every operand is then a literal and
     * {@link FunctionRegistry} has a function matching the canonicalized operator name and the literal argument
     * types, that function is invoked and its result is returned as a literal expression. A call whose operand
     * list is {@code null} is treated as a call with no operands.
     *
     * @param expression the expression to rewrite; may be {@code null}
     * @return a literal expression holding the invocation result. Otherwise {@code expression} itself (its operands
     *     possibly rewritten) when it cannot be evaluated at compile time, or {@code null} when {@code expression}
     *     is {@code null}
     * @throws SqlCompilationException if invoking the matched function, or converting its result to a literal,
     *     throws
     */
    @VisibleForTesting
    public static Expression invokeCompileTimeFunctionExpression(@Nullable Expression expression) {
        if (expression == null || expression.getFunctionCall() == null) {
            return expression;
        }
        Function function = expression.getFunctionCall();
        List<Expression> operands = function.getOperands();
        // Guard against a call whose operand list was never set, instead of failing with an NPE.
        int numOperands = operands != null ? operands.size() : 0;
        DataSchema.ColumnDataType[] argumentTypes = new DataSchema.ColumnDataType[numOperands];
        Object[] arguments = new Object[numOperands];
        boolean allLiterals = true;
        for (int i = 0; i < numOperands; i++) {
            // Keep rewriting the remaining operands after a non-literal is seen, so nested calls are still folded.
            Expression operand = invokeCompileTimeFunctionExpression(operands.get(i));
            operands.set(i, operand);
            Literal literal = operand.getLiteral();
            if (allLiterals && literal != null) {
                Pair<DataSchema.ColumnDataType, Object> typeAndValue = RequestUtils.getLiteralTypeAndValue(literal);
                argumentTypes[i] = typeAndValue.getLeft();
                arguments[i] = typeAndValue.getRight();
            } else {
                allLiterals = false;
            }
        }
        if (!allLiterals) {
            return expression;
        }
        String canonicalName = FunctionRegistry.canonicalize(function.getOperator());
        FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes);
        if (functionInfo == null) {
            return expression;
        }
        return invokeAsLiteral(functionInfo, arguments);
    }

    /**
     * Invokes the matched function with the literal argument values and wraps its result as a literal expression.
     *
     * @throws SqlCompilationException wrapping any exception thrown while invoking or converting the result
     */
    private static Expression invokeAsLiteral(FunctionInfo functionInfo, Object[] arguments) {
        try {
            FunctionInvoker invoker = new FunctionInvoker(functionInfo);
            Object result;
            if (invoker.getMethod().isVarArgs()) {
                // A varargs method receives all literal values packed into its single array parameter.
                result = invoker.invoke(new Object[]{arguments});
            } else {
                // convertTypes' return value is unused, so it is expected to adjust 'arguments' in place
                // before the call.
                invoker.convertTypes(arguments);
                result = invoker.invoke(arguments);
            }
            return RequestUtils.getLiteralExpression(result);
        } catch (Exception e) {
            throw new SqlCompilationException("Caught exception while invoking method: "
                + String.valueOf(functionInfo.getMethod()) + " with arguments: " + Arrays.toString(arguments), e);
        }
    }
}
