/*
 * Decompiled with CFR 0.152.
 */
package io.questdb.griffin.engine.functions.array;

import io.questdb.cairo.CairoConfiguration;
import io.questdb.cairo.CairoException;
import io.questdb.cairo.ColumnType;
import io.questdb.cairo.arr.ArrayView;
import io.questdb.cairo.sql.Function;
import io.questdb.cairo.sql.Record;
import io.questdb.griffin.FunctionFactory;
import io.questdb.griffin.SqlException;
import io.questdb.griffin.SqlExecutionContext;
import io.questdb.griffin.engine.functions.BinaryFunction;
import io.questdb.griffin.engine.functions.DoubleFunction;
import io.questdb.std.IntList;
import io.questdb.std.Numbers;
import io.questdb.std.ObjList;

public class DoubleArrayDotProductFunctionFactory
implements FunctionFactory {
    private static final String FUNCTION_NAME = "dot_product";

    @Override
    public String getSignature() {
        return "dot_product(D[]D[])";
    }

    @Override
    public Function newInstance(int position, ObjList<Function> args, IntList argPositions, CairoConfiguration configuration, SqlExecutionContext sqlExecutionContext) throws SqlException {
        return new Func(args.getQuick(0), args.getQuick(1), argPositions.getQuick(0));
    }

    private static class Func
    extends DoubleFunction
    implements BinaryFunction {
        private final Function leftArg;
        private final int leftArgPos;
        private final Function rightArg;

        public Func(Function leftArg, Function rightArg, int leftArgPos) throws SqlException {
            this.leftArg = leftArg;
            this.rightArg = rightArg;
            this.leftArgPos = leftArgPos;
            int nDimsLeft = ColumnType.decodeArrayDimensionality(leftArg.getType());
            int nDimsRight = ColumnType.decodeArrayDimensionality(rightArg.getType());
            if (nDimsLeft != nDimsRight) {
                throw SqlException.position(leftArgPos).put("arrays have different number of dimensions [nDimsLeft=").put(nDimsLeft).put(", nDimsRight=").put(nDimsRight).put(']');
            }
        }

        @Override
        public double getDouble(Record rec) {
            ArrayView left = this.leftArg.getArray(rec);
            ArrayView right = this.rightArg.getArray(rec);
            if (left.isNull() || right.isNull()) {
                return 0.0;
            }
            if (left.shapeDiffers(right)) {
                throw CairoException.nonCritical().position(this.leftArgPos).put("arrays have different shapes [leftShape=").put(left.shapeToString()).put(", rightShape=").put(right.shapeToString()).put(']');
            }
            if (left.isVanilla() && right.isVanilla()) {
                double value = 0.0;
                int n = left.getFlatViewLength();
                for (int i = 0; i < n; ++i) {
                    double leftVal = left.getDouble(i);
                    double rightVal = right.getDouble(i);
                    if (!Numbers.isFinite(leftVal) || !Numbers.isFinite(rightVal)) continue;
                    value += leftVal * rightVal;
                }
                return value;
            }
            return Func.applyRecursive(0, left, 0, right, 0, 0.0);
        }

        @Override
        public Function getLeft() {
            return this.leftArg;
        }

        @Override
        public String getName() {
            return DoubleArrayDotProductFunctionFactory.FUNCTION_NAME;
        }

        @Override
        public Function getRight() {
            return this.rightArg;
        }

        @Override
        public boolean isThreadSafe() {
            return false;
        }

        private static double applyRecursive(int dim, ArrayView left, int flatIndexLeft, ArrayView right, int flatIndexRight, double sum) {
            boolean atDeepestDim;
            int count = left.getDimLen(dim);
            int strideLeft = left.getStride(dim);
            int strideRight = right.getStride(dim);
            boolean bl = atDeepestDim = dim == left.getDimCount() - 1;
            if (atDeepestDim) {
                for (int i = 0; i < count; ++i) {
                    double leftVal = left.getDouble(flatIndexLeft);
                    double rightVal = right.getDouble(flatIndexLeft);
                    if (Numbers.isFinite(leftVal) && Numbers.isFinite(rightVal)) {
                        sum += leftVal * rightVal;
                    }
                    flatIndexLeft += strideLeft;
                    flatIndexRight += strideRight;
                }
            } else {
                for (int i = 0; i < count; ++i) {
                    sum = Func.applyRecursive(dim + 1, left, flatIndexLeft, right, flatIndexRight, sum);
                    flatIndexLeft += strideLeft;
                    flatIndexRight += strideRight;
                }
            }
            return sum;
        }
    }
}

