package samples.programs;

import boone.NeuralNet;
import boone.PatternSet;
import boone.Trainer;
import boone.structure.*;
import boone.training.AdamTrainer;
import boone.util.Common;
import boone.util.Conversion;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Sample program: trains a small convolutional network to separate all-black images from images
 * that contain at least one white pixel, then prints the training error and the final accuracy.
 *
 * <p>The input layer is 5x5 with 3 channels; every channel receives the same binary image.
 * Training and evaluation both use the same generated pattern set, so the reported accuracy is
 * not a held-out estimate.
 */
public class CNNTest {

    /** Output value above which the network's prediction is taken as class "some-white" (1.0). */
    private static final double DECISION_THRESHOLD = 0.5;

    /**
     * Builds the network, generates the synthetic data set, trains, saves the network to
     * {@code check.xnet} (best-effort) and reports accuracy on the generated patterns.
     *
     * @param args unused
     */
    public static void main(String[] args) {

        System.out.println("*** Creating convolutional network...");

        ArrayList<Layer> layers = new ArrayList<>();
        layers.add(new FeedForwardLayer(5, 5, 3));          // width, height, depth
        layers.add(new ConvolutionLayer(3, 1, 1, 1));       // filter, pad, stride, channels
        layers.add(new PoolingLayer(2, 0, 2));              // filter, pad, stride
        layers.add(new FeedForwardLayer(5, 1, 1));          // width, height, depth
        layers.add(new FeedForwardLayer(1));                // output layer

        NeuralNet net = NetFactory.createFeedForward(layers, new AdamTrainer(), true);
        System.out.println(layers.get(1));                  // convolution layer
        System.out.println(layers.get(2));                  // pooling layer
        System.out.println("Neurons: " + net.getNeuronCount());
        System.out.println("Links: " + net.getLinkCount());

        /* Create random binary images classified into all-black or some-white. */
        PatternSet patterns = new PatternSet();
        int trainingSamples = 5000;
        Layer inputLayer = layers.get(0);
        int channels = inputLayer.getDepth();
        double[] image = new double[inputLayer.getWidth() * inputLayer.getHeight()];

        for (int i = 0; i < trainingSamples; i++) {
            Arrays.fill(image, 0.0);                                        // start all-black
            // Fresh target array per sample so stored targets never share storage with each
            // other or with later scratch values (matters if Conversion.asList wraps, not copies).
            double[] target = {0.0};                                        // class all-black

            // Each extra pass (taken roughly 10% of the time) paints one random pixel white.
            while (Math.random() > 0.9) {
                int pix = (int) Math.floor(Math.random() * image.length);
                image[pix] = 1.0;
                target[0] = 1.0;                                            // class some-white
            }

            List<Double> input = new ArrayList<>(image.length * channels);
            for (int j = 0; j < channels; j++) {                            // same image per channel
                Conversion.copyTo(image, input);
            }

            patterns.getInputs().add(input);
            patterns.getTargets().add(Conversion.asList(target));
        }

        /* Set up training. */
        int rounds = 10;                                                    // number of train() calls
        int epochs = 20;                                                    // epochs set per train() call
        Trainer trainer = net.getTrainer();
        System.out.println(trainer);
        // Optional tuning, e.g. trainer.setLearnRate(0.0005);
        trainer.setTrainingData(patterns);
        trainer.setTestData(patterns);                                      // no separate test set
        trainer.setEpochs(epochs);
        trainer.setStepMode(true);
        System.out.println("*** Training " + (rounds * epochs) + " epochs...");
        System.out.println("Error: ");
        double start = Common.getTimeStamp();
        for (int i = 0; i < rounds; i++) {
            trainer.train();
            System.out.println(((i + 1) * epochs) + ". - " + trainer.test());
        }
        System.out.println("Training time[s] = " + Common.getDuration(start));

        File netFile = new File("check.xnet");
        try {
            net.save(netFile);
        } catch (IOException e) {
            // Saving is best-effort: report the failure and still evaluate the trained network.
            System.err.println("Could not save network to " + netFile + ": " + e);
        }

        System.out.println("\n*** Testing the network...");
        int total = patterns.size();
        int hits = 0;
        for (int i = 0; i < total; i++) {
            net.setInput(patterns.getInputs().get(i));
            net.innervate();
            // Separate local for the prediction; the stored targets are never overwritten.
            double predicted = net.getOutputNeuron(0).getOutput() > DECISION_THRESHOLD ? 1.0 : 0.0;
            double expected = patterns.getTargets().get(i).get(0);
            if (expected == predicted) {
                ++hits;
            }
        }
        double accuracy = total == 0 ? 0.0 : hits * 100.0 / total;
        System.out.println("Accuracy: " + accuracy + "%");
    }
}
