/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.concurrent.Callable;
import smile.data.Attribute;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;
import smile.sort.QuickSort;
import smile.util.MulticoreExecutor;

/**
 * Regression tree grown best-first by squared-error reduction.
 *
 * <p>Each internal node tests one feature and routes a sample to its "true"
 * child when the test holds and to its "false" child otherwise; each leaf
 * predicts a constant. Candidate nodes are kept in a priority queue keyed by
 * their split score {@code tc*trueMean^2 + fc*falseMean^2 - n*mean^2}, which
 * equals the reduction of the (multiplicity-weighted) sum of squared errors
 * achieved by the split. The best candidate is split repeatedly until the tree
 * has {@code maxNodes} leaves or no admissible split remains.
 *
 * <p>Two input layouts are supported:
 * <ul>
 * <li>dense {@code double[][]} rows described by numeric or nominal
 *     {@link Attribute}s; numeric tests are {@code x[f] <= threshold},
 *     nominal tests are {@code x[f] == level};</li>
 * <li>sparse binary {@code int[][]} rows listing the indices of their active
 *     features (the constructors taking {@code numFeatures}). The test is
 *     {@code x[j] == featureIndex}, i.e. it assumes a given feature index
 *     always sits at the same position {@code j} of a row (as in a one-hot
 *     encoding of nominal variables). NOTE(review): confirm this input
 *     contract with callers.</li>
 * </ul>
 */
public class RegressionTree implements Regression<double[]>, Serializable {
    private static final long serialVersionUID = 1L;

    /** Descriptors of the dense features; stays null for trees trained on sparse binary data. */
    private Attribute[] attributes;
    /** For each feature, the sum of the split scores of the nodes that split on it. */
    private double[] importance;
    /** Root of the fitted tree. */
    private Node root;
    /** Minimum number of (multiplicity-weighted) samples each child of a split must receive. */
    private int nodeSize = 5;
    /** Maximum number of leaves. */
    private int maxNodes = 6;
    /** Number of candidate features evaluated at each node of a dense tree. */
    private int mtry;
    /** Number of distinct sparse binary features (sparse trees only). */
    private int numFeatures;
    /**
     * For each numeric column, the row indices sorted by that column's value
     * (null for other columns). Only used while training, hence transient.
     */
    private transient int[][] order;

    /**
     * Trains a dense tree whose features are all numeric, with a minimum leaf size of 5.
     *
     * @param x the training rows.
     * @param y the response of each row.
     * @param maxNodes the maximum number of leaves.
     */
    public RegressionTree(double[][] x, double[] y, int maxNodes) {
        this(null, x, y, maxNodes, 5);
    }

    /**
     * Trains a dense tree whose features are all numeric.
     *
     * @param x the training rows.
     * @param y the response of each row.
     * @param maxNodes the maximum number of leaves.
     * @param nodeSize the minimum number of samples in each child of a split.
     */
    public RegressionTree(double[][] x, double[] y, int maxNodes, int nodeSize) {
        this(null, x, y, maxNodes, nodeSize);
    }

    /**
     * Trains a dense tree with a minimum leaf size of 5.
     *
     * @param attributes the feature descriptors, or null to treat every feature as numeric.
     * @param x the training rows.
     * @param y the response of each row.
     * @param maxNodes the maximum number of leaves.
     */
    public RegressionTree(Attribute[] attributes, double[][] x, double[] y, int maxNodes) {
        this(attributes, x, y, maxNodes, 5);
    }

    /**
     * Trains a dense tree that evaluates every feature at every node.
     *
     * @param attributes the feature descriptors, or null to treat every feature as numeric.
     * @param x the training rows.
     * @param y the response of each row.
     * @param maxNodes the maximum number of leaves.
     * @param nodeSize the minimum number of samples in each child of a split.
     */
    public RegressionTree(Attribute[] attributes, double[][] x, double[] y, int maxNodes, int nodeSize) {
        this(attributes, x, y, maxNodes, nodeSize, x[0].length, null, null, null);
    }

    /**
     * Trains a dense tree.
     *
     * @param attributes the feature descriptors, or null to treat every column as a
     *        numeric attribute named {@code V1..Vp}.
     * @param x the training rows.
     * @param y the response of each row.
     * @param maxNodes the maximum number of leaves (at least 2).
     * @param nodeSize the minimum number of samples in each child of a split (at least 2).
     * @param mtry the number of features evaluated at each node. When smaller than the
     *        number of attributes a random subset is evaluated sequentially; otherwise
     *        all attributes are evaluated through {@link MulticoreExecutor}.
     * @param order pre-computed sort order of the numeric columns (as produced for this
     *        {@code x}), or null to compute it here.
     * @param samples the multiplicity of each row (0 excludes a row), or null to use each
     *        row once. The array is used directly as the root's sample set and its
     *        entries are overwritten during training.
     * @param output if non-null, recomputes the output of every leaf from the
     *        multiplicities of the rows that reached it.
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y} differ,
     *         {@code x} is empty, or a size parameter is out of range.
     */
    public RegressionTree(Attribute[] attributes, double[][] x, double[] y, int maxNodes, int nodeSize, int mtry, int[][] order, int[] samples, NodeOutput output) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (mtry < 1 || mtry > x[0].length) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxNodes);
        }
        if (nodeSize < 2) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
        }

        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int j = 0; j < p; ++j) {
                attributes[j] = new NumericAttribute("V" + (j + 1));
            }
        }

        this.attributes = attributes;
        this.maxNodes = maxNodes;
        this.nodeSize = nodeSize;
        this.mtry = mtry;
        this.importance = new double[attributes.length];
        this.order = (order != null) ? order : sortNumericColumns(attributes, x);

        if (samples == null) {
            samples = unitSamples(y.length);
        }

        this.root = new Node(weightedMean(y, samples));
        TrainNode trainRoot = new TrainNode(this.root, x, y, samples);

        PriorityQueue<TrainNode> nextSplits = new PriorityQueue<>();
        if (trainRoot.findBestSplit()) {
            nextSplits.add(trainRoot);
        }

        // Best-first growth. A split that turns out to be inadmissible leaves its
        // node a leaf, so only successful splits add a leaf to the count.
        int leaves = 1;
        while (leaves < this.maxNodes) {
            TrainNode node = nextSplits.poll();
            if (node == null) {
                break;
            }
            if (node.split(nextSplits)) {
                ++leaves;
            }
        }

        if (output != null) {
            trainRoot.calculateOutput(output);
        }
    }

    /**
     * Trains a sparse binary tree with a minimum leaf size of 5.
     *
     * @param numFeatures the number of distinct binary features.
     * @param x for each row, the indices of its active features.
     * @param y the response of each row.
     * @param maxNodes the maximum number of leaves.
     */
    public RegressionTree(int numFeatures, int[][] x, double[] y, int maxNodes) {
        this(numFeatures, x, y, maxNodes, 5);
    }

    /**
     * Trains a sparse binary tree.
     *
     * @param numFeatures the number of distinct binary features.
     * @param x for each row, the indices of its active features.
     * @param y the response of each row.
     * @param maxNodes the maximum number of leaves.
     * @param nodeSize the minimum number of samples in each child of a split.
     */
    public RegressionTree(int numFeatures, int[][] x, double[] y, int maxNodes, int nodeSize) {
        this(numFeatures, x, y, maxNodes, nodeSize, null, null);
    }

    /**
     * Trains a sparse binary tree.
     *
     * @param numFeatures the number of distinct binary features; every index in
     *        {@code x} must be smaller than this.
     * @param x for each row, the indices of its active features.
     * @param y the response of each row.
     * @param maxNodes the maximum number of leaves (at least 2).
     * @param nodeSize the minimum number of samples in each child of a split (at least 2).
     * @param samples the multiplicity of each row (0 excludes a row), or null to use each
     *        row once. The array is used directly as the root's sample set and its
     *        entries are overwritten during training.
     * @param output if non-null, recomputes the output of every leaf from the
     *        multiplicities of the rows that reached it.
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y} differ or
     *         a size parameter is out of range.
     */
    public RegressionTree(int numFeatures, int[][] x, double[] y, int maxNodes, int nodeSize, int[] samples, NodeOutput output) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (numFeatures <= 0) {
            throw new IllegalArgumentException("Invalid number of sparse binary features: " + numFeatures);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum number of leaves: " + maxNodes);
        }
        if (nodeSize < 2) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
        }

        this.maxNodes = maxNodes;
        this.nodeSize = nodeSize;
        this.numFeatures = numFeatures;
        this.mtry = numFeatures;
        this.importance = new double[numFeatures];

        if (samples == null) {
            samples = unitSamples(y.length);
        }

        this.root = new Node(weightedMean(y, samples));
        SparseBinaryTrainNode trainRoot = new SparseBinaryTrainNode(this.root, x, y, samples);

        PriorityQueue<SparseBinaryTrainNode> nextSplits = new PriorityQueue<>();
        if (trainRoot.findBestSplit()) {
            nextSplits.add(trainRoot);
        }

        // Best-first growth; sparse splits always succeed once chosen.
        SparseBinaryTrainNode node;
        for (int leaves = 1; leaves < this.maxNodes && (node = nextSplits.poll()) != null; ++leaves) {
            node.split(nextSplits);
        }

        if (output != null) {
            trainRoot.calculateOutput(output);
        }
    }

    /** Returns a multiplicity array that includes each of the {@code n} rows exactly once. */
    private static int[] unitSamples(int n) {
        int[] samples = new int[n];
        for (int i = 0; i < n; ++i) {
            samples[i] = 1;
        }
        return samples;
    }

    /** Returns the mean of {@code y} with row {@code i} counted {@code samples[i]} times. */
    private static double weightedMean(double[] y, int[] samples) {
        int n = 0;
        double sum = 0.0;
        for (int i = 0; i < y.length; ++i) {
            n += samples[i];
            sum += samples[i] * y[i];
        }
        return sum / n;
    }

    /**
     * Returns, for every numeric column of {@code x}, the row indices sorted by
     * that column's value; entries for non-numeric columns stay null.
     */
    private static int[][] sortNumericColumns(Attribute[] attributes, double[][] x) {
        int n = x.length;
        int p = x[0].length;
        double[] column = new double[n];
        int[][] index = new int[p][];
        for (int j = 0; j < p; ++j) {
            if (!(attributes[j] instanceof NumericAttribute)) {
                continue;
            }
            for (int i = 0; i < n; ++i) {
                column[i] = x[i][j];
            }
            index[j] = QuickSort.sort(column);
        }
        return index;
    }

    /**
     * Returns the per-feature importance: the sum of the split scores of all
     * nodes splitting on each feature. The internal array is returned, not a copy.
     */
    public double[] importance() {
        return this.importance;
    }

    /** Predicts the response of a dense row. */
    @Override
    public double predict(double[] x) {
        return this.root.predict(x);
    }

    /** Predicts the response of a sparse binary row. */
    @Override
    public double predict(int[] x) {
        return this.root.predict(x);
    }

    /** Returns the number of nodes on the longest root-to-leaf path. */
    public int maxDepth() {
        return this.maxDepth(this.root);
    }

    private int maxDepth(Node node) {
        if (node == null) {
            return 0;
        }
        int lDepth = this.maxDepth(node.trueChild);
        int rDepth = this.maxDepth(node.falseChild);
        return (lDepth > rDepth ? lDepth : rDepth) + 1;
    }

    /**
     * Returns a Graphviz DOT description of the tree. Nodes are numbered in
     * breadth-first order; only the two edges leaving the root are labelled True/False.
     */
    public String dot() {
        StringBuilder builder = new StringBuilder();
        builder.append("digraph RegressionTree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n");

        int n = 0;
        LinkedList<DotNode> queue = new LinkedList<>();
        queue.add(new DotNode(-1, 0, this.root));

        while (!queue.isEmpty()) {
            DotNode dnode = queue.poll();
            int id = dnode.id;
            int parent = dnode.parent;
            Node node = dnode.node;

            if (node.trueChild == null && node.falseChild == null) {
                builder.append(String.format(" %d [label=<%.4f>, fillcolor=\"#00000000\", shape=ellipse];\n", id, node.output));
            } else if (this.attributes == null) {
                // Sparse binary tree: no attribute descriptors exist, so name the row
                // position tested and the feature index it must equal.
                builder.append(String.format(" %d [label=<V%d = %d<br/>score = %.4f>, fillcolor=\"#00000000\"];\n", id, node.splitFeature + 1, (int) node.splitValue, node.splitScore));
            } else {
                Attribute attr = this.attributes[node.splitFeature];
                if (attr.getType() == Attribute.Type.NOMINAL) {
                    builder.append(String.format(" %d [label=<%s = %s<br/>score = %.4f>, fillcolor=\"#00000000\"];\n", id, attr.getName(), attr.toString(node.splitValue), node.splitScore));
                } else if (attr.getType() == Attribute.Type.NUMERIC) {
                    builder.append(String.format(" %d [label=<%s &le; %.4f<br/>score = %.4f>, fillcolor=\"#00000000\"];\n", id, attr.getName(), node.splitValue, node.splitScore));
                } else {
                    throw new IllegalStateException("Unsupported attribute type: " + attr.getType());
                }
            }

            if (parent >= 0) {
                builder.append(' ').append(parent).append(" -> ").append(id);
                if (parent == 0) {
                    if (id == 1) {
                        builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
                    } else {
                        builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
                    }
                }
                builder.append(";\n");
            }

            if (node.trueChild != null) {
                queue.add(new DotNode(id, ++n, node.trueChild));
            }
            if (node.falseChild != null) {
                queue.add(new DotNode(id, ++n, node.falseChild));
            }
        }

        builder.append("}");
        return builder.toString();
    }

    /** A tree node queued for DOT output together with its id and its parent's id (-1 for the root). */
    private static class DotNode {
        final int parent;
        final int id;
        final Node node;

        DotNode(int parent, int id, Node node) {
            this.parent = parent;
            this.id = id;
            this.node = node;
        }
    }

    /**
     * Training-time wrapper around a {@link Node} of a sparse binary tree,
     * holding the multiplicities of the rows that reach the node.
     */
    class SparseBinaryTrainNode implements Comparable<SparseBinaryTrainNode> {
        Node node;
        SparseBinaryTrainNode trueChild;
        SparseBinaryTrainNode falseChild;
        int[][] x;
        double[] y;
        /** Multiplicity of each row at this node; 0 means the row is elsewhere. */
        int[] samples;

        public SparseBinaryTrainNode(Node node, int[][] x, double[] y, int[] samples) {
            this.node = node;
            this.x = x;
            this.y = y;
            this.samples = samples;
        }

        /** Orders by descending split score, so the priority queue yields the best split first. */
        @Override
        public int compareTo(SparseBinaryTrainNode a) {
            return (int) Math.signum(a.node.splitScore - this.node.splitScore);
        }

        /**
         * Finds the binary feature whose presence best splits this node and
         * stores it in {@link #node}.
         *
         * @return true if an admissible split was found.
         * @throws IllegalStateException if the node already has children.
         */
        public boolean findBestSplit() {
            if (this.node.trueChild != null || this.node.falseChild != null) {
                throw new IllegalStateException("Split non-leaf node.");
            }

            int p = RegressionTree.this.numFeatures;
            double[] trueSum = new double[p];
            int[] trueCount = new int[p];
            // Position at which each feature index was last seen in a row; becomes
            // the split feature used by split() and Node.predict(int[]).
            int[] featureIndex = new int[p];

            int n = Math.sum(this.samples);
            double sum = 0.0;
            for (int i = 0; i < this.x.length; ++i) {
                if (this.samples[i] == 0) {
                    continue;
                }
                double target = this.samples[i] * this.y[i];
                // Weighted like trueSum so that (sum - trueSum[f]) is the false-branch total.
                sum += target;
                for (int j = 0; j < this.x[i].length; ++j) {
                    int f = this.x[i][j];
                    trueSum[f] += target;
                    trueCount[f] += this.samples[i];
                    featureIndex[f] = j;
                }
            }

            this.node.splitScore = 0.0;
            this.node.splitFeature = -1;
            this.node.splitValue = -1.0;

            double parentScore = n * this.node.output * this.node.output;
            for (int f = 0; f < p; ++f) {
                double tc = trueCount[f];
                double fc = n - tc;
                if (tc < RegressionTree.this.nodeSize || fc < RegressionTree.this.nodeSize) {
                    continue;
                }
                double trueMean = trueSum[f] / tc;
                double falseMean = (sum - trueSum[f]) / fc;
                // Reduction of the sum of squared errors achieved by this split.
                double gain = tc * trueMean * trueMean + fc * falseMean * falseMean - parentScore;
                if (gain > this.node.splitScore) {
                    this.node.splitFeature = featureIndex[f];
                    this.node.splitValue = f;
                    this.node.splitScore = gain;
                    this.node.trueChildOutput = trueMean;
                    this.node.falseChildOutput = falseMean;
                }
            }
            return this.node.splitFeature != -1;
        }

        /**
         * Splits this node using the split found by {@link #findBestSplit()}.
         * Rows whose entry at {@code splitFeature} equals {@code splitValue} move to
         * the true child; the remaining rows keep this node's sample array, which is
         * handed to the false child. Children with an admissible split are queued in
         * {@code nextSplits}, or split recursively when it is null.
         *
         * @throws IllegalStateException if there is no split or the node has children.
         */
        public void split(PriorityQueue<SparseBinaryTrainNode> nextSplits) {
            if (this.node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            if (this.node.trueChild != null || this.node.falseChild != null) {
                throw new IllegalStateException("Split non-leaf node.");
            }

            int n = this.x.length;
            int tc = 0;
            int fc = 0;
            int[] trueSamples = new int[n];
            int feature = this.node.splitFeature;
            int value = (int) this.node.splitValue;
            for (int i = 0; i < n; ++i) {
                if (this.samples[i] <= 0) {
                    continue;
                }
                if (this.x[i][feature] == value) {
                    trueSamples[i] = this.samples[i];
                    tc += trueSamples[i];
                    this.samples[i] = 0;
                } else {
                    fc += this.samples[i];
                }
            }

            this.node.trueChild = new Node(this.node.trueChildOutput);
            this.node.falseChild = new Node(this.node.falseChildOutput);

            this.trueChild = new SparseBinaryTrainNode(this.node.trueChild, this.x, this.y, trueSamples);
            if (tc > RegressionTree.this.nodeSize && this.trueChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(this.trueChild);
                } else {
                    this.trueChild.split(null);
                }
            }

            this.falseChild = new SparseBinaryTrainNode(this.node.falseChild, this.x, this.y, this.samples);
            if (fc > RegressionTree.this.nodeSize && this.falseChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(this.falseChild);
                } else {
                    this.falseChild.split(null);
                }
            }

            RegressionTree.this.importance[feature] += this.node.splitScore;
        }

        /** Recomputes the output of every leaf below this node from its row multiplicities. */
        public void calculateOutput(NodeOutput output) {
            if (this.node.trueChild == null && this.node.falseChild == null) {
                this.node.output = output.calculate(this.samples);
            } else {
                if (this.trueChild != null) {
                    this.trueChild.calculateOutput(output);
                }
                if (this.falseChild != null) {
                    this.falseChild.calculateOutput(output);
                }
            }
        }
    }

    /**
     * Training-time wrapper around a {@link Node} of a dense tree, holding the
     * multiplicities of the rows that reach the node.
     */
    class TrainNode implements Comparable<TrainNode> {
        Node node;
        TrainNode trueChild;
        TrainNode falseChild;
        double[][] x;
        double[] y;
        /** Multiplicity of each row at this node; 0 means the row is elsewhere. */
        int[] samples;

        public TrainNode(Node node, double[][] x, double[] y, int[] samples) {
            this.node = node;
            this.x = x;
            this.y = y;
            this.samples = samples;
        }

        /** Orders by descending split score, so the priority queue yields the best split first. */
        @Override
        public int compareTo(TrainNode a) {
            return (int) Math.signum(a.node.splitScore - this.node.splitScore);
        }

        /** Recomputes the output of every leaf below this node from its row multiplicities. */
        public void calculateOutput(NodeOutput output) {
            if (this.node.trueChild == null && this.node.falseChild == null) {
                this.node.output = output.calculate(this.samples);
            } else {
                if (this.trueChild != null) {
                    this.trueChild.calculateOutput(output);
                }
                if (this.falseChild != null) {
                    this.falseChild.calculateOutput(output);
                }
            }
        }

        /**
         * Finds the best split of this node over {@code mtry} features and stores it
         * in {@link #node}. A random subset of features is evaluated sequentially
         * when {@code mtry} is below the number of attributes; otherwise every
         * attribute is evaluated through {@link MulticoreExecutor}, falling back to
         * sequential evaluation if that fails.
         *
         * @return true if an admissible split was found.
         */
        public boolean findBestSplit() {
            int n = 0;
            for (int s : this.samples) {
                n += s;
            }
            if (n <= RegressionTree.this.nodeSize) {
                return false;
            }

            double sum = this.node.output * n;
            int p = RegressionTree.this.attributes.length;
            int[] variables = new int[p];
            for (int i = 0; i < p; ++i) {
                variables[i] = i;
            }

            int mtry = RegressionTree.this.mtry;
            if (mtry < p) {
                Math.permutate(variables);
                for (int j = 0; j < mtry; ++j) {
                    this.adopt(this.findBestSplit(n, sum, variables[j]));
                }
            } else {
                ArrayList<SplitTask> tasks = new ArrayList<>(mtry);
                for (int j = 0; j < mtry; ++j) {
                    tasks.add(new SplitTask(n, sum, variables[j]));
                }
                try {
                    for (Node split : MulticoreExecutor.run(tasks)) {
                        this.adopt(split);
                    }
                } catch (Exception ex) {
                    // Deliberate best-effort: parallel evaluation failed, so evaluate
                    // the same candidate features on this thread instead.
                    for (int j = 0; j < mtry; ++j) {
                        this.adopt(this.findBestSplit(n, sum, variables[j]));
                    }
                }
            }
            return this.node.splitFeature != -1;
        }

        /** Copies {@code split} into this node if it scores better than the current best split. */
        private void adopt(Node split) {
            if (split.splitScore > this.node.splitScore) {
                this.node.splitFeature = split.splitFeature;
                this.node.splitValue = split.splitValue;
                this.node.splitScore = split.splitScore;
                this.node.trueChildOutput = split.trueChildOutput;
                this.node.falseChildOutput = split.falseChildOutput;
            }
        }

        /**
         * Finds the best split of this node on feature {@code j}.
         *
         * @param n the total multiplicity of the rows at this node.
         * @param sum the multiplicity-weighted sum of their responses.
         * @param j the feature to evaluate.
         * @return a scratch node carrying the best split; its {@code splitFeature}
         *         stays -1 when no admissible split exists.
         * @throws IllegalStateException if the attribute is neither nominal nor numeric.
         */
        public Node findBestSplit(int n, double sum, int j) {
            Node split = new Node(0.0);
            // Subtracting n * parentMean^2 turns the score into the reduction of the
            // sum of squared errors, which is comparable across nodes (the priority
            // queue and the importance both rely on this).
            double parentMean = sum / n;
            double parentScore = n * parentMean * parentMean;

            if (RegressionTree.this.attributes[j].getType() == Attribute.Type.NOMINAL) {
                int m = ((NominalAttribute) RegressionTree.this.attributes[j]).size();
                double[] trueSum = new double[m];
                int[] trueCount = new int[m];
                for (int i = 0; i < this.x.length; ++i) {
                    if (this.samples[i] <= 0) {
                        continue;
                    }
                    int level = (int) this.x[i][j];
                    trueSum[level] += this.samples[i] * this.y[i];
                    trueCount[level] += this.samples[i];
                }

                for (int k = 0; k < m; ++k) {
                    double tc = trueCount[k];
                    double fc = n - tc;
                    if (tc < RegressionTree.this.nodeSize || fc < RegressionTree.this.nodeSize) {
                        continue;
                    }
                    double trueMean = trueSum[k] / tc;
                    double falseMean = (sum - trueSum[k]) / fc;
                    double gain = tc * trueMean * trueMean + fc * falseMean * falseMean - parentScore;
                    if (gain > split.splitScore) {
                        split.splitFeature = j;
                        split.splitValue = k;
                        split.splitScore = gain;
                        split.trueChildOutput = trueMean;
                        split.falseChildOutput = falseMean;
                    }
                }
            } else if (RegressionTree.this.attributes[j].getType() == Attribute.Type.NUMERIC) {
                double trueSum = 0.0;
                int trueCount = 0;
                double prevx = Double.NaN;

                // Sweep the rows in increasing order of x[.][j]; everything already
                // swept forms the candidate "true" branch.
                for (int i : RegressionTree.this.order[j]) {
                    if (this.samples[i] <= 0) {
                        continue;
                    }
                    double xi = this.x[i][j];
                    // A threshold can only sit between two distinct consecutive values.
                    if (!Double.isNaN(prevx) && xi != prevx) {
                        int falseCount = n - trueCount;
                        if (trueCount >= RegressionTree.this.nodeSize && falseCount >= RegressionTree.this.nodeSize) {
                            double trueMean = trueSum / trueCount;
                            double falseMean = (sum - trueSum) / falseCount;
                            double gain = (double) trueCount * trueMean * trueMean + (double) falseCount * falseMean * falseMean - parentScore;
                            if (gain > split.splitScore) {
                                split.splitFeature = j;
                                split.splitValue = (xi + prevx) / 2.0;
                                split.splitScore = gain;
                                split.trueChildOutput = trueMean;
                                split.falseChildOutput = falseMean;
                            }
                        }
                    }
                    prevx = xi;
                    trueSum += this.samples[i] * this.y[i];
                    trueCount += this.samples[i];
                }
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + RegressionTree.this.attributes[j].getType());
            }
            return split;
        }

        /**
         * Splits this node using the split found by {@link #findBestSplit()}.
         * Rows passing the test move to the true child; the remaining rows keep this
         * node's sample array, which is handed to the false child. Children with an
         * admissible split are queued in {@code nextSplits}, or split recursively
         * when it is null.
         *
         * @return false, leaving this node an intact leaf, if the realised partition
         *         gives either child fewer than {@code nodeSize} samples.
         * @throws IllegalStateException if there is no split, the node already has
         *         children, or the attribute type is unsupported.
         */
        public boolean split(PriorityQueue<TrainNode> nextSplits) {
            if (this.node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            if (this.node.trueChild != null || this.node.falseChild != null) {
                throw new IllegalStateException("Split non-leaf node.");
            }

            int n = this.x.length;
            int tc = 0;
            int fc = 0;
            int[] trueSamples = new int[n];
            int feature = this.node.splitFeature;
            Attribute attr = RegressionTree.this.attributes[feature];

            if (attr.getType() == Attribute.Type.NOMINAL) {
                for (int i = 0; i < n; ++i) {
                    if (this.samples[i] <= 0) {
                        continue;
                    }
                    if (Math.equals(this.x[i][feature], this.node.splitValue)) {
                        trueSamples[i] = this.samples[i];
                        tc += trueSamples[i];
                        this.samples[i] = 0;
                    } else {
                        fc += this.samples[i];
                    }
                }
            } else if (attr.getType() == Attribute.Type.NUMERIC) {
                for (int i = 0; i < n; ++i) {
                    if (this.samples[i] <= 0) {
                        continue;
                    }
                    if (this.x[i][feature] <= this.node.splitValue) {
                        trueSamples[i] = this.samples[i];
                        tc += trueSamples[i];
                        this.samples[i] = 0;
                    } else {
                        fc += this.samples[i];
                    }
                }
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + attr.getType());
            }

            if (tc < RegressionTree.this.nodeSize || fc < RegressionTree.this.nodeSize) {
                // The realised partition can differ from the scored one (e.g. when the
                // numeric midpoint threshold rounds onto a data value). Give the
                // true-branch rows back so this leaf keeps all of its samples.
                for (int i = 0; i < n; ++i) {
                    if (trueSamples[i] > 0) {
                        this.samples[i] = trueSamples[i];
                    }
                }
                this.node.splitFeature = -1;
                this.node.splitValue = Double.NaN;
                this.node.splitScore = 0.0;
                return false;
            }

            this.node.trueChild = new Node(this.node.trueChildOutput);
            this.node.falseChild = new Node(this.node.falseChildOutput);

            this.trueChild = new TrainNode(this.node.trueChild, this.x, this.y, trueSamples);
            if (tc > RegressionTree.this.nodeSize && this.trueChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(this.trueChild);
                } else {
                    this.trueChild.split(null);
                }
            }

            this.falseChild = new TrainNode(this.node.falseChild, this.x, this.y, this.samples);
            if (fc > RegressionTree.this.nodeSize && this.falseChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(this.falseChild);
                } else {
                    this.falseChild.split(null);
                }
            }

            RegressionTree.this.importance[feature] += this.node.splitScore;
            return true;
        }

        /** Evaluates one candidate feature of the enclosing node; used for parallel split search. */
        class SplitTask implements Callable<Node> {
            final int n;
            final double sum;
            final int j;

            SplitTask(int n, double sum, int j) {
                this.n = n;
                this.sum = sum;
                this.j = j;
            }

            @Override
            public Node call() {
                return TrainNode.this.findBestSplit(this.n, this.sum, this.j);
            }
        }
    }

    /**
     * A node of the fitted tree. Leaves have no children and predict
     * {@link #output}; internal nodes route on {@link #splitFeature}.
     */
    class Node implements Serializable {
        /** Prediction when this node is a leaf. */
        double output = 0.0;
        /** Feature tested by this node, or -1 if none. */
        int splitFeature = -1;
        /** Numeric threshold, nominal level, or sparse feature index of the test. */
        double splitValue = Double.NaN;
        /** Reduction of the sum of squared errors achieved by the split. */
        double splitScore = 0.0;
        Node trueChild;
        Node falseChild;
        /** Mean response of the true branch, used as the true child's output. */
        double trueChildOutput = 0.0;
        /** Mean response of the false branch, used as the false child's output. */
        double falseChildOutput = 0.0;

        public Node(double output) {
            this.output = output;
        }

        /** Predicts a dense row by descending to a leaf. */
        public double predict(double[] x) {
            if (this.trueChild == null && this.falseChild == null) {
                return this.output;
            }
            Attribute attr = RegressionTree.this.attributes[this.splitFeature];
            if (attr.getType() == Attribute.Type.NOMINAL) {
                return Math.equals(x[this.splitFeature], this.splitValue)
                        ? this.trueChild.predict(x)
                        : this.falseChild.predict(x);
            }
            if (attr.getType() == Attribute.Type.NUMERIC) {
                return x[this.splitFeature] <= this.splitValue
                        ? this.trueChild.predict(x)
                        : this.falseChild.predict(x);
            }
            throw new IllegalStateException("Unsupported attribute type: " + attr.getType());
        }

        /** Predicts a sparse binary row by descending to a leaf. */
        public double predict(int[] x) {
            if (this.trueChild == null && this.falseChild == null) {
                return this.output;
            }
            return x[this.splitFeature] == (int) this.splitValue
                    ? this.trueChild.predict(x)
                    : this.falseChild.predict(x);
        }
    }

    /** Computes a leaf's output from the multiplicities of the training rows that reached it. */
    public static interface NodeOutput {
        public double calculate(int[] samples);
    }

    /** Trainer for regression trees. */
    public static class Trainer extends RegressionTrainer<double[]> {
        /**
         * Minimum number of samples in each child of a split. Defaults to 5, the
         * same default the RegressionTree constructors use; it must be at least 2.
         */
        private int nodeSize = 5;
        /** Maximum number of leaves. */
        private int maxNodes = 100;
        /** Number of sparse binary features, or -1 to infer it from the data. */
        private int numFeatures = -1;

        /**
         * @param maxNodes the maximum number of leaves.
         * @throws IllegalArgumentException if {@code maxNodes < 2}.
         */
        public Trainer(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
        }

        /**
         * @param attributes the feature descriptors of dense training data.
         * @param maxNodes the maximum number of leaves.
         * @throws IllegalArgumentException if {@code maxNodes < 2}.
         */
        public Trainer(Attribute[] attributes, int maxNodes) {
            super(attributes);
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
        }

        /**
         * @param numFeatures the number of sparse binary features.
         * @param maxNodes the maximum number of leaves.
         * @throws IllegalArgumentException if {@code numFeatures <= 0} or {@code maxNodes < 2}.
         */
        public Trainer(int numFeatures, int maxNodes) {
            if (numFeatures <= 0) {
                throw new IllegalArgumentException("Invalid number of sparse binary features: " + numFeatures);
            }
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.numFeatures = numFeatures;
            this.maxNodes = maxNodes;
        }

        /** Sets the maximum number of leaves (at least 2). */
        public Trainer setMaxNodes(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
            return this;
        }

        /** Sets the minimum number of samples in each child of a split (at least 2). */
        public Trainer setNodeSize(int nodeSize) {
            if (nodeSize < 2) {
                throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
            }
            this.nodeSize = nodeSize;
            return this;
        }

        /** Trains a dense tree. */
        public RegressionTree train(double[][] x, double[] y) {
            return new RegressionTree(this.attributes, x, y, this.maxNodes, this.nodeSize);
        }

        /** Trains a sparse binary tree, inferring the feature count from {@code x} if it was not given. */
        public RegressionTree train(int[][] x, double[] y) {
            if (this.numFeatures <= 0) {
                return new RegressionTree(Math.max(x) + 1, x, y, this.maxNodes, this.nodeSize);
            }
            return new RegressionTree(this.numFeatures, x, y, this.maxNodes, this.nodeSize);
        }
    }
}

