/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.client.solrj.io.eval;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.stat.regression.MultipleLinearRegression;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.eval.ManyValueWorker;
import org.apache.solr.client.solrj.io.eval.Matrix;
import org.apache.solr.client.solrj.io.eval.RecursiveObjectEvaluator;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;

/**
 * Evaluator for {@code olsRegress}: fits an ordinary least squares multiple linear regression
 * (Apache Commons Math {@link OLSMultipleLinearRegression}) to an observation matrix and an
 * outcome array.
 *
 * <p>The result is a {@link MultipleRegressionTuple} that carries the model's summary statistics
 * and can compute predictions for new predictor rows.
 */
public class OLSRegressionEvaluator
extends RecursiveObjectEvaluator
implements ManyValueWorker {
    protected static final long serialVersionUID = 1L;

    public OLSRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
        super(expression, factory);
    }

    /**
     * Fits the regression and packages its statistics into a tuple.
     *
     * @param values {@code values[0]} must be the observation {@link Matrix} (passed as the
     *     {@code x} data to {@code newSampleData}); {@code values[1]} must be a {@link List} of
     *     numeric outcomes (the {@code y} data)
     * @return a {@link MultipleRegressionTuple} with the keys {@code regressandVariance},
     *     {@code regressionParameters}, {@code RSquared}, {@code adjustedRSquared},
     *     {@code residualSumSquares} and, when they can be computed,
     *     {@code regressionParametersStandardErrors} and {@code regressionParametersVariance}
     * @throws IOException if fewer than two parameters are supplied, a parameter has the wrong
     *     type, or an outcome is not a number
     */
    @Override
    public Object doWork(Object ... values) throws IOException {
        if (values == null || values.length < 2) {
            throw new IOException("olsRegress expects two parameters: the observation matrix and the outcome array.");
        }
        if (!(values[0] instanceof Matrix)) {
            throw new IOException("The first parameter for olsRegress should be the observation matrix.");
        }
        if (!(values[1] instanceof List)) {
            throw new IOException("The second parameter for olsRegress should be outcome array. ");
        }
        Matrix observations = (Matrix)values[0];
        double[][] observationData = observations.getData();
        double[] outcomeData = toOutcomeArray((List<?>)values[1]);

        // doWork reads OLS-specific statistics, so regress() (including overrides) must return an
        // OLSMultipleLinearRegression.
        OLSMultipleLinearRegression regression = (OLSMultipleLinearRegression)this.regress(observationData, outcomeData);

        Map<String, Object> map = new HashMap<>();
        map.put("regressandVariance", regression.estimateRegressandVariance());
        map.put("regressionParameters", this.list(regression.estimateRegressionParameters()));
        map.put("RSquared", regression.calculateRSquared());
        map.put("adjustedRSquared", regression.calculateAdjustedRSquared());
        map.put("residualSumSquares", regression.calculateResidualSumOfSquares());
        try {
            map.put("regressionParametersStandardErrors", this.list(regression.estimateRegressionParametersStandardErrors()));
            map.put("regressionParametersVariance", new Matrix(regression.estimateRegressionParametersVariance()));
        }
        catch (Exception ignored) {
            // Deliberately best-effort: these estimates require an extra matrix inversion that
            // can fail (presumably for ill-conditioned data). The remaining statistics are still
            // valid, so they are returned without these two optional entries.
        }
        return new MultipleRegressionTuple(regression, map);
    }

    /** Converts the outcome list to a primitive array, rejecting non-numeric entries. */
    private static double[] toOutcomeArray(List<?> outcomes) throws IOException {
        double[] outcomeData = new double[outcomes.size()];
        int i = 0;
        for (Object outcome : outcomes) {
            if (!(outcome instanceof Number)) {
                throw new IOException("The outcome array for olsRegress should contain only numbers.");
            }
            outcomeData[i++] = ((Number)outcome).doubleValue();
        }
        return outcomeData;
    }

    /** Boxes a primitive array into a mutable list of {@link Double} values. */
    private List<Number> list(double[] values) {
        List<Number> list = new ArrayList<>(values.length);
        for (double value : values) {
            list.add(value);
        }
        return list;
    }

    /**
     * Creates the regression model and loads the sample data into it.
     *
     * <p>Subclasses may override this, but {@link #doWork} casts the result to
     * {@link OLSMultipleLinearRegression}, so overrides must return that type.
     *
     * @param observations observation rows, passed as the {@code x} data
     * @param outcomes outcome values, passed as the {@code y} data
     * @return the fitted model
     */
    protected MultipleLinearRegression regress(double[][] observations, double[] outcomes) {
        OLSMultipleLinearRegression olsMultipleLinearRegression = new OLSMultipleLinearRegression();
        olsMultipleLinearRegression.newSampleData(outcomes, observations);
        return olsMultipleLinearRegression;
    }

    /** Tuple holding regression statistics that can also compute predictions. */
    public static class MultipleRegressionTuple
    extends Tuple {
        private final MultipleLinearRegression multipleLinearRegression;

        public MultipleRegressionTuple(MultipleLinearRegression multipleLinearRegression, Map<?, ?> map) {
            super(map);
            this.multipleLinearRegression = multipleLinearRegression;
        }

        /**
         * Predicts the outcome for one row of predictor values.
         *
         * <p>The first entry of {@code regressionParameters} is applied to a constant 1.0
         * (the intercept term). Each following entry is applied to the corresponding element
         * of {@code values}.
         *
         * @param values predictor values, one per non-intercept regression parameter
         * @return the predicted outcome
         * @throws IllegalArgumentException if {@code values.length} does not equal the number of
         *     regression parameters minus one
         */
        public double predict(double[] values) {
            List<?> weights = (List<?>)this.get("regressionParameters");
            if (values.length + 1 != weights.size()) {
                throw new IllegalArgumentException("Expected " + (weights.size() - 1)
                        + " predictor values but got " + values.length);
            }
            // Intercept term: weight 0 times the constant predictor 1.0.
            double prediction = ((Number)weights.get(0)).doubleValue();
            for (int i = 0; i < values.length; ++i) {
                prediction += ((Number)weights.get(i + 1)).doubleValue() * values[i];
            }
            return prediction;
        }
    }
}

