codeforces 888G Xor-MST Sollin演算法求最小生成樹,0-1異或True
You are given a complete undirected graph with n vertices. A number ai is assigned to each vertex, and the weight of an edge between vertices i and j is equal to ai xor aj.
Calculate the weight of the minimum spanning tree in this graph.
InputThe first line contains n (1 ≤ n ≤ 200000) — the number of vertices in the graph.
The second line contains n integers a1, a2, ..., an (0 ≤ ai < 230) — the numbers assigned to the vertices.
OutputPrint one number — the weight of the minimum spanning tree in the graph.
Examples input5 1 2 3 4 5output
8input
4 1 2 3 4output
8
題解:
這個題用的演算法比較古老偏僻,反正在這之前我是沒有聽說過的。。。。。
1、Sollin演算法介紹
Sollin(Boruvka)演算法。
原理大概是這樣的:剛開始把每個點看成是一個聯通分量,然後同時對所有的聯通分量進行擴充套件,這樣的話,每次至少有一半數量的聯通分量被合併。
合併的時候是這樣進行操作的,首先拿出一個聯通分量,然後從這個聯通分量向其他的聯通分量求一個最小邊,然後把最小邊兩個端點相連的聯通分量合併,再去列舉其他的聯通分量,保證每次迭代的所有聯通分量都被考慮過。
我們只需要迭代logn次就可以了。
2、Sollin演算法在本題中的應用:
考慮到邊是xor運算得到的,這是套路之一,我們首先建立一個0-1的trie樹。
然後把所有的點都加進去。
每次遍歷一個聯通分量的時候,我們就把這個聯通分量從Trie裡面刪除掉,然後列舉這個聯通分量裡面的點,對於這個點,在Trie裡面找xor最小的點。
然後合併這兩個聯通分量就好了。
聯通分量使用並查集來維護。
3、細節:
注意,這裡不能用vector來存放聯通分量,否則會超記憶體的。
正確的方法應該是:
把點按照他所屬的聯通分量進行排序,這樣的話,屬於同一個聯通分量的點都在連續的一個區段裡面,處理起來非常方便。
程式碼:
#include<bits/stdc++.h>
#define convert(s,i) ((s>>i)&1)
using namespace std;
typedef pair<int,int> P;
const int inf = 2e9;
const int maxn = 200007;
struct Trie{
int frq,nxt[2];
}pool[maxn*31];
int cnt;
int n;
void insert(int s){
int cur = 0;
for(int i = 30;i >= 0;--i){
int &pos = pool[cur].nxt[convert(s,i)];
if(!pos) pos = ++cnt;
cur = pos;
pool[cur].frq++;
}
}
int findxor(int s){
int cur = 0,ans = s;
for(int i = 30;i >= 0;--i){
int pos = pool[cur].nxt[convert(s,i)];
if(!(pos && pool[pos].frq)) pos = pool[cur].nxt[1^convert(s,i)],ans ^= (1<<i);
cur = pos;
}
return ans;
}
void del(int s){
int cur = 0;
for(int i = 30;i >= 0;--i){
int pos = pool[cur].nxt[convert(s,i)];
cur = pos;
pool[cur].frq--;
}
}
int a[maxn],parent[maxn],used[maxn];
int find(int x){
return x == parent[x]?x:parent[x] = find(parent[x]);
}
int join(int x,int y){
int px = find(x);
int py = find(y);
if(px == py) return 0;
parent[py] = px;
return 1;
}
bool check(){
int f = 0;
for(int i = 1;i <= n;++i) f += parent[i] == i;
return f == 1;
}
long long res = 0;
P ps[maxn];
int main(){
cnt = 0;
cin>>n;
for(int i = 1;i <= n;++i) parent[i] = i;
memset(pool,0,sizeof(pool));
for(int i = 1;i <= n;++i) scanf("%d",&a[i]);
sort(a+1,a+1+n);n = unique(a+1,a+1+n) - (a+1);
for(int i = 1;i <= n;++i) insert(a[i]);
while(!check()){
memset(used,0,sizeof(used));
for(int i = 1;i <= n;++i) ps[i] = make_pair(find(i),i);
sort(ps+1,ps+1+n);
int pre = ps[1].first,last = 1;
for(int i = 1;i <= n;++i){
int u = ps[i].second;
if(!used[pre] && ps[i].first == pre) del(a[u]);
if(ps[i+1].first != pre){
if(used[find(u)]) {
for(int j = last;j <= i;j++) insert(a[ps[j].second]);
last = i+1;pre = ps[last].first;
continue;
}
used[pre] = 1;
int mi = inf,cv;
for(int j = last;j <= i;++j) {
int v = findxor(a[ps[j].second]);
if((v^a[ps[j].second]) < mi) mi = v^a[ps[j].second],cv = v;
}
res += mi;
for(int j = last;j <= i;++j) insert(a[ps[j].second]);
int pj = lower_bound(a+1,a+1+n,cv)-a;
pj = find(pj);
pre = find(u);
if(pre > pj) swap(pre,pj);
join(pre,pj);
pre = ps[i+1].first,last = i+1;
}
}
}
cout<<res<<endl;
return 0;
}