package org.apache.pinot.core.operator.transform.function;

import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.request.context.RequestContextUtils;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.class */
public class VectorTransformFunctionTest extends BaseTransformFunctionTest {
    // Number of leading rows of the projection block that every test inspects.
    // NOTE(review): presumably matches the row count populated by BaseTransformFunctionTest — confirm.
    private static final int NUM_ROWS_TO_CHECK = 1000;

    // Dimensionality expected for the vector1/vector2 columns and used to build the literal zero vector.
    private static final int VECTOR_DIMENSION = 512;

    /**
     * Evaluates a vector transform expression over the projection block and checks that every
     * inspected row's result lies within the inclusive range {@code [lowerBound, upperBound]}.
     *
     * @param expression SQL expression to compile into a transform function
     * @param lowerBound inclusive lower bound each row's double result must satisfy
     * @param upperBound inclusive upper bound each row's double result must satisfy
     */
    @Test(dataProvider = "testVectorTransformFunctionDataProvider")
    public void testVectorTransformFunction(String expression, double lowerBound, double upperBound) {
        double[] values = TransformFunctionFactory.get(RequestContextUtils.getExpression(expression), _dataSourceMap)
            .transformToDoubleValuesSV(_projectionBlock);
        for (int i = 0; i < NUM_ROWS_TO_CHECK; i++) {
            double value = values[i];
            // Report the violated bound (not the boolean outcome) so failures are diagnosable.
            Assert.assertTrue(value >= lowerBound,
                expression + " at row " + i + ": " + value + " < " + lowerBound);
            Assert.assertTrue(value <= upperBound,
                expression + " at row " + i + ": " + value + " > " + upperBound);
        }
    }

    /**
     * Checks that {@code vectorDims} reports {@link #VECTOR_DIMENSION} for every inspected row of
     * both the {@code vector1} and {@code vector2} columns.
     */
    @Test
    public void testVectorDimsTransformFunction() {
        for (String column : new String[]{"vector1", "vector2"}) {
            String expression = "vectorDims(" + column + ")";
            int[] dims = TransformFunctionFactory.get(RequestContextUtils.getExpression(expression), _dataSourceMap)
                .transformToIntValuesSV(_projectionBlock);
            for (int i = 0; i < NUM_ROWS_TO_CHECK; i++) {
                Assert.assertEquals(dims[i], VECTOR_DIMENSION, expression + " at row " + i);
            }
        }
    }

    /**
     * Supplies {@code {expression, lowerBound, upperBound}} rows for
     * {@link #testVectorTransformFunction(String, double, double)}. Bounds are inclusive and given as
     * doubles to match the test method's parameter types.
     *
     * <p>NOTE(review): the ranges for vector1/vector2 presumably follow from how the base test
     * generates its vector data — confirm there if they ever need adjusting.
     */
    @DataProvider(name = "testVectorTransformFunctionDataProvider")
    public Object[][] testVectorTransformFunctionDataProvider() {
        // Literal all-zero vector with VECTOR_DIMENSION elements: ARRAY[0.0,0.0,...,0.0].
        String zeroVectorLiteral = "ARRAY[0.0" + StringUtils.repeat(",0.0", VECTOR_DIMENSION - 1) + "]";
        return new Object[][]{
            {"cosineDistance(vector1, vector2)", 0.1, 0.4},
            {"cosineDistance(vector1, vector2, 0)", 0.1, 0.4},
            // The third argument is the value returned when a norm is zero.
            {"cosineDistance(vector1, zeroVector, 0)", 0.0, 0.0},
            {"innerProduct(vector1, vector2)", 100.0, 160.0},
            {"l1Distance(vector1, vector2)", 140.0, 210.0},
            {"l2Distance(vector1, vector2)", 8.0, 11.0},
            {"vectorNorm(vector1)", 10.0, 16.0},
            {"vectorNorm(vector2)", 10.0, 16.0},
            {String.format("cosineDistance(vector1, %s, 0)", zeroVectorLiteral), 0.0, 0.0},
            {String.format("innerProduct(vector1, %s)", zeroVectorLiteral), 0.0, 0.0},
            {String.format("l1Distance(vector1, %s)", zeroVectorLiteral), 0.0, (double) VECTOR_DIMENSION},
            {String.format("l2Distance(vector1, %s)", zeroVectorLiteral), 0.0, (double) VECTOR_DIMENSION},
            {String.format("vectorDims(%s)", zeroVectorLiteral), (double) VECTOR_DIMENSION, (double) VECTOR_DIMENSION},
            {String.format("vectorNorm(%s)", zeroVectorLiteral), 0.0, 0.0}
        };
    }
}
