package org.apache.pinot.core.operator.transform.function;

import java.util.List;
import java.util.Map;
import org.apache.pinot.$internal.com.google.common.base.Preconditions;
import org.apache.pinot.common.function.scalar.VectorFunctions;
import org.apache.pinot.core.operator.ColumnContext;
import org.apache.pinot.core.operator.blocks.ValueBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;

/* loaded from: input_file:org/apache/pinot/core/operator/transform/function/VectorTransformFunctions.class */
public class VectorTransformFunctions {

    /**
     * Validates the argument list of a single-vector transform function and returns its only argument.
     *
     * @param arguments    the arguments passed to {@code init}
     * @param functionName the name of the calling function, used in the error message
     * @return the single (multi-valued) argument
     * @throws IllegalArgumentException if there is not exactly one argument, or if it is single-valued
     */
    private static TransformFunction getSingleVectorArgument(List<TransformFunction> arguments,
        String functionName) {
        if (arguments.size() != 1) {
            throw new IllegalArgumentException("Exactly 1 argument is required for Vector transform function");
        }
        TransformFunction argument = arguments.get(0);
        Preconditions.checkArgument(!argument.getResultMetadata().isSingleValue(),
            "Argument must be multi-valued float vector for vector transform function: %s", functionName);
        return argument;
    }

    /**
     * Base class for transform functions that compute a double-valued result from two multi-valued vector
     * arguments, row by row. Subclasses supply the per-row computation via {@link #computeDistance}.
     */
    public static abstract class VectorDistanceTransformFunction extends BaseTransformFunction {
        protected TransformFunction _leftTransformFunction;
        protected TransformFunction _rightTransformFunction;

        /**
         * Validates the argument count (via {@link #checkArgumentSize}) and stores the first two arguments as the
         * left/right vector operands.
         *
         * @throws IllegalArgumentException if the argument count is invalid or either operand is single-valued
         */
        @Override
        public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
            super.init(arguments, columnContextMap);
            checkArgumentSize(arguments);
            _leftTransformFunction = arguments.get(0);
            _rightTransformFunction = arguments.get(1);
            Preconditions.checkArgument(
                !_leftTransformFunction.getResultMetadata().isSingleValue()
                    && !_rightTransformFunction.getResultMetadata().isSingleValue(),
                "Argument must be multi-valued float vector for vector distance transform function: %s", getName());
        }

        /**
         * Validates the number of arguments. By default, exactly two arguments are required; subclasses that
         * accept optional extra arguments override this.
         *
         * @throws IllegalArgumentException if the argument count is not accepted
         */
        protected void checkArgumentSize(List<TransformFunction> arguments) {
            if (arguments.size() != 2) {
                throw new IllegalArgumentException("Exactly 2 arguments are required for Vector transform function");
            }
        }

        @Override
        public TransformResultMetadata getResultMetadata() {
            return DOUBLE_SV_NO_DICTIONARY_METADATA;
        }

        /**
         * Reads both operands as float arrays for every row in the block and stores
         * {@link #computeDistance} of each pair into the reusable double result buffer.
         */
        @Override
        public double[] transformToDoubleValuesSV(ValueBlock valueBlock) {
            int numDocs = valueBlock.getNumDocs();
            initDoubleValuesSV(numDocs);
            float[][] leftVectors = _leftTransformFunction.transformToFloatValuesMV(valueBlock);
            float[][] rightVectors = _rightTransformFunction.transformToFloatValuesMV(valueBlock);
            for (int i = 0; i < numDocs; i++) {
                _doubleValuesSV[i] = computeDistance(leftVectors[i], rightVectors[i]);
            }
            return _doubleValuesSV;
        }

        /** Computes the result for a single row from its two vector values. */
        protected abstract double computeDistance(float[] leftVector, float[] rightVector);
    }

    /**
     * {@code cosineDistance(v1, v2 [, defaultValue])}. The optional third argument must be a literal; when present
     * it is forwarded to {@link VectorFunctions#cosineDistance(float[], float[], double)}. Presumably it is the
     * value returned when the distance is undefined (e.g. a zero-norm vector). NOTE(review): confirm against
     * VectorFunctions.
     */
    public static class CosineDistanceTransformFunction extends VectorDistanceTransformFunction {
        public static final String FUNCTION_NAME = "cosineDistance";
        // Optional default value from the third argument; null when only two arguments were given.
        private Double _defaultValue = null;

        @Override
        protected void checkArgumentSize(List<TransformFunction> arguments) {
            if (arguments.size() < 2 || arguments.size() > 3) {
                throw new IllegalArgumentException("2 or 3 arguments are required for CosineDistance function");
            }
        }

        /**
         * Initializes the vector operands and, when a third argument is present, reads it as a double literal.
         *
         * @throws IllegalArgumentException if the third argument is not a literal
         */
        @Override
        public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
            super.init(arguments, columnContextMap);
            // Reset so a re-initialization with two arguments does not keep a stale default.
            _defaultValue = null;
            if (arguments.size() == 3) {
                TransformFunction defaultValueArgument = arguments.get(2);
                // Fail with a descriptive error instead of a ClassCastException for non-literal defaults.
                Preconditions.checkArgument(defaultValueArgument instanceof LiteralTransformFunction,
                    "Third argument (default value) must be a literal for transform function: %s", getName());
                _defaultValue = ((LiteralTransformFunction) defaultValueArgument).getDoubleLiteral();
            }
        }

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        protected double computeDistance(float[] leftVector, float[] rightVector) {
            if (_defaultValue == null) {
                return VectorFunctions.cosineDistance(leftVector, rightVector);
            }
            return VectorFunctions.cosineDistance(leftVector, rightVector, _defaultValue);
        }
    }

    /** {@code innerProduct(v1, v2)}: delegates to {@link VectorFunctions#innerProduct}. */
    public static class InnerProductTransformFunction extends VectorDistanceTransformFunction {
        public static final String FUNCTION_NAME = "innerProduct";

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        protected double computeDistance(float[] leftVector, float[] rightVector) {
            return VectorFunctions.innerProduct(leftVector, rightVector);
        }
    }

    /** {@code l1Distance(v1, v2)}: delegates to {@link VectorFunctions#l1Distance}. */
    public static class L1DistanceTransformFunction extends VectorDistanceTransformFunction {
        public static final String FUNCTION_NAME = "l1Distance";

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        protected double computeDistance(float[] leftVector, float[] rightVector) {
            return VectorFunctions.l1Distance(leftVector, rightVector);
        }
    }

    /** {@code l2Distance(v1, v2)}: delegates to {@link VectorFunctions#l2Distance}. */
    public static class L2DistanceTransformFunction extends VectorDistanceTransformFunction {
        public static final String FUNCTION_NAME = "l2Distance";

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        protected double computeDistance(float[] leftVector, float[] rightVector) {
            return VectorFunctions.l2Distance(leftVector, rightVector);
        }
    }

    /**
     * {@code vectorDims(v)}: returns, per row, the int result of {@link VectorFunctions#vectorDims} applied to
     * the single multi-valued argument.
     */
    public static class VectorDimsTransformFunction extends BaseTransformFunction {
        public static final String FUNCTION_NAME = "vectorDims";
        private TransformFunction _transformFunction;

        /** @throws IllegalArgumentException if there is not exactly one multi-valued argument */
        @Override
        public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
            super.init(arguments, columnContextMap);
            _transformFunction = getSingleVectorArgument(arguments, getName());
        }

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        public TransformResultMetadata getResultMetadata() {
            return INT_SV_NO_DICTIONARY_METADATA;
        }

        @Override
        public int[] transformToIntValuesSV(ValueBlock valueBlock) {
            int numDocs = valueBlock.getNumDocs();
            initIntValuesSV(numDocs);
            float[][] vectors = _transformFunction.transformToFloatValuesMV(valueBlock);
            for (int i = 0; i < numDocs; i++) {
                _intValuesSV[i] = VectorFunctions.vectorDims(vectors[i]);
            }
            return _intValuesSV;
        }
    }

    /**
     * {@code vectorNorm(v)}: returns, per row, the double result of {@link VectorFunctions#vectorNorm} applied to
     * the single multi-valued argument.
     */
    public static class VectorNormTransformFunction extends BaseTransformFunction {
        public static final String FUNCTION_NAME = "vectorNorm";
        private TransformFunction _transformFunction;

        /** @throws IllegalArgumentException if there is not exactly one multi-valued argument */
        @Override
        public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
            super.init(arguments, columnContextMap);
            _transformFunction = getSingleVectorArgument(arguments, getName());
        }

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        public TransformResultMetadata getResultMetadata() {
            return DOUBLE_SV_NO_DICTIONARY_METADATA;
        }

        @Override
        public double[] transformToDoubleValuesSV(ValueBlock valueBlock) {
            int numDocs = valueBlock.getNumDocs();
            initDoubleValuesSV(numDocs);
            float[][] vectors = _transformFunction.transformToFloatValuesMV(valueBlock);
            for (int i = 0; i < numDocs; i++) {
                _doubleValuesSV[i] = VectorFunctions.vectorNorm(vectors[i]);
            }
            return _doubleValuesSV;
        }
    }
}
