1. 程式人生 > 實用技巧 >主席樹(可持久化線段樹)

主席樹(可持久化線段樹)

主席樹(可持久化線段樹)

前置芝士

知識點

線段樹,權值線段樹(不一樣),離散化,字首和(思想)

由來

據說,是一位叫fotile主席的大大在寫一道題時因為不會劃分樹就臨時yy出一個演算法,於是,這演算法就這麼誕生了。

作用

對區間求 \(kth\)

思想

思考優化策略

一列數,可以對於每個點i都建一棵權值線段樹,維護1~i這些數,每個不同的數出現的個數(權值線段樹以值域作為區間)

現在,n棵線段樹就建出來了,第i棵線段樹代表1~i這個區間

例如,一列數,n為6,數分別為1 3 2 3 6 1
首先,每棵樹都是這樣的:

![](C:\Users\Administrator\Desktop\my things\2\20190511122617292.png)

以第4棵線段樹為例,1~4的數分別為1 3 2 3

![](C:\Users\Administrator\Desktop\my things\2\20190511123014414.png)

因為是同一個問題,n棵權值線段樹的形狀是一模一樣的,只有節點的權值不一樣
所以這樣的兩棵線段樹之間是可以相加減的(兩顆線段樹相減就是每個節點對應相減)

想想,第x棵線段樹減去第y棵線段樹會發生什麼?
第x棵線段樹代表的區間是[1,x]
第y棵線段樹代表的區間是[1,y]
兩棵線段樹一減
設x>y,[1,x]−[1,y]=[y+1,x][1,x]-[1,y]=[y+1,x][1,x]−[1,y]=[y+1,x]
所以這兩棵線段樹相減可以產生一個新的區間對應的線段樹!

等等,這不是字首和的思想嗎
這樣一來,任意一個區間的線段樹,都可以由我這n個基礎區間表示出來了!
因為每個區間都有一個線段樹
然後詢問對應區間,在區間對應的線段樹中查詢kth就行了

這就是主席樹的一個核心思想:字首和思想

具體做法待會兒再講,現在還有一個嚴峻的問題,就是n棵線段樹空間太大了!
如何優化空間,就是主席樹另一個核心思想

我們發現這n棵線段樹中,有很多重複的點,這些重複的點浪費了大部分的空間,所以考慮如何去掉這些冗餘點

在建樹中優化

假設現在有一棵線段樹,序列往右移一位,建一棵新的線段樹
對於一個兒子的值域區間,如果權值有變化,那麼新建一個節點,否則,連到原來的那個節點上

現在舉幾個例子來說明
序列4 3 2 3 6 1

區間[1,1]的線段樹(藍色節點為新節點)

![](C:\Users\Administrator\Desktop\my things\2\20190511125552180.png)

區間[1,2]的線段樹(橙色節點為新節點)

![](C:\Users\Administrator\Desktop\my things\2\20190511130206210.png)

區間[1,3]的線段樹(紫色節點為新節點)

![](C:\Users\Administrator\Desktop\my things\2\20190511130727560.png)

這樣是不是非常優秀啊?
(部分借用https://blog.csdn.net/ModestCoder_/java/article/details/90107874)

模板及程式碼

離散化
for(int i=1;i<=n;i++) a[i]=read(),b[i]=a[i];
sort(b+1,b+n+1);
q=unique(b+1,b+n+1)-b-1;
插入(建新樹)
int update(int o,int l,int r){
    int oo=++cnt;
    ls[oo]=ls[o],rs[oo]=rs[o],sum[oo]=sum[o]+1;
    if(l==r) return oo;
    int mid=(l+r)>>1;
    if(mid>=p) ls[oo]=update(ls[oo],l,mid);	//下文有p
    else rs[oo]=update(rs[oo],mid+1,r);
}
建樹
//函式
void build(int &rt,int l,int r){
    rt=++cnt,sum[rt]=0;
    if(l==r) return ;
    int mid=(l+r)>>1;
    build(ls[rt],l,mid);
    build(rs[rt],mid+1,r);
}

//主函式中的操作
build(rt[0],1,q);				//建一棵空樹,雖說不建也沒關係 以防萬一
for(int i=1;i<=n;i++){			//1~n依次建樹
    p=lower_bound(b+1,b+q+1,a[i])-b;
    rt[i]=update(rt[i-1],1,q);
}
查詢
//函式
int query(int u,int v,int l,int r,int k){//u、v為兩棵線段樹當前節點編號,相減就是詢問區間
	int mid=(l+r)>>1,x=sum[ls[v]]-sum[ls[u]];
	if(l==r) return l;
    if(x>=k) return query(ls[u],ls[v],l,mid,k);
    else return query(rs[u],rs[v],mid+1,r,k-x);
    //kth操作,排名<=左兒子的數的個數,說明在左兒子,進入左兒子;反之,目標在右兒子,排名需要減去左兒子的權值
}

//主函式中的操作
while(m--){
    int l=read(),r=read(),k=read();
    printf("%d\n",b[query(rt[l-1],rt[r],1,q,k)]);
}

模板題1

區間第 \(k\)
程式碼
#include <bits/stdc++.h>
#define maxn 200010
using namespace std;
int a[maxn], b[maxn], n, m, q, p, sz;
int lc[maxn << 5], rc[maxn << 5], sum[maxn << 5], rt[maxn << 5];
//空間要注意

inline int read(){
	int s = 0, w = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}

void build(int &rt, int l, int r){
	rt = ++sz, sum[rt] = 0;
	if (l == r) return;
	int mid = (l + r) >> 1;
	build(lc[rt], l, mid); build(rc[rt], mid + 1, r);
}

int update(int o, int l, int r){
	int oo = ++sz;
	lc[oo] = lc[o], rc[oo] = rc[o], sum[oo] = sum[o] + 1;
	if (l == r) return oo;
	int mid = (l + r) >> 1;
	if (mid >= p) lc[oo] = update(lc[oo], l, mid); else rc[oo] = update(rc[oo], mid + 1, r);
	return oo;
}

int query(int u, int v, int l, int r, int k){
	int mid = (l + r) >> 1, x = sum[lc[v]] - sum[lc[u]];
	if (l == r) return l;
	if (x >= k) return query(lc[u], lc[v], l, mid, k); else return query(rc[u], rc[v], mid + 1, r, k - x);
}

int main(){
	n = read(), m = read();
	for (int i = 1; i <= n; ++i) a[i] = read(), b[i] = a[i];
	sort(b + 1, b + 1 + n);
	q = unique(b + 1, b + 1 + n) - b - 1;
	build(rt[0], 1, q);
	for (int i = 1; i <= n; ++i){
		p = lower_bound(b + 1, b + 1 + q, a[i]) - b;
		rt[i] = update(rt[i - 1], 1, q);
	} 
	while (m--){
		int l = read(), r = read(), k = read();
		printf("%d\n", b[query(rt[l - 1], rt[r], 1, q, k)]);
	}
	return 0;
}

模板題2

可持久化陣列
程式碼
#include <bits/stdc++.h>
#define maxn 1000010
using namespace std;
struct chairman{
	int l, r, v;
}seg[maxn << 5];
int rt[maxn], sz, n, m, a[maxn];

inline int read(){
	int s = 0, w = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}

void build(int &rt, int l, int r){
	rt = ++sz;
	if (l == r){
		seg[rt].v = a[l]; return;
	}
	int mid = (l + r) >> 1;
	build(seg[rt].l, l, mid); build(seg[rt].r, mid + 1, r);
}

int update(int o, int l, int r, int p, int k){
	int oo = ++sz;
	seg[oo] = seg[o];
	if (l == r){
		seg[oo].v = k; return oo;
	}
	int mid = (l + r) >> 1;
	if (mid >= p) seg[oo].l = update(seg[oo].l, l, mid, p, k); else
	seg[oo].r = update(seg[oo].r, mid + 1, r, p, k);
	return oo;
}

int query(int rt, int l, int r, int p){
	if (l == r) return seg[rt].v;
	int mid = (l + r) >> 1;
	if (mid >= p) return query(seg[rt].l, l, mid, p); else
	return query(seg[rt].r, mid + 1, r, p);
}

int main(){
	n = read(), m = read();
	for (int i = 1; i <= n; ++i) a[i] = read();
	build(rt[0], 1, n);
	for (int i = 1; i <= m; ++i){
		int x = read(), opt = read(), y = read();
		if (opt == 1){
			int z = read();
			rt[i] = update(rt[x], 1, n, y, z);
		} else{
			rt[i] = rt[x];
			printf("%d\n", query(rt[i], 1, n, y));
		}
	}
	return 0;
}