/*
 * Decompiled with CFR 0.152.
 */
package boone;

import boone.BrainPart;
import boone.Link;
import boone.NeuralNet;
import boone.Neuron;
import boone.PatternSet;
import boone.io.Storable;
import boone.training.SquareError;
import boone.training.TrainingSignalGenerator;
import boone.util.CloneException;
import boone.util.Xml;
import java.util.List;
import org.jdom2.Element;

/**
 * Base class for training algorithms that adjust the biases of {@link Neuron}s and the weights of
 * {@link Link}s of a {@link NeuralNet}.
 *
 * <p>Subclasses supply the per-pattern step ({@link #train(List, List)}), the per-parameter
 * adaptation ({@link #calcAdaptation(BrainPart)}) and the {@link #reset()} / {@link #endTrain()}
 * hooks. In {@link #endBatch()} the adaptation computed for each trainable part is added to its
 * bias or weight, after which that part's gradient is set back to {@code 0}.
 */
public abstract class Trainer implements Cloneable, Storable {
    /** Network being trained; {@code null} in a freshly cloned trainer. */
    protected NeuralNet net;
    protected PatternSet trainingData;
    protected PatternSet testData;
    /**
     * Number of patterns between calls to {@link #endBatch()} in {@link #train()}. {@code 0} means
     * "not set yet": it is then replaced by the size of the next non-null training set assigned.
     */
    private int batchSize;
    protected TrainingSignalGenerator trainingSignalGenerator;
    /** When {@code true}, {@link #train()} does not call {@link #reset()} / {@link #endTrain()}. */
    protected boolean stepMode;
    /** When {@code true}, the training data is shuffled at the start of every epoch. */
    protected boolean shuffle;
    protected double learnRate;
    protected int epochs;
    /** Index of the epoch currently being run by {@link #train()}. */
    protected int currentEpoch;

    /** Creates a trainer using {@link SquareError}, a learn rate of 0.2 and 1000 epochs. */
    public Trainer() {
        this.setTrainingSignalGenerator(new SquareError());
        this.learnRate = 0.2;
        this.epochs = 1000;
    }

    /**
     * Returns a shallow copy of this trainer that is not bound to any network ({@code net} is
     * {@code null}) and owns a clone of the training signal generator. The training and test
     * pattern sets are shared with the original.
     *
     * @throws CloneException if {@code super.clone()} reports the object is not cloneable
     */
    @Override
    public Trainer clone() {
        try {
            Trainer copy = (Trainer) super.clone();
            copy.net = null;
            copy.trainingSignalGenerator = this.trainingSignalGenerator.clone();
            return copy;
        } catch (CloneNotSupportedException e) {
            throw new CloneException(e);
        }
    }

    public PatternSet getTrainingData() {
        return this.trainingData;
    }

    /**
     * Sets the training patterns. If no batch size has been chosen yet ({@code batchSize == 0}),
     * the batch size becomes the size of {@code patternSet}, i.e. full-batch training.
     *
     * @param patternSet the training patterns; may be {@code null} to clear them
     */
    public void setTrainingData(PatternSet patternSet) {
        this.trainingData = patternSet;
        // Guard against null so clearing the training data does not throw.
        if (this.batchSize == 0 && patternSet != null) {
            this.batchSize = patternSet.size();
        }
    }

    public PatternSet getTestData() {
        return this.testData;
    }

    public void setTestData(PatternSet patternSet) {
        this.testData = patternSet;
    }

    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the number of patterns processed between calls to {@link #endBatch()} in
     * {@link #train()}. A value {@code <= 0} means {@code train()} never calls {@code endBatch()}.
     */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /**
     * Switches step mode on or off. Enabling it calls {@link #reset()} immediately; disabling it
     * calls {@link #endTrain()} immediately. While step mode is on, {@link #train()} skips both
     * hooks.
     */
    public void setStepMode(boolean stepMode) {
        this.stepMode = stepMode;
        if (stepMode) {
            this.reset();
        } else {
            this.endTrain();
        }
    }

    /** Enables or disables shuffling of the training data at the start of every epoch. */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /**
     * Runs {@link #getEpochs()} epochs over the training data.
     *
     * <p>Unless step mode is active, {@link #reset()} is called first and {@link #endTrain()}
     * last. Every pattern is passed to {@link #train(List, List)}, and {@link #endBatch()} is
     * invoked after every {@code batchSize} patterns. The batch counter is not reset between
     * epochs, so a batch may span an epoch boundary. Patterns left over after the last complete
     * batch do not trigger {@code endBatch()} here.
     *
     * <p>NOTE(review): the inner loop assumes the training set's element enumeration becomes
     * available again for each epoch; confirm against {@code PatternSet}.
     */
    public void train() {
        if (!this.stepMode) {
            this.reset();
        }
        int patternsInBatch = 0;
        for (this.currentEpoch = 0; this.currentEpoch < this.epochs; ++this.currentEpoch) {
            if (this.shuffle) {
                this.trainingData.shuffle();
            }
            while (this.trainingData.hasMoreElements()) {
                int index = this.trainingData.getNextElement();
                this.train(this.trainingData.getInputs().get(index), this.trainingData.getTargets().get(index));
                if (++patternsInBatch == this.batchSize) {
                    this.endBatch();
                    patternsInBatch = 0;
                }
            }
        }
        if (!this.stepMode) {
            this.endTrain();
        }
    }

    /** Trains a single input/target pattern and immediately applies the update via endBatch(). */
    public void trainPattern(List<Double> inputs, List<Double> targets) {
        this.train(inputs, targets);
        this.endBatch();
    }

    /** Returns the training signal generator's error of the network for one input/target pair. */
    public double test(List<Double> inputs, List<Double> targets) {
        return this.trainingSignalGenerator.computeError(this.net, inputs, targets);
    }

    /**
     * Sums the error over the patterns yielded by the test data's element enumeration.
     *
     * <p>NOTE(review): this does not rewind the enumeration itself; confirm that {@code PatternSet}
     * restarts it, otherwise a second call may see no patterns.
     */
    public double test() {
        double totalError = 0.0;
        while (this.testData.hasMoreElements()) {
            int index = this.testData.getNextElement();
            totalError += this.trainingSignalGenerator.computeError(
                    this.net, this.testData.getInputs().get(index), this.testData.getTargets().get(index));
        }
        return totalError;
    }

    /** Applies the accumulated adaptations: neuron biases first, then link weights. */
    protected void endBatch() {
        this.updateNeurons();
        this.updateLinks();
    }

    /** Adds calcAdaptation() to the bias of every bias-using neuron and zeroes its gradient. */
    protected void updateNeurons() {
        int count = this.net.getNeuronCount();
        for (int i = 0; i < count; ++i) {
            Neuron neuron = this.net.getNeuron(i);
            if (!neuron.isUsingBias()) {
                continue;
            }
            neuron.addToBias(this.calcAdaptation(neuron));
            neuron.setGradient(0.0);
        }
    }

    /** Adds calcAdaptation() to the weight of every trainable link and zeroes its gradient. */
    protected void updateLinks() {
        int count = this.net.getLinkCount();
        for (int i = 0; i < count; ++i) {
            Link link = this.net.getLink(i);
            if (!link.isTrainable()) {
                continue;
            }
            link.addToWeight(this.calcAdaptation(link));
            link.setGradient(0.0);
        }
    }

    public NeuralNet getNetwork() {
        return this.net;
    }

    public void setNetwork(NeuralNet neuralNet) {
        this.net = neuralNet;
    }

    public void setLearnRate(double learnRate) {
        this.learnRate = learnRate;
    }

    public double getLearnRate() {
        return this.learnRate;
    }

    public int getEpochs() {
        return this.epochs;
    }

    public void setEpochs(int epochs) {
        this.epochs = epochs;
    }

    public TrainingSignalGenerator getTrainingSignalGenerator() {
        return this.trainingSignalGenerator;
    }

    public void setTrainingSignalGenerator(TrainingSignalGenerator trainingSignalGenerator) {
        this.trainingSignalGenerator = trainingSignalGenerator;
    }

    /**
     * Sets {@code inputs} as the network input, calls {@code innervate()}, and returns the output
     * neuron with the largest output. On ties, the neuron with the lowest output index wins.
     *
     * @return the winning output neuron, or {@code null} if no output exceeds
     *     {@link Double#NEGATIVE_INFINITY} (e.g. there are no output neurons, or all outputs are NaN)
     */
    public Neuron getWinningNeuron(List<Double> inputs) {
        Neuron winner = null;
        // Start from NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest *positive*
        // double, so outputs <= 0 (e.g. tanh or linear output units) could otherwise never win.
        double bestOutput = Double.NEGATIVE_INFINITY;
        this.net.setInput(inputs);
        this.net.innervate();
        for (int i = 0; i < this.net.getOutputNeuronCount(); ++i) {
            Neuron candidate = this.net.getOutputNeuron(i);
            double output = candidate.getOutput();
            if (output > bestOutput) {
                bestOutput = output;
                winner = candidate;
            }
        }
        return winner;
    }

    /**
     * Sets {@code inputs} as the network input, calls {@code innervate()}, and returns the rank of
     * output neuron {@code outputIndex}. The rank is the number of output neurons whose output is
     * {@code >=} that neuron's output, counting the neuron itself. The largest output therefore
     * has rank 1, and tied neurons share the larger (worse) rank.
     */
    public int getRank(List<Double> inputs, int outputIndex) {
        this.net.setInput(inputs);
        this.net.innervate();
        double reference = this.net.getOutputNeuron(outputIndex).getOutput();
        int rank = 0;
        for (int i = 0; i < this.net.getOutputNeuronCount(); ++i) {
            if (this.net.getOutputNeuron(i).getOutput() >= reference) {
                ++rank;
            }
        }
        return rank;
    }

    @Override
    public String toString() {
        return "(" + this.getClass().getName() + ")";
    }

    /**
     * Restores the learn rate, epoch count and training signal generator from {@code element}.
     * The current learn rate and epoch count are passed to {@code Xml.getProperty}, presumably as
     * fallbacks when the attribute is missing (confirm in {@code Xml}).
     */
    @Override
    public void fromXML(Element element) {
        this.learnRate = Xml.getProperty(element, "learnRate", this.learnRate);
        this.epochs = Xml.getProperty(element, "epochs", this.epochs);
        this.trainingSignalGenerator = (TrainingSignalGenerator) Xml.getStorable(element, "trainingSignalGenerator");
    }

    /**
     * Adds a {@code trainer} child to {@code element}. The child holds the learn rate and epoch
     * count as attributes, plus the serialized training signal generator. Batch size, shuffle and
     * step mode are not written.
     *
     * @return the newly added child element
     */
    @Override
    public Element toXML(Element element) {
        Element trainerElement = Xml.addStorable(element, "trainer", this);
        trainerElement.setAttribute("learnRate", String.valueOf(this.learnRate));
        trainerElement.setAttribute("epochs", String.valueOf(this.epochs));
        this.trainingSignalGenerator.toXML(trainerElement);
        return trainerElement;
    }

    /**
     * Hook called at the start of {@link #train()} (outside step mode) and when step mode is
     * enabled.
     */
    protected abstract void reset();

    /**
     * Hook called at the end of {@link #train()} (outside step mode) and when step mode is
     * disabled.
     */
    protected abstract void endTrain();

    /**
     * Processes one input/target pattern. Called by {@link #train()} and
     * {@link #trainPattern(List, List)}.
     */
    protected abstract void train(List<Double> inputs, List<Double> targets);

    /**
     * Returns the amount to add to a neuron's bias or a link's weight in {@link #endBatch()}.
     */
    protected abstract double calcAdaptation(BrainPart part);
}

