/*
 * Decompiled with CFR 0.152.
 */
package boone.training;

import boone.BrainPart;
import boone.Link;
import boone.NeuralNet;
import boone.Neuron;
import boone.PatternSet;
import boone.Trainer;
import boone.map.Topology;
import boone.util.Xml;
import java.util.List;
import org.jdom2.Element;

public class SOMTrainer
extends Trainer {
    protected Topology topology;
    protected double startRadius;
    protected double radius;
    protected double deltaRadius;

    public SOMTrainer(Topology topology) {
        this.topology = topology;
        this.learnRate = 100.0;
        this.startRadius = 3.0;
    }

    @Override
    public SOMTrainer clone() {
        return (SOMTrainer)super.clone();
    }

    @Override
    protected void reset() {
        this.radius = this.startRadius;
        this.deltaRadius = (this.startRadius - 1.0) / (double)this.epochs;
    }

    @Override
    public void setNetwork(NeuralNet neuralNet) {
        super.setNetwork(neuralNet);
        this.topology.setNet(neuralNet);
    }

    public Topology getTopology() {
        return this.topology;
    }

    public double getStartRadius() {
        return this.startRadius;
    }

    public void setStartRadius(double d) {
        this.startRadius = d;
    }

    @Override
    protected void train(List<Double> list, List<Double> list2) {
        Neuron neuron = this.getWinningNeuron(list);
        List<Neuron> list3 = this.topology.getNeighbors(neuron, (int)this.radius);
        for (Neuron neuron2 : list3) {
            this.train(neuron2, list, this.learnRate / (double)(this.currentEpoch + 1));
        }
    }

    @Override
    protected double calcAdaptation(BrainPart brainPart) {
        return 0.0;
    }

    @Override
    protected void endBatch() {
        this.radius -= this.deltaRadius;
    }

    @Override
    protected void endTrain() {
        this.labelNeurons(this.trainingData);
    }

    @Override
    public double test() {
        int n = 0;
        for (List<Double> list : this.testData.getInputs()) {
            int n2;
            String string;
            String string2 = this.assignLabel(list);
            if (string2.equals(string = this.testData.getTargetLabelOfPattern(n2 = this.testData.getInputs().indexOf(list)))) continue;
            ++n;
        }
        return (double)n / (double)this.testData.size();
    }

    @Override
    public Element toXML(Element element) {
        Element element2 = super.toXML(element);
        element2.setAttribute("StartRadius", String.valueOf(this.startRadius));
        this.topology.toXML(element2);
        return element2;
    }

    @Override
    public void fromXML(Element element) {
        super.fromXML(element);
        this.startRadius = Xml.getProperty(element, "StartRadius", 3.0);
        this.topology = (Topology)Xml.getStorable(element, "Topology");
    }

    protected void train(Neuron neuron, List<Double> list, double d) {
        BrainPart brainPart;
        int n;
        double d2 = 0.0;
        int n2 = this.net.getInputNeuronCount();
        for (n = 0; n < n2; ++n) {
            brainPart = this.net.getInputNeuron(n);
            Link link = ((Neuron)brainPart).getLinkTo(neuron);
            double d3 = link.getWeight() + d * list.get(n);
            link.setWeight(d3);
            d2 += d3 * d3;
        }
        d2 = Math.sqrt(d2);
        n2 = this.net.getInputNeuronCount();
        for (n = 0; n < n2; ++n) {
            brainPart = this.net.getInputNeuron(n).getLinkTo(neuron);
            brainPart.setWeight(brainPart.getWeight() / d2);
        }
    }

    private String assignLabel(List<Double> list) {
        Neuron neuron = this.net.getOutputNeuron(0);
        double d = -1.7976931348623157E308;
        this.net.setInput(list);
        this.net.innervate();
        for (int i = 0; i < this.net.getOutputNeuronCount(); ++i) {
            double d2;
            Neuron neuron2 = this.net.getOutputNeuron(i);
            if (neuron2.getName().isEmpty() || !((d2 = neuron2.getOutput()) > d)) continue;
            d = d2;
            neuron = neuron2;
        }
        return neuron.getName();
    }

    private void labelNeurons(PatternSet patternSet) {
        int n;
        int[][] nArray = new int[this.net.getOutputNeuronCount()][patternSet.getTargetLabelCount()];
        for (n = 0; n < patternSet.size(); ++n) {
            Neuron neuron = this.getWinningNeuron(patternSet.getInputs().get(n));
            double d = patternSet.getTargets().get(n).get(0);
            int[] nArray2 = nArray[(int)neuron.getID()];
            int n2 = (int)d;
            nArray2[n2] = nArray2[n2] + 1;
        }
        for (n = 0; n < nArray.length; ++n) {
            int n3 = 0;
            int n4 = nArray[n][0];
            for (int i = 1; i < nArray[n].length; ++i) {
                if (nArray[n][i] <= n4) continue;
                n4 = nArray[n][i];
                n3 = i;
            }
            String string = "";
            if (n4 > 0) {
                string = patternSet.getTargetLabel(n3);
            }
            this.net.getOutputNeuron(n).setName(string);
        }
    }
}

