1. 程式人生 > 其它 >插頭DP詳解

插頭DP詳解

前言

圖片來源:我老師的PDF因為我不會畫圖

前排膜拜一波(

插頭dp雖然模版難度就是黑,但是我認為並不難。我認為dp的難度排序:

ddp(動態DP,P4719)>分治dp>插頭dp>決策單調性優化dp>網路流(類dp)>斜率優化dp>暴力dp

模版/推薦題目

P5056 【模版】插頭dp

P2289 郵遞員

P2337 喵星人的入侵

目的

給定一個點集$S$,求圖上的一條線,要求包含$S$裡包括的所有點,並且不包含圖上非$S$集的點,且不能重複經過一個點。

可能不是很好理解(**不會看題嗎),比如下圖兩條線都滿足條件(白色的格子表示$S$集)。

問能畫出來多少種線。上圖情況答案就是2,因為除了這兩種畫不出別的了。

實現

如果沒做過輪廓線DP,可以先做一下P4363,那是很基礎的一道輪廓線DP題。

勾勒一個輪廓線,意義與P4363相同,即為已經dp過的點的下輪廓,並給輪廓線的每條邊加上插頭。

上圖為正在計算點(3,3)的時候的輪廓線。

但是隻記錄是否空插頭還是不夠的,因為一條線既可以是區間左端,也可以是右端。

如何儲存:狀壓括號序列即可。0表示空插頭,1表示向上,2表示向下。

正確性證明:很顯然,如果括號序列是交叉的,形如$\color{red}(\color{green}(\color{red})\color{green})$,那麼它是一個不合法的序列(重複經過)。

例:

上圖插頭們的括號序列是$1120212$。

為了方便,將上圖挨著正在考慮的點$5,5$的兩個插頭稱作左插頭(插頭5)和上插頭(插頭6)。

接下來,對於點$(i,j)$,分類討論情況。

1. 障礙:別動,跳過就行。
2. 上插頭和左插頭都是空:因為任何一個點都要經過,所以只能向下向右都連邊。
3. 只有上插頭空:下和右隨便連邊,因為已經連了一個了。連下、連右、連下&右都合法。(當然都不連不合法)
4. 只有左插頭空:同3,只是要改方向。
5. 左、上都是起始(狀態都是1):因為不能重複,這時候只能合併插頭。但是合併後其他插頭可能會有改變,具體實現詳見程式碼。

6. 左、上都是終結(狀態都是2):同5,只是要改方向。
7. 迴路:如果已經走完了(即迴路形成點在$(n,m)$),那麼加進答案,否則丟棄。

到這裡,插頭dp就已經可以實現了。

優化

1. 一次轉移顯然只和上一次的答案有關,因此可以用滾動陣列優化空間。
2. 優化兩個狀態之間的轉移:雜湊表。

程式碼

另外,雜湊建議一個質數,可以減少衝突。1e6左右我找的是1e6+3。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e6 + 3;
int n, m, ex, ey, bits[15], state[mod + 5][2];
bool mp[15][15];
ll f[mod + 5][2], ans;
int head[mod + 5], to[mod + 5], nxt[mod + 5], sz[2], qwq, tot;
inline void link(int u, int v)
{
    to[tot] = v;
    nxt[tot] = head[u];
    head[u] = tot++;
}
inline void add(int x, ll k)
{
    int key = x % mod;
    for (int i = head[key]; ~i; i = nxt[i])
        if (state[to[i]][qwq] == x)
        {
            f[to[i]][qwq] += k;
            return;
        }
    state[++sz[qwq]][qwq] = x;
    f[sz[qwq]][qwq] = k;
    link(key, sz[qwq]);
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++)
        {
            char c;
            cin >> c;
            if (c == '.')
            {
                mp[i][j] = true;
                ex = i, ey = j;
            }
        }
    for (int i = 1; i <= 12; i++)
        bits[i] = i << 1;
    sz[qwq] = 1;
    f[1][qwq] = 1, state[1][qwq] = 0;
    for (int i = 1; i <= n; i++)
    {
        for (int j = 1; j <= sz[qwq]; j++)
            state[j][qwq] <<= 2;
        for (int j = 1; j <= m; j++)
        {
            tot = 0;
            memset(head, -1, sizeof(head));
            qwq ^= 1;
            sz[qwq] = 0;
            for (int k = 1; k <= sz[qwq ^ 1]; k++)
            {
                int stt = state[k][qwq ^ 1];
                int up = (stt >> bits[j]) % 4, left = (stt >> bits[j - 1]) % 4;
                ll val = f[k][qwq ^ 1];
                if (!mp[i][j])
                    add(stt, val);
                else if (!up && !left)
                {
                    if (mp[i + 1][j] && mp[i][j + 1])
                        add(stt | 1 << bits[j - 1] | 2 << bits[j], val);
                }
                else if (left && !up)
                {
                    if (mp[i + 1][j])
                        add(stt, val);
                    if (mp[i][j + 1])
                        add(stt - left * (1 << bits[j - 1]) + left * (1 << bits[j]), val);
                }
                else if (!left && up)
                {
                    if (mp[i][j + 1])
                        add(stt, val);
                    if (mp[i + 1][j])
                        add(stt - up * (1 << bits[j]) + up * (1 << bits[j - 1]), val);
                }
                else if (left == 1 && up == 1)
                {
                    int cnt = 1;
                    for (int p = j + 1; p <= m; p++)
                    {
                        if ((stt >> bits[p]) % 4 == 1) // left plug
                            cnt++;
                        else if ((stt >> bits[p]) % 4 == 2) // right plug
                            cnt--;
                        if (!cnt) // matched
                        {
                            add(stt - (1 << bits[p]) - (1 << bits[j]) - (1 << bits[j - 1]), val);
                            break;
                        }
                    }
                }
                else if (left == 2 && up == 2)
                {
                    int cnt = 1;
                    for (int p = j - 2; p >= 0; p--)
                    {
                        if ((stt >> bits[p]) % 4 == 1) // left plug
                            cnt--;
                        if ((stt >> bits[p]) % 4 == 2) // right plug
                            cnt++;
                        if (!cnt) // matched
                        {
                            add(stt - (2 << bits[j]) - (2 << bits[j - 1]) + (1 << bits[p]), val);
                            break;
                        }
                    }
                }
                else if (left == 2 && up == 1)
                    add(stt ^ 2 << bits[j - 1] ^ 1 << bits[j], val);
                else if (left == 1 && up == 2 && i == ex && j == ey)
                    ans += val;
            }
        }
    }
    cout << ans;
    return 0;
}