xzit-xiaxu

NeuralNetwork.java

import java.awt.Color;
import java.util.ArrayList;
import acm.graphics.GImage;
import acm.program.ConsoleProgram;

/**
 * Console program that classifies images as "bird" or "plane" using a
 * two-layer feed-forward network whose weights are loaded from text files.
 *
 * <p>Each image is flattened row by row into one greyscale value per pixel,
 * fed through N_LAYER1 hidden neurons, and the hidden outputs are fed into a
 * single output neuron whose activation is compared against BIRD_THRESHOLD.
 */
public class NeuralNetwork extends ConsoleProgram{

	/** Number of greyscale pixel inputs each hidden neuron expects (e.g. a 32x32 image). */
	private static final int N_INPUTS = 1024;
	/** Number of neurons in the hidden layer. */
	private static final int N_LAYER1 = 20;
	/**
	 * Output activations strictly above this value are reported as a bird.
	 * It is below 0.5, so the classifier is slightly biased towards birds.
	 */
	private static final double BIRD_THRESHOLD = 0.4;

	/** Hidden-layer neurons, populated by loadNeuralNetwork(). */
	private ArrayList<Neuron> layer1 = null;
	/** Single output neuron, populated by loadNeuralNetwork(). */
	private Neuron prediction = null;

	/** Loads the network weights, then classifies two sample images. */
	public void run() {
		loadNeuralNetwork();

		// make predictions
		GImage birdImage = new GImage("bird6.png");
		GImage planeImage = new GImage("airplane4.png");

		makePrediction(birdImage);
		makePrediction(planeImage);
	}

	/**
	 * Classifies an image and prints "It's a bird!" or "It's a plane!".
	 *
	 * @param img image to classify; must contain exactly N_INPUTS pixels
	 * @throws IllegalArgumentException if the image's pixel count differs from
	 *         N_INPUTS (the hidden neurons have exactly that many weights, so
	 *         any other size would either crash or give a meaningless result)
	 */
	private void makePrediction(GImage img) {
		// turn the image into inputs
		ArrayList<Double> inputs = toGreyInputs(img);
		if (inputs.size() != N_INPUTS) {
			throw new IllegalArgumentException("Expected an image with " + N_INPUTS
					+ " pixels but got " + inputs.size());
		}

		// feed forward into layer 1
		ArrayList<Double> layer1Outputs = new ArrayList<Double>(N_LAYER1);
		for (Neuron hidden : layer1) {
			layer1Outputs.add(hidden.activate(inputs));
		}

		// feed forward into the prediction layer
		double output = prediction.activate(layer1Outputs);

		// the threshold is below 0.5, so this is slightly biased towards birds
		if (output > BIRD_THRESHOLD) {
			println("It\'s a bird!");
		} else {
			println("It\'s a plane!");
		}
	}

	/**
	 * Flattens an image row by row into one greyscale value per pixel.
	 *
	 * @param img source image
	 * @return greyscale values in row-major order
	 */
	private ArrayList<Double> toGreyInputs(GImage img) {
		int[][] pixelArray = img.getPixelArray();
		ArrayList<Double> inputs = new ArrayList<Double>();
		for (int[] row : pixelArray) {
			for (int pixel : row) {
				inputs.add(getGrey(new Color(pixel)));
			}
		}
		return inputs;
	}

	/**
	 * Returns the HSB brightness component of a colour, used here as its grey level.
	 *
	 * @param color pixel colour
	 * @return brightness in [0, 1]
	 */
	private double getGrey(Color color) {
		float[] hsb = Color.RGBtoHSB(color.getRed(),
				color.getGreen(),
				color.getBlue(),
				null);
		return hsb[2];
	}

	/**
	 * Builds the network from disk: hidden neuron i reads
	 * "weights/layer1_i.txt" (N_INPUTS weights each) and the output neuron
	 * reads "weights/prediction.txt" (N_LAYER1 weights).
	 */
	private void loadNeuralNetwork() {
		layer1 = new ArrayList<Neuron>();
		for(int i = 0; i < N_LAYER1; i++) {
			Neuron hidden = new Neuron("weights/layer1_" + i + ".txt", N_INPUTS);
			layer1.add(hidden);
		}
		prediction = new Neuron("weights/prediction.txt", N_LAYER1);
	}

	/** Sets the console font used for the printed predictions. */
	public void init() {
		setFont("courier-24");
	}
}

Neuron.java

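// A single artificial neuron: computes the sigmoid of the weighted sum of its
// inputs, using weights loaded from a text file with one number per line.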

import java.awt.Color;
import java.awt.event.MouseEvent;
import java.io.*;
import java.util.ArrayList;
import acm.graphics.*;
import acm.program.*;
import acm.util.RandomGenerator;

/**
 * A single artificial neuron: its activation is the sigmoid of the weighted
 * sum of its inputs, with weights read from a text file (one per line).
 *
 * <p>NOTE(review): this class extends GraphicsProgram although it uses no
 * graphics features; the superclass is kept only so existing code that relies
 * on it keeps compiling.
 */
public class Neuron extends GraphicsProgram {

	/** Connection weights; weights.get(i) multiplies input i. */
	private ArrayList<Double> weights = null;

	/**
	 * Creates a neuron whose weights are read from a file.
	 *
	 * @param fileName path of a text file containing one weight per line
	 * @param n number of inputs this neuron will receive; the file must
	 *          contain at least this many weights
	 * @throws RuntimeException if the file cannot be read
	 * @throws NumberFormatException if a non-blank line is not a number
	 * @throws IllegalArgumentException if the file has fewer than n weights
	 */
	public Neuron(String fileName, int n) {
		loadWeightsFromFile(fileName, n);
	}

	/**
	 * Computes sigmoid(sum of inputs[i] * weights[i]).
	 *
	 * @param inputs input values; must not outnumber the loaded weights
	 * @return activation in (0, 1)
	 * @throws IllegalArgumentException if there are more inputs than weights
	 */
	public double activate(ArrayList<Double> inputs) {
		if (inputs.size() > weights.size()) {
			throw new IllegalArgumentException("Neuron has " + weights.size()
					+ " weights but received " + inputs.size() + " inputs");
		}
		double weightedSum = 0.0;
		for(int i = 0; i < inputs.size(); i++) {
			weightedSum += inputs.get(i) * weights.get(i);
		}
		return sigmoid(weightedSum);
	}

	/** Logistic function 1 / (1 + e^-x). */
	private double sigmoid(double x) {
		return 1.0 / (1.0 + Math.exp(-x));
	}

	/**
	 * Replaces the weights with those parsed from fileName, one per line.
	 * Blank lines (such as a trailing newline) are skipped.
	 */
	private void loadWeightsFromFile(String fileName, int n) {
		weights = new ArrayList<Double>();
		BufferedReader rd = null;
		try {
			rd = new BufferedReader(new FileReader(fileName));
			while (true) {
				String line = rd.readLine();
				if (line == null) break;
				line = line.trim();
				if (line.isEmpty()) continue;
				weights.add(Double.parseDouble(line));
			}
		} catch (IOException e) {
			throw new RuntimeException("Could not read weights from " + fileName, e);
		} finally {
			// close even when parsing fails, so the file handle is not leaked
			if (rd != null) {
				try {
					rd.close();
				} catch (IOException ignored) {
					// nothing useful to do: the file was only read, and any
					// earlier error is already propagating
				}
			}
		}
		if (weights.size() < n) {
			throw new IllegalArgumentException(fileName + " contains " + weights.size()
					+ " weights but " + n + " are required");
		}
	}
}

NeuronGraphics.java

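// Interactive demo of one artificial neuron: click the four input squares to
// switch them on/off and watch the weighted sum and sigmoid output update.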

import java.awt.Color;
import java.awt.event.MouseEvent;
import java.util.ArrayList;
import acm.graphics.*;
import acm.program.*;

/**
 * Interactive demo of a single artificial neuron with binary inputs.
 *
 * <p>Clicking an input square toggles it between off (light grey, value 0.0)
 * and on (black, value 1.0). The weighted sum of the inputs is shown inside
 * the node, and its sigmoid is shown as the output, coloured green when above
 * 0.7, red when below 0.2 and black otherwise.
 */
public class NeuronGraphics extends GraphicsProgram {

	/** Left edge of every input square. */
	private static final int INPUT_X = 100;
	/** Vertical spacing of the input squares; input i has its top edge at (i + 1) * INPUT_SPACING. */
	private static final int INPUT_SPACING = 100;
	/** Side length of an input square. */
	private static final int INPUT_SIZE = 30;

	private GLabel outputNum = new GLabel("0.0");
	private GLabel inputNum = new GLabel("0.0");

	/** Weight of each input; input i is drawn and read using index i. */
	private ArrayList<Double> weights = new ArrayList<Double>();

	public void run() {
		weights.add(2.5);
		weights.add(1.0);
		weights.add(7.0);
		weights.add(-9.0);
		drawNeuron();
		calculateOutput();
		addMouseListeners();
	}

	/** Recomputes the weighted sum and sigmoid output and updates both labels. */
	private void calculateOutput() {
		double inputSum = 0;
		for(int i = 0; i < weights.size(); i++) {
			inputSum += getInputValue(i) * weights.get(i);
		}
		double output = sigmoidFunction(inputSum);

		// update the input str
		inputNum.setLabel(format(inputSum));

		// update the output str
		outputNum.setLabel(format(output));
		colorLabel(outputNum, output);
	}

	/** Draws one input per weight, then the output line, the node and the title. */
	private void drawNeuron() {
		for (int i = 0; i < weights.size(); i++) {
			drawInput("input" + (i + 1), INPUT_X, (i + 1) * INPUT_SPACING, weights.get(i));
		}
		drawOutput();
		drawNode();
		drawTitle();
	}

	/**
	 * Toggles a clicked input square between on and off and refreshes the output.
	 * Clicks on anything else (labels, lines, node, title) are ignored.
	 */
	public void mouseClicked(MouseEvent e) {
		GObject obj = getElementAt(e.getX(), e.getY());
		// the input squares are the only GRects on the canvas
		if (obj instanceof GRect) {
			if (Color.BLACK.equals(obj.getColor())) {
				obj.setColor(Color.LIGHT_GRAY);
			} else {
				obj.setColor(Color.BLACK);
			}
			calculateOutput();
		}
	}

	/** Logistic function 1 / (1 + e^-x). */
	private double sigmoidFunction(double x) {
		return 1.0 / (1.0 + Math.exp(-x));
	}

	//--------------- Lower Level Methods-------------//

	/** Draws the blue title centred horizontally near the top. */
	private void drawTitle() {
		GLabel title = new GLabel("Artificial Neuron");
		title.setFont("courier-30");
		title.setColor(Color.BLUE);
		add(title, (getWidth() - title.getWidth())/2,30);
	}

	/** Draws the node circle in the centre with the weighted-sum label inside it. */
	private void drawNode() {
		int nodeSize = 100;
		GOval oval = new GOval(nodeSize, nodeSize);
		oval.setFilled(true);
		oval.setColor(new Color(102, 204, 255));
		add(oval, (getWidth() - nodeSize)/2, (getHeight() - nodeSize)/2);

		inputNum.setFont("courier-24");
		add(inputNum,
				(getWidth() - inputNum.getWidth())/2,
				getHeight()/2 + 8);
	}

	/** Draws the line from the node to the output label, and the label itself. */
	private void drawOutput() {
		GLine output = new GLine(getWidth()/2, getHeight()/2,
				getWidth() * .76, getHeight()/2);
		add(output);
		outputNum.setFont("courier-24");
		add(outputNum, getWidth() * .78, getHeight()/2 + 8);
	}

	/**
	 * Draws one input: its square (initially off), its name, the connection
	 * line to the node centre, and the weight label.
	 *
	 * @param name label shown above the square
	 * @param x left edge of the square
	 * @param y top edge of the square
	 * @param weight weight printed next to the connection line
	 */
	private void drawInput(String name, int x, int y, double weight) {
		GRect input = new GRect(INPUT_SIZE, INPUT_SIZE);
		input.setFilled(true);
		input.setColor(Color.LIGHT_GRAY);
		add(input, x, y);

		GLabel label = new GLabel(name);
		label.setFont("courier-18");
		add(label, x - 10, y - 10);

		GLine line = new GLine(x + INPUT_SIZE, y + INPUT_SIZE / 2, getWidth()/2, getHeight()/2);
		add(line);

		// place the weight a quarter of the way along the line, near the input
		double midX = (line.getStartPoint().getX()*1.5 + line.getEndPoint().getX()*.5) / 2;
		double midY = (line.getStartPoint().getY()*1.5 + line.getEndPoint().getY()*.5) / 2;

		GLabel weightLabel = new GLabel(weight + "");
		weightLabel.setFont("courier-18");
		add(weightLabel, midX, midY);
	}

	/**
	 * Truncates a value toward zero to at most three decimal places and
	 * returns it as a string (e.g. 0.98765 becomes "0.987").
	 */
	private String format(double output) {
		int thousandths = (int) (output * 1000);
		double decimal = thousandths / 1000.0;
		return decimal + "";
	}

	/**
	 * Returns 1.0 if input square i is on (black), otherwise 0.0.
	 *
	 * @param i zero-based input index
	 */
	private double getInputValue(int i) {
		// probe a point inside square i, clear of its name label and connection line
		double x = INPUT_X + 10;
		double y = (i + 1) * INPUT_SPACING + 10;
		GObject obj = getElementAt(x, y);
		if (obj instanceof GRect && Color.BLACK.equals(obj.getColor())) {
			return 1.0;
		}
		return 0.0;
	}

	/** Colours a label green above 0.7, red below 0.2 and black otherwise. */
	private void colorLabel(GLabel label, double value) {
		if(value > 0.7) {
			label.setColor(new Color(0, 150, 0));
		} else if(value < 0.2) {
			label.setColor(new Color(150, 0, 0));
		} else {
			label.setColor(Color.BLACK);
		}
	}

	public static final int APPLICATION_WIDTH = 700;
	public static final int APPLICATION_HEIGHT = 500;
}

分类:

技术点:

相关文章:

  • 2021-07-29
  • 2021-05-30
  • 2021-09-29
  • 2021-06-04
  • 2022-03-02
  • 2022-02-01
猜你喜欢
  • 2021-12-16
  • 2022-12-23
  • 2021-11-28
  • 2022-12-23
  • 2021-09-24
  • 2021-08-25
相关资源
相似解决方案