package samples.programs;

import boone.PatternSet;
import boone.io.CSVPatternFilter;
import boone.io.IOFilter;
import boone.util.Patterns;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * A workbench for manipulating pattern sets and pattern files.
 *
 * <p>{@link #main(String[])} currently rescales the MNIST test file by streaming it through
 * {@link #streamScale(IOFilter, IOFilter)}. {@link #createPattern(int, int, PatternSet)} and
 * {@link #adapt(PatternSet)} are helpers for other ad-hoc jobs and are not called from {@code main}.
 */
public class BatchGenerator {

    /** Number of input values per pattern (MNIST images: 28 x 28 = 784 pixels). */
    private static final int INPUT_COUNT = 784;

    /** Number of target classes used for the 1-of-N target encoding. */
    private static final int CLASS_COUNT = 10;

    /**
     * Last {@link CSVPatternFilter} constructor argument used when streaming. A value of 0 disables
     * streaming; a positive value enables it and presumably sets the number of patterns per chunk
     * (TODO confirm against CSVPatternFilter).
     */
    private static final int STREAM_CHUNK = 1024;

    /**
     * Streams the raw MNIST test file through {@link #streamScale(IOFilter, IOFilter)} and writes the
     * rescaled patterns to a compressed CSV file.
     *
     * @param args ignored
     */
    public static void main(String[] args) {

        System.out.println("*** Reading data (this may take a while)...");

        // Comma-separated CSV without a header row (fourth argument false), read/written in chunks.
        IOFilter readFilter = new CSVPatternFilter(',', INPUT_COUNT, CLASS_COUNT, false, STREAM_CHUNK);
        IOFilter writeFilter = new CSVPatternFilter(',', INPUT_COUNT, CLASS_COUNT, false, STREAM_CHUNK);
        readFilter.setFile(new File("samples/data/mnist_testA.csv"));
        writeFilter.setFile(new File("samples/data/mnistTestUnit.csv"));
        writeFilter.setCompressed(true);

        streamScale(readFilter, writeFilter);
    }


    /**
     * Adds a simple input image with a single uniform grey level in the unit interval to the given
     * patterns. The class is {@code floor(greyLevel * catCount)}, so with {@code catCount == 10} it is
     * the first decimal of the grey level, e.g., a level of 0.34 is class 3.
     *
     * @param size     the number of pixels
     * @param catCount the number of categories
     * @param patterns the pattern set the new input and target rows are appended to
     */
    private static void createPattern(int size, int catCount, PatternSet patterns) {

        // Math.random() is in [0, 1), so category always lies in [0, catCount - 1].
        double greyLevel = Math.random();
        int category = (int) (greyLevel * catCount);
        List<Double> inputs = new ArrayList<>(size);
        List<Double> targets = new ArrayList<>(catCount);

        for (int i = 0; i < size; i++) {
            inputs.add(greyLevel);
        }
        patterns.getInputs().add(inputs);

        // Presumably appends the 1-of-catCount code for category to targets (TODO confirm).
        Patterns.encode(category, catCount, targets);
        patterns.getTargets().add(targets);
    }


    /**
     * Exchanges the input and target rows of a pattern set whose columns were read in the wrong
     * order. It then replaces each target row's leading class label with a 1-of-N encoding, where
     * N is {@value #CLASS_COUNT}.
     *
     * @param set the incorrectly ordered pattern set; modified in place
     */
    private static void adapt(PatternSet set) {

        List<List<Double>> inputs = set.getInputs();
        List<List<Double>> targets = set.getTargets();

        set.setInputs(targets);                                // reorder
        set.setTargets(inputs);

        for (List<Double> target : set.getTargets()) {
            // The first value of each (former input) row is the numeric class label.
            double label = target.remove(0);
            Patterns.encode((int) label, CLASS_COUNT, target);  // 1-of-N encode into the same row
        }
    }

    /**
     * Scales a large streamed pattern set. The pattern set is streamed in, scaled, and streamed out
     * until all the patterns are scaled. The scaling has to be set manually here. Note that not all
     * filters support streaming.
     *
     * <p>Errors are reported on {@code System.err} and are not rethrown.
     *
     * @param readFilter  the filter for reading data from a pattern file
     * @param writeFilter the filter for writing data to a pattern file
     */
    private static void streamScale(IOFilter readFilter, IOFilter writeFilter) {

        PatternSet patterns = new PatternSet(readFilter);

        try {
            patterns.load(readFilter);
            // NOTE(review): nothing in this loop visibly advances the input stream. Confirm that
            // hasMoreElements()/save() fetch the next chunk; otherwise this never terminates and
            // re-scales the same chunk on every pass.
            while (patterns.hasMoreElements()) {
                // Assumes the argument order (set, newMin, newMax, oldMin, oldMax), i.e. grey values
                // 0..255 mapped into [0, 1] -- TODO confirm against Patterns.mapInputsToInterval.
                Patterns.mapInputsToInterval(patterns, 0, 1, 0, 255);
                patterns.save(writeFilter);
                System.out.println("Stream saved.");
            }
            writeFilter.reset();
        } catch (IOException e) {
            System.err.println("Could not stream patterns!");
            e.printStackTrace();
        }
    }

}
