package samples.programs;

import boone.NeuralNet;
import boone.PatternSet;
import boone.Trainer;
import boone.io.IOFilter;
import boone.io.Proben1PatternFilter;
import boone.map.Function;
import boone.structure.NetFactory;
import boone.training.BackpropTrainer;
import boone.training.SAETrainer;
import boone.util.Conversion;
import boone.util.Patterns;

import java.io.File;
import java.io.IOException;

/**
 * Sample program: a stacked auto-encoder (SAE) used as a feature extractor for a classification task.
 *
 * <p>Task: Proben1 "heart" benchmark (data file {@code heart1.dt}).
 * The loaded pattern set is split into a training subset and a test subset.
 *
 * <p>Pipeline:
 * <ol>
 *   <li>Train a feed-forward SAE ({@code 35 -> 13 -> 7}) with an {@link SAETrainer} on the training data.</li>
 *   <li>Encode the training and test patterns with the SAE and train a separate {@code 7 -> 2}
 *       classifier on those encodings.</li>
 *   <li>Link the classifier behind the SAE, optionally fine-tune the combined net, and report the
 *       percentage of misclassified test patterns.</li>
 * </ol>
 *
 * <p>Usage: {@code java samples.programs.SAEClassifierProben1 [dataFile]}. When no argument is given,
 * {@value #DEFAULT_DATA_FILE} is loaded (relative to the working directory).
 *
 * @author Team stacked auto-encoder
 */

public class SAEClassifierProben1 {

    /** Data file loaded when no path is passed on the command line. */
    private static final String DEFAULT_DATA_FILE = "test/samples/data/heart1.dt";

    /**
     * Runs the SAE + classifier experiment and prints its progress and final test error rate.
     *
     * @param args optional; {@code args[0]} overrides the path of the Proben1 data file
     */
    public static void main(String[] args) {

        // settings - print, training
        boolean print = false;
        boolean fineTune = false;

        // layer sizes: the SAE maps inputLayer -> layer1 -> encodingLayer units,
        // the classifier maps encodingLayer -> classificationLayer units
        int inputLayer = 35, layer1 = 13, encodingLayer = 7, classificationLayer = 2;

        Trainer trainer;


        // -- Read Data -------------------------------------------------------------------------------------------------
        File file = new File(args != null && args.length > 0 ? args[0] : DEFAULT_DATA_FILE);
        // report the path that is actually opened (the old message named a different path)
        System.out.println("Loading Proben1 file '" + file.getPath() + "'...");

        IOFilter filter = new Proben1PatternFilter();
        PatternSet proben1 = new PatternSet(filter);

        try {
            proben1.load(file);
        } catch (IOException e) {
            System.err.println("Could not load Proben1 file '" + file.getPath() + "': " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // training and test data
        // NOTE(review): whether getSubSet's end index is inclusive is defined by Patterns.getSubSet - confirm
        PatternSet trainingData = Patterns.getSubSet(proben1, 0, 699);
        PatternSet testData = Patterns.getSubSet(proben1, 700, 900);


        // -- SAE -------------------------------------------------------------------------------------------------
        System.out.println("\n\n-- SAE ----------------------------------------------------------------------");
        System.out.println("*** Creating feed forward network with SAE trainer...");
        int steps = 10;

        NeuralNet net = NetFactory.createFeedForward(
                new int[]{inputLayer, layer1, encodingLayer},
                false,
                new SAETrainer(new BackpropTrainer(), steps),
                null,
                null);

        // setup training
        int epochs = 10;

        trainer = net.getTrainer();
        trainer.setTrainingData(trainingData);
        trainer.setTestData(testData);
        trainer.setEpochs(epochs);
        trainer.setStepMode(true);
        // the net was created with an SAETrainer above, so this cast matches that construction
        ((SAETrainer) trainer).setPrint(print);

        // training in steps
        System.out.println("*** Training " + (steps * epochs) + " epochs...");

        trainer.train();


        // -- Classifier --------------------------------------------------------------------------------------
        System.out.println("\n\n-- Classifier ---------------------------------------------------------------");
        System.out.println("*** Creating feed forward classifier network...");

        NeuralNet cfnet = NetFactory.createFeedForward(
                new int[]{encodingLayer, classificationLayer},
                false,
                new BackpropTrainer(),
                null,
                null);

        // encode input the classifier is trained on
        PatternSet cfTrainingPatterns = AEUtils.encodePatternSet(net, trainingData);
        PatternSet cfTestPatterns = AEUtils.encodePatternSet(net, testData);

        // setup training classifier
        epochs = 10;
        steps = 100;

        trainer = cfnet.getTrainer();
        trainer.setTrainingData(cfTrainingPatterns);
        trainer.setTestData(cfTestPatterns);
        trainer.setEpochs(epochs);
        trainer.setStepMode(true);

        System.out.println("*** Training " + (steps * epochs) + " epochs...");
        AEUtils.printConditionally("Training Error: ", print);

        // train() is called once per step; each call runs `epochs` epochs in step mode
        for (int i = 0; i < steps; i++) {
            trainer.train();

            AEUtils.printConditionally((i * epochs) + ". - " + cfnet.getTrainer().test(), print);
        }

        System.out.println("\n*** Testing the classifier...");
        // per-pattern errors are printed unconditionally (independent of `print`)
        for (int i = 0; i < cfTestPatterns.size(); i++) {
            AEUtils.printConditionally("Testing Error " + i + " = " + cfnet.getTrainer().test(cfTestPatterns.getInputs().get(i), cfTestPatterns.getTargets().get(i)), true);
        }


        // -- SAE + Classifier --------------------------------------------------------------------------------------
        System.out.println("\n\n-- SAE + Classifier ---------------------------------------------------------");
        System.out.println("*** Connecting network...");

        net.linkNet(cfnet);

        if (fineTune) {
            // setup fine-tuning: reuse the classifier's trainer (created as a BackpropTrainer above)
            // for the linked net, now fed with the raw (un-encoded) patterns
            epochs = 10;
            steps = 100;

            trainer = cfnet.getTrainer();
            trainer.setEpochs(epochs);
            trainer.setTrainingData(trainingData);
            trainer.setTestData(testData);
            net.setTrainer(trainer);
            trainer.setNetwork(net);

            System.out.println("*** Fine-tuning the network...");
            AEUtils.printConditionally("Fine-tuning Error: ", print);
            for (int i = 0; i < steps; i++) {
                trainer.train();
                AEUtils.printConditionally((i * epochs) + ". - " + net.getTrainer().test(), print);
            }
        }


        // -- testing the SAE + classifier -------------------------------------------------------------------------
        System.out.println("\n*** Testing...");
        int wrong = countMisclassified(net, testData, classificationLayer, print);

        if (testData.size() == 0) {
            System.out.println("No test patterns - error rate undefined.");
        } else {
            System.out.println("Error rate: " + ((double) wrong / testData.size() * 100) + " %");
        }
        System.out.println("Done.");
    }

    /**
     * Counts the patterns in {@code data} whose diagnosis (computed by {@code AEUtils.getDiagnose}
     * from the net's raw output) differs from the pattern's target.
     *
     * @param net     network to evaluate; its input is set to each pattern in turn
     * @param data    patterns to evaluate
     * @param outputs number of output values to read from {@code net}
     * @param print   whether to print the target and raw output of each misclassified pattern
     * @return number of misclassified patterns
     */
    private static int countMisclassified(NeuralNet net, PatternSet data, int outputs, boolean print) {
        int wrong = 0;
        double[] netOutput = new double[outputs];

        for (int i = 0; i < data.size(); i++) {
            net.setInput(data.getInputs().get(i));
            net.innervate();
            netOutput = net.getOutput(netOutput);

            if (!Conversion.asList(AEUtils.getDiagnose(netOutput)).equals(data.getTargets().get(i))) {
                AEUtils.printConditionally(data.getTargets().get(i) + "," + Conversion.asList(netOutput), print);
                wrong++;
            }
        }
        return wrong;
    }
}
