主席樹(可持久化線段樹)
主席樹(可持久化線段樹)
前置芝士
知識點
線段樹,權值線段樹(不一樣),離散化,字首和(思想)
由來
據說,是一位叫fotile主席的大大在寫一道題時因為不會劃分樹就臨時yy出一個演算法,於是,這演算法就這麼誕生了。
作用
對區間求 \(kth\)
思想
思考優化策略
一列數,可以對於每個點i都建一棵權值線段樹,維護1~i這些數,每個不同的數出現的個數(權值線段樹以值域作為區間)
現在,n棵線段樹就建出來了,第i棵線段樹代表1~i這個區間
例如,一列數,n為6,數分別為1 3 2 3 6 1
首先,每棵樹都是這樣的:
![](C:\Users\Administrator\Desktop\my things\2\20190511122617292.png)
以第4棵線段樹為例,1~4的數分別為1 3 2 3
![](C:\Users\Administrator\Desktop\my things\2\20190511123014414.png)
因為是同一個問題,n棵權值線段樹的形狀是一模一樣的,只有節點的權值不一樣
所以這樣的兩棵線段樹之間是可以相加減的(兩顆線段樹相減就是每個節點對應相減)
想想,第x棵線段樹減去第y棵線段樹會發生什麼?
第x棵線段樹代表的區間是[1,x]
第y棵線段樹代表的區間是[1,y]
兩棵線段樹一減
設x>y,[1,x]−[1,y]=[y+1,x][1,x]-[1,y]=[y+1,x][1,x]−[1,y]=[y+1,x]
所以這兩棵線段樹相減可以產生一個新的區間對應的線段樹!
等等,這不是字首和的思想嗎
這樣一來,任意一個區間的線段樹,都可以由我這n個基礎區間表示出來了!
因為每個區間都有一個線段樹
然後詢問對應區間,在區間對應的線段樹中查詢kth就行了
這就是主席樹的一個核心思想:字首和思想
具體做法待會兒再講,現在還有一個嚴峻的問題,就是n棵線段樹空間太大了!
如何優化空間,就是主席樹另一個核心思想
我們發現這n棵線段樹中,有很多重複的點,這些重複的點浪費了大部分的空間,所以考慮如何去掉這些冗餘點
在建樹中優化
假設現在有一棵線段樹,序列往右移一位,建一棵新的線段樹
對於一個兒子的值域區間,如果權值有變化,那麼新建一個節點,否則,連到原來的那個節點上
現在舉幾個例子來說明
序列4 3 2 3 6 1
區間[1,1]的線段樹(藍色節點為新節點)
![](C:\Users\Administrator\Desktop\my things\2\20190511125552180.png)
區間[1,2]的線段樹(橙色節點為新節點)
![](C:\Users\Administrator\Desktop\my things\2\20190511130206210.png)
區間[1,3]的線段樹(紫色節點為新節點)
![](C:\Users\Administrator\Desktop\my things\2\20190511130727560.png)
這樣是不是非常優秀啊?
(部分借用https://blog.csdn.net/ModestCoder_/java/article/details/90107874)
模板及程式碼
離散化
for(int i=1;i<=n;i++) a[i]=read(),b[i]=a[i];
sort(b+1,b+n+1);
q=unique(b+1,b+n+1)-b-1;
插入(建新樹)
int update(int o,int l,int r){
int oo=++cnt;
ls[oo]=ls[o],rs[oo]=rs[o],sum[oo]=sum[o]+1;
if(l==r) return oo;
int mid=(l+r)>>1;
if(mid>=p) ls[oo]=update(ls[oo],l,mid); //下文有p
else rs[oo]=update(rs[oo],mid+1,r);
}
建樹
//函式
void build(int &rt,int l,int r){
rt=++cnt,sum[rt]=0;
if(l==r) return ;
int mid=(l+r)>>1;
build(ls[rt],l,mid);
build(rs[rt],mid+1,r);
}
//主函式中的操作
build(rt[0],1,q); //建一棵空樹,雖說不建也沒關係 以防萬一
for(int i=1;i<=n;i++){ //1~n依次建樹
p=lower_bound(b+1,b+q+1,a[i])-b;
rt[i]=update(rt[i-1],1,q);
}
查詢
//函式
int query(int u,int v,int l,int r,int k){//u、v為兩棵線段樹當前節點編號,相減就是詢問區間
int mid=(l+r)>>1,x=sum[ls[v]]-sum[ls[u]];
if(l==r) return l;
if(x>=k) return query(ls[u],ls[v],l,mid,k);
else return query(rs[u],rs[v],mid+1,r,k-x);
//kth操作,排名<=左兒子的數的個數,說明在左兒子,進入左兒子;反之,目標在右兒子,排名需要減去左兒子的權值
}
//主函式中的操作
while(m--){
int l=read(),r=read(),k=read();
printf("%d\n",b[query(rt[l-1],rt[r],1,q,k)]);
}
模板題1
區間第 \(k\) 大
程式碼
#include <bits/stdc++.h>
#define maxn 200010
using namespace std;
int a[maxn], b[maxn], n, m, q, p, sz;
int lc[maxn << 5], rc[maxn << 5], sum[maxn << 5], rt[maxn << 5];
//空間要注意
inline int read(){
int s = 0, w = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
return s * w;
}
void build(int &rt, int l, int r){
rt = ++sz, sum[rt] = 0;
if (l == r) return;
int mid = (l + r) >> 1;
build(lc[rt], l, mid); build(rc[rt], mid + 1, r);
}
int update(int o, int l, int r){
int oo = ++sz;
lc[oo] = lc[o], rc[oo] = rc[o], sum[oo] = sum[o] + 1;
if (l == r) return oo;
int mid = (l + r) >> 1;
if (mid >= p) lc[oo] = update(lc[oo], l, mid); else rc[oo] = update(rc[oo], mid + 1, r);
return oo;
}
int query(int u, int v, int l, int r, int k){
int mid = (l + r) >> 1, x = sum[lc[v]] - sum[lc[u]];
if (l == r) return l;
if (x >= k) return query(lc[u], lc[v], l, mid, k); else return query(rc[u], rc[v], mid + 1, r, k - x);
}
int main(){
n = read(), m = read();
for (int i = 1; i <= n; ++i) a[i] = read(), b[i] = a[i];
sort(b + 1, b + 1 + n);
q = unique(b + 1, b + 1 + n) - b - 1;
build(rt[0], 1, q);
for (int i = 1; i <= n; ++i){
p = lower_bound(b + 1, b + 1 + q, a[i]) - b;
rt[i] = update(rt[i - 1], 1, q);
}
while (m--){
int l = read(), r = read(), k = read();
printf("%d\n", b[query(rt[l - 1], rt[r], 1, q, k)]);
}
return 0;
}
模板題2
可持久化陣列
程式碼
#include <bits/stdc++.h>
#define maxn 1000010
using namespace std;
struct chairman{
int l, r, v;
}seg[maxn << 5];
int rt[maxn], sz, n, m, a[maxn];
inline int read(){
int s = 0, w = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
return s * w;
}
void build(int &rt, int l, int r){
rt = ++sz;
if (l == r){
seg[rt].v = a[l]; return;
}
int mid = (l + r) >> 1;
build(seg[rt].l, l, mid); build(seg[rt].r, mid + 1, r);
}
int update(int o, int l, int r, int p, int k){
int oo = ++sz;
seg[oo] = seg[o];
if (l == r){
seg[oo].v = k; return oo;
}
int mid = (l + r) >> 1;
if (mid >= p) seg[oo].l = update(seg[oo].l, l, mid, p, k); else
seg[oo].r = update(seg[oo].r, mid + 1, r, p, k);
return oo;
}
int query(int rt, int l, int r, int p){
if (l == r) return seg[rt].v;
int mid = (l + r) >> 1;
if (mid >= p) return query(seg[rt].l, l, mid, p); else
return query(seg[rt].r, mid + 1, r, p);
}
int main(){
n = read(), m = read();
for (int i = 1; i <= n; ++i) a[i] = read();
build(rt[0], 1, n);
for (int i = 1; i <= m; ++i){
int x = read(), opt = read(), y = read();
if (opt == 1){
int z = read();
rt[i] = update(rt[x], 1, n, y, z);
} else{
rt[i] = rt[x];
printf("%d\n", query(rt[i], 1, n, y));
}
}
return 0;
}