1. 程式人生 > 實用技巧 >[BZOJ2616]SPOJ PERIODNI(笛卡爾樹+樹形dp)

[BZOJ2616]SPOJ PERIODNI(笛卡爾樹+樹形dp)

題面

http://darkbzoj.tk/problem/2616

題解

前置知識

先建出笛卡爾樹,並定義幾個變數:

\(lb[i]\)表示第i列左側第一個低於i的位置。

\(rb[i]\)表示第i列右側第一個低於i的位置。

\(dh[i]=h[i]-max(h[lb[i]],h[rb[i]])\)

\(S_i\)描述一個圖形:擷取原圖中的lb[i]+1到rb[i]-1列,並把最下方h[i]行去掉。

然後就可以描述轉移函式:$f[u][x] \(表示\)S_i$中放入x個車的方案數。

由於lb[u]+1到rb[u]-1這些點正好構成了笛卡爾樹中u的子樹,所以轉化為一個樹形dp。並且,有u的子樹大小sz[u]=rb[u]-lb[u]+1。

for(int i = 0;i <= sz[lc[u]];i++)
	for(int j = 0;j <= sz[rc[u]];j++)
		for(int d = 0;i + j + d <= sz[u];d++)	
			f[u][i+j+d] += f[lc[u]][i] * f[rc[u]][j] % mod * C(sz[u]-i-j,d) % mod * P(dh[u],d) % mod);

其中lc[u],rc[u]代表u的左右子節點,C、P代表組合數與排列數。

這是為什麼呢?\(S_u\)其實就是\(S_{lc[u]}\)\(S_{rc[u]}\)“中間隔一格”拼在一起,再在下面加上\(dh[u]\)行所得。程式碼中d列舉的就是下面這dh[u]行中共放了幾個車。如果\(S_{lc[u]}\)中放了i個車,\(S_{rc[u]}\)中放了j個車,那麼\(S_u\)中還沒被佔用的列數就是\(sz[u]-i-j\)。在這些列中選出無序的d列,再在最下面新增的dh[u]行中選出有序的d行放車,這就解釋了上面這個轉移方程。

時間複雜度方面,由於i和j都只列舉到對應的sz,所以總時間複雜度為\(O(n^3)\)

  • P.S.這一過程還可以用卷積優化。

程式碼

#include<bits/stdc++.h>

using namespace std;

#define ll long long
#define rg register
#define In inline

const int N = 500;
const ll mod = 1e9 + 7;
const ll W = 1e6;

namespace ModCalc{
	In void Inc(ll &x,ll y){
		x += y;if(x >= mod)x -= mod;
	}
	In void Dec(ll &x,ll y){
		x -= y;if(x < 0)x += mod;
	}
	In ll Add(ll x,ll y){
		Inc(x,y);return x;
	}
	In ll Sub(ll x,ll y){
		Dec(x,y);return x;
	}
}
using namespace ModCalc;

int n,k;
ll jc[W+5],iv[W+5];

ll power(ll a,ll n){
	ll s = 1,x = a;
	while(n){
		if(n & 1)s = s * x % mod;
		x = x * x % mod;
		n >>= 1;
	}
	return s;
}

void prepro(){
	jc[0] = 1;
	for(rg int i = 1;i <= W;i++)jc[i] = jc[i-1] * i % mod;
	iv[W] = power(jc[W],mod - 2);
	for(rg int i = W - 1;i >= 0;i--)iv[i] = iv[i+1] * (i + 1) % mod;
}

In ll C(ll n,ll m){
	if(n < m)return 0;
	return jc[n] * iv[m] % mod * iv[n-m] % mod;
}

In ll P(ll n,ll m){
	if(n < m)return 0;
	return jc[n] * iv[n-m] % mod;
}

struct CartTree{
	int top,rt;
	int fa[N+5],lc[N+5],rc[N+5],sz[N+5];
	ll h[N+5],dh[N+5],f[N+5][N+5];
	int st[N+5];	
	In void build(){
		for(rg int i = 1;i <= n;i++){
			scanf("%lld",&h[i]);
			while(top && h[st[top]] > h[i])
				lc[i] = st[top--];
			fa[i] = st[top];
			if(!fa[i])rt = i;else rc[fa[i]] = i;
			if(lc[i])fa[lc[i]] = i;
			st[++top] = i;
		}
		f[0][0] = 1;
	}
	In void prepro(int u){
		int l = u;while(lc[l])l = lc[l];
		int r = u;while(rc[r])r = rc[r];
		dh[u] = h[u] - max(h[l-1],h[r+1]);
	}
	In void dfs(int u){
		if(lc[u])dfs(lc[u]);
		if(rc[u])dfs(rc[u]);
		sz[u] = sz[lc[u]] + sz[rc[u]] + 1;
		for(rg int i = 0;i <= sz[lc[u]];i++)
			for(rg int j = 0;j <= sz[rc[u]];j++)
				for(rg int d = 0;i + j + d <= sz[u];d++)	
					Inc(f[u][i+j+d],f[lc[u]][i] * f[rc[u]][j] % mod * C(sz[u]-i-j,d) % mod * P(dh[u],d) % mod);
	}
}T;

int main(){
	scanf("%d%d",&n,&k);
	prepro();
	T.build();
	for(rg int i = 1;i <= n;i++)T.prepro(i);
	T.dfs(T.rt);
	cout << T.f[T.rt][k] << endl;
	return 0;
}