package samples.programs;

import boone.NeuralNet;
import boone.PatternSet;
import boone.Trainer;
import boone.structure.NetFactory;
import boone.training.RpropTrainer;
import boone.util.Conversion;

import java.io.File;
import java.io.IOException;
import java.util.List;

/**
 * Sample program: trains a feed-forward network on small MNIST pattern batches and, after
 * every training step, prints the trainer's test result and the classification accuracy on
 * the test batch.
 *
 * <p>The batches are read from {@code samples/data/mnistTrainBatch.xpat} and
 * {@code samples/data/mnistTestBatch.xpat} (generated by BatchGenerator).
 */
public class MNISTFFTest {

    /** Number of input units: one per pixel of a 28x28 image. */
    private static final int INPUT_SIZE = 28 * 28;

    /**
     * Number of output units. Also used as the size of the output buffer when scoring, so the
     * topology and the scoring code cannot drift apart.
     */
    private static final int NUM_CLASSES = 10;

    public static void main(String[] args) {

        System.out.println("*** Creating feed forward network...");

        NeuralNet net = NetFactory.createFeedForward(
                new int[]{INPUT_SIZE, 200, 50, NUM_CLASSES}, false, new RpropTrainer(), null, null);
        System.out.println("Links: " + net.getLinkCount());

        System.out.println("*** Reading MNIST batches...");

        File trainBatch = new File("samples/data/mnistTrainBatch.xpat");       // small batches for testing
        File testBatch = new File("samples/data/mnistTestBatch.xpat");        // from BatchGenerator

        PatternSet trainPatterns = new PatternSet();
        PatternSet testPatterns = new PatternSet();

        try {
            trainPatterns.load(trainBatch);
            testPatterns.load(testBatch);
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("Could not load batches!");
            System.exit(-1);
        }

        int steps = 20;
        int epochs = 5;
        Trainer trainer = net.getTrainer();
        trainer.setTrainingData(trainPatterns);
        // The trainer's test() must report on the held-out batch, not on the training data.
        trainer.setTestData(testPatterns);
        trainer.setEpochs(epochs);
        trainer.setStepMode(true);   // training in steps; each train() call is counted as `epochs` epochs below
        System.out.println("*** Training " + (steps * epochs) + " epochs...");
        for (int i = 0; i < steps; i++) {
            trainer.train();
            System.out.println(((i + 1) * epochs) + ": " + trainer.test());
            printCurrentScore(net, testPatterns);   // test score for the current epoch
        }
    }

    /**
     * Prints the classification accuracy of {@code net} on {@code patterns}.
     *
     * <p>A pattern counts as correctly classified when the index of the largest network output
     * equals the index of the largest target value. The same count is used both for the loop
     * bound and for the reported denominator. An empty pattern set is reported explicitly
     * instead of printing a NaN percentage.
     *
     * @param net      the network to evaluate; its input is overwritten for every pattern
     * @param patterns the labelled patterns to classify
     */
    private static void printCurrentScore(NeuralNet net, PatternSet patterns) {

        int total = patterns.size();
        if (total == 0) {
            System.out.println("Score: no patterns to evaluate (0/0)");
            return;
        }

        int correct = 0;
        for (int i = 0; i < total; i++) {
            net.setInput(patterns.getInputs().get(i));
            net.innervate();

            List<Double> targets = patterns.getTargets().get(i);
            double[] outs = net.getOutput(new double[NUM_CLASSES]);
            int predicted = findMax(Conversion.asList(outs));
            int expected = findMax(targets);

            if (predicted == expected) {
                correct++;
            }
        }
        System.out.println("Score: " + (correct * 100.0 / total) + "% (" + correct + "/" + total + ")");
    }

    /**
     * Returns the index of the largest element of {@code a}.
     *
     * <p>On ties the first (lowest) index wins. An empty list yields 0. Elements are unboxed, so
     * they must be non-null. A NaN element is never selected unless it sits at index 0, because
     * every comparison against NaN is false.
     *
     * @param a the values to scan
     * @return the index of the first maximal element, or 0 if {@code a} is empty
     */
    private static int findMax(List<Double> a) {

        int max = 0;
        for (int i = 1; i < a.size(); i++) {
            if (a.get(i) > a.get(max)) {
                max = i;
            }
        }
        return max;
    }

}
