POJ 2778 DNA Sequence【AC自動機+矩陣快速冪】
題意:給m個病毒字串,問長度為n的DNA片段有多少種沒有包含病毒串的。
首先解決這個問題:給定一個有向圖,問從A點恰好走k步(允許重複經過邊)到達B點的方案數mod p的值
把給定的圖轉為鄰接矩陣,即A(i,j)=1當且僅當存在一條邊i->j。令C=A*A,那麼C(i,j)=ΣA(i,k)*A(k,j),實際上就等於從點i到點j恰好經過2條邊的路徑數(列舉k為中轉點)。類似地,C*A的第i行第j列就表示從i到j經過3條邊的路徑數。同理,如果要求經過k步的路徑數,我們只需要二分求出A^k即可。
首先考慮長度為n的DNA可以由長度為n-1的加上一個字母得來,可以把長度為n的不含這些片段的DNA分成幾類,並建立遞推關係。
例如不含ATC,AAA,GGC,CT的DNA可以分成以下幾類:
字尾是?A的、字尾是AT的、字尾是AA的、字尾是?G的、字尾是GG的、字尾是?C的,以及字尾是??的(?表示其他任意字母,這些類都是不重疊的,也就是說?A可以是AC,AT,AG,而不能是AA)之所以分成這些類是為了建立它們之間的遞推關係,比如?A轉化到AT有1種方法(後面加個T),轉化到AA有1種方法(後面加個A),轉化到?G有1種方法(後面加個G),轉化到?C有1種方法(後面加個C),由於?A無論加哪個字母都能轉化到一種情況,所以?A不能轉化到??,也就是說我們可以求出每種情況與其他各種情況的轉化關係,剩下的方法數就是轉化到??的。最後我們要求的是從??轉化到所有情況的方法數。
這些轉化關係可以用AC自動機求出。由於AC自動機的fail指標可以指向與當前串字尾相同的最長字串尾節結點,那麼當前結點就可以轉化到這個結點的後繼結點。
上述做法的思想就是:從一個合法的字串字尾,轉移到另外一個合法的字串字尾。如果存在這種轉移,那麼就在這兩個狀態之間連線一條邊(建圖)。然後問題就轉化成了第一個提出的問題。
程式碼說明:
val值表示當前節點是否為病毒串(即字尾是否為病毒串)
建AC自動機的時候,將所有的失配節點的fail指標省略了,直接連線了一條邊(因此將失配邊和其他邊等同看待了),最終的Trie樹將變成一個有向圖。
對於自動機程式碼的細節可以參考劉汝佳的《訓練指南》。
注意自動機的兩個註釋。因為當前結點可能不是一個病毒串的終點,但是其失配邊指向的串是個病毒,說明當前串的字尾也是個病毒串。
#include <cstdio> #include <cstring> #include <algorithm> #include <vector> #include <iostream> #include <queue> using namespace std; typedef long long ll; struct AC_Automata { #define Nn 102 #define M 4 int ch[Nn][M], val[Nn], f[Nn], last[Nn], sz; void clear() { sz = 1; memset(ch[0], 0, sizeof(ch[0])); } int idx(char c) { if (c == 'A') return 0; if (c == 'C') return 1; if (c == 'T') return 2; return 3; } void insert(char s[], int v) { int u = 0; for (int i=0; s[i]; i++) { int c = idx(s[i]); if (!ch[u][c]) { memset(ch[sz], 0, sizeof(ch[sz])); val[sz] = 0; ch[u][c] = sz++; } u = ch[u][c]; } val[u] = 1; ///標記當前節點是病毒串 } void build() { queue<int> q; f[0] = 0; for (int c=0; c<M; c++) { int u = ch[0][c]; if (u) { f[u] = last[u] = 0; q.push(u); } } while (!q.empty()) { int r = q.front(); q.pop(); for (int c=0; c<M; c++) { int u = ch[r][c]; if (!u) { ch[r][c] = ch[f[r]][c]; val[r] = val[r] || val[f[r]]; ///如果失配邊指向的結點是病毒,那麼當前串的字尾也是病毒串 continue; } q.push(u); f[u] = ch[f[r]][c]; last[u] = val[f[u]] ? f[u] : last[f[u]]; } } } } ac; #define Mod 100000ll #define N 200 ll a[N][N]; int n; //c = a*b void Multi(ll a[][N], ll b[][N], ll c[][N]) { for (int i=0; i<n; i++) for (int j=0; j<n; j++) { c[i][j] = 0; for (int k=0; k<n; k++) c[i][j] = (c[i][j] + a[i][k]*b[k][j]) % Mod; } } //d = s void copy(ll d[][N], ll s[][N]) { for (int i=0; i<n; i++) for (int j=0; j<n; j++) d[i][j] = s[i][j]; } //a = a^b % Mod void PowerMod(ll a[][N], ll b) { ll t[N][N], ret[N][N]; for (int i=0; i<n; i++) ret[i][i] = 1; while (b) { if (b & 1) { Multi(ret, a, t); copy(ret, t); } Multi(a, a, t); copy(a, t); b >>= 1; } copy(a, ret); } void init() { n = ac.sz; int u; memset(a, 0, sizeof(a)); for (int i=0; i<n; i++) if (!ac.val[i]) { for (int j=0; j<4; j++) { u = ac.ch[i][j]; if (!ac.val[u]) a[i][u]++; } } } int main() { char s[12]; int m; ll b; while (scanf("%d %lld", &m, &b) == 2) { ac.clear(); for (int i=1; i<=m; i++) { scanf(" %s", s); ac.insert(s, i); } ac.build(); init(); PowerMod(a, b); ll sum = 0; for (int i=0; i<n; i++) sum = (sum + a[0][i]) % Mod; cout << sum << endl; } return 0; }