a*

6 minute read

A* 알고리즘

java와 c++로 작성하였다.

c++ 코드는 포인터 사용에 익숙지 않아 java로 구현하는 것보다 오래 걸렸다.

openList는 priority queue로, closeList는 java에서는 ArrayList, c++에서는 vector로 구현하였다.

openList는 아직 방문하지 않은 노드 중 가장 작은 f값을 가진 노드를 구하기 위한 자료구조이다.

closeList는 방문을 완료한 노드를 저장하는 자료구조이다.

시작지점에서 방문하지 않은 이웃노드를 찾아 f값을 계산하여 openList에 넣는다.

openList에 이미 포함된 노드는 f값을 다시 계산하여 기존의 값보다 낮다면 갱신한다.

java에서는 priority queue의 contains로 openList에 있는지 확인하였으나 c++의 priority queue는 같은 방법이 불가능하여 openList에 입력되었는지 확인하는 set을 만들어 방문 유무를 확인하였다.

JAVA 코드

import java.util.*;
import java.io.*;
import java.io.IOException;
import java.io.InputStreamReader;

/**
 * Entry point: reads a character grid from standard input and runs A* on it.
 *
 * <p>Input format: the first line holds the row count and the column count,
 * separated by whitespace; each following line is one row of the map made of
 * the characters {@code S} (start), {@code F} (finish), {@code X} (obstacle)
 * and {@code O} (empty).
 */
public class Main {
    /**
     * Parses the map and delegates the search to {@link Graph#aStar}.
     *
     * @param args unused
     * @throws IOException if the size line or a map row is missing
     * @throws NumberFormatException if the size line is not two integers
     */
    public static void main(String args[]) throws Exception {
        try (BufferedReader bufReader = new BufferedReader(new InputStreamReader(System.in))) {
            String input = bufReader.readLine();
            if (input == null) {
                throw new IOException("missing map size line");
            }
            // Split on any run of whitespace so "3  4" or a trailing space still parses.
            String[] inputs = input.trim().split("\\s+");
            if (inputs.length < 2) {
                throw new IOException("map size line must contain two integers: " + input);
            }

            int y = Integer.parseInt(inputs[0]);
            int x = Integer.parseInt(inputs[1]);

            // Cell encoding: 0 = empty (O), 1 = obstacle (X), 2 = start (S), 3 = finish (F).
            int[][] map = new int[y][x];

            int startX = 0, startY = 0;
            int endX = 0, endY = 0;

            for (int i = 0; i < y; i++) {
                input = bufReader.readLine();
                if (input == null) {
                    throw new IOException("expected " + y + " map rows but got " + i);
                }
                // A row shorter than x leaves its missing cells empty (0) instead of throwing.
                int cols = Math.min(x, input.length());
                for (int p = 0; p < cols; p++) {
                    switch (input.charAt(p)) {
                        case 'X':
                            map[i][p] = 1;
                            break;
                        case 'S':
                            map[i][p] = 2;
                            startX = p;
                            startY = i;
                            break;
                        case 'F':
                            map[i][p] = 3;
                            endX = p;
                            endY = i;
                            break;
                        default:
                            // 'O' and any other character stay 0 (empty).
                            break;
                    }
                }
            }

            Graph graph = new Graph(x, y);
            graph.inputData(map);
            graph.aStar(startX, startY, endX, endY);
        }
    }
}

// S: 출발 2
// O: 빈공간 0
// F: 도착 3
// X: 장애물 1

/**
 * Grid graph used to run A* between two cells with 4-directional movement.
 *
 * <p>Cells whose input value is 1 are obstacles and have no {@link Node}: their
 * slot in {@code nodeMap} is {@code null}. Every other cell gets a node.
 */
class Graph{
    /** Cost added to g for one horizontal or vertical step. */
    private final int moveCost = 10;
    private int maxX, maxY;
    /** Indexed as {@code nodeMap[y][x]}; {@code null} marks an obstacle. */
    Node[][] nodeMap;

    /**
     * @param sizeX number of columns
     * @param sizeY number of rows
     */
    public Graph(int sizeX, int sizeY){
        maxX = sizeX;
        maxY = sizeY;
        nodeMap = new Node[maxY][maxX];
    }

    /**
     * Builds the node grid from an integer map.
     *
     * @param data map indexed as {@code data[y][x]}; the value 1 is an obstacle,
     *             anything else is walkable. Cells outside the graph size are ignored.
     */
    public void inputData(int[][] data){
        int rows = Math.min(data.length, maxY);
        for(int y = 0; y < rows; y++){
            // Use each row's own length so a jagged array cannot overrun.
            int cols = Math.min(data[y].length, maxX);
            for(int x = 0; x < cols; x++){
                // Clearing obstacle slots keeps the grid correct when inputData is called again.
                nodeMap[y][x] = (data[y][x] != 1) ? new Node(x, y) : null;
            }
        }
    }

    /**
     * Returns the walkable neighbours of {@code node} in the order
     * right, left, down, up.
     *
     * @param node cell whose neighbours are requested
     * @return up to four adjacent nodes that lie inside the grid and are not obstacles
     */
    public ArrayList<Node> getNeighbor(Node node){
        final int[] dx = {1, -1, 0, 0};
        final int[] dy = {0, 0, 1, -1};
        ArrayList<Node> result = new ArrayList<>(dx.length);

        for(int i = 0; i < dx.length; i++){
            int nx = node.x + dx[i];
            int ny = node.y + dy[i];
            if(inBounds(nx, ny) && nodeMap[ny][nx] != null){
                result.add(nodeMap[ny][nx]);
            }
        }
        return result;
    }

    /** True when (x, y) addresses a cell of the grid. */
    private boolean inBounds(int x, int y){
        return x >= 0 && x < maxX && y >= 0 && y < maxY;
    }

    /**
     * Runs A* from the start cell to the end cell and prints the result.
     *
     * <p>Prints {@code find!!} followed by the step count when the end is
     * reached, otherwise {@code Not Found}. Coordinates outside the grid or on
     * an obstacle also produce {@code Not Found}.
     *
     * @param startX start column
     * @param startY start row
     * @param endX   end column
     * @param endY   end row
     */
    public void aStar(int startX, int startY, int endX, int endY){
        if(!inBounds(startX, startY) || !inBounds(endX, endY)){
            System.out.println("Not Found");
            return;
        }

        Node startNode = nodeMap[startY][startX];
        Node endNode = nodeMap[endY][endX];
        if(startNode == null || endNode == null){
            System.out.println("Not Found");
            return;
        }

        PriorityQueue<Node> openList = new PriorityQueue<>();
        // Hash sets mirror open/closed membership so the checks are not linear scans.
        // Node does not override equals/hashCode, so membership is by identity.
        Set<Node> openSet = new HashSet<>();
        Set<Node> closeList = new HashSet<>();

        // Reset the start node: a previous search may have left a parent link on it,
        // which would make the path walk-back run past the start (or cycle).
        startNode.parent = null;
        startNode.g = 0;
        startNode.f = calDistance(startNode, endNode);

        openList.offer(startNode);
        openSet.add(startNode);

        while(!openList.isEmpty()){
            Node current = openList.poll();
            openSet.remove(current);

            if(current == endNode){
                System.out.println("find!!");
                // printPath(current) can be called here to dump the route.
                printDist(current);
                return;
            }
            closeList.add(current);

            for(Node item : getNeighbor(current)){
                if(closeList.contains(item)){
                    continue;
                }
                double g = current.g + moveCost;
                // The heuristic is measured in cells while a step costs moveCost,
                // so it never overestimates the remaining cost.
                double f = g + calDistance(item, endNode);

                if(openSet.add(item)){
                    // First time this node is discovered in the current search.
                    item.parent = current;
                    item.g = g;
                    item.f = f;
                    openList.offer(item);
                }
                else if(f < item.f){
                    // PriorityQueue does not notice key changes: take the node out,
                    // update it, and put it back so the heap order stays valid.
                    openList.remove(item);
                    item.parent = current;
                    item.g = g;
                    item.f = f;
                    openList.offer(item);
                }
            }
        }
        System.out.println("Not Found");
    }

    /** Straight-line (Euclidean) distance between two nodes, in cells. */
    private double calDistance(Node n1, Node n2){
        double x = n1.x - n2.x;
        double y = n1.y - n2.y;
        return Math.sqrt(x * x + y * y);
    }

    /** Prints the coordinates of {@code node} and of every ancestor, end first. */
    private void printPath(Node node){
        System.out.println("x: " + node.x + "  y: " + node.y);

        while(node.parent != null){
            node = node.parent;
            System.out.println("x: " + node.x + "  y: " + node.y);
        }
    }

    /** Prints the number of parent links between {@code node} and the start. */
    private void printDist(Node node){
        int cost = 0;
        while(node.parent != null){
            ++cost;
            node = node.parent;
        }
        System.out.println("dist: " + cost);
    }
}

class Node implements Comparable<Node>{
    int x;
    int y;
    double f, g;
    Node parent;

    public Node(int x, int y){
        this.x = x;
        this.y = y;
    }

    @Override
    public int compareTo(Node target) {
        if(this.f > target.f)
            return 1;
        else if(this.f < target.f)
            return -1;
        return 0;
    }
}

C++ 코드

#include <algorithm>
#include <cmath>
#include <iostream>
#include <queue>
#include <set>
#include <string>
#include <vector>
#include <stdio.h>

using namespace std;

// S: 출발 2
// O: 빈공간 0
// F: 도착 3
// X: 장애물 1

// One cell of the grid together with its A* bookkeeping.
class Node {
public:
	int x, y;		// column and row of the cell
	int type;		// 0 empty, 1 obstacle, 2 start, 3 finish
	double f, g;	// g: cost accumulated so far, f: g plus the distance estimate
	Node *parent;	// node this one was reached from, NULL if none

	// A default-constructed node is marked as an obstacle (type 1) so that grid
	// cells which are never populated through setNode are not walkable, and no
	// member is ever read uninitialized.
	Node() : x(0), y(0), type(1), f(0.0), g(0.0), parent(NULL) {
	}
	// Creates an empty (type 0) cell at the given coordinates.
	Node(int x, int y) : x(x), y(y), type(0), f(0.0), g(0.0), parent(NULL) {
	}
	bool operator==(const Node& node)const;
	bool operator<(const Node& node)const;
	void setNode(int x, int y, int type);
};

// Two nodes are equal when they refer to the same grid coordinates.
bool Node::operator==(const Node& node)const {
	return x == node.x && y == node.y;
}

// Orders nodes by row first and by column within the same row.
bool Node::operator<(const Node& node)const {
	return (y != node.y) ? (y < node.y) : (x < node.x);
}

// Stores the cell's coordinates and its map type; other members are untouched.
void Node::setNode(int x, int y, int type) {
	// The parameters shadow the members, so the members are named explicitly.
	Node::type = type;
	Node::x = x;
	Node::y = y;
}

class Graph {
private:
	int moveCost;
	int maxX;
	int maxY;
	Node **nodeMap;

public:
	Graph(int sizeX, int sizeY) {
		maxX = sizeX;
		maxY = sizeY;
		nodeMap = new Node*[maxY];
		for (int i = 0; i < maxY; ++i)
			nodeMap[i] = new Node[maxX];
	}
	~Graph();

	void inputData(int** data, int ySize, int xSize);
	vector<Node*> getNeighbor(const Node& node) const;
	double calDistance(Node& n1, Node& n2);
	void printDist(const Node& node) const;
	void printPath(const Node& node) const;
	void aStar(int startX, int startY, int endX, int endY);
};

// Releases every row of the node grid and then the array of row pointers.
Graph::~Graph() {
	int row = 0;
	while (row < maxY) {
		delete[] nodeMap[row];
		++row;
	}
	delete[] nodeMap;
}

void Graph::inputData(int** data, int ySize, int xSize) {
	for (int y = 0; y < ySize; y++) {
		for (int x = 0; x < xSize; x++) {
			if (data[y][x] != 1) {
				nodeMap[y][x].setNode(x, y, data[y][x]);
			}
		}
	}
}

// Collects pointers to the adjacent cells of `node` whose type is not 1,
// in the order right, left, down, up.
vector<Node*> Graph::getNeighbor(const Node& node) const {
	const int col = node.x;
	const int row = node.y;
	vector<Node*> adjacent;

	// right
	if (col + 1 < maxX && nodeMap[row][col + 1].type != 1)
		adjacent.push_back(&nodeMap[row][col + 1]);
	// left
	if (col > 0 && nodeMap[row][col - 1].type != 1)
		adjacent.push_back(&nodeMap[row][col - 1]);
	// down
	if (row + 1 < maxY && nodeMap[row + 1][col].type != 1)
		adjacent.push_back(&nodeMap[row + 1][col]);
	// up
	if (row > 0 && nodeMap[row - 1][col].type != 1)
		adjacent.push_back(&nodeMap[row - 1][col]);

	return adjacent;
}

// Straight-line (Euclidean) distance between two nodes, measured in cells.
// Uses std::sqrt from <cmath>.
double Graph::calDistance(Node& n1, Node& n2) {
	const double dx = n1.x - n2.x;
	const double dy = n1.y - n2.y;
	return std::sqrt(dx * dx + dy * dy);
}

void Graph::printDist(const Node& node) const {
	int cost = 0;
	const Node* temp = &node;
	while (temp->parent != NULL) {
		++cost;
		temp = temp->parent;
	}
	printf("dist: %d", cost);
}

void Graph::printPath(const Node& node) const {
	const Node* temp = &node;

	printf("x: %d  y: %d\n", temp->x, temp->y);

	while (temp->parent != NULL) {
		temp = temp->parent;
		printf("x: %d  y: %d\n", temp->x, temp->y);
	}
}

void Graph::aStar(int startX, int startY, int endX, int endY) {
	priority_queue<Node*> openList;
	vector<Node*> closeList;
	set<Node*> visitedList;

	Node *startNode = &nodeMap[startY][startX];
	Node *endNode = &nodeMap[endY][endX];

	openList.push(startNode);

	while (!openList.empty()) {
		Node *current = openList.top();
		openList.pop();
		if (current == endNode) {
			printf("find!!\n");
			printPath(*current);
			printDist(*current);
			return;
		}
		closeList.push_back(current);

		vector<Node*> neighbors = getNeighbor(*current);

		for (Node *item : neighbors) {
			if (find(closeList.begin(), closeList.end(), item) == closeList.end()) { // 이웃 노드 중 closeList에 없는 노드
				double g = current->g + moveCost;
				double f = g + calDistance(*item, *endNode);
				auto iter = visitedList.find(item);

				if (iter == visitedList.end()) { // 처음 방문하는 노드일 경우
					item->parent = current;
					item->g = g;
					item->f = f;
					openList.push(item);
					visitedList.insert(item);
				}
				else {
					if (f < item->f) {
						item->parent = current;
						item->g = g;
						item->f = f;
					}
				}
			}
		}
	}
	printf("Not Found\n");
}

int main() {
	Graph *graph;
	int x, y;
	int startX, startY;
	int endX = 0, endY = 0;
	int **data;
	char* input = NULL;

	scanf_s("%d %d", &y, &x);

	data = new int*[y];
	input = new char[x + 1];
	for (int i = 0; i < y; ++i) {
		data[i] = new int[x];
	}

	for (int i = 0; i < y; ++i) {
		scanf_s("%s", input, x + 1);
		for (int p = 0; p < x; ++p) {
			switch (input[p]) {
			case 'O':
				data[i][p] = 0;
				break;
			case 'X':
				data[i][p] = 1;
				break;
			case 'S':
				data[i][p] = 2;
				startX = p;
				startY = i;
				break;
			case 'F':
				data[i][p] = 3;
				endX = p;
				endY = i;
				break;
			}
		}
	}

	graph = new Graph(x, y);
	graph->inputData(data, y, x);
	graph->aStar(startX, startY, endX, endY);
}