package org.apache.pinot.core.data.function;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.apache.pinot.segment.local.function.InbuiltFunctionEvaluator;
import org.apache.pinot.spi.data.readers.GenericRow;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/* loaded from: input_file:org/apache/pinot/core/data/function/VectorFunctionsTest.class */
/**
 * Exercises the vector functions ({@code cosineDistance}, {@code innerProduct}, {@code l2Distance},
 * {@code l1Distance}, {@code vectorDims}, {@code vectorNorm}) through {@link InbuiltFunctionEvaluator}.
 *
 * <p>Each test case is a four-element row: the function expression, the column names the evaluator is
 * expected to report as arguments, the input {@link GenericRow}, and the expected result.
 */
public class VectorFunctionsTest {
    /**
     * Builds an evaluator for the expression and checks both its reported arguments and its result.
     *
     * @param functionExpression expression handed to {@link InbuiltFunctionEvaluator}
     * @param expectedArguments column names the evaluator is expected to return from {@code getArguments()}
     * @param row input row the expression is evaluated against
     * @param expectedResult value the evaluation is expected to produce
     */
    private void testFunction(String functionExpression, List<String> expectedArguments, GenericRow row,
            Object expectedResult) {
        InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(functionExpression);
        Assert.assertEquals(evaluator.getArguments(), expectedArguments);
        Assert.assertEquals(evaluator.evaluate(row), expectedResult);
    }

    /** Runs every case supplied by {@link #vectorFunctionsDataProvider()}. */
    @Test(dataProvider = "vectorFunctionsDataProvider")
    public void testVectorFunctions(String functionExpression, List<String> expectedArguments, GenericRow row,
            Object expectedResult) {
        testFunction(functionExpression, expectedArguments, row, expectedResult);
    }

    /**
     * Supplies cases in which both {@code vector1} and {@code vector2} are non-zero five-element
     * float arrays.
     *
     * @return one {@code Object[]} per case: expression, expected arguments, input row, expected result
     */
    @DataProvider(name = "vectorFunctionsDataProvider")
    public Object[][] vectorFunctionsDataProvider() {
        List<Object[]> testCases = new ArrayList<>();
        GenericRow row = new GenericRow();
        row.putValue("vector1", new float[]{0.1f, 0.2f, 0.3f, 0.4f, 0.5f});
        row.putValue("vector2", new float[]{0.6f, 0.7f, 0.8f, 0.9f, 1.0f});

        testCases.add(new Object[]{
            "cosineDistance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 0.03504950750101454d
        });
        testCases.add(new Object[]{
            "innerProduct(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 1.2999999970197678d
        });
        testCases.add(new Object[]{
            "l2Distance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 1.1180339754218913d
        });
        testCases.add(new Object[]{
            "l1Distance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 2.4999999701976776d
        });
        testCases.add(new Object[]{"vectorDims(vector1)", Lists.newArrayList("vector1"), row, 5});
        testCases.add(new Object[]{"vectorDims(vector2)", Lists.newArrayList("vector2"), row, 5});
        testCases.add(new Object[]{"vectorNorm(vector1)", Lists.newArrayList("vector1"), row, 0.741619857751291d});
        testCases.add(new Object[]{"vectorNorm(vector2)", Lists.newArrayList("vector2"), row, 1.8165902091773676d});

        // The array argument must be Object[][]: toArray(new Object[0]) yields a plain Object[]
        // whose cast to Object[][] fails with ClassCastException at runtime.
        return testCases.toArray(new Object[0][]);
    }

    /** Runs every case supplied by {@link #vectorFunctionsZeroDataProvider()}. */
    @Test(dataProvider = "vectorFunctionsZeroDataProvider")
    public void testVectorFunctionsWithZeroVector(String functionExpression, List<String> expectedArguments,
            GenericRow row, Object expectedResult) {
        testFunction(functionExpression, expectedArguments, row, expectedResult);
    }

    /**
     * Supplies cases in which the second operand is an all-zero vector, given either as the
     * {@code vector2} column or as an inline {@code ARRAY[...]} literal.
     *
     * <p>The expectations encode that two-argument {@code cosineDistance} yields {@code NaN} here,
     * while the three-argument form yields the value passed as its third argument.
     *
     * @return one {@code Object[]} per case: expression, expected arguments, input row, expected result
     */
    @DataProvider(name = "vectorFunctionsZeroDataProvider")
    public Object[][] vectorFunctionsZeroDataProvider() {
        List<Object[]> testCases = new ArrayList<>();
        GenericRow row = new GenericRow();
        row.putValue("vector1", new float[]{0.1f, 0.2f, 0.3f, 0.4f, 0.5f});
        row.putValue("vector2", new float[]{0.0f, 0.0f, 0.0f, 0.0f, 0.0f});

        // Zero vector read from a column: both columns are reported as arguments.
        testCases.add(new Object[]{
            "cosineDistance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, Double.NaN
        });
        testCases.add(new Object[]{
            "cosineDistance(vector1, vector2, 0.0)", Lists.newArrayList("vector1", "vector2"), row, 0.0d
        });
        testCases.add(new Object[]{
            "cosineDistance(vector1, vector2, 1.0)", Lists.newArrayList("vector1", "vector2"), row, 1.0d
        });
        testCases.add(new Object[]{
            "innerProduct(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 0.0d
        });
        testCases.add(new Object[]{
            "l2Distance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 0.741619857751291d
        });
        testCases.add(new Object[]{
            "l1Distance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 1.5000000223517418d
        });
        testCases.add(new Object[]{"vectorDims(vector1)", Lists.newArrayList("vector1"), row, 5});
        testCases.add(new Object[]{"vectorDims(vector2)", Lists.newArrayList("vector2"), row, 5});
        testCases.add(new Object[]{"vectorNorm(vector1)", Lists.newArrayList("vector1"), row, 0.741619857751291d});
        testCases.add(new Object[]{"vectorNorm(vector2)", Lists.newArrayList("vector2"), row, 0.0d});

        // Zero vector given as a literal: only the column operand is reported as an argument.
        testCases.add(new Object[]{
            "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", Lists.newArrayList("vector1"), row, Double.NaN
        });
        testCases.add(new Object[]{
            "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0], 0.0)", Lists.newArrayList("vector1"), row, 0.0d
        });
        testCases.add(new Object[]{
            "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0], 1.0)", Lists.newArrayList("vector1"), row, 1.0d
        });
        testCases.add(new Object[]{
            "innerProduct(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", Lists.newArrayList("vector1"), row, 0.0d
        });
        testCases.add(new Object[]{
            "l2Distance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", Lists.newArrayList("vector1"), row,
            0.741619857751291d
        });
        testCases.add(new Object[]{
            "l1Distance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", Lists.newArrayList("vector1"), row,
            1.5000000223517418d
        });

        // See vectorFunctionsDataProvider: the target array must be two-dimensional.
        return testCases.toArray(new Object[0][]);
    }
}
