poj3074, DLX解数独源代码

2013年6月26日 20:49

原题

 

#include <cstdio>
#include <iostream>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <vector>
#include <set>
#include <queue>
#include <map>
#include <iterator>
#include <algorithm>


using namespace std;


int const N = 3;
int PN = N * N, QN = PN * PN;

/***最大行***/
#define MAXROW 1001
/***最大列***/
#define MAXCOL 1001

struct DancingLinksNode {
    /***结点所在的行列位置***/
    int r, c;
    /***结点的上下左右结点指针***/
    DancingLinksNode *U, *D, *L, *R;
};

/****备用结点****/
DancingLinksNode node[MAXROW * 101];
/****行头****/
DancingLinksNode row[MAXROW];
/****列头****/
DancingLinksNode col[MAXCOL];
/****表头****/
DancingLinksNode head;
/****使用了多少结点****/
int cnt;
/****列含有多少个域****/
int size[MAXCOL];
/****表的行与列变量****/
int m, n;
/****选择的行****/
int choice[MAXROW];

/****初始化,r, c分别表示表的大小***/
void init(int r, int c) {
    /****将可以使用的结点设为第一个****/
    cnt = 0;
    /****head结点的r,c分别表示表的大小,以备查****/
    head.r = r;
    head.c = c;
    /****初始化head结点****/
    head.L = head.R = head.U = head.D = &head;

    /***初始化列头***/
    for(int i = 0; i < c; ++i) {
        col[i].r = r;
        col[i].c = i;
        col[i].L = head.L;
        col[i].R = &head;
        col[i].L->R = col[i].R->L = &col[i];
        col[i].U = col[i].D = &col[i];
        size[i] = 0;
    }


    /***初始化行头,在删除的时候,如果碰到row[i].c  == c的情形应当被跳过***/
    for(int i = r - 1; i > -1; --i) {
        row[i].r = i;
        row[i].c = c;
        row[i].U = head.U;
        row[i].D = &head;
        row[i].U->D = row[i].D->U = &row[i];
        row[i].L = row[i].R = &row[i];
    }
}

/****增加一个结点,在原表中的位置为r行,c列***/
inline void addNode(int r, int c) {
    /****找一个未曾使用的结点****/
    DancingLinksNode *ptr = &node[cnt++];

    /****设置结点的行列号****/
    ptr->r = r;
    ptr->c = c;

    /****将结点加入双向链表中****/
    ptr->L = row[r].L;
    ptr->R = &row[r];
    ptr->L->R = ptr->R->L = ptr;

    ptr->U = col[c].U;
    ptr->D = &col[c];
    ptr->U->D = ptr->D->U = ptr;

    /****将size域加1****/
    ++size[c];
}

/****删除ptr所指向的结点的左右方向****/
inline void delLR(DancingLinksNode * ptr) {
    ptr->L->R = ptr->R;
    ptr->R->L = ptr->L;
}

/****删除ptr所指向的结点的上下方向****/
inline void delUD(DancingLinksNode * ptr) {
    ptr->U->D = ptr->D;
    ptr->D->U = ptr->U;
}

/****重置ptr所指向的结点的左右方向****/
inline void resumeLR(DancingLinksNode * ptr) {
    ptr->L->R = ptr->R->L = ptr;
}

/****重置ptr所指向的结点的上下方向****/
inline void resumeUD(DancingLinksNode * ptr) {
    ptr->U->D = ptr->D->U = ptr;
}

/****覆盖第c例***/
inline void cover(int c) {
    /**** c == n 表示头****/
    if(c == n) {
        return;
    }

    /****删除表头****/
    delLR(&col[c]);

    DancingLinksNode *R, *C;
    for(C = col[c].D; C != (&col[c]); C = C->D) {
        if(C->c == n)
            continue;
        for(R = C->L; R != C; R = R->L){
            if(R->c == n)
                continue;
            --size[R->c];
            delUD(R);
        }
        delLR(C);
    }

}

/****重置第c列****/
inline void resume(int c) {
    if(c == n)
        return;
    DancingLinksNode *R, *C;
    for(C = col[c].U; C != (&col[c]); C = C->U) {
        if(C->c == n)
            continue;
        resumeLR(C);
        for(R = C->R; R != C; R = R->R) {
            if(R->c == n)
                continue;
            ++size[R->c];
            resumeUD(R);
        }
    }

    /****把列头接进表头中****/
    resumeLR(&col[c]);
}

/****搜索核心算法,k表示搜索层数****/
int search(int k = 0) {

    /***搜索成功,返回true***/
    if(head.L == (&head)) {
        return k;
    }
    /***c表示下一个列对象位置,找一个分支数目最小的进行覆盖***/
    int INF = (1<<30), c = -1;

    for(DancingLinksNode * ptr = head.L; ptr != (&head); ptr = ptr->L) {
        if(size[ptr->c] < INF) {
            INF = size[ptr->c];
            c = ptr->c;
        }
    }
    /***覆盖第c列***/
    cover(c);

    DancingLinksNode * ptr;

    for(ptr = col[c].D; ptr != (&col[c]); ptr = ptr->D) {
        DancingLinksNode *rc;
        ptr->R->L = ptr;
        choice[k] = ptr->r;
        for(rc = ptr->L; rc != ptr; rc = rc->L) {
            cover(rc->c);
        }
        ptr->R->L = ptr->L;
        int ans = search(k + 1);
        if(ans) {
            return ans;
        }
        ptr->L->R = ptr;
        for(rc = ptr->R; rc != ptr; rc = rc->R) {
            resume(rc->c);
        }
        ptr->L->R = ptr->R;
    }

    /***取消覆盖第c列***/
    resume(c);
    return 0;
}


void addPoss(int i, int j) {
    int x = i / PN, y = i % PN;
    addNode(i * PN + j, i);
    addNode(i * PN + j, QN * 1+ x * PN + j);
    addNode(i * PN + j, QN * 2 + y * PN + j);
    addNode(i * PN + j, QN * 3 + (x / N * N + y / N) * PN + j);
}

int main(int argc, char** argv) {
    char s[100];
    while(1) {
        scanf("%s", s);
        if(strcmp(s, "end") == 0) {
            break;
        }
        m = PN * QN, n = 4 * QN;
        init(m, n);
        for(int i = 0; i < QN; ++i) {
            if(s[i] == '.')  {
            	for(int j = 0; j < PN; ++j) addPoss(i, j);
            } else {
            	addPoss(i, s[i] - '1');
            }	
        }
        search();
        for(int i = 0; i < QN; i ++) {
            s[choice[i] / PN] = choice[i] % PN + '1';
        }
        puts(s);
    }
    return 0;
}

从“八数码”走进搜索

2009年9月25日 04:17

八数码是经典的搜索问题这已经是众所周知的 。

今天一天就做了这一个问题,感觉挺有收获的,于是想写点心得。

先来看下题目:

http://acm.pku.edu.cn/JudgeOnline/problem?id=1077

最先想到的是直接BFS,这也是最简单的想法。但一开始就遇到了一个公共的问题(为什么说是公共的问题呢。因为不管你用什么方式的搜索都会遇到的)——如何判重。最简单的就是直接将棋盘转化成一个9位的十进制数。但这用什么数据结构呢?数组,10^9!这个内存是不允许的;用set,这个时间是不允许的。后面听勇哥说,全排是有完美的hash函数的,于是在网上找到这个——http://bbs3.chinaunix.net/viewthread.php?tid=1283459,这上面解释得很清楚。于是按照上面的hash方法写了个单向的BFS,附:

 

#include <cstdio>
#include <cstring>

struct node {
    int pre; //上一个节点和如何变换的信息
	int p; // x即9的下标
    char ch[10];//用于记录棋盘信息
}st[400000];
//队列数组 ,9!= 362880

bool vis[400000];//判重数组

int perm[] = {1,1,2,6,24,120,720,5040,40320};//n!
int d[] = { -1 , -3, 1, 3};//四个方向的下标变换
bool move[][4] = {0,0,1,1, 1,0,1,1, 1,0,0,1, 0,1,1,1, 1,1,1,1, 1,1,0,1, 0,1,1,0, 1,1,1,0, 1,1,0,0};
//各个位置的可行变换

int hash(char x[])//用逆序数和变进制进行hash
{
    int h = 0;
    for(int i = 1;i<9;i++){
        int count = 0;
        for(int j=0;j<i;j++)
            if(x[j] > x[i])count ++;
        h += count * perm[i];
    }
    return h;
}

bool end (char x[])//判断是否到目标状态
{
    for(char i = 0;i<9;i++){
        if(x[i] != i + '1') return false;
    }
    return true;
}

int BFS(char *num,int p)
{
    int left ,right;
    left = right = 1;
    memset(vis,0,sizeof(vis));
    strcpy(st[right].ch,num);
    st[right].p = p;
    st[right ++].pre = 0;
    vis[hash(st[left].ch)] = 1;
    while(left < right){
        int p1 ,p2 ;
        p1 = st[left].p;
        for(int i = 0 ;i<4;i++)if(move[p1][i]){
            p2 = st[right].p = p1 + d[i];
            strcpy(st[right].ch,st[left].ch);
            char tp = st[right].ch[p2] ;
            st[right].ch[p2] = st[right].ch[p1] ;
            st[right].ch[p1] = tp ;
            st[right].pre = left << 2 | i;//将上一个节点的下标和变换规则压进一个数字
            if(end(st[right].ch)) return st[right].pre;
            int key = hash(st[right].ch);
            if( ! vis[key] ) { vis[key] = 1 ; right ++;}
        }
        left ++;
    }
    return -1;
}


int main()
{
    char num[10] ,der[]="lurd",step[100];
    int p ;
    for(int i =0 ;i<9;){
        num[i] = getchar();
        if(num[i] == 'x') { num[i] = '9' ; p = i; }//将x转化成'9',这样便于hash,和判结束
        if(num[i]<='9' && num[i]>='1') i++;
    }
    num[9] = '\0';
    int i = 0;
    if(!end(num)){
        int k = BFS(num,p);
        if( k == -1)  printf("unsolvable");
        else {
            while(k){
                step[ ++ i] = der[k&3];
                k = st[k>>2].pre;
            }
            for(;i>0;i--) putchar(step[i]);
            puts("");
        }
    }
    return 0;
}

 

在知道初状态(输入)和末状态(12345678x)的情况下还可以用双向BFS,无论从空间还是从时间的复杂度来说,双向BFS都是比单向BFS都优。单向的BFS的时间和空间都是k^n级的,k是一个状态分支的个数,这里理论是 k = 4,实际上2到3之间,n是迭代的层数,双向BFS的时间和空间的复杂度都是k^(n/2)级的。就是双向BFS的代码复杂度要高一些。在poj上单向BFS是4104K 266MS 2067B,双向BFS是1280K 16MS 2589B

附:双向BFS代码

 
 

#include <cstdio>
#include <cstring>
#define M 400000
struct node {
    int pre,p;
	int key; //这个节点状态对应的hash值
    char ch[10];
}st[2][M];

bool vis[2][M];

int perm[] = {1,1,2,6,24,120,720,5040,40320};
int d[] = { -1 , -3, 1, 3};
bool move[][4] = {0,0,1,1, 1,0,1,1, 1,0,0,1, 0,1,1,1, 1,1,1,1, 1,1,0,1, 0,1,1,0, 1,1,1,0, 1,1,0,0};

int hash(char x[])
{
    int h = 0;
    for(int i = 1;i<9;i++){
        int count = 0;
        for(int j=0;j<i;j++)
            if(x[j] > x[i])count ++;
        h += count * perm[i];
    }
    return h;
}

int BFS(char *num,int p)
{
    int l[2] ,r[2] ,key;
    l[0] = l[1] = r[0] = r[1] = 1;
    memset(vis,0,sizeof(vis));

    strcpy(st[0][r[0]].ch,num);
    st[0][r[0]].p = p;
    st[0][r[0]].pre = 0;
    key = hash(st[0][l[0]].ch);
    vis[0][key] = 1;
    st[0][r[0] ++].key = key;

    strcpy(st[1][r[1]].ch,"123456789");
    st[1][r[1]].p = 8;
    st[1][r[1]].pre = 0;
    key = hash(st[1][l[1]].ch);
    vis[1][key] = 1;
    st[1][r[1] ++].key = key;

    while(l[0] < r[0] && l[1]<r[1]){
        int p1 ,p2 ,now = (r[0] - l[0] > r[1] - l[1]);//将节点少的一边进行扩展
        p1 = st[now][l[now]].p;
        for(int i = 0 ;i<4;i++)if(move[p1][i]){
            p2 = st[now][r[now]].p = p1 + d[i];
            strcpy(st[now][r[now]].ch,st[now][l[now]].ch);
            char tp = st[now][r[now]].ch[p2] ;
            st[now][r[now]].ch[p2] = st[now][r[now]].ch[p1] ;
            st[now][r[now]].ch[p1] = tp ;
            st[now][r[now]].pre = l[now] << 2 | i;
            key = hash(st[now][r[now]].ch);
            st[now][r[now]].key = key;
            if( ! vis[now][key] ) { vis[now][key] = 1 ; r[now] ++;}
            if( vis[1-now][key]) return key;//检测当前扩展的状态在另一边是否已出现,
        }
        l[now] ++;
    }
    return -1;
}


int main()
{
    char num[10] ,der[]="lurd",step[100];
    int p ;
    for(int i =0 ;i<9;){
        num[i] = getchar();
        if(num[i] == 'x') { num[i] = '9' ; p = i; }
        if(num[i]<='9' && num[i]>='1') i++;
    }
    num[9] = '\0';
    int i = 0;
    int k = BFS(num,p);
    if( k == -1)  printf("unsolvable\n");
    else {
        int p = 1;
        while(st[0][p].key != k) p++;
        p = st[0][p].pre ;
        while(p){
            step[ ++ i] = der[p&3];
            p = st[0][p>>2].pre;
        }
        for(;i>0;i--) putchar(step[i]);
        p = 1;
        while(st[1][p].key != k) p++;
        p = st[1][p].pre;
        while(p){
            putchar(der[((p & 3) + 2)%4]);
            p = st[1][p>>2].pre;
        }
        puts("");
    }
    return 0;
}

 

这道题目还可以用A*来写,原来看过这方面的书籍,还没用自己动手写过A*,于是尝试着写了写。这里用搜索树的深度做g(s),用但前状态各个字符和目标状态这个字符的哈密尔顿距离之和做f(s),这里做了个比较危险的假设:我把路径压缩在一个long long 中,一个变换只有4种可能,用2个bit就能保存了,大家都知道long long 占64个bit,所以最多可以存32个变化,也就是说,一个long long 只能保存一天短于32的路径(不知道这样描述是否清楚)。我假设,对于任何状态,变到目标状态至多的步数不会超过32步(无解的状态除外),就像3阶魔方被证明20+步一定能被搞定一样,(如果哪个神牛能证明一下就好了……)。

同样这里也贴出代码(就不注释了,如果有没理解的发邮件给我吧,hustsxh@gmail.com):

 

#include <cstdio>
#include <cstring>
#include <cmath>
#include <queue>

struct node {
    long long path;
    char ch[10];
    int g,f,p;
    bool operator <(const node & a)const{
        return f>a.f || f==a.f &&  g > a.g;
    }
};

bool vis[400000];

int perm[] = {1,1,2,6,24,120,720,5040,40320};
int d[] = { -1 , -3, 1, 3};
bool move[][4] = {0,0,1,1, 1,0,1,1, 1,0,0,1, 0,1,1,1, 1,1,1,1, 1,1,0,1, 0,1,1,0, 1,1,1,0, 1,1,0,0};

int hash(char x[])
{
    int h = 0;
    for(int i = 1;i<9;i++){
        int count = 0;
        for(int j=0;j<i;j++)
            if(x[j] > x[i])count ++;
        h += count * perm[i];
    }
    return h;
}

bool end (char x[])
{
    for(char i = 0;i<9;i++){
        if(x[i] != i + '1') return false;
    }
    return true;
}

int h(char x[])
{
    int h = 0;
    for(int i=0;i<9;i++){
        h += abs( '1' + i -x[i]);
    }
    return h;
}

int Astar(char num[],int p,long long & path)
{
    memset(vis,0,sizeof(vis));
    std::priority_queue <node> c;
    node start ;
    strcpy(start.ch,num);
    start.path = 0;
    start.p = p;
    start.g = 0;
    start.f = start.g + h(start.ch);
    vis[hash(start.ch)] = 1;
    c.push(start);
    while(!c.empty()){
        node next, pre = c.top();
        c.pop();
        for(int i = 0;i<4;i++)if(move[pre.p][i]){
            next.p = pre.p + d[i];
            next.g = pre.g + 1;
            strcpy(next.ch,pre.ch);
            next.ch[pre.p] = next.ch[next.p];
            next.ch[next.p] = '9';
            next.path = pre.path<<2 | i;
            next.f = h(next.ch) + next.g ;
            if(end(next.ch)) { path = next.path ; return next.g;}
            int key = hash(next.ch);
            if(!vis[key]) { vis[key] = 1; c.push(next); }
        }
    }
    return -1;
}

int main()
{
    char num[10] ,der[]="lurd",step[100];
    int p ;
    for(int i =0 ;i<9;){
        num[i] = getchar();
        if(num[i] == 'x') { num[i] = '9' ; p = i; }
        if(num[i]<='9' && num[i]>='1') i++;
    }
    num[9] = '\0';
    int i = 0;
    if(!end(num)){
        long long path ;
        int k = Astar(num,p,path);
        if( k==-1) printf("unsolvable\n");
        else {
            for(int i = k-1;i>=0;--i) {
                step[i] = der[path&3];
                path = path >>2;
            }
            step[k] = '\0';
            puts(step);
        }
    }
    return 0;
}

 估计是我写得不够优,在poj的的各个参数是1012K 47MS 2459B

 从八数码中学到了一个精妙、完美的hash,第一次写双向BFS,第一次写A*,而且交了三份代码三份代码都是1Y的,很hanppy,在此留下一笔,作为成长的痕迹……