1. 程式人生 > 其它 >[模板] 插頭 DP

[模板] 插頭 DP

[模板] 插頭DP——從入門到入墳

陳丹琦——《基於連通性狀態壓縮的動態規劃問題》

傳送門

模板是插頭DP的入門題,詢問 帶障礙網格中的合法迴路個數

概念類

  • 棋盤模型問題:採用逐行,逐列,逐格的狀態轉移方式。

對於此題,逐格轉移是最快的。

插頭:整個 DP 中的核心。

定義

對於一個四聯通問題來說,一個格子通常有上下左右四個插頭,一個格子再一個方向上的插頭定義為 該格子可以通過這個方向與外界連通。

比如說,插頭 \(1\) 是上面黃底色格子的下插頭,是下面格子的上插頭。

如果逐行DP的話,第 \(i\) 行的所有下插頭會成為第 \(i+1\) 行的所有上插頭。

狀態表示法

通過狀態表示法,我們可以把輪廓線處的插頭狀態表示為一個 \(p\)

進位制數。

  • 最小表示法。

最小表示法,也就是用不同的數字來表示 輪廓線以上(已知部分) 不同的插頭所處的線路。

比如上圖輪廓線處狀壓起來就是 \((1,1,2,2)\),至於最小表示類似於 字串的最小表示,採用四進位制儲存以加快運算。

  • 括號序列表示法。

對於一對連通的插頭,用數對 \((1,2)\) 來表示,用 \(1\) 來代替左括號 \((\)\(2\) 來代替右括號 \()\)

類似於括號序列,所以把它叫做括號序列表示法。

比如上圖狀壓起來就是 \((1,2,1,2)\)

所有的狀態表示法,都基於對線路的唯一表示,防止衝突。只要線路與狀態表示一一對應,就可以是一種狀態表示法。

輪廓線狀壓轉移

通過狀態表示法,根據輪廓線插頭對當前格 \((i,j)\) 的影響 考慮幾類可能的轉移。

所謂輪廓線狀壓,也就是隻關注輪廓線處的一些插頭對逐格轉移過程中當前格的影響。

根據格子所處位置附近連通塊的連通情況,可以分為三類。

  1. 新建一個連通塊。

當且僅當輪廓線處不存在向右或者向下的插頭,\((i,j)\) 提供向下和向右的插頭。

  1. 連線兩個已有連通塊。

輪廓線處有向右和向下的插頭,\((i,j)\) 分別通過向左和向上的插頭把它們連線起來。

  1. 接上之前的連通塊。

也就是輪廓線處只有一個向右或向下的插頭,\((i,j)\) 分別提供向左或向上的插頭。

那麼設 \(f(i,j,S)\)

表示當前處理完了 \((i,j)\) 及之前的點,輪廓線附近括號序列為 \(S\) 的方案數。

程式碼實現

  • 採用滾動陣列,對於一個 \((i,j)\) 維護一維 \(f\)
  • 用雜湊表把 \(S\) 對映為序列數字,維護上一個點 \((i,j-1)\) 的所有狀態資訊,用掛鏈法處理雜湊衝突。
  • 採用刷表法,用雜湊表維護 DP 過程。

上面討論過了按照連通性分類,下面按照 \(up\)\(left\) 插頭的情況進行分類。

  1. 當前為障礙點,必須保證兩個插頭均為空才可加入決策集合。

  2. 當前不是障礙點。

    1. 兩個插頭都沒有:考慮新增一個連通塊,括號序列改變。

    2. 只有上面過來的插頭:選擇向右走或者向下走,注意:插頭序列均改變

    3. 只有左面過來的插頭:選擇向右走或者向下走,插頭序列均改變。

    4. 上面和左面過來的插頭都是 \(1\),也就是都是左括號:向右找到第一個能和當前左括號匹配上的右括號的位置,計算插頭貢獻,插頭序列改變。

    5. 上面和左面過來的插頭都是 \(2\),也就是都是右括號:向左找到第一個能和當前右括號匹配上的左括號的位置,計算插頭貢獻,插頭序列改變。

    6. 上面過來的插頭是 \(1\),左面過來的插頭是 \(2\),直接連起來即可。

    7. 上面過來的插頭是 \(2\),左面過來的插頭是 \(1\),說明形成了迴路,在最後遍歷到的非障礙點把答案加上,其它情況不管,這保證了最後一定是一種方案對應一條迴路。

狀壓插頭序列時的細節問題

對應於這一行:

for(int j=1;j<=tot[now];j++)a[now][j]<<=2;//詳細分析

這就涉及到了當前 \(i\) 這一行的狀壓問題。

可以發現,一條豎線( \((i,j)\) 左側的邊)將這一行分成了兩部分,其中這一豎線佔據第 \(j-1\) 位,這條豎線之前的第 \(t\) 列由於它都佔據了第 \(t-1\) 位 ,而豎線右側的第 \(t\) 列都佔據了第 \(t\) 位。這是 \((i,j)\) 被刷到之前,也就是轉移之前。

轉移之後,豎線變為佔據第 \(j\) 位,而之前的第 \(j\) 列佔據了第 \(j-1\) 列,也就是說每次取到上面插頭都是第 \(j\) 列,對於 上一行結束的狀態序列,第 \(j\) 列表示的還是第 \(j-1\),所以需要左移一位,這僅僅在每一行開始的時候做這個事情。

程式碼

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
template <typename T>
inline T read(){
	T x=0;char ch=getchar();bool fl=false;
	while(!isdigit(ch)){if(ch=='-')fl=true;ch=getchar();}
	while(isdigit(ch)){
		x=(x<<3)+(x<<1)+(ch^48);ch=getchar();
	}
	return fl?-x:x;
}
const int maxn = 20 , P = 590027;
const int maxm = 6e5 + 10;
int n,m,mp[maxn][maxn],stx,sty;
#define LL long long
#define read() read<int>()
LL ans,f[2][maxm];
int head[maxm],tot[2],now,last;
int nxt[maxm],a[2][maxm];
int bit[28];
void Hash(int sta,LL val){//輔助維護上一階段資訊的雜湊表
	int t=sta%P+1;
	for(int i=head[t];i;i=nxt[i]){
		if(a[now][i]==sta)return f[now][i]+=val,void();
	}
	nxt[++tot[now]]=head[t];head[t]=tot[now];
	a[now][tot[now]]=sta;f[now][tot[now]]=val;
}
void solve(){
	tot[now]=1;a[now][1]=0;f[now][1]=1;
	for(int i=1;i<=n;i++){
		for(int j=1;j<=tot[now];j++)a[now][j]<<=2;//詳細分析
		for(int j=1;j<=m;j++){
			last=now;now^=1;tot[now]=0;
			memset(head,0,sizeof head);
			for(int k=1;k<=tot[last];k++){
				int sta=a[last][k],up=(sta>>(j*2))%4,left=(sta>>(j*2-2))%4;
				LL val=f[last][k];
				if(!mp[i][j]){if(!up && !left)Hash(sta,val);}//裡面判不判都一樣,保證合法即可
				else if(!up && !left){
					if(mp[i+1][j] && mp[i][j+1])Hash(sta+bit[j-1]+2*bit[j],val);//新開一個,1,0
				}
				else if(up && !left){
					if(mp[i+1][j])Hash(sta-bit[j]*up+bit[j-1]*up,val);//go down
					if(mp[i][j+1])Hash(sta-bit[j]*up+bit[j]*up,val);//go right
				}
				else if(!up && left){
					if(mp[i][j+1])Hash(sta-bit[j-1]*left+bit[j]*left,val);//go right
					if(mp[i+1][j])Hash(sta-bit[j-1]*left+bit[j-1]*left,val);//go down
				}
				else if(up==1 && left==1){//找到第一個匹配的右括號
					int sz=1;
					for(int t=j+1;t<=m;t++){
						if((sta>>(t*2))%4==1)sz++;
						if((sta>>(t*2))%4==2)sz--;
						if(!sz){
							Hash(sta-bit[j]-bit[j-1]-bit[t],val);//右括號->左括號
							break;
						}
					}
				}
				else if(up==2 && left==2){//找到第一個匹配的左括號
					int sz=1;
					for(int t=j-2;t>=0;t--){//t-1 --> t(真實)
						if((sta>>(t*2))%4==1)sz--;
						if((sta>>(t*2))%4==2)sz++;
						if(!sz){
							Hash(sta-2*bit[j]-2*bit[j-1]+bit[t],val);//左括號->右括號
							break;
						}
					}
				}
				else if(up==1 && left==2)Hash(sta-2*bit[j-1]-bit[j],val);
				else if(up==2 && left==1){
					if(i==stx && j==sty)ans+=val;
				}	
			}
		}
	}
}
char s[maxn];
int main(){
	n=read();m=read();
	for(int i=1;i<=n;i++){
		cin>>s+1;
		for(int j=1;j<=m;j++){
			if(s[j]=='.')mp[i][j]=1,stx=i,sty=j;
			else if(s[j]=='*')mp[i][j]=0;
		}
	}
	bit[0]=1;
	for(int i=1;i<=12;i++)bit[i]=bit[i-1]<<2;
	solve();
	printf("%lld\n",ans);
	return 0;
}

簡單例題

[SCOI2011]地板

題意

\(L\) 型地板鋪滿非障礙格子的方案數,\(L\) 型格子不能是條形的。

解題報告

類似於求迴路的普通插頭DP,有以下幾點不同:

  1. 不需要使用上面的任何狀態表示法,因為 不需要記錄每一條線的連通情況 了。

  2. 由於一條 \(L\) 型地板只能拐一次彎,在插頭處記錄能不能拐彎。(\(1\) 表示可以拐彎,\(2\) 表示不能拐彎)

  3. 可以 在一條拐過彎的地板的任何時刻中止它,這也是最容易忘的。

新的體會(2021/8/6)

  • 插頭的真正含義是 相鄰的兩個格子可以連通,比如說中止當前地板時,就沒有往外走的插頭了。

  • 三進位制數在改狀態改的少的情況下同樣可以跑的飛快(除法對時間的影響小)。

換行時對於所有狀態左移一位的操作,需要 分維計算

for(int j=0;j<bit[m+1];j++)f[now^1][j]=0;
for(int j=0;j<bit[m];j++)f[now^1][j*3]=f[now][j];
now^=1;

否則會產生同一維之間的影響和一些奇怪錯誤。

插頭DP真的是細節非常多,主要是分討容易丟情況。

除錯時的技巧

看當前結點可以由哪些狀態轉移過來,或者是與哪些插頭相鄰,手玩列舉這些所有可能的情況,看看丟沒丟解。

cerr<<"pos: "<<i<<" "<<j<<" "<<s<<endl;//
cerr<<up<<" "<<left<<" "<<val<<endl;//

這是我除錯時的圖(第二個樣例),明顯對於 \((2,3)\) 這個點少了一種情況是:上插頭為 \(1\),這就是由於當時沒有考慮停止當前地板的情況。

三進位制寫法:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
template <typename T>
inline T read(){
    T x=0;char ch=getchar();bool fl=false;
    while(!isdigit(ch)){if(ch=='-')fl=true;ch=getchar();}
    while(isdigit(ch)){
        x=(x<<3)+(x<<1)+(ch^48);ch=getchar();
    }
    return fl?-x:x;
}
const int P = 20110520;
inline void Plus(int &x,int y){
	x+=y;
	if(x>=P)x-=P;
}
const int maxm = 2e5 + 10;
const int maxn = 105;
int f[2][maxm],bit[100],mp[maxn][maxn];
char s[maxn];
int n,m,now,last;
void solve(){
	f[0][0]=1;
	for(int i=1;i<=n;i++){
		for(int j=0;j<bit[m+1];j++)f[now^1][j]=0;
		for(int j=0;j<bit[m];j++)f[now^1][j*3]=f[now][j];
		now^=1;
		for(int j=1;j<=m;j++){
			last=now;now^=1;
			for(int s=0;s<bit[m+1];s++)f[now][s]=0;
			for(int s=0;s<bit[m+1];s++){
				if(!f[last][s])continue;
				int up=(s/bit[j])%3,left=(s/bit[j-1])%3;
				int val=f[last][s];
				if(!mp[i][j]){
					if(!up && !left)Plus(f[now][s],val);continue;
				}
				if(!up && !left){
					if(mp[i+1][j] && mp[i][j+1])Plus(f[now][s+2*bit[j]+2*bit[j-1]],val);
					if(mp[i+1][j])Plus(f[now][s+bit[j-1]],val);
					if(mp[i][j+1])Plus(f[now][s+bit[j]],val);
				}
				if(up && !left){
					if(up==1){
						if(mp[i][j+1])Plus(f[now][s-bit[j]+2*bit[j]],val);
					}
					if(up==2){
						Plus(f[now][s-2*bit[j]],val);
					}
					if(mp[i+1][j])Plus(f[now][s-up*bit[j]+up*bit[j-1]],val);
				}
				if(!up && left){
					if(left==1){
						if(mp[i+1][j])Plus(f[now][s-bit[j-1]+2*bit[j-1]],val);
					}
					if(left==2){
						Plus(f[now][s-2*bit[j-1]],val);
					}
					if(mp[i][j+1])Plus(f[now][s-left*bit[j-1]+left*bit[j]],val);
				}
				if(up==1 && left==1){
					Plus(f[now][s-bit[j]-bit[j-1]],val);
				}
			}
		}
	}
}
#define read() read<int>()
int main(){
	n=read();m=read();
	bool fl=(n<m);
	for(int i=1;i<=n;i++){
		cin>>s+1;
		for(int j=1;j<=m;j++){
			if(s[j]=='_')fl?(mp[j][i]=1):(mp[i][j]=1);
		}
	}
	if(fl)swap(n,m);
	bit[0]=1;
	for(int i=1;i<=m+1;i++)bit[i]=bit[i-1]*3;
	solve();
	printf("%d\n",f[now][0]);
	return 0;
}