/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.client.solrj.io.eval;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.commons.math3.analysis.BivariateFunction;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.solr.client.solrj.io.eval.KnnRegressionEvaluator;
import org.apache.solr.client.solrj.io.eval.ManyValueWorker;
import org.apache.solr.client.solrj.io.eval.Matrix;
import org.apache.solr.client.solrj.io.eval.OLSRegressionEvaluator;
import org.apache.solr.client.solrj.io.eval.RecursiveObjectEvaluator;
import org.apache.solr.client.solrj.io.eval.RegressionEvaluator;
import org.apache.solr.client.solrj.io.eval.VectorFunction;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;

/**
 * Implements the {@code predict} stream-expression function: applies a model or function produced
 * by another evaluator to new input values.
 *
 * <p>The first argument is dispatched on its type, in this order:
 *
 * <ul>
 *   <li>{@link RegressionEvaluator.RegressionTuple}: a Number yields one prediction; a List of
 *       Numbers yields a List of predictions.
 *   <li>{@link OLSRegressionEvaluator.MultipleRegressionTuple}: a List of Numbers is one
 *       observation; a {@link Matrix} yields one prediction per row.
 *   <li>{@link KnnRegressionEvaluator.KnnRegressionTuple}: when {@code getBivariate()} is true, a
 *       Number or a List of Numbers (each value is one single-element observation); otherwise a
 *       List (one observation) or a Matrix (one observation per row), scaled by the model first
 *       when {@code getScale()} is true.
 *   <li>{@link VectorFunction} whose wrapped function is cast to {@link UnivariateFunction}: a
 *       Number or a List of Numbers.
 *   <li>{@link BivariateFunction}: either two Number arguments, or a single two-column Matrix
 *       (one value per row).
 * </ul>
 */
public class PredictEvaluator
extends RecursiveObjectEvaluator
implements ManyValueWorker {
    protected static final long serialVersionUID = 1L;

    public PredictEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
        super(expression, factory);
    }

    /**
     * Applies the model/function in {@code objects[0]} to the input(s) in the remaining arguments.
     *
     * @param objects 2 or 3 evaluated arguments: the model/function, then the input value(s)
     * @return a single prediction, a List of predictions, or {@code null} for argument combinations
     *     that pass the initial type checks but are not handled by the selected model type
     * @throws IOException if the argument count or types are invalid
     */
    @Override
    public Object doWork(Object ... objects) throws IOException {
        if (objects.length != 2 && objects.length != 3) {
            throw new IOException("The predict function expects 2 or 3 parameters.");
        }
        Object first = objects[0];
        Object second = objects[1];
        if (!(first instanceof BivariateFunction
                || first instanceof VectorFunction
                || first instanceof RegressionEvaluator.RegressionTuple
                || first instanceof OLSRegressionEvaluator.MultipleRegressionTuple
                || first instanceof KnnRegressionEvaluator.KnnRegressionTuple)) {
            throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple", this.toExpression(this.constructingFactory), typeName(first)));
        }
        if (!(second instanceof Number || second instanceof List || second instanceof Matrix)) {
            // Report the type of the offending SECOND argument (the original reported the first).
            throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a Number, Array or Matrix", this.toExpression(this.constructingFactory), typeName(second)));
        }
        if (first instanceof RegressionEvaluator.RegressionTuple) {
            return predictRegression((RegressionEvaluator.RegressionTuple) first, second);
        }
        if (first instanceof OLSRegressionEvaluator.MultipleRegressionTuple) {
            return predictMultipleRegression((OLSRegressionEvaluator.MultipleRegressionTuple) first, second);
        }
        if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
            return predictKnn((KnnRegressionEvaluator.KnnRegressionTuple) first, second);
        }
        if (first instanceof VectorFunction) {
            return predictVectorFunction((VectorFunction) first, second);
        }
        // The type check above guarantees only BivariateFunction remains here.
        return predictBivariate((BivariateFunction) first, objects);
    }

    /** Predicts with a simple regression model from a Number or a List of Numbers. */
    private Object predictRegression(RegressionEvaluator.RegressionTuple regressedTuple, Object second) throws IOException {
        if (second instanceof Number) {
            return regressedTuple.predict(((Number) second).doubleValue());
        }
        if (second instanceof List) {
            return ((List<?>) second).stream()
                .map(value -> regressedTuple.predict(((Number) value).doubleValue()))
                .collect(Collectors.toList());
        }
        // A Matrix passes the generic check but was previously a ClassCastException here.
        throw unsupportedSecond(second, "a Number or Array");
    }

    /** Predicts with an OLS multiple-regression model: one List observation or one per Matrix row. */
    private Object predictMultipleRegression(OLSRegressionEvaluator.MultipleRegressionTuple regressedTuple, Object second) throws IOException {
        if (second instanceof List) {
            return regressedTuple.predict(toDoubleArray((List<?>) second));
        }
        if (second instanceof Matrix) {
            double[][] rows = ((Matrix) second).getData();
            List<Double> predictions = new ArrayList<>(rows.length);
            for (double[] row : rows) {
                predictions.add(regressedTuple.predict(row));
            }
            return predictions;
        }
        // Preserved from the original: a Number input is not handled and yields null.
        return null;
    }

    /** Predicts with a KNN regression model, honoring its bivariate and scaling flags. */
    private Object predictKnn(KnnRegressionEvaluator.KnnRegressionTuple regressedTuple, Object second) throws IOException {
        if (regressedTuple.getBivariate()) {
            // Each scalar input is wrapped as a single-element observation.
            if (second instanceof Number) {
                return regressedTuple.predict(new double[]{((Number) second).doubleValue()});
            }
            if (second instanceof List) {
                List<?> values = (List<?>) second;
                List<Double> predictions = new ArrayList<>(values.size());
                for (Object value : values) {
                    predictions.add(regressedTuple.predict(new double[]{((Number) value).doubleValue()}));
                }
                return predictions;
            }
            // Preserved from the original: a Matrix input is not handled and yields null.
            return null;
        }
        if (second instanceof List) {
            double[] predictors = toDoubleArray((List<?>) second);
            if (regressedTuple.getScale()) {
                predictors = regressedTuple.scale(predictors);
            }
            return regressedTuple.predict(predictors);
        }
        if (second instanceof Matrix) {
            Matrix m = (Matrix) second;
            if (regressedTuple.getScale()) {
                m = regressedTuple.scale(m);
            }
            double[][] rows = m.getData();
            List<Double> predictions = new ArrayList<>(rows.length);
            for (double[] row : rows) {
                predictions.add(regressedTuple.predict(row));
            }
            return predictions;
        }
        // Preserved from the original: a Number input is not handled and yields null.
        return null;
    }

    /** Evaluates the UnivariateFunction wrapped by a VectorFunction at a Number or each List element. */
    private Object predictVectorFunction(VectorFunction vectorFunction, Object second) throws IOException {
        UnivariateFunction univariateFunction = (UnivariateFunction) vectorFunction.getFunction();
        if (second instanceof Number) {
            return univariateFunction.value(((Number) second).doubleValue());
        }
        if (second instanceof List) {
            return ((List<?>) second).stream()
                .map(value -> univariateFunction.value(((Number) value).doubleValue()))
                .collect(Collectors.toList());
        }
        // A Matrix passes the generic check but was previously a ClassCastException here.
        throw unsupportedSecond(second, "a Number or Array");
    }

    /** Evaluates a BivariateFunction at (x, y) or at every row of a two-column Matrix. */
    private Object predictBivariate(BivariateFunction bivariateFunction, Object[] objects) throws IOException {
        Object second = objects[1];
        if (objects.length == 3) {
            Object third = objects[2];
            if (second instanceof Number && third instanceof Number) {
                return bivariateFunction.value(((Number) second).doubleValue(), ((Number) third).doubleValue());
            }
            throw new IOException("BivariateFunction requires two numeric parameters.");
        }
        // doWork has validated that objects.length is 2 or 3, so here it is 2.
        if (!(second instanceof Matrix)) {
            throw new IOException("Bivariate Function requires a matrix parameter.");
        }
        double[][] data = ((Matrix) second).getData();
        List<Double> out = new ArrayList<>(data.length);
        if (data.length == 0) {
            // No rows means no points to evaluate; avoids indexing data[0] on an empty matrix.
            return out;
        }
        if (data[0].length != 2) {
            throw new IOException("Bivariate Function expects a matrix with two columns");
        }
        for (double[] row : data) {
            out.add(bivariateFunction.value(row[0], row[1]));
        }
        return out;
    }

    /** Error for a second argument whose type is not supported by the selected model type. */
    private IOException unsupportedSecond(Object second, String expected) throws IOException {
        return new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting %s", this.toExpression(this.constructingFactory), typeName(second), expected));
    }

    /** Converts a List of Numbers to a primitive array, preserving order. */
    private static double[] toDoubleArray(List<?> values) {
        double[] out = new double[values.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = ((Number) values.get(i)).doubleValue();
        }
        return out;
    }

    /** Null-safe simple type name for error messages. */
    private static String typeName(Object value) {
        return value == null ? "null" : value.getClass().getSimpleName();
    }
}

