Snake AI trained with an MLP and a GA shows no intelligent behavior even after many generations

Problem description

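I am a high school student working on a project for my CS research class (I'm very fortunate to have the opportunity to take such a course)! The goal of the project is to have an AI learn the popular game Snake using a multilayer perceptron (MLP) that learns through a genetic algorithm (GA). The project is inspired by many videos I have seen on YouTube that implement what I just described, as you can see here and here. I have written the project using JavaFX and an AI library called Neuroph.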

This is what my program currently looks like:

GeneticSnake Project

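The name is irrelevant: I generate the names from a list of nouns and adjectives (I thought it would make things more interesting). The number in parentheses after the score is the best score in the current generation, since only one snake is displayed at a time.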

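When breeding, I set x% of the snakes as parents (20 in this case). The children are then divided evenly among the pairs of parent snakes. The "genes" in this case are the weights of the MLP. Since my library doesn't really support biases, I add a bias neuron to the input layer and connect it to all the other neurons in every layer so that its weights act as biases (as described in a thread here). Each child has a 50/50 chance of getting each gene from either parent. There is also a 5% chance for a gene to mutate, in which case it is set to a random number between -1.0 and 1.0.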
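The breeding rule described above can be sketched in isolation as follows. This is a minimal illustration rather than the project's actual code: the class name `CrossoverSketch`, the method `breedOne`, and the injected `Random` are assumptions made for the example.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of uniform crossover with per-gene mutation. */
class CrossoverSketch {

    /**
     * Builds one child. Each gene mutates with probability mutationRate
     * (it is replaced by a uniform random value in [-1.0, 1.0)); otherwise
     * it is copied from parent1 or parent2 with equal probability.
     */
    static double[] breedOne(double[] parent1, double[] parent2,
                             double mutationRate, Random rng) {
        if (parent1.length != parent2.length) {
            throw new IllegalArgumentException("parents must have the same number of genes");
        }
        double[] child = new double[parent1.length];
        for (int i = 0; i < child.length; i++) {
            if (rng.nextDouble() < mutationRate) {
                child[i] = -1.0 + 2.0 * rng.nextDouble();               // mutation
            } else {
                child[i] = rng.nextBoolean() ? parent1[i] : parent2[i]; // crossover
            }
        }
        return child;
    }

    public static void main(String[] args) {
        double[] a = {0.1, 0.2, 0.3, 0.4};
        double[] b = {-0.1, -0.2, -0.3, -0.4};
        System.out.println(Arrays.toString(breedOne(a, b, 0.05, new Random(42))));
    }
}
```

Comparing `nextDouble()` against the rate triggers a mutation with exactly that probability, whereas an integer test such as `nextInt(100) <= 5` succeeds for six of the hundred possible values.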

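Each snake's MLP has 3 layers: 18 input neurons, 14 hidden neurons, and 4 output neurons (one for each direction). The inputs I feed it are the x of the head, the y of the head, the x of the food, the y of the food, and the steps left. It also looks in 4 directions and checks the distance to the food, the wall, and itself (if it doesn't see one, the value is set to -1.0). There is also the bias neuron I mentioned, which brings the total to 18.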
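For reference, the 18 inputs listed above could be packed into a single array like this. It is a hedged sketch: `InputVectorSketch`, `buildInputs`, and the `{food, wall, self}` ordering within each direction are assumptions made for illustration; only the counts (5 + 4 × 3 + 1 bias = 18) come from the description.

```java
import java.util.Arrays;

/** Hypothetical sketch of the 18-value input layout. */
class InputVectorSketch {
    static final int INPUT_COUNT = 18;

    /**
     * directions holds 4 rows (left, right, up, down); each row is
     * {distance to food, distance to wall, distance to self}, with -1.0
     * meaning "nothing seen in that direction".
     */
    static double[] buildInputs(double headX, double headY, double foodX, double foodY,
                                int stepsLeft, double[][] directions) {
        if (directions.length != 4) {
            throw new IllegalArgumentException("expected 4 directions");
        }
        double[] inputs = new double[INPUT_COUNT];
        inputs[0] = headX;
        inputs[1] = headY;
        inputs[2] = foodX;
        inputs[3] = foodY;
        inputs[4] = stepsLeft;
        int k = 5;
        for (double[] direction : directions) {
            if (direction.length != 3) {
                throw new IllegalArgumentException("expected {food, wall, self}");
            }
            for (double distance : direction) {
                inputs[k++] = distance;
            }
        }
        inputs[k] = 1.0; // constant bias input in the last slot (index 17)
        return inputs;
    }

    public static void main(String[] args) {
        double[][] directions = {
            {-1.0, 120.0, -1.0}, // left
            {36.0, 420.0, -1.0}, // right
            {-1.0, 240.0, -1.0}, // up
            {-1.0, 240.0, 12.0}  // down
        };
        double[] inputs = buildInputs(120.0, 240.0, 156.0, 240.0, 100, directions);
        System.out.println(inputs.length + " inputs: " + Arrays.toString(inputs));
    }
}
```

Counting the values this way makes it easy to check that the number of arguments handed to the network matches its input layer.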

The way I calculate a snake's score is through my fitness function, which is (apples consumed × 5 + seconds alive / 2)
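As a small worked example of that formula (the class and method names here are illustrative assumptions, not part of the project):

```java
/** Hypothetical sketch of the fitness formula: apples * 5 + seconds / 2. */
class FitnessSketch {

    static double fitness(int applesConsumed, int secondsAlive) {
        // Dividing by 2.0 keeps the half-second precision instead of truncating.
        return applesConsumed * 5 + secondsAlive / 2.0;
    }

    public static void main(String[] args) {
        System.out.println(fitness(3, 10)); // 3 * 5 + 10 / 2.0 = 20.0
        System.out.println(fitness(0, 30)); // 0 * 5 + 30 / 2.0 = 15.0
        System.out.println(fitness(3, 0));  // 3 * 5 +  0 / 2.0 = 15.0
    }
}
```

Under this formula, surviving 30 seconds without eating anything is worth as much as eating three apples.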

Here is my GAMLPAgent.java, where all the MLP and GA stuff happens

package agents;

import graphics.Snake;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;
import javafx.scene.shape.Rectangle;
import org.neuroph.core.Layer;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.neuron.BiasNeuron;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.TransferFunctionType;
import util.Direction;

/**
 *
 * @author Preston Tang
 *
 * GAMLPAgent stands for Genetic Algorithm Multi-Layer Perceptron Agent
 */
public class GAMLPAgent implements Comparable<GAMLPAgent> {

    public Snake mask;

    private final MultiLayerPerceptron mlp;

    private final int width;
    private final int height;
    private final double size;

    private final double mutationRate = 0.05;

    public GAMLPAgent(Snake mask,int width,int height,double size) {
        this.mask = mask;
        this.width = width;
        this.height = height;
        this.size = size;

        //Input: x of head,y of head,x of food,y of food,steps left
        //Input: 4 directions,check for distance to food,wall,and self  + 1 bias neuron (18 total)
        //14 hidden perceptrons (1 hidden layer)
        //Output: A direction,4 possibilities
        mlp = new MultiLayerPerceptron(TransferFunctionType.SIGMOID,18,14,4);
        //Adding connections
        List<Layer> layers = mlp.getLayers();

        for (int r = 0; r < layers.size(); r++) {
            for (int c = 0; c < layers.get(r).getNeuronsCount(); c++) {
                mlp.getInputNeurons().get(mlp.getInputsCount() - 1).addInputConnection(layers.get(r).getNeuronAt(c));
            }
        }

//        System.out.println(mlp.getInputNeurons().get(17).getInputConnections() + " " + mlp.getInputNeurons().get(17).getOutConnections());
        mlp.randomizeWeights();

//        System.out.println(Arrays.toString(mlp.getInputNeurons().get(17).getWeights()));
    }

    public void compute() {
        if (mask.isAlive()) {
            Rectangle head = mask.getSnakeParts().get(0);
            Rectangle food = mask.getFood();

            double headX = head.getX();
            double headY = head.getY();
            double foodX = mask.getFood().getX();
            double foodY = mask.getFood().getY();
            int stepsLeft = mask.getSteps();

            double foodL = -1.0,wallL,selfL = -1.0;
            double foodR = -1.0,wallR,selfR = -1.0;
            double foodU = -1.0,wallU,selfU = -1.0;
            double foodD = -1.0,wallD,selfD = -1.0;

            //The 4 directions
            //Left Direction
            if (head.getY() == food.getY() && head.getX() > food.getX()) {
                foodL = head.getX() - food.getX();
            }

            wallL = head.getX() - size;

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getY() == part.getY() && head.getX() > part.getX()) {
                    selfL = head.getX() - part.getX();
                    break;
                }
            }

            //Right Direction
            if (head.getY() == food.getY() && head.getX() < food.getX()) {
                foodR = food.getX() - head.getX();
            }

            wallR = size * width - head.getX();

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getY() == part.getY() && head.getX() < part.getX()) {
                    selfR = part.getX() - head.getX();
                    break;
                }
            }

            //Up Direction
            if (head.getX() == food.getX() && head.getY() < food.getY()) {
                foodU = food.getY() - head.getY();
            }

            wallU = size * height - head.getY();

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getX() == part.getX() && head.getY() < part.getY()) {
                    selfU = part.getY() - head.getY();
                    break;
                }
            }

            //Down Direction
            if (head.getX() == food.getX() && head.getY() > food.getY()) {
                foodD = head.getY() - food.getY();
            }

            wallD = head.getY() - size;

            for (Rectangle part : mask.getSnakeParts()) {
                if (head.getX() == part.getX() && head.getY() > part.getY()) {
                    selfD = head.getY() - part.getY();
                    break;
                }
            }

            mlp.setInput(
                    headX,headY,foodX,foodY,stepsLeft,foodL,wallL,selfL,foodR,wallR,selfR,foodU,wallU,selfU,foodD,wallD,selfD,1);

            mlp.calculate();

            if (getIndexOfLargest(mlp.getOutput()) == 0) {
                mask.setDirection(Direction.UP);
            } else if (getIndexOfLargest(mlp.getOutput()) == 1) {
                mask.setDirection(Direction.DOWN);
            } else if (getIndexOfLargest(mlp.getOutput()) == 2) {
                mask.setDirection(Direction.LEFT);
            } else if (getIndexOfLargest(mlp.getOutput()) == 3) {
                mask.setDirection(Direction.RIGHT);
            }
        }
    }

    public double[][] breed(GAMLPAgent agent,int num) {
        //Converts Double[] to double[]
        //https://stackoverflow.com/questions/1109988/how-do-i-convert-double-to-double
        double[] parent1 = Stream.of(mlp.getWeights()).mapToDouble(Double::doubleValue).toArray();
        double[] parent2 = Stream.of(agent.getMLP().getWeights()).mapToDouble(Double::doubleValue).toArray();

        double[][] childGenes = new double[num][parent1.length];

        for (int r = 0; r < num; r++) {
            for (int c = 0; c < childGenes[r].length; c++) {
                if (new Random().nextInt(100) <= mutationRate * 100) {
                    childGenes[r][c] = ThreadLocalRandom.current().nextDouble(-1.0,1.0);
//childGenes[r][c] += childGenes[r][c] * 0.1;
                } else {
                    childGenes[r][c] = new Random().nextDouble() < 0.5 ? parent1[c] : parent2[c];
                }
            }
        }

        return childGenes;
    }

    public MultiLayerPerceptron getMLP() {
        return mlp;
    }

    public void setMask(Snake mask) {
        this.mask = mask;
    }

    public Snake getMask() {
        return mask;
    }

    public int getIndexOfLargest(double[] array) {
        if (array == null || array.length == 0) {
            return -1; // null or empty
        }
        int largest = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[largest]) {
                largest = i;
            }
        }
        return largest; // position of the first largest found
    }

    @Override
    public int compareTo(GAMLPAgent t) {
        if (this.getMask().getScore() < t.getMask().getScore()) {
            return -1;
        } else if (t.getMask().getScore() < this.getMask().getScore()) {
            return 1;
        }
        return 0;
    }

    public void debugLocation() {
        Rectangle head = mask.getSnakeParts().get(0);
        Rectangle food = mask.getFood();
        System.out.println(head.getX() + " " + head.getY() + " " + food.getX() + " " + food.getY());
        System.out.println(mask.getName() + ": " + Arrays.toString(mlp.getOutput()));
    }

    public void debugInput() {
        String s = "";
        for (int i = 0; i < mlp.getInputNeurons().size(); i++) {
            s += mlp.getInputNeurons().get(i).getOutput() + " ";
        }
        System.out.println(s);
    }

    public double[] getOutput() {
        return mlp.getOutput();
    }
}

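Here is the main class of my code, geneticSnake2.java, where the game loop is located and where I assign the genes to the child snakes (I know this could be done more cleanly).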

package main;

import agents.GAMLPAgent;
import ui.InfoBar;
import graphics.Snake;
import graphics.SnakeGrid;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.Scanner;
import javafx.animation.AnimationTimer;
import javafx.application.Application;
import static javafx.application.Application.launch;
import javafx.scene.Scene;
import javafx.scene.control.Slider;
import javafx.scene.layout.Pane;
import javafx.scene.paint.Color;
import javafx.stage.Stage;

/**
 *
 * @author Preston Tang
 */
public class geneticSnake2 extends Application {

    private final int width = 45;
    private final int height = 40;

    private final double displaySize = 120;
    private final double size = 12;

    private final Color pathColor = Color.rgb(120,120,120);
    private final Color wallColor = Color.rgb(50,50,50);

    private final int initSnakeLength = 2;

    private final int populationSize = 1000;

    private int generation = 0;

    private int initSteps = 100;
    private int stepsIncrease = 50;

    private double parentPercentage = 0.2;

    private final ArrayList<Color> snakeColors = new ArrayList() {
        {
            add(Color.GREEN);
            add(Color.RED);
            add(Color.YELLOW);
            add(Color.BLUE);
            add(Color.MAGENTA);
            add(Color.PINK);
            add(Color.ORANGERED);
            add(Color.BLACK);
            add(Color.GOLDENROD);
            add(Color.WHITE);
        }
    };

    private final ArrayList<Snake> snakes = new ArrayList<>();

    private final ArrayList<GAMLPAgent> agents = new ArrayList<>();

    private long initTime = System.nanoTime();

    @Override
    public void start(Stage stage) {
        Pane root = new Pane();
        Pane graphics = new Pane();
        graphics.setPrefHeight(height * size);
        graphics.setPrefWidth(width * size);
        graphics.setTranslateX(0);
        graphics.setTranslateY(displaySize);

        Pane display = new Pane();
        display.setStyle("-fx-background-color: BLACK");
        display.setPrefHeight(displaySize);
        display.setPrefWidth(width * size);
        display.setTranslateX(0);
        display.setTranslateY(0);

        root.getChildren().add(display);

        SnakeGrid sg = new SnakeGrid(pathColor,wallColor,width,height,size);

        //Parsing "adjectives.txt" and "nouns.txt" to form possible names
        ArrayList<String> adjectives = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/adjectives.txt").getFile())).split("\n")));
        ArrayList<String> nouns = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/nouns.txt").getFile())).split("\n")));

        //Initializing the population
        for (int i = 0; i < populationSize; i++) {
            //Get random String from lists and capitalize first letter
            String adj = adjectives.get(new Random().nextInt(adjectives.size()));
            adj = adj.substring(0,1).toUpperCase() + adj.substring(1);

            String noun = nouns.get(new Random().nextInt(nouns.size()));
            noun = noun.substring(0,1).toUpperCase() + noun.substring(1);

            Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));

            //We want to see the first snake
            if (i == 0) {
                InfoBar bar = new InfoBar();
                bar.getStatusText().setText("Status: Alive");
                bar.getStatusText().setFill(Color.GREENYELLOW);
                bar.getSizeText().setText("Population Size: " + populationSize);

                Snake snake = new Snake(bar,adj + " " + noun,size,initSnakeLength,color,initSteps,stepsIncrease);
                bar.getNameText().setText("Name: " + snake.getName());

                snakes.add(snake);
                agents.add(new GAMLPAgent(snake,width,height,size));

            } else {
                Snake snake = new Snake(adj + " " + noun,stepsIncrease);

                snakes.add(snake);
                agents.add(new GAMLPAgent(snake,width,height,size));
            }
        }

        //Focused on original snake
        display.getChildren().add(snakes.get(0).getInfoBar());

        graphics.getChildren().addAll(sg);

        graphics.getChildren().addAll(snakes.get(0));

        root.getChildren().add(graphics);

        //Add the speed controller (slider)
        Slider slider = new Slider(1,10,10);
        slider.setTranslateX(205);
        slider.setTranslateY(75);
        slider.setDisable(true);

        root.getChildren().add(slider);

        Scene scene = new Scene(root,width * size,height * size + displaySize);
        stage.setScene(scene);

        //Fixes the setResizable bug
        //https://stackoverflow.com/questions/20732100/javafx-why-does-stage-setresizablefalse-cause-additional-margins
        stage.setTitle("21-geneticSnake2 Cause the First Version Got Deleted ;-; Started on 6/8/2020");
        stage.setResizable(false);
        stage.sizeToScene();
        stage.show();

        AnimationTimer timer = new AnimationTimer() {
            private long lastUpdate = 0;

            @Override
            public void handle(long Now) {
                if (Now - lastUpdate >= (10 - (int) slider.getValue()) * 50_000_000) {
                    lastUpdate = Now;

                    int alive = populationSize;
                    for (int i = 0; i < snakes.size(); i++) {
                        Snake snake = snakes.get(i); //Current snake

                        if (i == 0) {
                            Collections.sort(agents);
                            snake.getInfoBar().getScoreText().setText("Score: " + snake.getScore() + " (" + agents.get(agents.size() - 1).getMask().getScore() + ")");
                        }

                        if (!snake.isAlive()) {
                            alive--;

                            //Update graphics for main snake
                            if (i == 0) {
                                snake.getInfoBar().getStatusText().setText("Status: Dead");
                                snake.getInfoBar().getStatusText().setFill(Color.RED);
                                graphics.getChildren().remove(snake);
                            }

                        } else {
                            //If out of steps
                            if (snake.getSteps() <= 0) {
                                snake.setAlive(false);
                            }

                            //Bounds Detection (left right up down)
                            if (snake.getSnakeParts().get(0).getX() >= width * size
                                    || snake.getSnakeParts().get(0).getX() <= 0
                                    || snake.getSnakeParts().get(0).getY() >= height * size
                                    || snake.getSnakeParts().get(0).getY() <= 0) {
                                snake.setAlive(false);
                            }

                            //Self-Collision Detection
                            for (int o = 1; o < snake.getSnakeParts().size(); o++) {
                                if (snake.getSnakeParts().get(0).getX() == snake.getSnakeParts().get(o).getX()
                                        && snake.getSnakeParts().get(0).getY() == snake.getSnakeParts().get(o).getY()) {
                                    snake.setAlive(false);
                                }
                            }

                            int rate = (int) slider.getValue();
                            int seconds = (int) ((System.nanoTime() - initTime) * rate / 1_000_000_000);

                            agents.get(i).compute();
                            snake.manageMovement();
                            snake.setSecondsAlive(seconds);

//                            agents.get(0);
//                            System.out.println(Arrays.toString(agents.get(0).getOutput()));
//                            
//                            System.out.println("\n\n\n\n\n\n\n");
                            //Expression to calculate score
                            double exp = (snake.getConsumed() * 5 + snake.getSecondsAlive() / 2.0D);
//double exp = snake.getSteps() + (Math.pow(2,snake.getConsumed()) + Math.pow(snake.getConsumed(),2.1) * 500)
//        - (Math.pow(snake.getConsumed(),1.2) * Math.pow(0.25 * snake.getSteps(),1.3));

                            snake.setScore(Math.round(exp * 100.0) / 100.0);

                            //Update graphics for main snake
                            if (i == 0) {
                                snake.getInfoBar().getTimeText().setText("Time Survived: " + snake.getSecondsAlive() + "s");
                                snake.getInfoBar().getFoodText().setText("Food Consumed: " + snake.getConsumed());
                                snake.getInfoBar().getGenerationText().setText("Generation: " + generation);
                                snake.getInfoBar().getStepsText().setText("Steps Remaining: " + snake.getSteps());
                            }
                        }
                    }

                    //Reset and breed
                    if (alive == 0) {
                        //Ascending order
                        initTime = System.nanoTime();
                        generation++;
                        graphics.getChildren().clear();
                        graphics.getChildren().addAll(sg);
                        snakes.clear();

                        //x% of snakes are parents
                        int parentNum = (int) (populationSize * parentPercentage);

                        //Faster odd number check
                        if ((parentNum & 1) != 0) {
                            //If odd make even
                            parentNum += 1;
                        }

                        for (int i = 0; i < parentNum; i += 2) {
                            //Get the 2 parents,sorted by score
                            GAMLPAgent p1 = agents.get(populationSize - (i + 2));
                            GAMLPAgent p2 = agents.get(populationSize - (i + 1));

                            //Produce the next generation
                            double[][] childGenes = p1.breed(p2,((populationSize - parentNum) / parentNum) * 2);

                            //Debugs Genes
//                            System.out.println(Arrays
//                                    .stream(childGenes)
//                                    .map(Arrays::toString)
//                                    .collect(Collectors.joining(System.lineSeparator())));
                            //Soft copy
                            ArrayList<GAMLPAgent> temp = new ArrayList<>(agents);

                            for (int o = 0; o < childGenes.length; o++) {
                                temp.get(o).getMLP().setWeights(childGenes[o]);
                            }

                            //Add the genes of every pair of parents to the children
                            for (int o = 0; o < childGenes.length; o++) {
                                //Useful debug message
//                                System.out.println("ParentNum: " + parentNum
//                                        + " ChildPerParent: " + (populationSize - parentNum) / parentNum
//                                        + " Index: " + (o + (i / 2 * childGenes.length))
//                                        + " ChildGenesNum: " + childGenes.length
//                                        + " Var O: " + o);

                                //Adds the genes of the temp to the agents
                                agents.set((o + (i / 2 * childGenes.length)),temp.get(o));
                            }
//                            System.out.println("\n\n\n\n\n\n");
                        }

                        //Debugging the snakes' genes to a file
//                        String str = "";
//                        for (int i = 0; i < agents.size(); i++) {
//                            str += "Index: " + i + "\t" + Arrays.toString(agents.get(i).getMLP().getWeights())+  "\n\n\n";
//                        }
//
//                        printToFile(str,"gen" + generation);

                        for (int i = 0; i < populationSize; i++) {
                            //Get random String from lists and capitalize first letter
                            String adj = adjectives.get(new Random().nextInt(adjectives.size()));
                            adj = adj.substring(0,1).toUpperCase() + adj.substring(1);

                            String noun = nouns.get(new Random().nextInt(nouns.size()));
                            noun = noun.substring(0,1).toUpperCase() + noun.substring(1);

                            Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));

                            //We want to see the first snake
                            if (i == 0) {
                                InfoBar bar = new InfoBar();
                                bar.getStatusText().setText("Status: Alive");
                                bar.getStatusText().setFill(Color.GREENYELLOW);
                                bar.getSizeText().setText("Population Size: " + populationSize);

                                Snake snake = new Snake(bar,stepsIncrease);
                                bar.getNameText().setText("Name: " + snake.getName());
                                snakes.add(snake);
                                agents.get(i).setMask(snake);
                            } else {
                                Snake snake = new Snake(adj + " " + noun,stepsIncrease);
                                snakes.add(snake);
                                agents.get(i).setMask(snake);
                            }
                        }

                        graphics.getChildren().add(snakes.get(0));
                        display.getChildren().clear();

                        //Focused on original snake at first
                        display.getChildren().add(snakes.get(0).getInfoBar());
                    }
                }
            }
        };
        //Starts the infinite loop
        timer.start();
    }

    public String readFile(File f) {
        String content = "";
        try {
            content = new Scanner(f).useDelimiter("\\Z").next();
        } catch (FileNotFoundException ex) {
            System.err.println("Error: Unable to read " + f.getName());
        }
        return content;
    }

    public void printToFile(String str,String name) {
        FileWriter fileWriter;
        try {
            fileWriter = new FileWriter(name + ".txt");
            try (BufferedWriter bufferedWriter = new BufferedWriter(fileWriter)) {
                bufferedWriter.write(str);
            }

        } catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    public static void main(String[] args) {
        launch(args);
    }
}

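The main problem is that, even after thousands of generations, the snakes still simply run into the wall and die. In the videos I linked above, the snakes were dodging walls and getting food by around the 5th generation. I suspect that the problem lies in the main class, where I assign the genes to the newborn snakes.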

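I have actually been stuck on this for a few weeks. Earlier, one of the problems I suspected was a lack of inputs, since I had fewer of them back then. But now I don't think that is the case anymore. If needed, I can try looking in the 4 diagonal directions as well, which would add another 12 inputs to the snake's MLP. I have also gone to the Artificial Intelligence Discord to ask for help, but I haven't really found a solution.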

If needed, I am willing to send my entire code so that you can run the simulation yourself.

If you have read this far, thank you for taking the time to help me! I really appreciate it.

Solution

I'm not surprised your snakes are dying.

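Let's take a step back. What is AI, exactly? Well, it's a search problem. We are searching through some parameter space to find the set of parameters that solves Snake given the current state of the game. You can imagine a parameter space that has a global minimum: the best possible snake, the snake that makes the fewest mistakes.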

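All learning algorithms start at some point in this parameter space and attempt to find that global minimum over time. First, let's think about MLPs. An MLP learns by trying a set of weights, computing a loss function, and then taking a step in the direction that would further decrease the loss (gradient descent). It's fairly obvious that an MLP will find a minimum, but whether it can find a good enough minimum is an open question, and there are a lot of training techniques that exist to improve that chance.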
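To make the gradient descent step concrete, here is a minimal sketch on a one-parameter toy loss. The loss function, learning rate, and class name are assumptions picked for illustration; a real MLP applies the same update to every weight using gradients obtained by backpropagation.

```java
/** Hypothetical sketch of gradient descent on a one-dimensional loss. */
class GradientDescentSketch {

    /** Toy loss with its minimum at w = 3. */
    static double loss(double w) {
        return (w - 3.0) * (w - 3.0);
    }

    /** Derivative of the toy loss with respect to w. */
    static double gradient(double w) {
        return 2.0 * (w - 3.0);
    }

    /** Repeatedly steps against the gradient, which lowers the loss. */
    static double descend(double start, double learningRate, int steps) {
        double w = start;
        for (int i = 0; i < steps; i++) {
            w -= learningRate * gradient(w);
        }
        return w;
    }

    public static void main(String[] args) {
        double w = descend(0.0, 0.1, 100);
        System.out.println("w = " + w + ", loss = " + loss(w));
    }
}
```

Every update here is driven by the loss itself, which is what gives the method a direction to move in.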

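Genetic algorithms, on the other hand, have very poor convergence characteristics. First, let's stop calling these genetic algorithms. Let's call them smorgasbord algorithms instead. A smorgasbord algorithm takes two sets of parameters from two parents, jumbles them together, and yields a new smorgasbord. What makes you think the result would be better than either of the two? What are you minimizing here? How do you know it's getting any better? If you attach a loss function, how do you know you are in a space that can actually be minimized?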
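A toy sketch of why mixing two good parameter sets need not yield a good one. The two-parameter loss and the class name are assumptions chosen purely for illustration; both parents sit exactly at a minimum, yet a child that takes one parameter from each can be far worse.

```java
/** Hypothetical sketch: crossover of two optimal parents can be worse than both. */
class RecombinationSketch {

    /** Toy loss that is zero whenever p[0] * p[1] == 1. */
    static double loss(double[] p) {
        double error = p[0] * p[1] - 1.0;
        return error * error;
    }

    /** Takes each of the two genes from parent a or parent b as requested. */
    static double[] cross(double[] a, double[] b, boolean firstFromA, boolean secondFromA) {
        return new double[] {
            firstFromA ? a[0] : b[0],
            secondFromA ? a[1] : b[1]
        };
    }

    public static void main(String[] args) {
        double[] parent1 = {2.0, 0.5};
        double[] parent2 = {0.5, 2.0};
        System.out.println("parent1 loss: " + loss(parent1));
        System.out.println("parent2 loss: " + loss(parent2));
        System.out.println("child (2.0, 2.0) loss: " + loss(cross(parent1, parent2, true, false)));
        System.out.println("child (0.5, 0.5) loss: " + loss(cross(parent1, parent2, false, true)));
    }
}
```

Both parents have a loss of 0, while the two mixed children have losses of 9 and 0.5625: the parameters only make sense in combination, and crossover ignores that.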

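The point I'm trying to make is that genetic algorithms are unprincipled, unlike nature. Nature does not just put codons in a blender to make a new strand of DNA, but that's exactly what genetic algorithms do. There are techniques to add some hill climbing, but genetic algorithms still have tons of problems.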
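One simple form of hill climbing is to accept a mutated candidate only when it scores better than the current best. The sketch below assumes a toy sum-of-squares loss and hypothetical names (`HillClimbSketch`, `climb`); it is one possible variant, not a prescription.

```java
import java.util.Random;

/** Hypothetical sketch of stochastic hill climbing on a parameter vector. */
class HillClimbSketch {

    /** Toy loss: sum of squares, minimized at the origin. */
    static double loss(double[] p) {
        double sum = 0.0;
        for (double v : p) {
            sum += v * v;
        }
        return sum;
    }

    /** Perturbs one random parameter per iteration and keeps only improvements. */
    static double[] climb(double[] start, int iterations, double stepSize, Random rng) {
        double[] best = start.clone();
        double bestLoss = loss(best);
        for (int i = 0; i < iterations; i++) {
            double[] candidate = best.clone();
            int index = rng.nextInt(candidate.length);
            candidate[index] += stepSize * rng.nextGaussian();
            double candidateLoss = loss(candidate);
            if (candidateLoss < bestLoss) { // reject anything that is not better
                best = candidate;
                bestLoss = candidateLoss;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] start = {1.0, -2.0, 0.5};
        double[] result = climb(start, 1000, 0.1, new Random(7));
        System.out.println("start loss: " + loss(start) + ", final loss: " + loss(result));
    }
}
```

Because worse candidates are discarded, the loss can never increase, but the search can still stall in a local minimum.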

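The point is, don't get swept up in the name. Genetic algorithms are simply smorgasbord algorithms. My view is that your approach doesn't work because GAs have no guarantee of converging even after infinitely many iterations, and MLPs have no guarantee of converging to a good global minimum.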

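What to do? Well, a better approach would be to use a learning paradigm that fits your problem. That better approach would be reinforcement learning. There is a very good course on Udacity from Georgia Tech on the subject.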
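As a rough idea of what reinforcement learning looks like in code, here is a sketch of the tabular Q-learning update. Q-learning is only one reinforcement learning method and is chosen here as an assumption; the class name, the table layout, and the encoding of states and actions as integer indices are all hypothetical.

```java
/** Hypothetical sketch of the tabular Q-learning update rule. */
class QLearningSketch {

    /**
     * Q(s, a) <- Q(s, a) + alpha * (reward + gamma * max over a' of Q(s', a') - Q(s, a))
     * q[state][action] holds the current estimate of long-term reward.
     */
    static void update(double[][] q, int state, int action, double reward,
                       int nextState, double alpha, double gamma) {
        double bestNext = q[nextState][0];
        for (double value : q[nextState]) {
            bestNext = Math.max(bestNext, value);
        }
        q[state][action] += alpha * (reward + gamma * bestNext - q[state][action]);
    }

    public static void main(String[] args) {
        double[][] q = new double[2][4]; // 2 states, 4 actions (e.g. up, down, left, right)
        q[1][2] = 1.0;                   // pretend state 1 already has a known good action
        update(q, 0, 3, 1.0, 1, 0.5, 0.9);
        System.out.println("Q(0, 3) = " + q[0][3]);
    }
}
```

Each update uses the reward the snake actually received for an action, so the feedback is tied to individual moves rather than to a whole lifetime score.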