org.usfirst.frc.team2084.neuralnetwork.HeadingNeuralNetworkTrainer.java Source code

Java tutorial

Introduction

Here is the source code for org.usfirst.frc.team2084.neuralnetwork.HeadingNeuralNetworkTrainer.java

Source

/* 
 * Copyright (c) 2015 RobotsByTheC. All rights reserved.
 *
 * Open Source Software - may be modified and shared by FRC teams. The code must
 * be accompanied by the BSD license file in the root directory of the project.
 */
package org.usfirst.frc.team2084.neuralnetwork;

import java.awt.BorderLayout;
import java.awt.Container;
import java.awt.GridLayout;
import java.beans.PropertyChangeEvent;
import java.beans.PropertyChangeListener;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Paths;
import java.text.NumberFormat;
import java.util.ArrayList;

import javax.swing.JButton;
import javax.swing.JFileChooser;
import javax.swing.JFormattedTextField;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.ProgressMonitor;
import javax.swing.SwingUtilities;
import javax.swing.SwingWorker;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.Range;
import org.jfree.data.xy.XYDataItem;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
 * @author Ben Wolsieffer
 */
public class HeadingNeuralNetworkTrainer {

    // Main application window; also the parent of the file choosers and the
    // training progress monitor.
    private final JFrame frame = new JFrame("Heading Neural Network Trainer");
    // Chooser for the directory containing the CSV data files (restricted to
    // directories in the constructor). Starts in the working directory.
    private final JFileChooser dataDirectoryChooser = new JFileChooser(getWorkingDirectory());
    // Chooser for the file the network is saved to. Starts in the working
    // directory.
    private final JFileChooser saveFileChooser = new JFileChooser(getWorkingDirectory());

    // Holds the chart (center of the window).
    private final JPanel graphPanel = new JPanel();
    // Holds the parameter fields and the buttons (bottom of the window).
    private final JPanel controlPanel = new JPanel();

    // Created while the UI is being built, hence not final.
    private ChartPanel outputGraphPanel;
    private final JFreeChart outputGraph;
    // Output of the network sampled over [-pi, pi]; filled by displayNetwork().
    private final XYSeries outputGraphNetworkSeries;
    // (error, output) pairs read from the CSV files; filled by displayData()
    // and used as the training set by trainNetwork().
    private final XYSeries outputGraphDataSeries;

    private final NumberFormat integerFormat = NumberFormat.getIntegerInstance();
    private final NumberFormat numberFormat = NumberFormat.getNumberInstance();

    private final JFormattedTextField etaField = new JFormattedTextField(numberFormat);
    private final JFormattedTextField momentumField = new JFormattedTextField(numberFormat);
    private final JFormattedTextField iterationsField = new JFormattedTextField(integerFormat);

    private final JButton chooseDataButton = new JButton("Choose Data");
    private final JButton trainButton = new JButton("Train");
    private final JButton saveButton = new JButton("Save");

    // Directory selected through dataDirectoryChooser; null until one has been
    // chosen.
    private File dataDirectory;
    // The network that is trained, plotted and saved.
    private final Network network;

    // True while a training worker is active; set in trainNetwork() and cleared
    // in the worker's done() callback.
    private boolean training;
    // Training parameters. Written by the text field listeners and read by the
    // training code, hence volatile.
    private volatile int iterations = 10000;
    private volatile double eta = 0.3;
    private volatile double momentum = 0.1;

    public HeadingNeuralNetworkTrainer() {
        outputGraphNetworkSeries = new XYSeries("Network Prediction");
        outputGraphDataSeries = new XYSeries("Error");
        final XYSeriesCollection data = new XYSeriesCollection();
        data.addSeries(outputGraphDataSeries);
        data.addSeries(outputGraphNetworkSeries);
        outputGraph = ChartFactory.createXYLineChart("Error vs. Output", "Time", "Error", data,
                PlotOrientation.VERTICAL, true, true, false);

        NumberAxis xAxis = new NumberAxis();
        xAxis.setRange(new Range(-Math.PI, Math.PI));
        xAxis.setAutoRange(false);
        NumberAxis yAxis = new NumberAxis();
        yAxis.setRange(new Range(-1, 1));

        XYPlot plot = (XYPlot) outputGraph.getPlot();
        plot.setDomainAxis(xAxis);
        plot.setRangeAxis(yAxis);

        network = new Network(new int[] { 1, 5, 1 }, eta, momentum, new TransferFunction.HyperbolicTangent());

        try {
            SwingUtilities.invokeAndWait(() -> {
                frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
                final Container content = frame.getContentPane();
                content.setLayout(new BorderLayout());

                outputGraphPanel = new ChartPanel(outputGraph);
                outputGraphPanel.setDomainZoomable(false);
                outputGraphPanel.setRangeZoomable(false);
                graphPanel.add(outputGraphPanel);

                graphPanel.setLayout(new GridLayout(1, 1));
                content.add(graphPanel, BorderLayout.CENTER);
                {

                    JLabel etaLabel = new JLabel("Eta:");
                    etaLabel.setLabelFor(etaField);
                    etaField.setText(Double.toString(network.getEta()));
                    etaField.setColumns(5);
                    etaField.addPropertyChangeListener("value",
                            (evt) -> eta = ((Number) evt.getNewValue()).doubleValue());
                    controlPanel.add(etaLabel);
                    controlPanel.add(etaField);

                    JLabel momentumLabel = new JLabel("Momentum:");
                    momentumLabel.setLabelFor(etaField);
                    momentumField.setText(Double.toString(network.getMomentum()));
                    momentumField.setColumns(5);
                    momentumField.addPropertyChangeListener("value",
                            (evt) -> momentum = ((Number) evt.getNewValue()).doubleValue());
                    controlPanel.add(momentumLabel);
                    controlPanel.add(momentumField);

                    JLabel iterationsLabel = new JLabel("Iterations:");
                    iterationsLabel.setLabelFor(iterationsField);
                    iterationsField.setText(Integer.toString(iterations));
                    iterationsField.setColumns(10);
                    iterationsField.addPropertyChangeListener("value",
                            (evt) -> iterations = ((Number) evt.getNewValue()).intValue());
                    controlPanel.add(iterationsLabel);
                    controlPanel.add(iterationsField);
                }

                chooseDataButton.addActionListener((e) -> {
                    if (dataDirectoryChooser.showOpenDialog(frame) == JFileChooser.APPROVE_OPTION) {
                        dataDirectory = dataDirectoryChooser.getSelectedFile();
                        displayData();
                    }
                });
                controlPanel.add(chooseDataButton);

                trainButton.addActionListener((e) -> trainNetwork());
                controlPanel.add(trainButton);

                saveButton.addActionListener((e) -> {
                    saveNetwork();
                });
                controlPanel.add(saveButton);

                content.add(controlPanel, BorderLayout.SOUTH);

                dataDirectoryChooser.setFileSelectionMode(JFileChooser.DIRECTORIES_ONLY);

                frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
                frame.pack();
                frame.setVisible(true);
            });
        } catch (InvocationTargetException | InterruptedException e) {
            e.printStackTrace();
        }
    }

    /**
     * Clears the data series and reloads it, on a background thread, from every
     * {@code .csv} file directly inside the selected data directory. Does
     * nothing when no directory has been selected.
     */
    private void displayData() {
        if (dataDirectory == null) {
            return;
        }
        outputGraphDataSeries.clear();

        // Create a copy of the data directory for the worker thread
        final File directory = this.dataDirectory;
        new Thread(() -> {
            if (!directory.isDirectory()) {
                return;
            }
            // listFiles() returns null when the directory cannot be read.
            final File[] files = directory.listFiles();
            if (files == null) {
                System.err.println("Could not list data directory: " + directory);
                return;
            }
            for (File file : files) {
                if (file.isFile() && file.getName().endsWith(".csv")) {
                    loadDataFile(file);
                }
            }
        }).start();
    }

    /**
     * Reads one CSV file and queues an (error, output) point for the data series
     * for every line whose first two comma-separated values are numbers. Lines
     * that are too short or not numeric are reported and skipped; the rest of
     * the file is still read.
     *
     * @param file the CSV file to read
     */
    private void loadDataFile(File file) {
        try (BufferedReader dataReader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = dataReader.readLine()) != null) {
                final String[] values = line.split(",");
                if (values.length < 2) {
                    System.err.println("Too few values in data line: " + line);
                    continue;
                }
                try {
                    final double error = Double.parseDouble(values[0]);
                    final double output = Double.parseDouble(values[1]);

                    // Add to the data series on the UI thread
                    SwingUtilities.invokeLater(() -> outputGraphDataSeries.add(error, output));
                } catch (NumberFormatException e) {
                    System.err.println("Data is not a valid number: " + e.getMessage());
                }
            }
        } catch (FileNotFoundException e) {
            System.err.println("Could not open data file: " + e.getMessage());
        } catch (IOException e) {
            System.err.println("Could not read data file: " + e.getMessage());
        }
    }

    /**
     * Samples the network's output for inputs from -pi to pi in steps of 0.05
     * and replaces the contents of the network series with the result.
     * <p>
     * May be called from any thread: the network is evaluated on the calling
     * thread while holding the {@code network} monitor, and the series is
     * cleared and refilled in a single task on the event dispatch thread.
     */
    private void displayNetwork() {
        final ArrayList<double[]> points = new ArrayList<>();

        // feedForward() followed by reading the output layer must not be
        // interleaved with other users of the network, so the whole sweep runs
        // under the network's monitor.
        synchronized (network) {
            for (double angle = -Math.PI; angle <= Math.PI; angle += 0.05) {
                network.feedForward(convertAngleToInput(angle));
                points.add(new double[] { angle, network.getOutputLayer()[0].getOutputValue() });
            }
        }

        // Clear and refill on the UI thread so the series is only ever touched
        // there and the old curve is replaced in one step.
        SwingUtilities.invokeLater(() -> {
            outputGraphNetworkSeries.clear();
            for (double[] point : points) {
                outputGraphNetworkSeries.add(point[0], point[1]);
            }
        });
    }

    /**
     * Trains the network on the points currently in the data series, using a
     * background worker and a cancellable progress monitor. Does nothing if a
     * training run is already active.
     * <p>
     * The network is reset before training. The eta, momentum and iteration
     * count are captured when training starts, so editing the fields during a
     * run only affects the next run. When the run finishes or is cancelled the
     * network curve is redrawn.
     */
    private void trainNetwork() {
        if (training) {
            return;
        }

        // Set training flag and disable button
        training = true;
        trainButton.setEnabled(false);

        final ProgressMonitor progressMonitor = new ProgressMonitor(frame, "Training Network...", "", 0, 100);
        progressMonitor.setMillisToDecideToPopup(100);
        progressMonitor.setMillisToPopup(400);

        @SuppressWarnings("unchecked")
        final ArrayList<XYDataItem> data = new ArrayList<>(outputGraphDataSeries.getItems());

        // Snapshot the parameters: the loop bound and the progress denominator
        // must agree, otherwise the progress could leave 0..100 and
        // setProgress() would throw.
        final int totalIterations = iterations;
        final double trainingEta = eta;
        final double trainingMomentum = momentum;
        // long, because iterations * samples can exceed the int range.
        final long totalSteps = (long) totalIterations * data.size();

        final SwingWorker<Void, Void> trainingWorker = new SwingWorker<Void, Void>() {

            @Override
            protected Void doInBackground() throws Exception {
                // Reset the neural network to default values. The lock is the
                // network itself; "this" would be the worker instance here,
                // which no other thread synchronizes on.
                synchronized (network) {
                    network.reset();
                    network.setEta(trainingEta);
                    network.setMomentum(trainingMomentum);
                }

                long completedSteps = 0;
                outer: for (int j = 0; j < totalIterations; j++) {
                    for (XYDataItem d : data) {
                        if (isCancelled()) {
                            break outer;
                        }
                        final double error = convertAngleToInput(d.getXValue());
                        final double output = d.getYValue();
                        synchronized (network) {
                            network.feedForward(error);
                            network.backPropagation(output);
                        }
                        completedSteps++;
                        // totalSteps > 0 whenever this line is reached.
                        setProgress((int) (completedSteps * 100 / totalSteps));
                    }
                }

                displayNetwork();
                return null;
            }

            @Override
            protected void done() {
                training = false;
                trainButton.setEnabled(true);
                progressMonitor.close();

                // get() rethrows anything doInBackground() threw; without it a
                // failed run would look like a successful one.
                if (!isCancelled()) {
                    try {
                        get();
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } catch (java.util.concurrent.ExecutionException e) {
                        System.err.println("Training failed: " + e.getCause());
                    }
                }
            }
        };
        trainingWorker.addPropertyChangeListener((evt) -> {
            if ("progress".equals(evt.getPropertyName())) {
                progressMonitor.setProgress((int) evt.getNewValue());
            }
            if (progressMonitor.isCanceled()) {
                trainingWorker.cancel(true);
            }
        });
        trainingWorker.execute();
    }

    /**
     * Asks the user for a file and saves the network to it on a background
     * worker. Does nothing if the dialog is dismissed. A failure while saving is
     * reported on standard error.
     */
    private void saveNetwork() {
        if (saveFileChooser.showSaveDialog(frame) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        final File saveFile = saveFileChooser.getSelectedFile();

        new SwingWorker<Void, Void>() {

            @Override
            protected Void doInBackground() throws Exception {
                synchronized (network) {
                    new Data(network).save(saveFile);
                }

                return null;
            }

            @Override
            protected void done() {
                // get() rethrows anything doInBackground() threw; without it a
                // failed save would go unnoticed.
                try {
                    get();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } catch (java.util.concurrent.ExecutionException e) {
                    System.err.println("Could not save network to " + saveFile + ": " + e.getCause());
                }
            }
        }.execute();
    }

    /**
     * Resolves the empty relative path against the current working directory.
     *
     * @return the working directory as an absolute {@link File}
     */
    private static File getWorkingDirectory() {
        final String absoluteLocation = Paths.get("").toAbsolutePath().toString();
        return new File(absoluteLocation);
    }

    /**
     * Scales an angle in radians into the network's input range by dividing it
     * by pi, so that -pi..pi maps onto -1..1.
     *
     * @param angle the angle in radians
     * @return {@code angle / Math.PI}
     */
    private static double convertAngleToInput(double angle) {
        final double halfTurn = Math.PI;
        return angle / halfTurn;
    }
}