1. 程式人生 > 實用技巧 >【CF600E】Lomset gelral 題解(樹上啟發式合併)

【CF600E】Lomset gelral 題解(樹上啟發式合併)

題目連結

題目大意:給出一顆含有$n$個結點的樹,每個節點有一個顏色。求樹中每個子樹最多的顏色的編號和。

-------------------------

樹上啟發式合併(dsu on tree)。

我們先考慮暴力怎麼做。遍歷整顆樹,暴力列舉子樹然後用桶維護顏色個數。這樣做是$O(n^2)$的,顯然會T。我們需要一種更快的演算法:樹上啟發式合併。

關於啟發式演算法的介紹,詳見OI Wiki。本文只介紹樹上啟發式合併演算法。本題的解法:

每處理完一顆子樹,我們都要把桶清空一次,以免對它的兄弟造成影響。而這樣做還要從它的祖先遍歷一遍,浪費時間。

我們發現:遍歷最後一顆子樹時,桶是不用清空的。因為遍歷完那顆子樹後可以直接把答案加入$ans$中。那我們肯定選重兒子啊,省時省力。遍歷輕兒子相對不費事。

看起來是不是沒有快多少?實際上它是$O(n\log n)$的。下面是證明:

對於每個節點,它被計算的次數就是它到根節點路徑的輕邊個數。

而結點往上跳一次,子樹大小至少為原來兩倍,所以輕邊個數最多是$\log n$。所以時間複雜度$O(n\log n)$。

證明過程跟樹鏈剖分的有點像。

程式碼:

#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,color[200005],bucket[200005],ans[200005];
int size[200005],son[200005],sum,mx;
int head[200005],cnt;
struct node { int next,to; }edge[200005]; inline int read() { int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();} while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();} return x*f; } inline void add(int from,int to) { edge[++cnt].next=head[from]; edge[cnt].to
=to; head[from]=cnt; } inline void dfs_son(int now,int fa) { size[now]=1; int mx=0,p=0; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==fa) continue; dfs_son(to,now); size[now]+=size[to]; if (size[to]>mx) { mx=size[to]; p=to; } } if (p) son[p]=1; } void getans(int x,int f,int p){ bucket[color[x]]++; if(bucket[color[x]]>mx){ mx=bucket[color[x]]; sum=color[x]; }else if(bucket[color[x]]==mx)sum+=color[x]; for(int i=head[x];i;i=edge[i].next){ int y=edge[i].to; if(y==f || y==p)continue; getans(y,x,p); } } inline void init(int now,int fa) { bucket[color[now]]--; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==fa) continue; init(to,now); } } inline void dfs(int now,int fa) { int p=0; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==fa) continue; if (!son[to]) { dfs(to,now); init(to,now); sum=mx=0; } else p=to; } if (p) dfs(p,now); getans(now,fa,p); ans[now]=sum; } signed main() { n=read(); for (int i=1;i<=n;i++) color[i]=read(); for (int i=1;i<n;i++) { int x=read(),y=read(); add(x,y);add(y,x); } dfs_son(1,0); dfs(1,0); for (int i=1;i<=n;i++) printf("%lld ",ans[i]); return 0; }