1. 程式人生 > 其它 >7123. 【2021.6.15NOI模擬】尼特

7123. 【2021.6.15NOI模擬】尼特

給出一個序列\(a_i\),長度\(n\)

現在對於每個長度為\(n-1\)的序列\(b_i\),值域\(m\),將\(a_i\)刪掉一個位置之後最大化\(\sum [a_i=b_i]\)

對於每個\(b_i\)求和。

\(n\le 10^6\)


沒前途的DP做法:

首先考慮\(b_i\)固定時怎麼搞。設\(f_i\)表示\(a_{1..i}\)\(b_{1..i-1}\)搞的時候的答案,\(c_i=\sum_{j\le i} [a_i=b_i]\)。於是有\(f_i=\max(f_{i-1}+[a_i=b_{i-1}],c_{i-1})\)\(c_i=c_{i-1}+[a_i=b_i]\)

.

既然要計數就DP套DP:設\(g_{i,f,c,lst}\)表示前\(i\)位,\(f\)是什麼,\(c\)是什麼,最後一個字元是什麼的方案數。容易發現\(lst\)可以去掉,於是就得到了個\(O(n^3)\)的做法。

注意到我們要求\(\sum f*g_{n,f,c}\)。於是考慮轉移時\(f\)每次新增就加入答案。於是設\(g_{n,c},s_{n,c}\)\(c\)表示原來的\(f-c\)\(g\)\(s\)轉移大體相同,只是多了個從\(g\)\(s\)的轉移。然後得到\(O(n^2)\)做法。

然後題解不知道在雲什麼……不知道它是怎麼從這個方法上擴充套件的。

然後gmh114514拯救世界:

先是一個模型轉化:對於\(a_i\neq a_{i+1}\)的位置,如果\(b_i=a_i\)則標個左箭頭,如果\(b_i=a_{i+1}\)標個右箭頭,否則不標。對於\(a_i=a_{i+1}\)的位置肯定有貢獻所以可以在最後算。

現在問題是:找到個分界點,最大化左邊左箭頭+右邊右箭頭,對其計數。

好啦發現這個東西完全符合上面的DP。然而gmh114514直接選擇計數!

首先要求的東西相當於:把左箭頭看做+1,右箭頭看做-1,字首和最大值加右箭頭的個數就是貢獻。

右箭頭個數的貢獻可以先算,於是只有字首最大值的貢獻。為了方便直接將長度記作\(n\)\(m\leftarrow m-2\)

按照套路,列舉\(j\ge 1\),計算\(字首最大值\ge j\)的方案,加起來。把它畫在座標系上,按照套路,如果終點不超過\(j\)就對稱過去。於是貢獻為\(\sum_{j\ge 1}\sum_i calc(n,\max(i,2j-i))\),其中\(calc(x,y)\)表示從原點到\((x,y)\),每次可以向右上、正右、右下走的方案數。

那個東西等於\(\sum_{i\ge 1} calc(n,i)(2i-1)\)。寫成生成函式推一下:

\[\sum_{i\ge 1} [x^i](x+m+x^{-1})^n(2i-1)\\ =2\sum_{i\ge 1} [x^i](x+m+x^{-1})^ni-\sum_{i\ge 1} [x^i](x+m+x^{-1})^n\\ =2\sum_{i\ge 0} [x^i]((x+m+x^{-1})^n)'-\frac{(m+2)^n-[x^0](x+m+x^{-1})^n}{2}\\ =2\sum_{i\ge 0} [x^i]n(x+m+x^{-1})^{n-1}(1-x^{-2})-\frac{(m+2)^n-[x^0](x+m+x^{-1})^n}{2}\\ =2n([x^0]+[x^1])(x+m+x^{-1})^{n-1}-\frac{(m+2)^n-[x^0](x+m+x^{-1})^n}{2}\\ \]

發現其實只需要算\(O(1)\)\(calc\),每次計算時間\(O(n)\)


using namespace std;
#include <bits/stdc++.h>
const int Mxdt=100000;
inline char gc() {
	static char buf[Mxdt],*p1=buf,*p2=buf;
	return p1==p2&&(p2=(p1=buf)+fread(buf,1,Mxdt,stdin),p1==p2)?EOF:*p1++;
}
inline int read() {
	int s=0,f=0;char ch=gc();
	while(ch<'0'||ch>'9')f|=(ch=='-'),ch=gc();
	while(ch>='0'&&ch<='9')s=(s<<3)+(s<<1)+(ch^48),ch=gc();
	return f?-s:s;
}
const int N=1000005,mo=998244353;
typedef long long ll;
ll qpow(ll x,ll y=mo-2){
	ll r=1;
	for (;y;y>>=1,x=x*x%mo)
		if (y&1)
			r=r*x%mo;
	return r;
}
ll fac[N],ifac[N];
void initC(int n){
	fac[0]=1;
	for (int i=1;i<=n;++i)
		fac[i]=fac[i-1]*i%mo;
	ifac[n]=qpow(fac[n]);
	for (int i=n-1;i>=0;--i)
		ifac[i]=ifac[i+1]*(i+1)%mo;
}
ll C(int m,int n){
	return fac[m]*ifac[n]%mo*ifac[m-n]%mo;
}
int n,m;
int a[N];
void add(int &x,ll y){x=(x+y)%mo;}
ll pw[N*2];
ll calc(int x,int y){
	ll ans=0;
	for (int i=0;i<=x;++i)
		if (i>=x+y-i)
			(ans+=C(i,x+y-i)*pw[i*2-x-y]%mo*C(x,i))%=mo;
	return ans;
}
int main(){
	freopen("nit.in","r",stdin);
	freopen("nit.out","w",stdout);
	n=read(),m=read();
	for (int i=1;i<=n;++i)
		a[i]=read();
	if (m==1){
		printf("%d\n",n-1);
		return 0;
	}
	int cnt=0;
	for (int i=1;i<n;++i)
		cnt+=(a[i]!=a[i+1]);
	initC(n);
	ll ans=0;
	for (int i=0;i<=cnt;++i)
		(ans+=qpow(m-1,cnt-i)*C(cnt,i)%mo*i)%=mo;
		
	pw[0]=1;
	for (int i=1;i<=cnt*2;++i)
		pw[i]=pw[i-1]*(m-2)%mo;
	(ans+=cnt*2*(calc(cnt-1,0)+calc(cnt-1,1)))%=mo;
	
	(ans+=-(qpow(m,cnt)-calc(cnt,0))%mo*qpow(2))%=mo;
	
	ans=ans*m%mo;
	ans=ans*qpow(m,n-1-cnt)%mo;
	(ans+=(ll)qpow(m,n-1)%mo*(n-1-cnt)%mo)%=mo;
	
	ans=ans*qpow(m,mo-1-n)%mo;
	
	ans=(ans+mo)%mo;
	printf("%lld\n",ans);
	return 0;
}