package samples.programs;

import boone.NeuralNet;
import boone.Neuron;
import boone.PatternSet;
import boone.Trainer;
import boone.io.BooneFilter;
import boone.map.Function;
import boone.structure.*;
import boone.training.AdamTrainer;
import boone.training.BackpropTrainer;
import boone.training.CrossEntropy;
import boone.training.SquareError;
import boone.util.Common;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Sample program that builds a small convolutional network for 28x28x1 MNIST images,
 * trains it on MNIST pattern batches, prints the test score after every training step,
 * and finally scores the network again after a save/load round trip through
 * {@code check.xnet}.
 */
public class MNISTCNNTest {

    public static void main(String[] args) {

        System.out.println("*** Creating convolutional network...");

        Neuron reluNeuron = new Neuron(new Function.ReLU());
        Neuron idNeuron = new Neuron(new Function.Identity());

        // Layer stack: input -> conv/pool -> conv/pool -> conv -> dense -> output.
        // A tanh-activated variant can be tried by passing new Neuron(new Function.TanH())
        // as the neuron argument, e.g. new ConvolutionLayer(5, 0, 1, 6, tanhNeuron, null).
        ArrayList<Layer> layers = new ArrayList<>();
        layers.add(new FeedForwardLayer(28, 28, 1, null, null));   // width, height, depth
        layers.add(new ConvolutionLayer(5, 0, 1, 6));              // filter, pad, stride, channels
        layers.add(new PoolingLayer(2, 0, 2));                     // filter, pad, stride
        layers.add(new ConvolutionLayer(5, 0, 1, 16));             // filter, pad, stride, channels
        layers.add(new PoolingLayer(2, 0, 2));                     // filter, pad, stride
        layers.add(new ConvolutionLayer(4, 0, 1, 120));            // filter, pad, stride, channels
        layers.add(new FeedForwardLayer(84, reluNeuron, null));    // width, ReLU
        layers.add(new FeedForwardLayer(10, idNeuron, null));      // output layer, 10 classes

        NeuralNet net = NetFactory.createFeedForward(layers, new AdamTrainer(), true);
        for (Layer layer : layers) {
            System.out.println(layer);
        }
        // Report link counts for every convolution layer instead of relying on
        // hard-coded list positions that break when the layer stack changes.
        for (Layer layer : layers) {
            if (layer instanceof ConvolutionLayer) {
                System.out.println("Layer " + layer.getLevel() + " links: " + layer.getLinkCount());
            }
        }
        System.out.println("Output layer level " + net.getOutputNeuron(9).getLayer());
        System.out.println("Links: " + net.getLinkCount());


        System.out.println("*** Reading MNIST data...");

        // Small batches for testing (produced by BatchGenerator). The full data set is also
        // available as compressed CSV (samples/data/mnistTrainUnit.csv / mnistTestUnit.csv),
        // readable in streaming mode with a CSV pattern filter (784 inputs, 10 targets).
        File trainSet = new File("samples/data/mnistTrainBatch.xpat");
        File testSet = new File("samples/data/mnistTestBatch.xpat");

        PatternSet trainPatterns = new PatternSet();
        PatternSet testPatterns = new PatternSet();

        try {
            trainPatterns.load(trainSet);
            testPatterns.load(testSet);
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("Could not load patterns!");
            System.exit(-1);
        }

        int steps = 10;
        int epochs = 1;
        Trainer trainer = net.getTrainer();
        trainer.setTrainingSignalGenerator(new CrossEntropy());
        // Learn rate, batch size and shuffling are left at the trainer's defaults.
        trainer.setTrainingData(trainPatterns);
        trainer.setTestData(testPatterns);
        trainer.setEpochs(epochs);
        trainer.setStepMode(true);
        System.out.println("*** Training " + (steps * epochs) + " epochs...");
        double timeStamp = Common.getTimeStamp();

        for (int i = 0; i < steps; i++) {
            trainer.train();
            System.out.println(((i + 1) * epochs) + ". - " + trainer.test());  // streaming
            printCurrentScore(net, testPatterns);                              // test score for the current epoch
        }
        System.out.println("Training time[s]: " + Common.getDuration(timeStamp));

        File netFile = new File("check.xnet");
        try {
            System.out.println("*** Saving CNN...");
            net.save(netFile);
            System.out.println("*** Loading CNN...");
            net = NeuralNet.load(netFile, new BooneFilter("net", true));
        } catch (IOException e) {
            e.printStackTrace();
            // 'net' is only reassigned on a successful load, so it still holds the trained network.
            System.err.println("Could not save/reload network; scoring the in-memory network instead.");
        }
        System.out.println("Links: " + net.getLinkCount());
        printCurrentScore(net, testPatterns);                                  // score of the (reloaded) network
    }


    /**
     * Scores {@code net} on every pattern delivered by {@code patterns} and prints the
     * percentage of patterns whose winning output neuron index equals the index of the
     * largest target value.
     *
     * @param net      network to evaluate; its trainer supplies the winning neuron
     * @param patterns pattern set iterated via {@code hasMoreElements()/getNextElement()}
     */
    private static void printCurrentScore(NeuralNet net, PatternSet patterns) {

        Trainer trainer = net.getTrainer();
        int patternCount = 0;
        int score = 0;

        while (patterns.hasMoreElements()) {                                  // streaming
            int index = patterns.getNextElement();
            ++patternCount;

            Neuron winner = trainer.getWinningNeuron(patterns.getInputs().get(index));
            int win = net.getOutputNeuronIndex(winner);
            List<Double> targets = patterns.getTargets().get(index);
            int corr = findMax(targets);
            // A pattern without targets (corr == -1) never counts as correctly classified.
            if (corr >= 0 && win == corr) {
                ++score;
            }
        }
        if (patternCount == 0) {
            // Without this guard the percentage below would be 0/0 and print "NaN%".
            System.out.println("Score: no patterns to evaluate");
            return;
        }
        System.out.println("Score: " + (score * 100.0 / patternCount) + "% (" + score + "/" + patternCount + ")");
    }


    /**
     * Returns the index of the largest value in {@code a}; ties resolve to the lowest index.
     *
     * @param a values to scan (here: a pattern's target vector)
     * @return index of the first maximum, or -1 if {@code a} is null or empty
     */
    private static int findMax(List<Double> a) {

        if (a == null || a.isEmpty()) {
            return -1;
        }
        int max = 0;
        for (int i = 1; i < a.size(); i++) {
            if (a.get(i) > a.get(max)) {
                max = i;
            }
        }
        return max;
    }

}
