插頭DP詳解
前言
圖片來源:我老師的PDF因為我不會畫圖
前排膜拜一波(
插頭dp雖然模版難度就是黑,但是我認為並不難。我認為dp的難度排序:
ddp(動態DP,P4719)>分治dp>插頭dp>決策單調性優化dp>網路流(類dp)>斜率優化dp>暴力dp
模版/推薦題目
P5056 【模版】插頭dp
P2289 郵遞員
P2337 喵星人的入侵
目的
給定一個點集$S$,求圖上的一條線,要求包含$S$裡包括的所有點,並且不包含圖上非$S$集的點,且不能重複經過一個點。
可能不是很好理解(**不會看題嗎),比如下圖兩條線都滿足條件(白色的格子表示$S$集)。
問能畫出來多少種線。上圖情況答案就是2,因為除了這兩種畫不出別的了。
實現
如果沒做過輪廓線DP,可以先做一下P4363,那是很基礎的一道輪廓線DP題。
勾勒一個輪廓線,意義與P4363相同,即為已經dp過的點的下輪廓,並給輪廓線的每條邊加上插頭。
上圖為正在計算點(3,3)的時候的輪廓線。
但是隻記錄是否空插頭還是不夠的,因為一條線既可以是區間左端,也可以是右端。
如何儲存:狀壓括號序列即可。0表示空插頭,1表示向上,2表示向下。
正確性證明:很顯然,如果括號序列是交叉的,形如$\color{red}(\color{green}(\color{red})\color{green})$,那麼它是一個不合法的序列(重複經過)。
例:
上圖插頭們的括號序列是$1120212$。
為了方便,將上圖挨著正在考慮的點$5,5$的兩個插頭稱作左插頭(插頭5)和上插頭(插頭6)。
接下來,對於點$(i,j)$,分類討論情況。
1. 障礙:別動,跳過就行。
2. 上插頭和左插頭都是空:因為任何一個點都要經過,所以只能向下向右都連邊。
3. 只有上插頭空:下和右隨便連邊,因為已經連了一個了。連下、連右、連下&右都合法。(當然都不連不合法)
4. 只有左插頭空:同3,只是要改方向。
5. 左、上都是起始(狀態都是1):因為不能重複,這時候只能合併插頭。但是合併後其他插頭可能會有改變,具體實現詳見程式碼。
6. 左、上都是終結(狀態都是2):同5,只是要改方向。
7. 迴路:如果已經走完了(即迴路形成點在$(n,m)$),那麼加進答案,否則丟棄。
到這裡,插頭dp就已經可以實現了。
優化
1. 一次轉移顯然只和上一次的答案有關,因此可以用滾動陣列優化空間。
2. 優化兩個狀態之間的轉移:雜湊表。
程式碼
另外,雜湊建議膜一個質數,可以減少衝突。1e6左右我找的是1e6+3。
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int mod = 1e6 + 3; int n, m, ex, ey, bits[15], state[mod + 5][2]; bool mp[15][15]; ll f[mod + 5][2], ans; int head[mod + 5], to[mod + 5], nxt[mod + 5], sz[2], qwq, tot; inline void link(int u, int v) { to[tot] = v; nxt[tot] = head[u]; head[u] = tot++; } inline void add(int x, ll k) { int key = x % mod; for (int i = head[key]; ~i; i = nxt[i]) if (state[to[i]][qwq] == x) { f[to[i]][qwq] += k; return; } state[++sz[qwq]][qwq] = x; f[sz[qwq]][qwq] = k; link(key, sz[qwq]); } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cin >> n >> m; for (int i = 1; i <= n; i++) for (int j = 1; j <= m; j++) { char c; cin >> c; if (c == '.') { mp[i][j] = true; ex = i, ey = j; } } for (int i = 1; i <= 12; i++) bits[i] = i << 1; sz[qwq] = 1; f[1][qwq] = 1, state[1][qwq] = 0; for (int i = 1; i <= n; i++) { for (int j = 1; j <= sz[qwq]; j++) state[j][qwq] <<= 2; for (int j = 1; j <= m; j++) { tot = 0; memset(head, -1, sizeof(head)); qwq ^= 1; sz[qwq] = 0; for (int k = 1; k <= sz[qwq ^ 1]; k++) { int stt = state[k][qwq ^ 1]; int up = (stt >> bits[j]) % 4, left = (stt >> bits[j - 1]) % 4; ll val = f[k][qwq ^ 1]; if (!mp[i][j]) add(stt, val); else if (!up && !left) { if (mp[i + 1][j] && mp[i][j + 1]) add(stt | 1 << bits[j - 1] | 2 << bits[j], val); } else if (left && !up) { if (mp[i + 1][j]) add(stt, val); if (mp[i][j + 1]) add(stt - left * (1 << bits[j - 1]) + left * (1 << bits[j]), val); } else if (!left && up) { if (mp[i][j + 1]) add(stt, val); if (mp[i + 1][j]) add(stt - up * (1 << bits[j]) + up * (1 << bits[j - 1]), val); } else if (left == 1 && up == 1) { int cnt = 1; for (int p = j + 1; p <= m; p++) { if ((stt >> bits[p]) % 4 == 1) // left plug cnt++; else if ((stt >> bits[p]) % 4 == 2) // right plug cnt--; if (!cnt) // matched { add(stt - (1 << bits[p]) - (1 << bits[j]) - (1 << bits[j - 1]), val); break; } } } else if (left == 2 && up == 2) { int cnt = 1; for (int p = j - 2; p >= 0; p--) { if ((stt >> bits[p]) % 4 == 1) // left plug cnt--; if ((stt >> bits[p]) % 4 == 2) // right plug cnt++; if (!cnt) // matched { add(stt - (2 << bits[j]) - (2 << bits[j - 1]) + (1 << bits[p]), val); break; } } } else if (left == 2 && up == 1) add(stt ^ 2 << bits[j - 1] ^ 1 << bits[j], val); else if (left == 1 && up == 2 && i == ex && j == ey) ans += val; } } } cout << ans; return 0; }