package samples.programs;

import boone.NeuralNet;
import boone.Neuron;
import boone.PatternSet;
import boone.Trainer;
import boone.io.BooneFilter;
import boone.io.CSVPatternFilter;
import boone.io.IOFilter;
import boone.training.AdamTrainer;
import boone.util.Common;
import boone.util.Nets;

import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

/**
 * Batch evaluation driver.
 * <p>
 * Loads a training and a test pattern set, then processes every {@code .xnet} file in the
 * current working directory (in sorted order). For each file it:
 * <ol>
 *   <li>loads the net;</li>
 *   <li>prints its size before and after {@link Nets#purify};</li>
 *   <li>reports test-set accuracy;</li>
 *   <li>trains it further with an {@link AdamTrainer};</li>
 *   <li>reports test-set accuracy again.</li>
 * </ol>
 */
public class NetBatch {

    /**
     * Entry point. Command-line arguments are ignored.
     *
     * @param args unused
     */
    public static void main(String[] args) {

        System.out.println("*** Reading data sets...");

        File trainSet = new File("samples/data/mnistTrainBatch.xpat");    // small batches for testing
        File testSet = new File("samples/data/mnistTestBatch.xpat");      // from BatchGenerator

        PatternSet trainPatterns = new PatternSet();
        PatternSet testPatterns = new PatternSet();

        // Alternative configuration: stream compressed MNIST CSV files instead of the
        // .xpat batches above (replace the declarations above with these).
        // File trainSet = new File("samples/data/mnistTrainUnit.csv");
        // File testSet = new File("samples/data/mnistTestUnit.csv");
        // IOFilter trainFilter = new CSVPatternFilter(',', 784, 10, false, 128);   // CSV without header, streaming
        // IOFilter testFilter = new CSVPatternFilter(',', 784, 10, false, 1000);   // CSV without header, streaming
        // trainFilter.setCompressed(true);
        // testFilter.setCompressed(true);
        // PatternSet trainPatterns = new PatternSet(trainFilter);
        // PatternSet testPatterns = new PatternSet(testFilter);

        try {
            trainPatterns.load(trainSet);
            testPatterns.load(testSet);
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("Could not load patterns!");
            System.exit(-1);
        }

        System.out.println("*** Evaluating net batch...");

        File currDir = new File("./");
        File[] netFiles = currDir.listFiles(new FilenameFilter() {
            @Override
            public boolean accept(File dir, String name) {
                return name.endsWith(".xnet");
            }
        });
        // listFiles() returns null only on an I/O error (or when the path is not a directory);
        // a directory without matching files yields an empty array, which must be checked too.
        if (netFiles == null || netFiles.length == 0) {
            System.out.println("No xnet files in current directory!");
            return;
        }
        // listFiles() guarantees no particular order; sort so batch output is reproducible.
        Arrays.sort(netFiles);

        IOFilter filter = new BooneFilter("net", true);

        for (File netFile : netFiles) {
            NeuralNet net;
            try {
                net = NeuralNet.load(netFile, filter);
            } catch (IOException e) {
                // One unreadable file should not abort evaluation of the remaining nets.
                e.printStackTrace();
                System.err.println("Could not load net " + netFile.getPath() + ", skipping.");
                continue;
            }
            System.out.println(netFile.getPath());
            printNetStats(net);
            Nets.purify(net);
            printNetStats(net);

            printCurrentAccuracy(net, testPatterns);
            trainNet(net, trainPatterns, null, 10, 10);
            printCurrentAccuracy(net, testPatterns);
        }
    }


    /**
     * Prints the neuron count, link count and layer count of {@code net}.
     * <p>
     * The layer count is the layer index of the first output neuron plus one.
     * NOTE(review): this assumes layer indices start at 0 and that output neurons sit in the
     * last layer. TODO confirm against {@code Neuron.getLayer()}.
     *
     * @param net the net to describe
     */
    private static void printNetStats(NeuralNet net) {
        System.out.println("Neurons: " + net.getNeuronCount() + ", Links: " + net.getLinkCount()
                + ", Layers: " + (net.getOutputNeuron(0).getLayer() + 1));
    }


    /**
     * Prints the classification accuracy of {@code net} on {@code patterns}.
     * <p>
     * A pattern counts as correct when the index of the winning output neuron equals the index
     * of the largest target value. Patterns are consumed through the set's streaming interface
     * ({@code hasMoreElements()} / {@code getNextElement()}) until it reports no more elements.
     * NOTE(review): this method is called several times on the same set, so it relies on the
     * enumeration restarting after exhaustion. TODO confirm in {@code PatternSet}.
     * If no pattern is produced, a notice is printed instead of a {@code NaN} accuracy.
     *
     * @param net      the net to evaluate (its current trainer determines the winning neuron)
     * @param patterns the patterns to evaluate on
     */
    private static void printCurrentAccuracy(NeuralNet net, PatternSet patterns) {

        Trainer trainer = net.getTrainer();
        int patternCount = 0;
        int score = 0;

        while (patterns.hasMoreElements()) {                                            // streaming
            int index = patterns.getNextElement();
            ++patternCount;

            Neuron winner = trainer.getWinningNeuron(patterns.getInputs().get(index));
            int win = net.getOutputNeuronIndex(winner);
            List<Double> targets = patterns.getTargets().get(index);
            int corr = findMax(targets);
            if (win == corr) {
                ++score;
            }
        }
        if (patternCount == 0) {
            // (double) 0 / 0 would print NaN; report the empty evaluation explicitly instead.
            System.out.println("Acc: no patterns evaluated\n");
            return;
        }
        System.out.println("Acc: " + Common.fixedPoint.format((double) score / patternCount) + " (" + score + "/" + patternCount + ")\n");
    }


    /**
     * Installs a fresh {@link AdamTrainer} on {@code net} and trains it in step mode.
     * <p>
     * Training runs for {@code steps} rounds of {@code epochs} epochs each. After every round
     * the trainer's {@code test()} value is printed. If {@code testSet} is non-null, the
     * classification accuracy on it is printed as well.
     *
     * @param net      the net to train (its trainer is replaced)
     * @param trainSet training patterns
     * @param testSet  optional held-out patterns for per-step accuracy; may be {@code null}
     * @param steps    number of training rounds
     * @param epochs   epochs per round
     */
    private static void trainNet(NeuralNet net, PatternSet trainSet, PatternSet testSet, int steps, int epochs) {

        net.setTrainer(new AdamTrainer());
        Trainer trainer = net.getTrainer();

        // Optional tuning knobs:
        // trainer.setTrainingSignalGenerator(new CrossEntropy());
        // trainer.setLearnRate(0.25);
        // trainer.setBatchSize(1);
        trainer.setTrainingData(trainSet);
        // The trainer's test data is the *training* set (testSet may be null), so the value
        // printed per round below is presumably a training-set error. TODO confirm intent.
        trainer.setTestData(trainSet);
        trainer.setEpochs(epochs);
        trainer.setStepMode(true);
        System.out.println("*** Training " + (steps * epochs) + " epochs...");

        for (int i = 0; i < steps; i++) {
            trainer.train();
            System.out.println(((i + 1) * epochs) + ". - " + Common.fixedPoint.format(net.getTrainer().test()));  // streaming
            if (testSet != null) {
                printCurrentAccuracy(net, testSet);
            }
        }
    }


    /**
     * Returns the index of the largest value in {@code a}.
     * <p>
     * Ties resolve to the first occurrence, and an empty list yields {@code 0}.
     *
     * @param a the values to scan; elements must be non-null (they are unboxed)
     * @return index of the first maximum, or {@code 0} if {@code a} is empty
     */
    private static int findMax(List<Double> a) {

        int max = 0;
        for (int i = 1; i < a.size(); i++) {
            if (a.get(i) > a.get(max)) {
                max = i;
            }
        }
        return max;
    }

}
