1. 程式人生 > 實用技巧 >【HDU-5785】Interesting(迴文串的性質+迴文自動機+map空間優化)

【HDU-5785】Interesting(迴文串的性質+迴文自動機+map空間優化)

題目連結:https://vjudge.net/problem/HDU-5785

題目大意

給定一個字串,求有多少對三元組 \((i, j, k)\) 滿足 \(1≤i≤j<k≤|S|\),要求 \(S[i,...j]\)\(S[j+1, .. k]\) 都為迴文串,對 \(1e9+7\) 取模。

思路

是對迴文的端點進行計數,可上回文自動機。

由迴文樹的性質可知,假設字串的第 \(i\) 個點在迴文樹上的編號為 \(p_{i}\),那麼其在 \(fail\) 樹上的祖先為以第 \(i\) 個點為右端點的所有迴文串。

假設迴文樹的右端為 \(i\),其長度為 \(len_{p_{i}}\)

,那麼左端點就為 \(i-len_{p_{i}}+1\)

那麼其左端點之和就為 \((i-len_{p_{i}}+1)+(i-len_{fail[p_{i}]}+1)+(i-len_{fail[fail[p_{i}]]}+1)+...\)

每次在迴文樹上增加節點時,維護 \(len_{p_{i}}\) 之和,記為 \(sum_{p_{i}}\)

那麼則將式子轉化成 \((i+i+..+i) + (1+1+...+1) - sum_{p_{i}}\),其中 \((i+i+...+i) = num_{p_{i}}\) 即以這個點為右端點時迴文串個數,要求這個可以做洛谷模板題

從右向左同理,在轉換時需要注意小細節。

但是由於空間特別卡,需要用 \(map\) 來優化空間。

AC程式碼

#include <bits/stdc++.h>

typedef long long ll;
using namespace std;
const int MAXN = 1e6 + 5;
const int MAXC = 26;
const int mod = 1000000007;

int nm[MAXN][2];

char str[MAXN];

class PAM {
public:
    struct node {
        map<int, int> ch;// int ch[MAXC];
        int fail, len, num;
        ll sum;
    } T[MAXN];
    int las, tot;

    inline int get_fail(int x, int pos) {
        while (str[pos - T[x].len - 1] != str[pos]) {
            x = T[x].fail;
        }
        return x;
    }

    void init() {
        T[0].ch.clear(), T[1].ch.clear();
        // memset(T[0].ch, 0, sizeof(T[0].ch)), memset(T[1].ch, 0, sizeof(T[1].ch));
        T[0].fail = 1, T[1].fail = 0;
        T[0].len = 0, T[1].len = -1;
        T[0].num = T[1].num = T[0].sum = T[1].sum = 0;
        las = 0, tot = 1;
    }

    void insert1(char s[], int len) {
        s[0] = -1;
        for (int i = 1; i <= len; i++) {
            int p = get_fail(las, i);
            if (!T[p].ch[s[i]-'a']) {
                T[++tot].len = T[p].len + 2;
                T[tot].ch.clear();// memset(T[tot].ch, 0, sizeof(T[tot].ch));
                int u = get_fail(T[p].fail, i);
                T[tot].fail = T[u].ch[s[i]-'a'];
                T[tot].num = T[T[tot].fail].num + 1;
                T[tot].sum = (T[T[tot].fail].sum + T[tot].len) % mod;
                T[p].ch[s[i]-'a'] = tot;
            }
            las = T[p].ch[s[i]-'a'];
            nm[i][0] = ((ll) T[las].num * (i + 1) % mod - T[las].sum + mod) % mod;
        }
    }

    void insert2(char s[], int len) {
        s[0] = 0;
        for (int i = 1; i <= len; i++) {
            int p = get_fail(las, i);
            if (!T[p].ch[s[i]-'a']) {
                T[++tot].len = T[p].len + 2;
                T[tot].ch.clear();// memset(T[tot].ch, 0, sizeof(T[tot].ch));
                int u = get_fail(T[p].fail, i);
                T[tot].fail = T[u].ch[s[i]-'a'];
                T[tot].num = T[T[tot].fail].num + 1;
                T[tot].sum = (T[T[tot].fail].sum + T[tot].len) % mod;
                T[p].ch[s[i]-'a'] = tot;
            }
            las = T[p].ch[s[i]-'a'];
            int pos = len - i + 1;
            nm[pos][1] = ((ll) T[las].num * (pos - 1) % mod + T[las].sum) % mod;
        }
    }

} tree;


int main() {
    while (~scanf("%s", str + 1)) {
        int n = strlen(str + 1);
        tree.init();
        tree.insert1(str, n);

        int n2 = n >> 1;
        for (int i = 1; i <= n2; i++) swap(str[i], str[n-i+1]);

        tree.init();
        tree.insert2(str, n);

        ll res = 0;
        for (int i = 1; i < n; i++) {
            res = (res + (ll) nm[i][0] * nm[i + 1][1] % mod) % mod;
        }
        printf("%lld\n", res);
    }
}