【HDU-5785】Interesting(迴文串的性質+迴文自動機+map空間優化)
阿新 • • 發佈:2020-12-05
題目連結:https://vjudge.net/problem/HDU-5785
題目大意
給定一個字串,求有多少對三元組 \((i, j, k)\) 滿足 \(1≤i≤j<k≤|S|\),要求 \(S[i,...j]\) 和 \(S[j+1, .. k]\) 都為迴文串,對 \(1e9+7\) 取模。
思路
是對迴文的端點進行計數,可上回文自動機。
由迴文樹的性質可知,假設字串的第 \(i\) 個點在迴文樹上的編號為 \(p_{i}\),那麼其在 \(fail\) 樹上的祖先為以第 \(i\) 個點為右端點的所有迴文串。
假設迴文樹的右端為 \(i\),其長度為 \(len_{p_{i}}\)
那麼其左端點之和就為 \((i-len_{p_{i}}+1)+(i-len_{fail[p_{i}]}+1)+(i-len_{fail[fail[p_{i}]]}+1)+...\)
每次在迴文樹上增加節點時,維護 \(len_{p_{i}}\) 之和,記為 \(sum_{p_{i}}\)。
那麼則將式子轉化成 \((i+i+..+i) + (1+1+...+1) - sum_{p_{i}}\),其中 \((i+i+...+i) = num_{p_{i}}\) 即以這個點為右端點時迴文串個數,要求這個可以做洛谷模板題。
從右向左同理,在轉換時需要注意小細節。
但是由於空間特別卡,需要用 \(map\) 來優化空間。
AC程式碼
#include <bits/stdc++.h> typedef long long ll; using namespace std; const int MAXN = 1e6 + 5; const int MAXC = 26; const int mod = 1000000007; int nm[MAXN][2]; char str[MAXN]; class PAM { public: struct node { map<int, int> ch;// int ch[MAXC]; int fail, len, num; ll sum; } T[MAXN]; int las, tot; inline int get_fail(int x, int pos) { while (str[pos - T[x].len - 1] != str[pos]) { x = T[x].fail; } return x; } void init() { T[0].ch.clear(), T[1].ch.clear(); // memset(T[0].ch, 0, sizeof(T[0].ch)), memset(T[1].ch, 0, sizeof(T[1].ch)); T[0].fail = 1, T[1].fail = 0; T[0].len = 0, T[1].len = -1; T[0].num = T[1].num = T[0].sum = T[1].sum = 0; las = 0, tot = 1; } void insert1(char s[], int len) { s[0] = -1; for (int i = 1; i <= len; i++) { int p = get_fail(las, i); if (!T[p].ch[s[i]-'a']) { T[++tot].len = T[p].len + 2; T[tot].ch.clear();// memset(T[tot].ch, 0, sizeof(T[tot].ch)); int u = get_fail(T[p].fail, i); T[tot].fail = T[u].ch[s[i]-'a']; T[tot].num = T[T[tot].fail].num + 1; T[tot].sum = (T[T[tot].fail].sum + T[tot].len) % mod; T[p].ch[s[i]-'a'] = tot; } las = T[p].ch[s[i]-'a']; nm[i][0] = ((ll) T[las].num * (i + 1) % mod - T[las].sum + mod) % mod; } } void insert2(char s[], int len) { s[0] = 0; for (int i = 1; i <= len; i++) { int p = get_fail(las, i); if (!T[p].ch[s[i]-'a']) { T[++tot].len = T[p].len + 2; T[tot].ch.clear();// memset(T[tot].ch, 0, sizeof(T[tot].ch)); int u = get_fail(T[p].fail, i); T[tot].fail = T[u].ch[s[i]-'a']; T[tot].num = T[T[tot].fail].num + 1; T[tot].sum = (T[T[tot].fail].sum + T[tot].len) % mod; T[p].ch[s[i]-'a'] = tot; } las = T[p].ch[s[i]-'a']; int pos = len - i + 1; nm[pos][1] = ((ll) T[las].num * (pos - 1) % mod + T[las].sum) % mod; } } } tree; int main() { while (~scanf("%s", str + 1)) { int n = strlen(str + 1); tree.init(); tree.insert1(str, n); int n2 = n >> 1; for (int i = 1; i <= n2; i++) swap(str[i], str[n-i+1]); tree.init(); tree.insert2(str, n); ll res = 0; for (int i = 1; i < n; i++) { res = (res + (ll) nm[i][0] * nm[i + 1][1] % mod) % mod; } printf("%lld\n", res); } }