/*
 * Decompiled with CFR 0.152.
 */
package boone.training;

import boone.BrainPart;
import boone.training.BackpropTrainer;
import boone.util.Xml;
import java.util.IdentityHashMap;
import java.util.Map;
import org.jdom2.Element;

public class AdamTrainer
extends BackpropTrainer {
    private double beta1 = 0.9;
    private double beta2 = 0.999;
    private double beta1Pot = 1.0;
    private double beta2Pot = 1.0;
    private double epsilon = 1.0E-8;

    public AdamTrainer() {
        this.setLearnRate(0.001);
        this.setBatchSize(32);
    }

    @Override
    protected void reset() {
        super.reset();
        this.beta1Pot = 1.0;
        this.beta2Pot = 1.0;
    }

    @Override
    protected double calcAdaptation(BrainPart brainPart) {
        double d = brainPart.getGradient();
        double d2 = this.beta1 * brainPart.getLastGradient() + (1.0 - this.beta1) * d;
        brainPart.setLastGradient(d2);
        this.beta1Pot *= this.beta1;
        double d3 = this.beta2 * brainPart.getAccu() + (1.0 - this.beta2) * d * d;
        brainPart.setAccu(d3);
        this.beta2Pot *= this.beta2;
        return -this.learnRate * ((d2 /= 1.0 - this.beta1Pot) / (Math.sqrt(d3 /= 1.0 - this.beta2Pot) + this.epsilon));
    }

    public double getBeta1() {
        return this.beta1;
    }

    public void setBeta1(double d) {
        this.beta1 = d;
    }

    public double getBeta2() {
        return this.beta2;
    }

    public void setBeta2(double d) {
        this.beta2 = d;
    }

    public double getEpsilon() {
        return this.epsilon;
    }

    public void setEpsilon(double d) {
        this.epsilon = d;
    }

    @Override
    public Element toXML(Element element) {
        Element element2 = super.toXML(element);
        element2.setAttribute("beta1", String.valueOf(this.beta1));
        element2.setAttribute("beta2", String.valueOf(this.beta2));
        element2.setAttribute("epsilon", String.valueOf(this.epsilon));
        return element2;
    }

    @Override
    public void fromXML(Element element) {
        super.fromXML(element);
        this.beta1 = Xml.getProperty(element, "beta1", this.beta1);
        this.beta2 = Xml.getProperty(element, "beta2", this.beta2);
        this.epsilon = Xml.getProperty(element, "epsilon", this.epsilon);
    }
}

