package samples.programs;

import boone.NeuralNet;
import boone.PatternSet;
import boone.Trainer;
import boone.map.Function;
import boone.structure.NetFactory;
import boone.training.BackpropTrainer;
import boone.training.RpropTrainer;
import boone.training.SAETrainer;
import boone.util.Conversion;

/**
 * Test of stacked auto-encoder on classification task
 *
 * Task: grading students based on test-results
 * Version: same patterns
 *
 * @author Team stacked auto-encoder
 */

public class SAEClassifierGrading {

    /**
     * Runs the grading demo end to end: trains a feed-forward network with an
     * {@link SAETrainer}, trains a separate classifier network on the encoded
     * patterns, links both networks, optionally fine-tunes the linked network
     * and finally prints the error rate measured on freshly generated samples.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String[] args) {

        // settings - print, training
        boolean print = false;
        boolean fineTune = false;

        // parameters
        int gradeScale = 5;
        int pointsPerTest = 20;
        int tests = 6;
        int students = 1000;

        int encodingLength = 4;

        // training schedule used by every stage
        int steps = 100;
        int epochs = 10;

        // number of freshly generated samples for the final evaluation
        int testruns = 300;

        // -- SAE ----------------------------------------------------------------
        System.out.println("\n\n-- SAE ----------------------------------------------------------------------");
        System.out.println("*** Creating feed forward network with SAE trainer...");

        NeuralNet net = NetFactory.createFeedForward(new int[]{tests, 7, 5, encodingLength}, false, new SAETrainer(new BackpropTrainer(), 100), null, null);

        // generate unrealistic data to test (different distribution, no correlation)
        // NOTE(review): the result is currently not used anywhere below; the call is kept
        // because it is unclear from here whether the generator influences later random data.
        double[][] randomPatterns = AEUtils.generateRandomTrainingsData(1000, 6, 0, 1, false);

        // generate realistic training data
        double[][] inPatterns = AEUtils.generateClassificationTestResults(students, tests, pointsPerTest);
        double[][] outPatterns = AEUtils.generateClassificationResult(inPatterns, gradeScale);
        PatternSet patterns = toPatternSet(inPatterns, outPatterns);

        trainAutoEncoder(net, patterns, steps, epochs, print);

        // -- Classifier ---------------------------------------------------------
        System.out.println("\n\n-- Classifier ---------------------------------------------------------------");
        System.out.println("*** Creating feed forward classifier network...");

        NeuralNet cfnet = NetFactory.createFeedForward(new int[]{encodingLength, gradeScale}, false, new RpropTrainer(), null, null);

        // encode input the classifier is trained on
        PatternSet cfPatterns = AEUtils.encodePatternSet(net, patterns);

        trainClassifier(cfnet, cfPatterns, steps, epochs, print);

        // -- SAE + Classifier ---------------------------------------------------
        System.out.println("\n\n-- SAE + Classifier ---------------------------------------------------------");
        System.out.println("*** Connecting network...");

        net.linkNet(cfnet);

        if (fineTune) {
            fineTuneLinkedNet(net, cfnet, patterns, steps, epochs, print);
        }

        // -- testing the SAE + classifier ---------------------------------------
        double errorRate = evaluate(net, testruns, tests, pointsPerTest, gradeScale);

        System.out.println("Error rate: " + errorRate);
        System.out.println("Done.");
    }

    /**
     * Builds a pattern set in which row {@code i} of {@code inputs} is paired
     * with row {@code i} of {@code targets}.
     *
     * @param inputs  input rows; one pattern is created per row
     * @param targets target rows; must have at least as many rows as {@code inputs}
     * @return the populated pattern set
     */
    private static PatternSet toPatternSet(double[][] inputs, double[][] targets) {
        PatternSet set = new PatternSet();
        for (int row = 0; row < inputs.length; row++) {
            set.getInputs().add(Conversion.asList(inputs[row]));
            set.getTargets().add(Conversion.asList(targets[row]));
        }
        return set;
    }

    /**
     * Configures the trainer attached to {@code net} (expected to be an
     * {@link SAETrainer}) with {@code patterns} as both training and test data
     * and triggers a single {@code train()} call.
     *
     * @param net      network whose trainer is configured and run
     * @param patterns data used for training and for testing
     * @param steps    only used for the announced epoch total
     * @param epochs   epoch count handed to the trainer
     * @param print    forwarded to {@link SAETrainer#setPrint}
     */
    private static void trainAutoEncoder(NeuralNet net, PatternSet patterns, int steps, int epochs, boolean print) {
        Trainer saeTrainer = net.getTrainer();
        saeTrainer.setTrainingData(patterns);
        saeTrainer.setTestData(patterns);
        saeTrainer.setEpochs(epochs);
        saeTrainer.setStepMode(true);
        ((SAETrainer) saeTrainer).setPrint(print);

        // NOTE(review): train() is called once here; presumably the SAE trainer
        // iterates over its own steps internally - confirm against SAETrainer.
        System.out.println("*** Training " + (steps * epochs) + " epochs...");
        saeTrainer.train();
    }

    /**
     * Trains the classifier network on the encoded patterns in {@code steps}
     * rounds and afterwards tests it on every single pattern.
     *
     * @param cfnet      classifier network; its own trainer is used
     * @param cfPatterns encoded patterns used for training and for testing
     * @param steps      number of {@code train()} calls
     * @param epochs     epoch count handed to the trainer
     * @param print      whether the error lines are actually printed
     */
    private static void trainClassifier(NeuralNet cfnet, PatternSet cfPatterns, int steps, int epochs, boolean print) {
        Trainer cfTrainer = cfnet.getTrainer();
        cfTrainer.setTrainingData(cfPatterns);
        cfTrainer.setTestData(cfPatterns);
        cfTrainer.setEpochs(epochs);
        cfTrainer.setStepMode(true);

        System.out.println("*** Training " + (steps * epochs) + " epochs...");
        AEUtils.printConditionally("Training Error: ", print);

        for (int i = 0; i < steps; i++) {
            cfTrainer.train();

            // label with the epochs completed so far; the error is measured after this step
            AEUtils.printConditionally(((i + 1) * epochs) + ". - " + cfnet.getTrainer().test(), print);
        }

        System.out.println("\n*** Testing the classifier...");
        for (int i = 0; i < cfPatterns.size(); i++) {
            AEUtils.printConditionally("Testing Error " + i + " = " + cfnet.getTrainer().test(cfPatterns.getInputs().get(i), cfPatterns.getTargets().get(i)), print);
        }
    }

    /**
     * Fine-tunes the linked network by re-using the classifier's trainer on the
     * original (un-encoded) patterns.
     *
     * @param net      linked network that becomes the trainer's network
     * @param cfnet    classifier network whose trainer is re-used
     * @param patterns original patterns used for training and for testing
     * @param steps    number of {@code train()} calls
     * @param epochs   epoch count handed to the trainer
     * @param print    whether the error lines are actually printed
     */
    private static void fineTuneLinkedNet(NeuralNet net, NeuralNet cfnet, PatternSet patterns, int steps, int epochs, boolean print) {
        // setup fine-tuning (classifier trainer)
        Trainer tuner = cfnet.getTrainer();
        tuner.setEpochs(epochs);
        tuner.setTrainingData(patterns);
        tuner.setTestData(patterns);
        net.setTrainer(tuner);
        tuner.setNetwork(net);

        System.out.println("*** Fine-tuning the network...");
        AEUtils.printConditionally("Fine-tuning Error: ", print);
        for (int i = 0; i < steps; i++) {
            tuner.train();

            // label with the epochs completed so far; the error is measured after this step
            AEUtils.printConditionally(((i + 1) * epochs) + ". - " + cfnet.getTrainer().test(), print);
        }
    }

    /**
     * Feeds freshly generated samples through {@code net}, compares the grade
     * derived from the network output with the expected grade and prints every
     * mismatch as {@code expected,actual}.
     *
     * @param net           network to evaluate
     * @param testruns      number of samples requested from the generator
     * @param tests         forwarded to the sample generator
     * @param pointsPerTest forwarded to the sample generator
     * @param gradeScale    length of the output buffer; also forwarded to the
     *                      expected-result generator
     * @return share of mismatching samples among the samples actually evaluated
     */
    private static double evaluate(NeuralNet net, int testruns, int tests, int pointsPerTest, int gradeScale) {
        System.out.println("\n*** Testing...");
        int wrong = 0;

        double[][] samples = AEUtils.generateClassificationTestResults(testruns, tests, pointsPerTest);
        double[][] expected = AEUtils.generateClassificationResult(samples, gradeScale);

        double[] netOutput = new double[gradeScale];

        // compare expected result with net output
        for (int i = 0; i < samples.length; i++) {
            net.setInput(samples[i]);
            net.innervate();

            netOutput = net.getOutput(netOutput);

            if (AEUtils.getGrade(expected[i]) != AEUtils.getGrade(netOutput)) {
                System.out.println(AEUtils.getGrade(expected[i]) + "," + AEUtils.getGrade(netOutput));
                wrong++;
            }
        }

        // divide by the number of samples that were really evaluated
        return (double) wrong / samples.length;
    }
}