import java.io.*;

public class MNIST1 extends CSV
{
    public MNIST1(){
    }
    
    public void init(){
        int[] eingabeSpalten = new int[784];
        for (int i = 0; i < 784; i++){
            eingabeSpalten[i] = i + 1;
        }
        super.init("mnist_train.csv", ',', 60000, eingabeSpalten, new int[]{0}, new double[]{10});
        erzeugeNetz(new int[]{20, 20, 10}, new NeuronalesNetz.Sigmoid());
    }

    /**
     * Trainiere das Netz
     * 
     * @return der Fehler
     */
    public double trainiere(){
        return trainiere(0.001, 1000, 100);
    }

    /**
     * Teste das Netz mit den MNIST-Trainingsdaten<br>
     * Die Ausgabe erfolgt über die Konsole.
     */
    public void teste(){
        int korrekt = 0;
        int[] abweichungen = new int[10];

        try {
            FileReader filereader = new FileReader("mnist_test.csv");
            BufferedReader reader = new BufferedReader(filereader);
            String line = reader.readLine();
            while (line != null){
                String[] eintraege = line.split(",");
                int zahl = Integer.parseInt(eintraege[0]);
                double daten[] = new double[784];
                for (int i = 1; i < 785; i++){
                    daten[i - 1] = Double.parseDouble(eintraege[i]);
                }
                double ausgabe[] = berechne(daten);

                int berechnet = (int)ausgabe[0];
                System.out.println(zahl + " - " + berechnet);
                if (zahl == berechnet) korrekt++; else abweichungen[zahl]++;
                line = reader.readLine();
            }
            reader.close();
        } catch (Exception ex){
            System.err.println("Fehler beim Einlesen der Trainingsdaten!");
        }

        System.out.println("Korrekt: " + korrekt);
        System.out.println("Anzahl der Fehler bei ...");
        for (int i = 0; i < 10; i++){
            System.out.println("   ... " + i + ": " + abweichungen[i]);
        }
        System.out.println("insgesamt: " + (10000 - korrekt));
    }
}