1. 程式人生 > >UOJ #348 州區劃分 —— 狀壓DP+子集卷積

UOJ #348 州區劃分 —— 狀壓DP+子集卷積

題目:http://uoj.ac/problem/348

一開始可以 3^n 子集DP,列舉一種狀態的最後一個集合是什麼來轉移;

設 \( f[s] \) 表示 \( s \) 集合內的點都劃分好了,\( g[s] = \sum\limits_{i \in s} w[i] \)

那麼 \( f[s] = \sum\limits_{d \subseteq s} \frac{f[s-d] * g[d]}{g[s]} \)

注意判斷一個集合是否合法,不僅要判斷每個點的度數,還要判斷整個集合是否連通;

這樣就可以過 n <= 15 的點了,UOJ上有30分;

#include<cstdio>
#include
<cstring> #include<algorithm> using namespace std; typedef long long ll; int const xn=(1<<21)+5,xxn=25,xm=505,mod=998244353; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0'
,ch=getchar(); return f?ret:-ret; } int n,m,p,f[xn],g[xn],w[xxn],g2[xn]; int hd[xxn],ct,to[xm],nxt[xm],bin[xxn]; void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;} int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;} ll pw(ll a,int b){ll ret=1; for(;b;b>>=1,a=a*a%mod)if
(b&1)ret=ret*a%mod; return ret;} bool vis[xxn]; int dfs(int x,int s) { vis[x]=1; int ret=1; for(int i=hd[x],u;i;i=nxt[i]) if(!vis[u=to[i]]&&(s&bin[u-1]))ret+=dfs(u,s); return ret; } bool ck(int s)// { int cnt=0; for(int x=1;x<=n;x++) { if(!(s&bin[x-1]))continue; int deg=0; cnt++; for(int i=hd[x];i;i=nxt[i]) { if(s&bin[to[i]-1])deg++; } if(deg&1)return 1; } for(int i=1;i<=n;i++)vis[i]=0; for(int i=1;i<=n;i++) if(s&bin[i-1])return dfs(i,s)!=cnt; } int main() { n=rd(); m=rd(); p=rd(); bin[0]=1; for(int i=1;i<=n;i++)bin[i]=bin[i-1]*2; for(int i=1,x,y;i<=m;i++)x=rd(),y=rd(),add(x,y),add(y,x); for(int i=1;i<=n;i++)g[bin[i-1]]=rd(); for(int s=0;s<bin[n];s++)g[s]=upt(g[s&(-s)]+g[s^(s&(-s))]); for(int s=0;s<bin[n];s++)g[s]=pw(g[s],p),g2[s]=pw(g[s],mod-2); for(int s=0;s<bin[n];s++)if(!ck(s))g[s]=0; int num=0; f[0]=1; for(int s=1;s<bin[n];s++) { for(int d=s;d;d=(s&(d-1)))//d=s f[s]=(f[s]+(ll)f[s^d]*g[d])%mod; f[s]=(ll)f[s]*g2[s]%mod; } printf("%d\n",f[bin[n]-1]); return 0; }
3^n

關於FMT(其實和高維字首和差不多)和子集卷積:https://www.cnblogs.com/Dance-Of-Faith/p/8818211.html

於是可以做子集卷積加速DP的過程。

程式碼如下:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=(1<<21)+5,xxn=25,xm=505,mod=998244353;
int rd()
{
  int ret=0,f=1; char ch=getchar();
  while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return f?ret:-ret;
}
int n,m,p,f[xxn][xn],g[xxn][xn],w[xxn],g2[xn];
int hd[xxn],ct,to[xm],nxt[xm],bin[xxn],cnt[xn];
void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;}
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
ll pw(ll a,int b){ll ret=1; for(;b;b>>=1,a=a*a%mod)if(b&1)ret=ret*a%mod; return ret;}
bool vis[xxn];
int dfs(int x,int s)
{
  vis[x]=1; int ret=1;
  for(int i=hd[x],u;i;i=nxt[i])
    if(!vis[u=to[i]]&&(s&bin[u-1]))ret+=dfs(u,s);
  return ret;
}
bool ck(int s)//
{
  int cnt=0;
  for(int x=1;x<=n;x++)
    {
      if(!(s&bin[x-1]))continue;
      int deg=0; cnt++;
      for(int i=hd[x];i;i=nxt[i])
    {
      if(s&bin[to[i]-1])deg++;
    }
      if(deg&1)return 1;
    }
  for(int i=1;i<=n;i++)vis[i]=0;
  for(int i=1;i<=n;i++)
    if(s&bin[i-1])return dfs(i,s)!=cnt;
}
int cal(int s){int ret=0; while(s)ret+=(s&1),s>>=1; return ret;}
void fmt(int *a,int tp)
{
  for(int d=1;d<bin[n];d<<=1)
    for(int s=0;s<bin[n];s++)
      if(s&d)a[s]=upt(a[s]+a[s^d]*tp);
}
int main()
{
  n=rd(); m=rd(); p=rd();
  bin[0]=1; for(int i=1;i<=n;i++)bin[i]=bin[i-1]*2;
  for(int i=1,x,y;i<=m;i++)x=rd(),y=rd(),add(x,y),add(y,x);
  for(int s=0;s<bin[n];s++)cnt[s]=cal(s);
  for(int i=1;i<=n;i++)g2[bin[i-1]]=rd();
  for(int s=0;s<bin[n];s++)g2[s]=upt(g2[s&(-s)]+g2[s^(s&(-s))]);
  for(int s=0;s<bin[n];s++)g[cnt[s]][s]=pw(g2[s],p),g2[s]=pw(g[cnt[s]][s],mod-2);
  for(int s=0;s<bin[n];s++)if(!ck(s))g[cnt[s]][s]=0;
  for(int i=1;i<=n;i++)fmt(g[i],1);
  f[0][0]=1; fmt(f[0],1);
  for(int i=1;i<=n;i++)
    {
      for(int j=0;j<=i;j++)
    for(int s=0;s<bin[n];s++)
      f[i][s]=(f[i][s]+(ll)f[j][s]*g[i-j][s])%mod;
      fmt(f[i],-1);
      for(int s=0;s<bin[n];s++)
    if(cnt[s]==i)f[i][s]=(ll)f[i][s]*g2[s]%mod;
    else f[i][s]=0;
      fmt(f[i],1);
    }
  fmt(f[n],-1);
  printf("%d\n",f[n][bin[n]-1]);
  return 0;
}