1. 程式人生 > >Luogu P3181 [HAOI2016]找相同字符 廣義$SAM$

Luogu P3181 [HAOI2016]找相同字符 廣義$SAM$

nts can include bit clas spa 聲音 nod 相同字符

題目鏈接 \(Click\) \(Here\)

設一個串\(s\)\(A\)中出現\(cnt[s][1]\)次,在\(B\)中出現\(cnt[s][2]\)次,我們要求的就是:

\[\sum cnt[s][1]*cnt[s][2]\]

\(SAM\)這種把多個串用一個點表示的東西裏,答案就變成了這個

\[\sum cnt[s][1] * cnt[s][2] * (len[fa[s]]-len[s])\]

其中的\(cnt\)求法,聽說好像可以兩個串隔開求?但是我不太會。學了一下用廣義\(SAM\)的寫法,似乎是第一個串建完之後把\(las\)指針指回根節點,再建第二個就好。因為網上對於這種寫法各種聲音都有,所以我打算這周末認真學習\(SA\)
\(SAM\)後再詳細進行解釋說明或者算法更正。

\(p.s\)這種寫法下似乎不能以\(len\)桶排序求\(cnt\),因為\(len\)會有相等情況。所以我們要用\(Parent\) \(Tree\)\(DP\)來寫。

最後提醒:別忘\(long\) \(long\)

#include <bits/stdc++.h>
using namespace std;

const int N = 800010;
typedef long long ll;

ll tot[2][N];
int node = 1, las = 1;
int fa[N], len[N], ch[N][26];



void extend (int c, int id) {
    int p = las, q = ++node;
    len[q] = len[p] + 1, tot[id][q] = 1, las = q;
    while (p != 0 && ch[p][c] == 0) {
        ch[p][c] = q;
        p = fa[p];
    }
    if (p == 0) {
        fa[q] = 1;
    } else {
        int x = ch[p][c];
        if (len[x] == len[p] + 1) {
            fa[q] = x;
        } else {
            int y = ++node;
            len[y] = len[p] + 1;
            fa[y] = fa[x];
            fa[x] = fa[q] = y;
            memcpy (ch[y], ch[x], sizeof (ch[x]));
            while (p != 0 && ch[p][c] == x) {
                ch[p][c] = y;
                p = fa[p];
            }
        }
    }
}
    
int cnt, head[N];

struct edge {
    int nxt, to;
}e[N];

void add_edge (int from, int to) {
    e[++cnt].nxt = head[from];
    e[cnt].to = to;
    head[from] = cnt;
}

void dfs (int u) {
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        dfs (v);
        tot[0][u] += tot[0][v];
        tot[1][u] += tot[1][v];
    }
}

ll get_ans () {
    ll ans = 0;
    for (int i = 1; i <= node; ++i) add_edge (fa[i], i); dfs (1);
    for (int i = 1; i <= node; ++i) ans += 1LL * (len[i] - len[fa[i]]) * tot[0][i] * tot[1][i];
    return ans;
} 

int n1, n2;
char s1[N], s2[N];

int main () {
    scanf ("%s %s", s1, s2);
    n1 = strlen (s1), n2 = strlen (s2);
    for (int i = 0; i < n1; ++i) extend (s1[i] - 'a', 0);
    las = 1;
    for (int i = 0; i < n2; ++i) extend (s2[i] - 'a', 1);
    cout << get_ans () << endl;
} 

Luogu P3181 [HAOI2016]找相同字符 廣義$SAM$