1. 程式人生 > 其它 >The 2021 ICPC Asia Shanghai Regional Programming Contest - B. Strange Permutations 題解

The 2021 ICPC Asia Shanghai Regional Programming Contest - B. Strange Permutations 題解


title: >-
  The 2021 ICPC Asia Shanghai Regional Programming Contest - B. Strange
  Permutations
date: 2021-12-13 15:27:09
tags: [inclusion-exclusion, combinatorics, math, FFT, team training, merge]

題意

給定一個全排列 \(P\),求滿足 \(\forall i \in \{1, 2, \cdots, n - 1\}, Q_{i+1} \neq P_{Q_i}\) 的全排列 \(Q\) 的個數(對 \(998244353\) 取模)

思路

抽象題意:取編號 \(1\) ~ \(n\) 的點出來,每個點 \(i\) 上有一個值 \(P_i\),表示不能從 \(i\) 連出的邊 \(i \to P_i\);計算經過每個頂點恰好一次的有向路徑(哈密頓路徑)的條數

(圖中橙色邊表示不可連,藍色邊表示可連)

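考慮容斥,列舉破壞 \(i\) 個條件(即強制選上 \(i\) 條橙色邊)。這 \(i\) 條邊把 \(n\) 個點串成 \(n-i\) 條鏈,鏈之間的排列方式有 \((n-i)!\) 種,故答案為 \(\sum_{i=0}^{n} (-1)^i f_i \cdot (n-i)!\),其中 \(f_i\) 為選出 \(i\) 條橙色邊且不形成迴路的方案數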

由於 \(P\) 是全排列,橙色邊必然構成若干個環(含自環)。在一個 \(k\) 元環中可以選 \(0\) ~ \(k-1\) 條橙色邊(不能 \(k\) 條全選,否則形成迴路,而每個點只經過一次),選 \(j\) 條的方案數為 \(C_k^j\),則該環貢獻的生成函式為

\[\begin{aligned} &1 + C_k^1 \cdot x + C_k^2 \cdot x^2 + \cdots + C_k^{k-1} \cdot x^{k-1} \\ =\ & (1 + x) ^k - x^k \end{aligned} \]

然後找出所有環,用啟發式合併(每次取出長度最短的兩個多項式用 NTT 相乘)把這些多項式乘起來即可,乘積中 \(x^i\) 項的係數就是 \(f_i\)

複雜度 \(O(n \log^2 n)\)

程式碼

#include <bits/stdc++.h>
// Integer type shorthands used throughout the file (type aliases instead of
// #define so they obey scoping and can't rewrite unrelated tokens).
using ll = long long;
using ull = unsigned long long;
using i64 = long long;
// Polynomial: coefficient of x^i stored at index i, values taken mod MOD.
using poly = std::vector<int>;
// Reminder: never read a[m] while a.size() <= m -- resize first, e.g.
//   (a = f(...)).resize(m + 1);
//   i64 res = a[m] - b[m];

constexpr int MOD = 998244353;

namespace Poly { // remember to resize
    const int N = (1 << 21), g = 3;
    inline int power(int x, int p) {
        int res = 1;
        for (; p; p >>= 1, x = (ll)x * x % MOD) 
            if (p & 1)
                res = (ll)res * x % MOD;
        return res;
    }
    inline int fix(const int x) { return x >= MOD ? x - MOD : x; }
    void dft(poly& A, int n) {
        static ull W[N << 1], *H[30], *las = W, mx = 0;
        for (; mx < n; mx++) {
            H[mx] = las;
            ull w = 1, wn = power(g, (MOD - 1) >> (mx + 1));
            for(int i=0;i<1<<n;++i) *las++ = w, w = w * wn % MOD;
        }
        if (A.size() != (1 << n))
            A.resize(1 << n);
        static ull a[N];
        for (int i = 0, j = 0; i < (1 << n); ++i) {
            a[i] = A[j];
            for (int k = 1 << (n - 1); (j ^= k) < k; k >>= 1);
        }
        for (int k = 0, d = 1; k < n; k++, d <<= 1)
            for (int i = 0; i < (1 << n); i += (d << 1)) {
                ull *l = a + i, *r = a + i + d, *w = H[k], t;
                for (int j = 0; j < d; j++, l ++, r++) {
                    t = (*r) * (*w++) % MOD;
                    *r = *l + MOD - t, *l += t;
                }
            }
        for(int i=0;i<1<<n;++i) A[i] = a[i] % MOD;
    }
 
    void idft(poly &a, int n) {
        a.resize(1 << n), reverse(a.begin() + 1, a.end());
        dft(a, n);
        int inv = power(1 << n, MOD - 2);
        for(int i=0;i<1<<n;++i) a[i] = (ll)a[i] * inv % MOD;
    }
 
    poly FIX(poly a) {
        while (!a.empty() && !a.back()) a.pop_back();
        return a;
    }

    // remember to resize
    poly mul(poly a, poly b, int t = 1) {
        if (t == 1 && a.size() + b.size() <= 24) {
            poly c(a.size() + b.size(), 0);
            for(int i=0;i<a.size();++i) for(int j=0;j<b.size();++j) c[i + j] = (c[i + j] + (ll)a[i] * b[j]) % MOD;
            return FIX(c);
        }
        int n = 1, aim = a.size() * t + b.size();
        while ((1<<n) <= aim) n++;
        dft(a, n); dft(b, n);
        if (t == 1)
            for(int i=0;i<1<<n;++i) a[i] = (ll) a[i] * b[i] % MOD;
        else
            for(int i=0;i<1<<n;++i) a[i] = (ll) a[i] * a[i] % MOD * b[i] % MOD;
        idft(a, n); a.resize(aim);
        return FIX(a);
    }
 
    int Merge(std::vector<poly>&a) { // return index
        std::priority_queue<std::pair<int,int> > H; // <-size, index>
        int n = a.size();
        for(int i=0;i<n;++i) {
            H.emplace(-a[i].size(), i);
        }

        while(H.size()>=2) {
            int o1 = H.top().second; H.pop();
            int o2 = H.top().second; H.pop();
            poly res = mul(a[o1], a[o2]);
            a[o1].clear(); a[o2].clear();
            for(int i=0;i<res.size();++i) a[o1].push_back(res[i]);
            H.emplace(-a[o1].size(), o1);
        }

        return H.top().second; // index
    }
};

// Folds x back into [0, MOD) after a single modular addition/subtraction:
// values at or above MOD lose one MOD, negative values gain one MOD.
void norm(int& x) {
    if (x >= MOD) {
        x -= MOD;
    } else if (x < 0) {
        x += MOD;
    }
}

// Returns (a * b) % MOD, forming the product in 64 bits so it cannot overflow.
int mul(int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}

int main(int argc, char const *argv[])
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr); std::cout.tie(nullptr);

    int n;
    std::cin >> n;
    std::vector<int> p(n);
    for(int i=0;i<n;++i) {
        std::cin >> p[i];
        --p[i];
    }

    std::vector<int> vis(n, false); // bool
    int circles = 0;
    std::vector<int> cnt;
    for(int i=0;i<n;++i) {
        if(!vis[i]) {
            cnt.push_back(0);
            for(int j=i;!vis[j];j=p[j]) {
                vis[j] = true;
                ++cnt[circles];
            }
            ++circles;
        }
    }
    std::vector<poly> ps(circles, poly());

    std::vector<int> fac(n+1),ifac(n+1),inv(n+1);
    fac[0] = fac[1] = ifac[0] = ifac[1] = inv[0] = inv[1] = 1;
    for(int i=2;i<=n;++i) {
        fac[i] = mul(i, fac[i - 1]);
        inv[i] = mul(inv[MOD % i], MOD - MOD/i);
        ifac[i] = mul(inv[i], ifac[i - 1]);
    }

    auto C = [&](int n, int m) {
        return mul( fac[n], mul(ifac[m], ifac[n - m]) );
    };

    for(int i=0;i<circles;++i) {
        poly &thiz = ps[i];
        thiz.resize(cnt[i]);
        for(int j=0;j<cnt[i];++j) {
            thiz[j] = C(cnt[i],j);
        }
    }

    int thiz = Poly::Merge(ps);
    poly &ans = ps[thiz];
    ans.resize(n+1);

    int res = 0;

    for(int i=0;i<=n;++i) {
        int thiz = mul(ans[i], fac[n - i]);
        norm(
            res += ((i & 1) ? MOD - thiz : thiz)
        );
    }
    
    std::cout << res;

    return 0;
}
Living with bustle, hearing of isolation.