package samples.programs;

import boone.NeuralNet;
import boone.PatternSet;
import boone.util.Conversion;
import boone.util.Patterns;

import java.util.List;
import java.util.Random;

/**
 * Utility methods for training and evaluating stacked auto-encoders (SAE).
 *
 * Task: Grade, Proben1
 *
 * @author Team stacked auto-encoder
 */
public class AEUtils {

	/**
	 * Generates uniformly distributed random training data for an auto-encoder.
	 *
	 * @param amount   number of patterns (rows) to generate
	 * @param size     number of values per pattern, i.e. the size of the input layer
	 * @param rangeMin lower bound (inclusive) of every generated value
	 * @param rangeMax upper bound (exclusive, up to floating-point rounding) of every generated value
	 * @param print    if {@code true}, every pattern is printed to standard output as one line of
	 *                 space-separated values
	 * @return the generated data as an {@code amount x size} array
	 */
	public static double[][] generateRandomTrainingsData(int amount, int size, double rangeMin, double rangeMax,
			boolean print) {

		double[][] trainingsData = new double[amount][size];
		Random r = new Random();
		double range = rangeMax - rangeMin;
		for (int i = 0; i < amount; i++) {
			for (int j = 0; j < size; j++) {
				trainingsData[i][j] = rangeMin + range * r.nextDouble();
			}
		}

		if (print) {
			for (double[] row : trainingsData) {
				StringBuilder line = new StringBuilder();
				for (double value : row) {
					line.append(value).append(' ');
				}
				System.out.println(line);
			}
		}

		return trainingsData;
	}

	/**
	 * Feeds one input pattern through a network and returns the network output.
	 *
	 * @param neuralNet the network to evaluate
	 * @param input     input values; must contain exactly {@code neuralNet.getInputNeuronCount()} values
	 * @return an array of {@code neuralNet.getOutputNeuronCount()} values filled by
	 *         {@code neuralNet.getOutput}, or {@code null} (after printing a message to standard output)
	 *         if the input size does not match the network's input neuron count
	 */
	public static double[] testAutoEncoder(NeuralNet neuralNet, List<Double> input) {

		if (input.size() != neuralNet.getInputNeuronCount()) {
			System.out.println("Wrong count of input values");
			return null;
		}

		double[] output = new double[neuralNet.getOutputNeuronCount()];

		neuralNet.setInput(input);
		// NOTE(review): innervate() presumably propagates the input through the net — confirm in boone.
		neuralNet.innervate();
		// getOutput writes the result into the supplied array.
		neuralNet.getOutput(output);

		return output;
	}


// CLASSIFICATION -------------------------------------------------------------------------------------------------------------------------------------

	/**
	 * Generates pseudo-random test results for a number of students (classification input data).
	 *
	 * @param amount number of students, i.e. patterns to generate
	 * @param tests  number of tests per student
	 * @param points maximum points per test
	 * @return an {@code amount x tests} array with one row of normalised test results per student
	 */
	public static double[][] generateClassificationTestResults(int amount, int tests, int points) {

		// Every row is replaced by a freshly generated array, so only the outer array is allocated.
		double[][] trainingsData = new double[amount][];

		for (int i = 0; i < amount; i++) {
			trainingsData[i] = pseudoTestResultGenerator(tests, points);
		}

		return trainingsData;
	}

	/**
	 * Converts test results into one-hot encoded grades (classification expected output).
	 *
	 * <p>The values of each row are summed to a total score. A total below 0.5 gets the worst grade
	 * (index {@code grades - 1}). The interval [0.5, 1.0) is split evenly into the remaining
	 * {@code grades - 1} grades, with higher totals mapping to lower indices. A total of 1.0 or more
	 * (e.g. full marks) gets the best grade (index 0). With a single grade every row maps to index 0.
	 *
	 * @param results test results, one row per student
	 * @param grades  number of grade intervals (length of each one-hot row)
	 * @return a {@code results.length x grades} array with exactly one 1 per row when {@code grades > 0}
	 */
	public static double[][] generateClassificationResult(double[][] results, int grades) {

		double[][] classifiedData = new double[results.length][grades];
		if (grades <= 0) {
			return classifiedData;
		}

		// Width of one grade interval above 0.5; zero for a single grade to avoid 0.5 / 0.
		double step = grades > 1 ? 0.5 / (grades - 1) : 0.0;

		for (int j = 0; j < results.length; j++) {
			double total = 0;
			for (double val : results[j]) {
				total += val;
			}

			// Default to the best grade: totals at or above the top threshold (1.0) match no interval.
			int gradeIdx = 0;
			for (int i = 0; i < grades; i++) {
				if (total < 0.5 + step * i) {
					gradeIdx = grades - i - 1;
					break; // first matching interval wins; keeps the row strictly one-hot
				}
			}
			classifiedData[j][gradeIdx] = 1;
		}
		return classifiedData;
	}

	/**
	 * Generates the test results of a single student.
	 *
	 * <p>A per-student Gaussian offset (standard deviation 2 points) is added to every test, which
	 * correlates the results of one student. The raw points of each test are drawn around 70% of the
	 * maximum. Each value is divided by {@code points * tests}, so a row sums to roughly 0.7. The
	 * values are not clamped and can be negative or push the sum above 1.0.
	 *
	 * @param tests  number of tests
	 * @param points maximum points per test
	 * @return the normalised result of each test
	 */
	private static double[] pseudoTestResultGenerator(int tests, int points) {

		double[] results = new double[tests];

		Random r = new Random();
		// Shared by all tests of this student, which correlates the results.
		double offsetToMean = r.nextGaussian() * 2;
		// Computed in double so large inputs cannot overflow int.
		double normalizer = (double) points * tests;

		for (int i = 0; i < tests; i++) {
			double mean = Math.abs(r.nextGaussian() % points + points * 0.7);
			results[i] = (mean + offsetToMean) / normalizer;
		}

		return results;
	}

	/**
	 * Returns the grade predicted by the network, i.e. the 1-based index of the largest output.
	 *
	 * @param result output of the network, one value per grade
	 * @return the 1-based index of the largest non-NaN value (the first one on ties), or 0 if
	 *         {@code result} is empty or contains only NaN
	 */
	public static int getGrade(double[] result) {
		int maxIdx = -1;
		double maxVal = 0;

		for (int i = 0; i < result.length; i++) {
			double value = result[i];
			if (Double.isNaN(value)) {
				continue;
			}
			// Seed with the first real value instead of 0 so outputs <= 0 still yield a valid grade.
			if (maxIdx < 0 || value > maxVal) {
				maxIdx = i;
				maxVal = value;
			}
		}
		return maxIdx + 1;
	}

	/**
	 * Converts a two-neuron network output into a one-hot diagnosis.
	 *
	 * @param result output of the network; only the first two values are read
	 * @return {@code {0, 1}} if {@code result[1]} is strictly greater than {@code result[0]},
	 *         otherwise {@code {1, 0}} (ties and NaN comparisons favour the first class)
	 */
	public static double[] getDiagnose(double[] result) {
		return result[0] < result[1] ? new double[] {0, 1} : new double[] {1, 0};
	}

// HELPER-------------------------------------------------------------------------------------------------------------------------------------

	/**
	 * Runs every pattern of an array through the given network.
	 *
	 * @param neuralNet the network to evaluate (presumably the encoder of a trained auto-encoder)
	 * @param input     patterns, one per row
	 * @return one output row per input pattern; a row is {@code null} when its pattern length does
	 *         not match the network's input neuron count (see {@code testAutoEncoder})
	 */
	public static double[][] encodeData(NeuralNet neuralNet, double[][] input) {

		// Every row is replaced by the array returned from testAutoEncoder.
		double[][] encoded = new double[input.length][];

		for (int i = 0; i < input.length; i++) {
			encoded[i] = testAutoEncoder(neuralNet, Conversion.asList(input[i]));
		}

		return encoded;
	}

	/**
	 * Replaces the inputs of a pattern set with their network outputs.
	 *
	 * <p>The set is first duplicated via {@code Patterns.getSubSet(patterns, 0, patterns.size())}.
	 * Every input of the duplicate is then overwritten with the network output for the corresponding
	 * input of {@code patterns}. NOTE(review): confirm that getSubSet copies the input list; otherwise
	 * {@code patterns} is modified in place as well. If an input has the wrong size,
	 * {@code testAutoEncoder} returns {@code null}; how {@code Conversion.asList} handles that is
	 * not visible here.
	 *
	 * @param neuralNet the network to evaluate
	 * @param patterns  pattern set whose inputs are encoded
	 * @return the pattern set containing the encoded inputs
	 */
	public static PatternSet encodePatternSet(NeuralNet neuralNet, PatternSet patterns) {

		PatternSet encoded = Patterns.getSubSet(patterns, 0, patterns.size());

		for (int i = 0; i < patterns.size(); i++) {
			encoded.getInputs().set(i, Conversion.asList(testAutoEncoder(neuralNet, patterns.getInputs().get(i))));
		}

		return encoded;
	}

	/**
	 * Prints a line to standard output if {@code print} is {@code true}.
	 *
	 * @param string the text to print
	 * @param print  whether to print at all
	 */
	public static void printConditionally(String string, boolean print) {
		if (print) {
			System.out.println(string);
		}
	}
}