[模板] 插頭 DP
[模板] 插頭DP——從入門到入墳
模板是插頭DP的入門題,詢問 帶障礙網格中的合法迴路個數。
概念類
- 棋盤模型問題:採用逐行,逐列,逐格的狀態轉移方式。
對於此題,逐格轉移是最快的。
插頭:整個 DP 中的核心。
定義
對於一個四聯通問題來說,一個格子通常有上下左右四個插頭,一個格子再一個方向上的插頭定義為 該格子可以通過這個方向與外界連通。
比如說,插頭 \(1\) 是上面黃底色格子的下插頭,是下面格子的上插頭。
如果逐行DP的話,第 \(i\) 行的所有下插頭會成為第 \(i+1\) 行的所有上插頭。
狀態表示法
通過狀態表示法,我們可以把輪廓線處的插頭狀態表示為一個 \(p\)
- 最小表示法。
最小表示法,也就是用不同的數字來表示 輪廓線以上(已知部分) 不同的插頭所處的線路。
比如上圖輪廓線處狀壓起來就是 \((1,1,2,2)\),至於最小表示類似於 字串的最小表示,採用四進位制儲存以加快運算。
- 括號序列表示法。
對於一對連通的插頭,用數對 \((1,2)\) 來表示,用 \(1\) 來代替左括號 \((\),\(2\) 來代替右括號 \()\) 。
類似於括號序列,所以把它叫做括號序列表示法。
比如上圖狀壓起來就是 \((1,2,1,2)\) 。
所有的狀態表示法,都基於對線路的唯一表示,防止衝突。只要線路與狀態表示一一對應,就可以是一種狀態表示法。
輪廓線狀壓轉移
通過狀態表示法,根據輪廓線插頭對當前格 \((i,j)\) 的影響 考慮幾類可能的轉移。
所謂輪廓線狀壓,也就是隻關注輪廓線處的一些插頭對逐格轉移過程中當前格的影響。
根據格子所處位置附近連通塊的連通情況,可以分為三類。
- 新建一個連通塊。
當且僅當輪廓線處不存在向右或者向下的插頭,\((i,j)\) 提供向下和向右的插頭。
- 連線兩個已有連通塊。
輪廓線處有向右和向下的插頭,\((i,j)\) 分別通過向左和向上的插頭把它們連線起來。
- 接上之前的連通塊。
也就是輪廓線處只有一個向右或向下的插頭,\((i,j)\) 分別提供向左或向上的插頭。
那麼設 \(f(i,j,S)\)
程式碼實現
- 採用滾動陣列,對於一個 \((i,j)\) 維護一維 \(f\)。
- 用雜湊表把 \(S\) 對映為序列數字,維護上一個點 \((i,j-1)\) 的所有狀態資訊,用掛鏈法處理雜湊衝突。
- 採用刷表法,用雜湊表維護 DP 過程。
上面討論過了按照連通性分類,下面按照 \(up\) 和 \(left\) 插頭的情況進行分類。
-
當前為障礙點,必須保證兩個插頭均為空才可加入決策集合。
-
當前不是障礙點。
-
兩個插頭都沒有:考慮新增一個連通塊,括號序列改變。
-
只有上面過來的插頭:選擇向右走或者向下走,注意:插頭序列均改變。
-
只有左面過來的插頭:選擇向右走或者向下走,插頭序列均改變。
-
上面和左面過來的插頭都是 \(1\),也就是都是左括號:向右找到第一個能和當前左括號匹配上的右括號的位置,計算插頭貢獻,插頭序列改變。
-
上面和左面過來的插頭都是 \(2\),也就是都是右括號:向左找到第一個能和當前右括號匹配上的左括號的位置,計算插頭貢獻,插頭序列改變。
-
上面過來的插頭是 \(1\),左面過來的插頭是 \(2\),直接連起來即可。
-
上面過來的插頭是 \(2\),左面過來的插頭是 \(1\),說明形成了迴路,在最後遍歷到的非障礙點把答案加上,其它情況不管,這保證了最後一定是一種方案對應一條迴路。
-
狀壓插頭序列時的細節問題
對應於這一行:
for(int j=1;j<=tot[now];j++)a[now][j]<<=2;//詳細分析
這就涉及到了當前 \(i\) 這一行的狀壓問題。
可以發現,一條豎線( \((i,j)\) 左側的邊)將這一行分成了兩部分,其中這一豎線佔據第 \(j-1\) 位,這條豎線之前的第 \(t\) 列由於它都佔據了第 \(t-1\) 位 ,而豎線右側的第 \(t\) 列都佔據了第 \(t\) 位。這是 \((i,j)\) 被刷到之前,也就是轉移之前。
轉移之後,豎線變為佔據第 \(j\) 位,而之前的第 \(j\) 列佔據了第 \(j-1\) 列,也就是說每次取到上面插頭都是第 \(j\) 列,對於 上一行結束的狀態序列,第 \(j\) 列表示的還是第 \(j-1\) 列,所以需要左移一位,這僅僅在每一行開始的時候做這個事情。
程式碼
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
template <typename T>
inline T read(){
T x=0;char ch=getchar();bool fl=false;
while(!isdigit(ch)){if(ch=='-')fl=true;ch=getchar();}
while(isdigit(ch)){
x=(x<<3)+(x<<1)+(ch^48);ch=getchar();
}
return fl?-x:x;
}
const int maxn = 20 , P = 590027;
const int maxm = 6e5 + 10;
int n,m,mp[maxn][maxn],stx,sty;
#define LL long long
#define read() read<int>()
LL ans,f[2][maxm];
int head[maxm],tot[2],now,last;
int nxt[maxm],a[2][maxm];
int bit[28];
void Hash(int sta,LL val){//輔助維護上一階段資訊的雜湊表
int t=sta%P+1;
for(int i=head[t];i;i=nxt[i]){
if(a[now][i]==sta)return f[now][i]+=val,void();
}
nxt[++tot[now]]=head[t];head[t]=tot[now];
a[now][tot[now]]=sta;f[now][tot[now]]=val;
}
void solve(){
tot[now]=1;a[now][1]=0;f[now][1]=1;
for(int i=1;i<=n;i++){
for(int j=1;j<=tot[now];j++)a[now][j]<<=2;//詳細分析
for(int j=1;j<=m;j++){
last=now;now^=1;tot[now]=0;
memset(head,0,sizeof head);
for(int k=1;k<=tot[last];k++){
int sta=a[last][k],up=(sta>>(j*2))%4,left=(sta>>(j*2-2))%4;
LL val=f[last][k];
if(!mp[i][j]){if(!up && !left)Hash(sta,val);}//裡面判不判都一樣,保證合法即可
else if(!up && !left){
if(mp[i+1][j] && mp[i][j+1])Hash(sta+bit[j-1]+2*bit[j],val);//新開一個,1,0
}
else if(up && !left){
if(mp[i+1][j])Hash(sta-bit[j]*up+bit[j-1]*up,val);//go down
if(mp[i][j+1])Hash(sta-bit[j]*up+bit[j]*up,val);//go right
}
else if(!up && left){
if(mp[i][j+1])Hash(sta-bit[j-1]*left+bit[j]*left,val);//go right
if(mp[i+1][j])Hash(sta-bit[j-1]*left+bit[j-1]*left,val);//go down
}
else if(up==1 && left==1){//找到第一個匹配的右括號
int sz=1;
for(int t=j+1;t<=m;t++){
if((sta>>(t*2))%4==1)sz++;
if((sta>>(t*2))%4==2)sz--;
if(!sz){
Hash(sta-bit[j]-bit[j-1]-bit[t],val);//右括號->左括號
break;
}
}
}
else if(up==2 && left==2){//找到第一個匹配的左括號
int sz=1;
for(int t=j-2;t>=0;t--){//t-1 --> t(真實)
if((sta>>(t*2))%4==1)sz--;
if((sta>>(t*2))%4==2)sz++;
if(!sz){
Hash(sta-2*bit[j]-2*bit[j-1]+bit[t],val);//左括號->右括號
break;
}
}
}
else if(up==1 && left==2)Hash(sta-2*bit[j-1]-bit[j],val);
else if(up==2 && left==1){
if(i==stx && j==sty)ans+=val;
}
}
}
}
}
char s[maxn];
int main(){
n=read();m=read();
for(int i=1;i<=n;i++){
cin>>s+1;
for(int j=1;j<=m;j++){
if(s[j]=='.')mp[i][j]=1,stx=i,sty=j;
else if(s[j]=='*')mp[i][j]=0;
}
}
bit[0]=1;
for(int i=1;i<=12;i++)bit[i]=bit[i-1]<<2;
solve();
printf("%lld\n",ans);
return 0;
}
簡單例題
題意
用 \(L\) 型地板鋪滿非障礙格子的方案數,\(L\) 型格子不能是條形的。
解題報告
類似於求迴路的普通插頭DP,有以下幾點不同:
-
不需要使用上面的任何狀態表示法,因為 不需要記錄每一條線的連通情況 了。
-
由於一條 \(L\) 型地板只能拐一次彎,在插頭處記錄能不能拐彎。(\(1\) 表示可以拐彎,\(2\) 表示不能拐彎)
-
可以 在一條拐過彎的地板的任何時刻中止它,這也是最容易忘的。
新的體會(2021/8/6)
-
插頭的真正含義是 相鄰的兩個格子可以連通,比如說中止當前地板時,就沒有往外走的插頭了。
-
三進位制數在改狀態改的少的情況下同樣可以跑的飛快(除法對時間的影響小)。
換行時對於所有狀態左移一位的操作,需要 分維計算。
for(int j=0;j<bit[m+1];j++)f[now^1][j]=0;
for(int j=0;j<bit[m];j++)f[now^1][j*3]=f[now][j];
now^=1;
否則會產生同一維之間的影響和一些奇怪錯誤。
插頭DP真的是細節非常多,主要是分討容易丟情況。
除錯時的技巧
看當前結點可以由哪些狀態轉移過來,或者是與哪些插頭相鄰,手玩列舉這些所有可能的情況,看看丟沒丟解。
cerr<<"pos: "<<i<<" "<<j<<" "<<s<<endl;//
cerr<<up<<" "<<left<<" "<<val<<endl;//
這是我除錯時的圖(第二個樣例),明顯對於 \((2,3)\) 這個點少了一種情況是:上插頭為 \(1\),這就是由於當時沒有考慮停止當前地板的情況。
三進位制寫法:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
template <typename T>
inline T read(){
T x=0;char ch=getchar();bool fl=false;
while(!isdigit(ch)){if(ch=='-')fl=true;ch=getchar();}
while(isdigit(ch)){
x=(x<<3)+(x<<1)+(ch^48);ch=getchar();
}
return fl?-x:x;
}
const int P = 20110520;
inline void Plus(int &x,int y){
x+=y;
if(x>=P)x-=P;
}
const int maxm = 2e5 + 10;
const int maxn = 105;
int f[2][maxm],bit[100],mp[maxn][maxn];
char s[maxn];
int n,m,now,last;
void solve(){
f[0][0]=1;
for(int i=1;i<=n;i++){
for(int j=0;j<bit[m+1];j++)f[now^1][j]=0;
for(int j=0;j<bit[m];j++)f[now^1][j*3]=f[now][j];
now^=1;
for(int j=1;j<=m;j++){
last=now;now^=1;
for(int s=0;s<bit[m+1];s++)f[now][s]=0;
for(int s=0;s<bit[m+1];s++){
if(!f[last][s])continue;
int up=(s/bit[j])%3,left=(s/bit[j-1])%3;
int val=f[last][s];
if(!mp[i][j]){
if(!up && !left)Plus(f[now][s],val);continue;
}
if(!up && !left){
if(mp[i+1][j] && mp[i][j+1])Plus(f[now][s+2*bit[j]+2*bit[j-1]],val);
if(mp[i+1][j])Plus(f[now][s+bit[j-1]],val);
if(mp[i][j+1])Plus(f[now][s+bit[j]],val);
}
if(up && !left){
if(up==1){
if(mp[i][j+1])Plus(f[now][s-bit[j]+2*bit[j]],val);
}
if(up==2){
Plus(f[now][s-2*bit[j]],val);
}
if(mp[i+1][j])Plus(f[now][s-up*bit[j]+up*bit[j-1]],val);
}
if(!up && left){
if(left==1){
if(mp[i+1][j])Plus(f[now][s-bit[j-1]+2*bit[j-1]],val);
}
if(left==2){
Plus(f[now][s-2*bit[j-1]],val);
}
if(mp[i][j+1])Plus(f[now][s-left*bit[j-1]+left*bit[j]],val);
}
if(up==1 && left==1){
Plus(f[now][s-bit[j]-bit[j-1]],val);
}
}
}
}
}
#define read() read<int>()
int main(){
n=read();m=read();
bool fl=(n<m);
for(int i=1;i<=n;i++){
cin>>s+1;
for(int j=1;j<=m;j++){
if(s[j]=='_')fl?(mp[j][i]=1):(mp[i][j]=1);
}
}
if(fl)swap(n,m);
bit[0]=1;
for(int i=1;i<=m+1;i++)bit[i]=bit[i-1]*3;
solve();
printf("%d\n",f[now][0]);
return 0;
}