前段时间人工智能的课介绍到A*算法,于是便去了解了一下,然后试着用这个算法去解决经典的八数码问题,一开始写用了挺久时间的,后来试着把算法的框架抽离出来,编写成一个通用的算法模板,这样子如果以后需要用到A*算法的话就可以利用这个模板进行快速开发了(对于刷OJ的题当然不适合,不过可以适用于平时写一些小游戏之类的东西)。
A*算法的原理就不过多介绍了,网上能找到一大堆,核心就是估价函数 g() 的定义,这个会直接影响搜索的速度,我在代码里使用 C++/Java 的多态性来编写业务无关的算法模板,用一个抽象类来表示搜索树中的状态,A*算法主类直接操纵这个抽象类,然后编写自己业务相关的类去继承这个抽象类并实现其中的所有抽象方法(C++里是纯虚函数),之后调用A*算法主类的 run 函数就能得到一条可行并且是最短的的搜索路径,下面具体看代码:(文末附所有代码的 github 地址)
先看 c++ 部分,毕竟一开始就是用 c++ 来写的
首先是表示状态的抽象基类CState,头文件 state.h:
#ifndef __state_h #define __state_h #include <cstddef> #include <vector> using std::vector; class CState { public: CState(); virtual bool operator < (const CState &) const=0; virtual void checkSomeFields(const CState &) const; virtual vector<CState*> getNextState() const=0; vector<CState*> __getNextState() const; // call the function getNextState and deal with iSteps and pparent virtual long astar_f() const; virtual long astar_g() const=0; // g函数的值越小,优先级就越高,f()和h()函数类似 virtual long astar_h() const; virtual ~CState(); int iSteps; const CState *pparent; // 必须指向实实际际存在的值!注意不要指向一个局部变量等! }; #endif
源文件 state.cpp:
#include "state.h" #include <algorithm> using std::for_each; CState::CState(): iSteps(0), pparent(NULL) {} void CState::checkSomeFields(const CState &) const {} vector<CState*> CState::__getNextState() const { vector<CState*> nextState = getNextState(); for_each(nextState.begin(), nextState.end(), [this](CState *pstate) { pstate->iSteps = this->iSteps + 1; pstate->pparent = this; }); return nextState; } long CState::astar_f() const { return iSteps; } long CState::astar_h() const { return astar_f() + astar_g(); } CState::~CState() {}
子类只需实现小于运算符,getNextState(),astar_g() 这三个纯虚函数就可以了,另外几个虚函数可以不重写,直接用父类的即可。
然后是A*算法主类 CAstar,头文件 astar.h:
#ifndef __ASTAR_H #define __ASTAR_H #include "state.h" #include <set> using std::set; class CAstar { public: CAstar(const CState &_start, const CState &_end); static set<const CState*> getStateByStartAndSteps(const CState &start, int steps); void run(); ~CAstar(); const CState &m_rStart, &m_rEnd; bool bCanSolve; int iSteps; vector<const CState*> vecSolve; long lRunTime; int iTotalStates; private: set<const CState*> pointerWaitToDelete; }; #endif
源文件 astar.cpp:
#include "astar.h" #include "timeval.h" #include "exception.h" #include <set> #include <queue> #include <algorithm> #include <cstdlib> #include <functional> using std::set; using std::queue; using std::priority_queue; using std::swap; using std::max; using std::sort; using std::function; #define For(i,s,t) for(auto i = (s); i != (t); ++i) CAstar::CAstar(const CState &_start, const CState &_end): m_rStart(_start), m_rEnd(_end), bCanSolve(false), iSteps(0), vecSolve{}, iTotalStates(0), lRunTime(0), pointerWaitToDelete{} { m_rStart.checkSomeFields(m_rEnd); } template <typename T> struct CPointerComp { bool operator () (const T &pl, const T &pr) const { return *pl < *pr; } }; set<const CState*> CAstar::getStateByStartAndSteps(const CState &start, int steps) { set<const CState*> retSet; set<const CState*, CPointerComp<const CState*> > inSet; inSet.insert(&start); queue<const CState*> queState; queState.push(&start); while(!queState.empty()) { const CState* const pCurState = queState.front(); queState.pop(); if(pCurState->iSteps > steps) { continue; } if(pCurState->iSteps == steps) { retSet.insert(pCurState); continue; } auto nextState = pCurState->__getNextState(); int len = nextState.size(); For(i, 0, len) { if(inSet.find(nextState[i]) == inSet.end()) { queState.push(nextState[i]); inSet.insert(nextState[i]); } else { delete nextState[i]; } } } inSet.erase(&start); For(ret_it, retSet.begin(), retSet.end()) { inSet.erase(*ret_it); } For(ins_it, inSet.begin(), inSet.end()) { delete *ins_it; } return retSet; } struct priority_state { bool operator () (const CState* const lhs, const CState* const rhs) const { return lhs->astar_h() > rhs->astar_h(); } }; void CAstar::run() { CTimeVal _time; set<const CState*, CPointerComp<const CState*>> setState; setState.insert(&m_rStart); priority_queue<const CState*, vector<const CState*>, priority_state> queState; queState.push(&m_rStart); while(!queState.empty()) { // auto pHeadState = *(setState.find(queState.top())); auto pHeadState = queState.top(); queState.pop(); if(!(*pHeadState < m_rEnd) && !(m_rEnd < *pHeadState)) { bCanSolve = true; iSteps = pHeadState->iSteps; vecSolve.push_back(pHeadState); const CState *lastState = pHeadState->pparent; while(lastState != NULL) { vecSolve.push_back(lastState); lastState = lastState->pparent; } break; } vector<CState*> nextState = pHeadState->__getNextState(); int len = nextState.size(); for(int i = 0; i < len; ++i) { auto state_it = setState.find(nextState[i]); if(state_it == setState.end()) { queState.push(nextState[i]); setState.insert(nextState[i]); } else { if((*state_it)->astar_f() > nextState[i]->astar_f()) { pointerWaitToDelete.insert(*state_it); // 这一句要放在setState.erase前面,防止迭代器失效 setState.erase(state_it); setState.insert(nextState[i]); queState.push(nextState[i]); } else { delete nextState[i]; } } } if(setState.size() > 6000 * 10000) { break ; } } iTotalStates = setState.size(); lRunTime = _time.costTime(); setState.erase(&m_rStart); For(vec_it, vecSolve.begin(), vecSolve.end()) { setState.erase(*vec_it); } For(s_it, setState.begin(), setState.end()) { delete *s_it; } } CAstar::~CAstar() { For(vec_it, vecSolve.begin(), vecSolve.end()) { if(*vec_it != &m_rStart && *vec_it != &m_rEnd) { delete *vec_it; } } for(const auto &pState: pointerWaitToDelete) { delete pState; } }
主搜索函数里是以 广度优先搜索 + 优先队列 来实现A*算法的,因为是用多态来实现,用到了指针,所以有些细节可能写得不是很好看,但是经运行测试过没有明显的bug,cpu和内存的使用均在正常的范围内。
以上两个类就是A*算法的主体框架了,但里面用到了自定义的异常类 CException 和计时类 CTimeVal 等一些工具类,具体代码可以在后面的 github 地址里看到。
然后是业务相关的类,这里首先是八数码问题的类 CChess,头文件 chess.h:
#ifndef __CCHESS_H #define __CCHESS_H #include "state.h" #include <iostream> #include <string> #include <vector> using std::string; using std::vector; using std::ostream; class CChess: public CState { friend ostream& operator << (ostream &, const CChess &); static int iLimitNum; public: CChess(const string &state, int row, int col, const string &standard=""); virtual bool operator < (const CState &) const; virtual void checkSomeFields(const CState &) const; const string& getStrState() const; void setStrStandard(const string &); virtual vector<CState*> getNextState() const; // virtual long astar_f() const; virtual long astar_g() const; // virtual long astar_h() const; private: void check_row_col() const; void check_value() const; void check_standard() const; inline int countNotMatch() const; inline int countLocalNotMatch(int, int) const; private: string strState; int iRow, iCol; int iZeroIdx; string strStandard; int iNotMatch; public: int iMoveFromLast; static const string directs[5]; enum DIRECT { UP, DOWN, LEFT, RIGHT, UNKOWN }; void output(ostream &out, const string &colSpace=" ", const string &rowSpace="\n") const; }; #endif