/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.client.solrj.io.eval;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.util.MathArrays;
import org.apache.solr.client.solrj.io.eval.ManyValueWorker;
import org.apache.solr.client.solrj.io.eval.Matrix;
import org.apache.solr.client.solrj.io.eval.RecursiveObjectEvaluator;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;

/**
 * Evaluator that rescales numeric data so that it sums to a target value.
 *
 * <p>The first parameter is either a numeric array (a {@link List} whose elements are
 * {@link Number}s) or a {@link Matrix}. For a matrix, every row is normalized independently and
 * the row/column labels of the input are copied to the result. The optional second parameter is
 * the target sum, which defaults to {@code 1.0}. The actual rescaling is delegated to
 * {@link MathArrays#normalizeArray(double[], double)}.
 */
public class NormalizeSumEvaluator
extends RecursiveObjectEvaluator
implements ManyValueWorker {
    protected static final long serialVersionUID = 1L;

    /** Target sum used when no second parameter is supplied. */
    private static final double DEFAULT_SUM_TO = 1.0;

    /**
     * Creates the evaluator from its parsed expression.
     *
     * @param expression the parsed stream expression
     * @param factory    the factory used to resolve the contained evaluators
     * @throws IOException if more than two parameters are supplied
     */
    public NormalizeSumEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
        super(expression, factory);
        if (2 < this.containedEvaluators.size()) {
            throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - expecting at most two parameters but found %d", expression, this.containedEvaluators.size()));
        }
    }

    /**
     * Normalizes the first value so that it sums to the requested total.
     *
     * <p>NOTE(review): {@link MathArrays#normalizeArray(double[], double)} may reject some inputs
     * (for example a row or array that sums to zero) with an unchecked exception; such exceptions
     * propagate unchanged from here — consider mapping them to {@link IOException}.
     *
     * @param values the evaluated parameters: {@code values[0]} is the numeric array or matrix to
     *               normalize, and the optional {@code values[1]} is the numeric target sum
     * @return {@code null} if the first value is {@code null}; a new {@link Matrix} with each row
     *         normalized (labels copied from the input) if it is a matrix; otherwise a new
     *         {@link List} of normalized {@link Double}s
     * @throws IOException if no value is supplied, the target sum or an array element is not
     *                     numeric, or the first value is neither a numeric array nor a matrix
     */
    @Override
    public Object doWork(Object ... values) throws IOException {
        // Guard explicitly so a missing argument surfaces as an evaluator error, not an AIOOBE.
        if (values == null || values.length == 0) {
            throw new IOException("The normalizeSum function expects either a numeric array or matrix as a parameter");
        }
        double sumTo = values.length > 1 ? toSumTo(values[1]) : DEFAULT_SUM_TO;
        Object value = values[0];
        if (value == null) {
            return null;
        }
        if (value instanceof Matrix) {
            return normalizeRows((Matrix) value, sumTo);
        }
        if (value instanceof List) {
            return normalizeList((List<?>) value, sumTo);
        }
        throw new IOException("The normalizeSum function expects either a numeric array or matrix as a parameter");
    }

    /** Converts the optional second parameter into the target sum, rejecting non-numeric input. */
    private static double toSumTo(Object param) throws IOException {
        if (!(param instanceof Number)) {
            throw new IOException(String.format(Locale.ROOT, "The normalizeSum function expects a numeric second parameter but found %s", param));
        }
        return ((Number) param).doubleValue();
    }

    /** Returns a new matrix whose rows are each normalized to {@code sumTo}, keeping the labels. */
    private static Matrix normalizeRows(Matrix matrix, double sumTo) {
        double[][] data = matrix.getData();
        double[][] unitData = new double[data.length][];
        for (int row = 0; row < data.length; row++) {
            unitData[row] = MathArrays.normalizeArray(data[row], sumTo);
        }
        Matrix result = new Matrix(unitData);
        result.setRowLabels(matrix.getRowLabels());
        result.setColumnLabels(matrix.getColumnLabels());
        return result;
    }

    /** Returns a new list holding the numeric elements of {@code vals} normalized to {@code sumTo}. */
    private static List<Double> normalizeList(List<?> vals, double sumTo) throws IOException {
        double[] doubles = new double[vals.size()];
        int index = 0;
        // Iterate rather than call get(i) so non-random-access lists stay linear.
        for (Object element : vals) {
            if (!(element instanceof Number)) {
                throw new IOException(String.format(Locale.ROOT, "The normalizeSum function expects numeric array elements but found %s at index %d", element, index));
            }
            doubles[index++] = ((Number) element).doubleValue();
        }
        double[] unitArray = MathArrays.normalizeArray(doubles, sumTo);
        List<Double> unitList = new ArrayList<>(unitArray.length);
        for (double d : unitArray) {
            unitList.add(d);
        }
        return unitList;
    }
}

