问题描述
我是一名高中生,正在为我的CS研究课设计一个项目(我很幸运有机会参加这样的课程)!该项目旨在通过多层感知器(MLP)使AI学习流行的游戏Snake,该感知器通过遗传算法(GA)学习。这个项目的灵感来自我在YouTube上看过的许多视频,这些视频实现了我刚刚描述的内容,如您所见here和here。我已经使用JavaFX和一个名为Neuroph的AI库编写了上述项目。
这个名字是无关紧要的,因为我有一个名词和形容词的列表,这些名词和形容词是我用来生成它们的(我认为这样会使它更有趣)。圆括号中的数字是这一代中最好的分数,因为一次只显示1条蛇。
繁殖时,我将x%的蛇设为父母(在这种情况下为20)。然后将孩子的数量平均分配给每对蛇父母。在这种情况下,“基因”是MLP的权重。由于我的库实际上并不支持偏见,因此我向输入层添加了一个偏向神经元,并将其连接到每一层中的所有其他神经元,以使其权重代替偏见(如线程here中所述)。每条蛇的孩子都有50、50的机会获得每个基因的父母一方的基因。将基因突变的概率设置为-1.0至1.0之间的随机数,也有5%的可能性。
每条蛇的MLP有3层:18个输入神经元,14个隐藏神经元和4个输出神经元(每个方向)。我输入的输入是头的x,头的y,食物的x,食物的y和左步。它还从四个方向看,并检查与食物,墙壁及其本身的距离(如果看不到,则将其设置为-1.0)。我还谈到了偏向神经元,将其加到18后就把数字增加到18。
我计算蛇得分的方法是通过健身函数,即(消耗的苹果×5 +存活秒数/ 2)
这是我的GAMLPAgent.java,所有MLP和GA内容都发生在这里。
package agents;
import graphics.Snake;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;
import javafx.scene.shape.Rectangle;
import org.neuroph.core.Layer;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.neuron.BiasNeuron;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.TransferFunctionType;
import util.Direction;
/**
*
* @author Preston Tang
*
* GAMLPAgent stands for genetic Algorithm multi-layer Perceptron Agent
*/
public class GAMLPAgent implements Comparable<GAMLPAgent> {
public Snake mask;
private final MultiLayerPerceptron mlp;
private final int width;
private final int height;
private final double size;
private final double mutationRate = 0.05;
public GAMLPAgent(Snake mask,int width,int height,double size) {
this.mask = mask;
this.width = width;
this.height = height;
this.size = size;
//Input: x of head,y of head,x of food,y of food,steps left
//Input: 4 directions,check for distance to food,wall,and self + 1 bias neuron (18 total)
//6 hidden perceptrons (2 hidden layer(s))
//Output: A direction,4 possibilities
mlp = new MultiLayerPerceptron(TransferFunctionType.SIGMOID,18,14,4);
//Adding connections
List<Layer> layers = mlp.getLayers();
for (int r = 0; r < layers.size(); r++) {
for (int c = 0; c < layers.get(r).getNeuronsCount(); c++) {
mlp.getInputNeurons().get(mlp.getInputsCount() - 1).addInputConnection(layers.get(r).getNeuronAt(c));
}
}
// System.out.println(mlp.getInputNeurons().get(17).getInputConnections() + " " + mlp.getInputNeurons().get(17).getoutConnections());
mlp.randomizeWeights();
// System.out.println(Arrays.toString(mlp.getInputNeurons().get(17).getWeights()));
}
public void compute() {
if (mask.isAlive()) {
Rectangle head = mask.getSnakeParts().get(0);
Rectangle food = mask.getFood();
double headX = head.getX();
double headY = head.getY();
double foodX = mask.getFood().getX();
double foodY = mask.getFood().getY();
int stepsLeft = mask.getSteps();
double foodL = -1.0,wallL,selfL = -1.0;
double foodR = -1.0,wallR,selfR = -1.0;
double foodU = -1.0,wallU,selfU = -1.0;
double foodD = -1.0,wallD,selfD = -1.0;
//The 4 directions
//Left Direction
if (head.getY() == food.getY() && head.getX() > food.getX()) {
foodL = head.getX() - food.getX();
}
wallL = head.getX() - size;
for (Rectangle part : mask.getSnakeParts()) {
if (head.getY() == part.getY() && head.getX() > part.getX()) {
selfL = head.getX() - part.getX();
break;
}
}
//Right Direction
if (head.getY() == food.getY() && head.getX() < food.getX()) {
foodR = food.getX() - head.getX();
}
wallR = size * width - head.getX();
for (Rectangle part : mask.getSnakeParts()) {
if (head.getY() == part.getY() && head.getX() < part.getX()) {
selfR = part.getX() - head.getX();
break;
}
}
//Up Direction
if (head.getX() == food.getX() && head.getY() < food.getY()) {
foodU = food.getY() - head.getY();
}
wallU = size * height - head.getY();
for (Rectangle part : mask.getSnakeParts()) {
if (head.getX() == part.getX() && head.getY() < part.getY()) {
selfU = part.getY() - head.getY();
break;
}
}
//Down Direction
if (head.getX() == food.getX() && head.getY() > food.getY()) {
foodD = head.getY() - food.getY();
}
wallD = head.getY() - size;
for (Rectangle part : mask.getSnakeParts()) {
if (head.getX() == part.getX() && head.getY() > part.getY()) {
selfD = head.getY() - food.getY();
break;
}
}
mlp.setInput(
headX,headY,foodX,foodY,stepsLeft,foodL,selfL,foodR,selfR,foodU,selfU,foodD,selfD,1);
mlp.calculate();
if (getIndexOfLargest(mlp.getoutput()) == 0) {
mask.setDirection(Direction.UP);
} else if (getIndexOfLargest(mlp.getoutput()) == 1) {
mask.setDirection(Direction.DOWN);
} else if (getIndexOfLargest(mlp.getoutput()) == 2) {
mask.setDirection(Direction.LEFT);
} else if (getIndexOfLargest(mlp.getoutput()) == 3) {
mask.setDirection(Direction.RIGHT);
}
}
}
public double[][] breed(GAMLPAgent agent,int num) {
//Converts Double[] to double[]
//https://stackoverflow.com/questions/1109988/how-do-i-convert-double-to-double
double[] parent1 = Stream.of(mlp.getWeights()).mapTodouble(Double::doubleValue).toArray();
double[] parent2 = Stream.of(agent.getMLP().getWeights()).mapTodouble(Double::doubleValue).toArray();
double[][] childGenes = new double[num][parent1.length];
for (int r = 0; r < num; r++) {
for (int c = 0; c < childGenes[r].length; c++) {
if (new Random().nextInt(100) <= mutationRate * 100) {
childGenes[r][c] = ThreadLocalRandom.current().nextDouble(-1.0,1.0);
//childGenes[r][c] += childGenes[r][c] * 0.1;
} else {
childGenes[r][c] = new Random().nextDouble() < 0.5 ? parent1[c] : parent2[c];
}
}
}
return childGenes;
}
public MultiLayerPerceptron getMLP() {
return mlp;
}
public void setMask(Snake mask) {
this.mask = mask;
}
public Snake getMask() {
return mask;
}
public int getIndexOfLargest(double[] array) {
if (array == null || array.length == 0) {
return -1; // null or empty
}
int largest = 0;
for (int i = 1; i < array.length; i++) {
if (array[i] > array[largest]) {
largest = i;
}
}
return largest; // position of the first largest found
}
@Override
public int compareto(GAMLPAgent t) {
if (this.getMask().getscore() < t.getMask().getscore()) {
return -1;
} else if (t.getMask().getscore() < this.getMask().getscore()) {
return 1;
}
return 0;
}
public void debugLocation() {
Rectangle head = mask.getSnakeParts().get(0);
Rectangle food = mask.getFood();
System.out.println(head.getX() + " " + head.getY() + " " + food.getX() + " " + food.getY());
System.out.println(mask.getName() + ": " + Arrays.toString(mlp.getoutput()));
}
public void debuginput() {
String s = "";
for (int i = 0; i < mlp.getInputNeurons().size(); i++) {
s += mlp.getInputNeurons().get(i).getoutput() + " ";
}
System.out.println(s);
}
public double[] getoutput() {
return mlp.getoutput();
}
}
这是我的代码的主要类geneticSnake2.java,它是游戏循环所在的位置,也是我为子蛇分配基因的位置(我知道这样做可以做得更干净)。
package main;
import agents.GAMLPAgent;
import ui.InfoBar;
import graphics.Snake;
import graphics.SnakeGrid;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.Scanner;
import javafx.animation.AnimationTimer;
import javafx.application.Application;
import static javafx.application.Application.launch;
import javafx.scene.Scene;
import javafx.scene.control.Slider;
import javafx.scene.layout.Pane;
import javafx.scene.paint.Color;
import javafx.stage.Stage;
/**
*
* @author Preston Tang
*/
public class geneticSnake2 extends Application {
private final int width = 45;
private final int height = 40;
private final double displaySize = 120;
private final double size = 12;
private final Color pathColor = Color.rgb(120,120,120);
private final Color wallColor = Color.rgb(50,50,50);
private final int initSnakeLength = 2;
private final int populationSize = 1000;
private int generation = 0;
private int initSteps = 100;
private int stepsIncrease = 50;
private double parentPercentage = 0.2;
private final ArrayList<Color> snakeColors = new ArrayList() {
{
add(Color.GREEN);
add(Color.RED);
add(Color.YELLOW);
add(Color.BLUE);
add(Color.magenta);
add(Color.PINK);
add(Color.ORANGERED);
add(Color.BLACK);
add(Color.GOLDENROD);
add(Color.WHITE);
}
};
private final ArrayList<Snake> snakes = new ArrayList<>();
private final ArrayList<GAMLPAgent> agents = new ArrayList<>();
private long initTime = System.nanoTime();
@Override
public void start(Stage stage) {
Pane root = new Pane();
Pane graphics = new Pane();
graphics.setPrefheight(height * size);
graphics.setPrefWidth(width * size);
graphics.setTranslateX(0);
graphics.setTranslateY(displaySize);
Pane display = new Pane();
display.setStyle("-fx-background-color: BLACK");
display.setPrefheight(displaySize);
display.setPrefWidth(width * size);
display.setTranslateX(0);
display.setTranslateY(0);
root.getChildren().add(display);
SnakeGrid sg = new SnakeGrid(pathColor,wallColor,width,height,size);
//Parsing "adjectives.txt" and "nouns.txt" to form possible names
ArrayList<String> adjectives = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/adjectives.txt").getFile())).split("\n")));
ArrayList<String> nouns = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/nouns.txt").getFile())).split("\n")));
//Initializing the population
for (int i = 0; i < populationSize; i++) {
//Get random String from lists and capitalize first letter
String adj = adjectives.get(new Random().nextInt(adjectives.size()));
adj = adj.substring(0,1).toupperCase() + adj.substring(1);
String noun = nouns.get(new Random().nextInt(nouns.size()));
noun = noun.substring(0,1).toupperCase() + noun.substring(1);
Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));
//We want to see the first snake
if (i == 0) {
InfoBar bar = new InfoBar();
bar.getStatusText().setText("Status: Alive");
bar.getStatusText().setFill(Color.GREENYELLOW);
bar.getSizeText().setText("Population Size: " + populationSize);
Snake snake = new Snake(bar,adj + " " + noun,size,initSnakeLength,color,initSteps,stepsIncrease);
bar.getNameText().setText("Name: " + snake.getName());
snakes.add(snake);
agents.add(new GAMLPAgent(snake,size));
} else {
Snake snake = new Snake(adj + " " + noun,stepsIncrease);
snakes.add(snake);
agents.add(new GAMLPAgent(snake,size));
}
}
//Focused on original snake
display.getChildren().add(snakes.get(0).getInfoBar());
graphics.getChildren().addAll(sg);
graphics.getChildren().addAll(snakes.get(0));
root.getChildren().add(graphics);
//Add the speed controller (slider)
Slider slider = new Slider(1,10,10);
slider.setTranslateX(205);
slider.setTranslateY(75);
slider.setdisable(true);
root.getChildren().add(slider);
Scene scene = new Scene(root,width * size,height * size + displaySize);
stage.setScene(scene);
//Fixes the setResizable bug
//https://stackoverflow.com/questions/20732100/javafx-why-does-stage-setresizablefalse-cause-additional-margins
stage.setTitle("21-geneticSnake2 Cause the First Version Got Deleted ;-; Started on 6/8/2020");
stage.setResizable(false);
stage.sizetoScene();
stage.show();
AnimationTimer timer = new AnimationTimer() {
private long lastUpdate = 0;
@Override
public void handle(long Now) {
if (Now - lastUpdate >= (10 - (int) slider.getValue()) * 50_000_000) {
lastUpdate = Now;
int alive = populationSize;
for (int i = 0; i < snakes.size(); i++) {
Snake snake = snakes.get(i); //Current snake
if (i == 0) {
Collections.sort(agents);
snake.getInfoBar().getscoreText().setText("score: " + snake.getscore() + " (" + agents.get(agents.size() - 1).getMask().getscore() + ")");
}
if (!snake.isAlive()) {
alive--;
//Update graphics for main snake
if (i == 0) {
snake.getInfoBar().getStatusText().setText("Status: Dead");
snake.getInfoBar().getStatusText().setFill(Color.RED);
graphics.getChildren().remove(snake);
}
} else {
//If out of steps
if (snake.getSteps() <= 0) {
snake.setAlive(false);
}
//Bounds Detection (left right up down)
if (snake.getSnakeParts().get(0).getX() >= width * size
|| snake.getSnakeParts().get(0).getX() <= 0
|| snake.getSnakeParts().get(0).getY() >= height * size
|| snake.getSnakeParts().get(0).getY() <= 0) {
snake.setAlive(false);
}
//Self-Collision Detection
for (int o = 1; o < snakes.get(o).getSnakeParts().size(); o++) {
if (snakes.get(o).getSnakeParts().get(0).getX() == snakes.get(o).getSnakeParts().get(o).getX()
&& snakes.get(o).getSnakeParts().get(0).getY() == snakes.get(o).getSnakeParts().get(o).getY()) {
snakes.get(o).setAlive(false);
}
}
int rate = (int) slider.getValue();
int seconds = (int) ((System.nanoTime() - initTime) * rate / 1_000_000_000);
agents.get(i).compute();
snake.manageMovement();
snake.setSecondsAlive(seconds);
// agents.get(0);
// System.out.println(Arrays.toString(agents.get(0).getoutput()));
//
// System.out.println("\n\n\n\n\n\n\n");
//Expression to calculate score
double exp = (snake.getConsumed() * 5 + snake.getSecondsAlive() / 2.0D);
//double exp = snake.getSteps() + (Math.pow(2,snake.getConsumed()) + Math.pow(snake.getConsumed(),2.1) * 500)
// - (Math.pow(snake.getConsumed(),1.2) * Math.pow(0.25 * snake.getSteps(),1.3));
snake.setscore(Math.round(exp * 100.0) / 100.0);
//Update graphics for main snake
if (i == 0) {
snake.getInfoBar().getTimeText().setText("Time Survived: " + snake.getSecondsAlive() + "s");
snake.getInfoBar().getFoodText().setText("Food Consumed: " + snake.getConsumed());
snake.getInfoBar().getGenerationText().setText("Generation: " + generation);
snake.getInfoBar().getStepsText().setText("Steps Remaining: " + snake.getSteps());
}
}
}
//Reset and breed
if (alive == 0) {
//Ascending order
initTime = System.nanoTime();
generation++;
graphics.getChildren().clear();
graphics.getChildren().addAll(sg);
snakes.clear();
//x% of snakes are parents
int parentNum = (int) (populationSize * parentPercentage);
//Faster odd number check
if ((parentNum & 1) != 0) {
//If odd make even
parentNum += 1;
}
for (int i = 0; i < parentNum; i += 2) {
//Get the 2 parents,sorted by score
GAMLPAgent p1 = agents.get(populationSize - (i + 2));
GAMLPAgent p2 = agents.get(populationSize - (i + 1));
//Produce the next generation
double[][] childGenes = p1.breed(p2,((populationSize - parentNum) / parentNum) * 2);
//Debugs Genes
// System.out.println(Arrays
// .stream(childGenes)
// .map(Arrays::toString)
// .collect(Collectors.joining(System.lineseparator())));
//Soft copy
ArrayList<GAMLPAgent> temp = new ArrayList<>(agents);
for (int o = 0; o < childGenes.length; o++) {
temp.get(o).getMLP().setWeights(childGenes[o]);
}
//Add the genes of every pair of parents to the children
for (int o = 0; o < childGenes.length; o++) {
//Useful debug message
// System.out.println("ParentNum: " + parentNum
// + " ChildPerParent: " + (populationSize - parentNum) / parentNum
// + " Index: " + (o + (i / 2 * childGenes.length))
// + " ChildGenesNum: " + childGenes.length
// + " Var O: " + o);
//Adds the genes of the temp to the agents
agents.set((o + (i / 2 * childGenes.length)),temp.get(o));
}
// System.out.println("\n\n\n\n\n\n");
}
//Debugging the snakes' genes to a file
// String str = "";
// for (int i = 0; i < agents.size(); i++) {
// str += "Index: " + i + "\t" + Arrays.toString(agents.get(i).getMLP().getWeights())+ "\n\n\n";
// }
//
// printToFile(str,"gen" + generation);
for (int i = 0; i < populationSize; i++) {
//Get random String from lists and capitalize first letter
String adj = adjectives.get(new Random().nextInt(adjectives.size()));
adj = adj.substring(0,1).toupperCase() + adj.substring(1);
String noun = nouns.get(new Random().nextInt(nouns.size()));
noun = noun.substring(0,1).toupperCase() + noun.substring(1);
Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));
//We want to see the first snake
if (i == 0) {
InfoBar bar = new InfoBar();
bar.getStatusText().setText("Status: Alive");
bar.getStatusText().setFill(Color.GREENYELLOW);
bar.getSizeText().setText("Population Size: " + populationSize);
Snake snake = new Snake(bar,stepsIncrease);
bar.getNameText().setText("Name: " + snake.getName());
snakes.add(snake);
agents.get(i).setMask(snake);
} else {
Snake snake = new Snake(adj + " " + noun,stepsIncrease);
snakes.add(snake);
agents.get(i).setMask(snake);
}
}
graphics.getChildren().add(snakes.get(0));
display.getChildren().clear();
//Focused on original snake at first
display.getChildren().add(snakes.get(0).getInfoBar());
}
}
}
};
//Starts the infinite loop
timer.start();
}
public String readFile(File f) {
String content = "";
try {
content = new Scanner(f).useDelimiter("\\Z").next();
} catch (FileNotFoundException ex) {
System.err.println("Error: Unable to read " + f.getName());
}
return content;
}
public void printToFile(String str,String name) {
FileWriter fileWriter;
try {
fileWriter = new FileWriter(name + ".txt");
try (BufferedWriter bufferedWriter = new BufferedWriter(fileWriter)) {
bufferedWriter.write(str);
}
} catch (IOException ex) {
ex.printstacktrace();
}
}
public static void main(String[] args) {
launch(args);
}
}
主要的问题是,即使经过几千代,这些蛇仍只是简单地自杀而入墙。在我上面链接的视频中,这些蛇在躲避墙壁,并像第5代一样获得食物。我怀疑问题出在我为正在出生的蛇分配基因的主要班级中。
我实际上已经坚持了几个星期。以前,我怀疑的问题之一是缺乏投入,因为那时我的路还少。但是现在,我认为情况已不再如此。如果需要,我可以尝试在4个对角线方向上看一下,再向蛇的MLP添加12个输入。我也曾去过人工智能不和谐寻求帮助,但实际上并没有找到解决方案。
如果需要,我愿意发送我的全部代码,以便您可以自己运行仿真。
如果您已经读到这里,谢谢您抽出宝贵的时间来帮助我!非常感谢。
解决方法
我不奇怪你的蛇快死了。
让我们退后一步。 AI到底是什么?好吧,这是一个搜索问题。我们正在某个参数空间中搜索,以找到在给定游戏当前状态下可以解决蛇的参数集。您可以想象一个具有全局最小值的参数空间:可能的最佳蛇,犯错误最少的蛇。
所有学习算法都从此参数空间中的某个点开始,并尝试找到随时间变化的全局最大值。首先,让我们考虑一下MLP。 MLP通过尝试一组权重,计算损失函数,然后朝进一步降低损失(梯度下降)的方向迈出一步来学习。一个MLP可以找到一个最小值很明显,但是是否可以找到一个足够好的最小值是一个问题,并且存在很多可以提高这种可能性的训练技术。
另一方面,遗传算法的收敛特性非常差。首先,让我们停止调用这些遗传算法。叫这些 而是使用smorgasbord算法。 smorgasbord算法从两个父对象获取两组参数,将它们混杂在一起,然后生成一个新的smorgasbord。是什么让您认为这比两者中的任何一个都更好?您在这里最小化什么?您怎么知道它正在变得更好?如果您附加了损失函数,怎么知道自己处在实际上可以最小化的空间中?我要提出的观点是,遗传算法是不受原理限制的,与自然不同。大自然不仅将密码子放入混合器中以制造新的DNA链,而且这正是遗传算法所做的。有一些技术可以增加一些爬山时间,但是遗传算法仍然具有tons of problems。
要点是,不要被名字淹没。遗传算法就是smorgasbord算法。我认为您的方法行不通,因为GA不能保证在无限迭代后收敛,而MLP不能保证收敛到良好的全局最小值。
该怎么办?好吧,一种更好的方法是使用适合您问题的学习范例。更好的方法是使用强化学习。这个主题有很好的course on Udacity from Georgia Tech。