/*
 * Decompiled with CFR 0.152.
 */
package boone.training;

import boone.BrainPart;
import boone.Link;
import boone.NeuralNet;
import boone.Neuron;
import boone.PatternSet;
import boone.Trainer;
import boone.map.Function;
import boone.neurons.NeuronList;
import boone.training.RpropTrainer;
import boone.training.SquareError;
import boone.util.Conversion;
import java.util.ArrayList;
import java.util.List;

/**
 * Greedy layer-wise pre-trainer for a layered network, auto-encoder style.
 *
 * <p>For every pair of adjacent layers (i, i+1) of {@link #net}, a temporary network
 * {@link #partNet} is assembled from:
 * <ul>
 *   <li>layer i as input,</li>
 *   <li>layer i+1 as hidden layer,</li>
 *   <li>clones of layer i as output.</li>
 * </ul>
 * It is then trained by {@link #partTrainer} to reproduce its own input. The hidden-layer
 * activations produced for each pattern become the training/test data for the next pair.
 */
public class SAETrainer
extends Trainer {
    /** Trainer used for each temporary auto-encoder; an {@link RpropTrainer} by default. */
    protected Trainer partTrainer;
    /** Temporary network holding the layer pair currently being trained. */
    protected NeuralNet partNet;
    /** Neurons of the current hidden layer (layer i+1), read by {@link #getHiddenLayerOutput}. */
    protected NeuronList hiddenLayer;
    /** Whether the layer currently used as input was an input layer of the main net. */
    protected boolean wasInputLayer = false;
    /** Whether the layer currently used as hidden layer was the output layer of the main net. */
    protected boolean wasOutputLayer = false;
    /** Number of {@code partTrainer.train()} calls per layer pair. */
    protected int steps = 100;
    /** When true, progress and per-pattern reconstruction results are printed to stdout. */
    protected boolean print = false;

    /** Creates a pre-trainer using square error and an {@link RpropTrainer} per layer. */
    public SAETrainer() {
        this.setTrainingSignalGenerator(new SquareError());
        this.hiddenLayer = new NeuronList();
        this.partTrainer = new RpropTrainer();
        this.partNet = new NeuralNet();
    }

    /**
     * Creates a pre-trainer with a custom per-layer trainer.
     *
     * @param trainer trainer used for each layer pair
     * @param n       number of {@code trainer.train()} calls per layer pair
     */
    public SAETrainer(Trainer trainer, int n) {
        this();
        this.steps = n;
        this.partTrainer = trainer;
    }

    /** No-op: per-pattern training is delegated to {@link #partTrainer}. */
    @Override
    protected void train(List<Double> list, List<Double> list2) {
    }

    /** No-op: adaptation is computed by {@link #partTrainer}; always returns 0. */
    @Override
    protected double calcAdaptation(BrainPart brainPart) {
        return 0.0;
    }

    /** No-op: this trainer keeps no per-run state of its own to reset. */
    @Override
    protected void reset() {
    }

    /** No-op: batch handling is done by {@link #partTrainer}. */
    @Override
    protected void endBatch() {
    }

    /** No-op: nothing to finalize after training. */
    @Override
    protected void endTrain() {
    }

    /**
     * Pre-trains the main network one layer pair at a time.
     *
     * <p>Does nothing when {@link #numberOfLayers()} returns -1, i.e. when the main net is
     * not reported as layered or has no output neurons.
     */
    @Override
    public void train() {
        this.resetTrainer();
        this.updatePartTrainer();
        this.updatePartTrainerData();
        // Evaluate once: the loop body temporarily rewires neuron flags and invalidates the
        // main net's topology, so re-querying the layer count each iteration is fragile.
        int layerCount = this.numberOfLayers();
        for (int i = 0; i < layerCount; ++i) {
            this.extractInputOutputLayer(this.partNet, i);
            this.extractHiddenLayer(this.partNet, i + 1);
            this.partNet.setTopology(0);
            for (int j = 0; j < this.steps; ++j) {
                this.partTrainer.train();
                if (this.print) {
                    System.out.println(j * this.epochs + ". - " + this.partTrainer.test());
                }
            }
            // The hidden activations of this pair become the data for the next pair.
            this.partTrainer.setTrainingData(this.getNextSet(this.partNet, this.partTrainer.getTrainingData()));
            this.partTrainer.setTestData(this.getNextSet(this.partNet, this.partTrainer.getTestData()));
            this.resetLayer(i);
            this.resetTrainer();
        }
    }

    /**
     * Undoes the temporary rewiring done for layer pair (n, n+1) and removes the output clones.
     *
     * @param n index of the layer that served as input of {@link #partNet}
     */
    protected void resetLayer(int n) {
        Neuron nextLayerTemplate = null;
        // Iterate backwards because output clones are removed from partNet during the loop.
        for (int k = this.partNet.getNeuronCount() - 1; k >= 0; --k) {
            Neuron neuron = this.partNet.getNeuron(k);
            if (neuron.isInputNeuron()) {
                if (!this.wasInputLayer) {
                    // A former hidden layer was flagged as input by extractInputOutputLayer;
                    // restore its flag, activation function and bias usage.
                    if (nextLayerTemplate == null) {
                        nextLayerTemplate = (Neuron)this.getNeuronsOfLayer(n + 1).get(0);
                    }
                    neuron.setInputNeuron(false);
                    neuron.setActivationFn(nextLayerTemplate.getActivationFn());
                    neuron.setUsingBias(true);
                }
            } else if (neuron.isOutputNeuron()) {
                // Reconstruction clones exist only in partNet; drop them, nothing to reset.
                this.partNet.removeNeuron(neuron);
                continue;
            } else if (this.wasOutputLayer) {
                neuron.setOutputNeuron(true);
            }
            neuron.setOutput(0.0);
            neuron.setInput(0.0);
            neuron.setLinkInput(0.0);
            neuron.setBias(0.0);
            neuron.reset();
            // NOTE(review): this re-randomizes the bias of the just-trained hidden layer, so the
            // pre-trained bias is not kept in the main net - confirm this is intended.
            neuron.setBias(Math.random() * 0.2 - 0.1);
        }
        int linkCount = this.partNet.getLinkCount();
        for (int k = 0; k < linkCount; ++k) {
            this.partNet.getLink(k).reset();
        }
        this.wasInputLayer = false;
        this.wasOutputLayer = false;
        // Presumably marks the main net's topology as stale after the flag changes - confirm.
        this.net.setTopology(0);
    }

    /** Replaces {@link #partNet} with a fresh empty network bound to {@link #partTrainer}. */
    protected void resetTrainer() {
        this.partNet = new NeuralNet();
        this.partNet.setTrainer(this.partTrainer);
        this.partTrainer.setNetwork(this.partNet);
    }

    /** Copies epochs, learn rate, shuffle and step mode from this trainer to {@link #partTrainer}. */
    protected void updatePartTrainer() {
        this.partTrainer.setEpochs(this.epochs);
        this.partTrainer.setLearnRate(this.learnRate);
        this.partTrainer.setShuffle(this.shuffle);
        this.partTrainer.setStepMode(this.stepMode);
    }

    /**
     * Builds auto-encoder data for the first layer pair: targets equal inputs.
     *
     * <p>Only the outer lists are new; the per-pattern input lists are shared with the
     * original data sets.
     */
    protected void updatePartTrainerData() {
        this.partTrainer.setTrainingData(new PatternSet());
        this.partTrainer.setTestData(new PatternSet());
        this.partTrainer.getTrainingData().getInputs().addAll(this.trainingData.getInputs());
        this.partTrainer.getTrainingData().getTargets().addAll(this.trainingData.getInputs());
        this.partTrainer.getTestData().getInputs().addAll(this.testData.getInputs());
        this.partTrainer.getTestData().getTargets().addAll(this.testData.getInputs());
    }

    /**
     * Collects the neurons of the main net whose layer index equals {@code n}.
     *
     * @param n layer index
     * @return a new list, empty when no neuron is in that layer
     */
    protected NeuronList getNeuronsOfLayer(int n) {
        NeuronList neuronList = new NeuronList();
        int n2 = this.net.getNeuronCount();
        for (int i = 0; i < n2; ++i) {
            if (this.net.getNeuron(i).getLayer() != n) continue;
            neuronList.add(this.net.getNeuron(i));
        }
        return neuronList;
    }

    /**
     * Adds layer {@code n} to {@code neuralNet} as input layer, together with one output clone
     * per neuron that the auto-encoder must reconstruct.
     *
     * <p>For an original input layer, the clones take the activation function of layer n+1.
     * For a former hidden layer, the input neurons are switched to identity without bias,
     * because their input values are already that layer's activations. The clones keep the
     * original function.
     *
     * @param neuralNet temporary network being assembled
     * @param n         index of the layer to use as input
     */
    protected void extractInputOutputLayer(NeuralNet neuralNet, int n) {
        Function.Identity identity = new Function.Identity();
        NeuronList layer = this.getNeuronsOfLayer(n);
        if (!layer.isEmpty()) {
            this.wasInputLayer = ((Neuron)layer.get(0)).isInputNeuron();
        }
        Neuron nextLayerTemplate = null;
        for (Neuron neuron : layer) {
            neuron.setInputNeuron(true);
            // Clone before the activation function is changed below so the clone keeps it.
            Neuron reconstruction = neuron.clone();
            reconstruction.setInputNeuron(false);
            reconstruction.setOutputNeuron(true);
            reconstruction.setUsingBias(true);
            reconstruction.setBias(Math.random() * 0.2 - 0.1);
            if (this.wasInputLayer) {
                if (nextLayerTemplate == null) {
                    nextLayerTemplate = (Neuron)this.getNeuronsOfLayer(n + 1).get(0);
                }
                reconstruction.setActivationFn(nextLayerTemplate.getActivationFn());
            } else {
                neuron.setActivationFn(identity);
                neuron.setUsingBias(false);
            }
            int linkCount = neuron.getLinks().size();
            for (int i = 0; i < linkCount; ++i) {
                Link link = neuron.getLink(i);
                if (link.getSource() != neuron) continue;
                // Mirror each outgoing link as hidden -> reconstruction (decoder weight).
                Link decoderLink = link.clone();
                decoderLink.setSource(link.getSink());
                decoderLink.setSink(reconstruction);
                decoderLink.randomize(-0.1, 0.1);
                neuralNet.addLink(link);
                neuralNet.addLink(decoderLink);
            }
            neuralNet.addNeuron(neuron);
            neuralNet.addNeuron(reconstruction);
        }
    }

    /**
     * Adds layer {@code n} of the main net to {@code neuralNet} as its hidden layer.
     *
     * @param neuralNet temporary network being assembled
     * @param n         index of the layer to use as hidden layer
     */
    protected void extractHiddenLayer(NeuralNet neuralNet, int n) {
        this.hiddenLayer = new NeuronList();
        NeuronList neuronList = this.getNeuronsOfLayer(n);
        if (!neuronList.isEmpty()) {
            this.wasOutputLayer = ((Neuron)neuronList.get(0)).isOutputNeuron();
        }
        for (Neuron neuron : neuronList) {
            neuron.setInputNeuron(false);
            neuron.setOutputNeuron(false);
            neuralNet.addNeuron(neuron);
            this.hiddenLayer.add(neuron);
        }
    }

    /**
     * Returns the layer index of the first output neuron, or -1.
     *
     * <p>The index is returned only when the main net's topology is 1 (presumably
     * "layered" - confirm) and it has output neurons; otherwise -1.
     */
    protected int numberOfLayers() {
        if (this.net.getTopology() == 1 && this.net.getOutputNeuronCount() != 0) {
            return this.net.getOutputNeuron(0).getLayer();
        }
        return -1;
    }

    /**
     * Feeds every pattern of {@code patternSet} through {@code neuralNet} and records the
     * hidden-layer activations as both input and target of the returned set.
     *
     * @param neuralNet  trained temporary network
     * @param patternSet patterns to encode
     * @return the encoded patterns, for training the next layer pair
     */
    protected PatternSet getNextSet(NeuralNet neuralNet, PatternSet patternSet) {
        PatternSet nextSet = new PatternSet();
        double[] outputBuffer = null;
        if (this.print) {
            outputBuffer = new double[neuralNet.getOutputNeuronCount()];
            System.out.println("\n*** Testing auto-encoding of layer ...");
        }
        for (int i = 0; i < patternSet.getInputs().size(); ++i) {
            List<Double> input = patternSet.getInputs().get(i);
            // computeError is what runs the pattern through the network (the net's output is
            // read right after it). It must run for every pattern, not only when printing,
            // or the hidden activations collected below would be stale.
            double error = this.trainingSignalGenerator.computeError(neuralNet, input, input);
            if (this.print) {
                outputBuffer = neuralNet.getOutput(outputBuffer);
                System.out.println();
                System.out.println("Error " + i + " = " + error);
                System.out.print("     Input: ");
                System.out.println(input);
                System.out.print("     Output: ");
                System.out.println(Conversion.asList(outputBuffer));
            }
            List<Double> hidden = this.getHiddenLayerOutput(new ArrayList<Double>());
            nextSet.getInputs().add(hidden);
            nextSet.getTargets().add(new ArrayList<Double>(hidden));
        }
        return nextSet;
    }

    /**
     * Reads the current outputs of {@link #hiddenLayer}.
     *
     * @param list buffer to reuse. If it is {@code null} or shorter than the hidden layer,
     *             a new list is returned instead. Otherwise its first elements are
     *             overwritten in place.
     * @return the list holding the hidden outputs in neuron order
     */
    protected List<Double> getHiddenLayerOutput(List<Double> list) {
        int size = this.hiddenLayer.size();
        boolean reuse = list != null && list.size() >= size;
        List<Double> result = reuse ? list : new ArrayList<Double>(size);
        for (int i = 0; i < size; ++i) {
            double value = ((Neuron)this.hiddenLayer.get(i)).getOutput();
            if (reuse) {
                // Overwrite in place; add(i, v) would insert and shift the old contents.
                result.set(i, value);
            } else {
                result.add(value);
            }
        }
        return result;
    }

    /**
     * Enables or disables progress output to stdout.
     *
     * @param bl true to print training progress and reconstruction results
     */
    public void setPrint(boolean bl) {
        this.print = bl;
    }
}

