1. 程式人生 > 實用技巧 >資料結構:樹狀陣列、線段樹

資料結構:樹狀陣列、線段樹

樹狀陣列/線段樹都可以把原來樸素的O(n2)變為O(n*logn),用於高效計算數列的字首和。具體主要表現為3種情況:區間修改單點查詢;單點修改區間查詢;區間修改區間查詢,這3種情況是一個遞進關係,理解規律之後就比較好記。

樹狀陣列的具體原理見https://www.cnblogs.com/xenny/p/9739600.html,這裡就不詳細描述了.....樹狀陣列的顯著特點就是藉助lowbit來確定修改或者查詢的位置。如下是區間修改區間查詢的完整板子程式碼:

#include<bits/stdc++.h>
using namespace std;
int n,m;
int s1[500005],s2[500005
],a[500005]; int lowbit(int x) { return x&(-x); } void update(int i,int k) { int x=i; while(i<=n){ s1[i]+=k; s2[i]+=k*(x-1); //注意 i+=lowbit(i); } } int getSum(int i) { int res=0,x=i; while(i>0){ res+=s1[i]*x-s2[i]; //注意 i-=lowbit(i); }
return res; } int main() { scanf("%d%d",&n,&m); memset(a,0,sizeof a); memset(s1,0,sizeof s1); memset(s2,0,sizeof s2); for(int i=1;i<=n;i++){ scanf("%d",&a[i]); update(i,a[i]-a[i-1]); //完全版樹狀陣列在構建時輸入a[i]-a[i-1] } int s,x,y,k; while(m--){ scanf(
"%d",&s); if(s==1){ scanf("%d%d%d",&x,&y,&k); update(x,k); update(y+1,-k); } else if(s==2){ scanf("%d",&x); printf("%d\n",getSum(x)-getSum(x-1)); } } return 0; }

樹狀陣列的應用包括RMQ問題,求逆序對等等。樹狀陣列用在RMQ問題需要在查詢函式改變一下,詳細情況見部落格RMQ問題。關於求逆序對數,將原數列從開始一個一個地加入元素到和它大小所對應的顛倒位置。因為某個較大的數先出現在數列中,所以它先被加入到樹狀陣列中,對較小的數(樹狀陣列中位置靠後的數)產生影響,從而達到統計逆序對的功能。隨著插入新數,順便求和,可得到逆序對數。例題洛谷P1774、P2678

#include<bits/stdc++.h>
using namespace std;
long long n,num[500005],loc[500005],tree[500005];
bool cmd(int a,int b)
{
    return num[a]==num[b]?a>b:num[a]>num[b];
}
int lowbit(int x)
{
    return x&(-x);
}
void add(int k,int v)
{
    while(k<=n){
        tree[k]+=v;
        k+=lowbit(k);
    }
}
int query(int k)
{
    int ans=0;
    while(k>0){
        ans+=tree[k];
        k-=lowbit(k);
    }
    return ans;
}
int main()
{
    scanf("%lld",&n);
    for(int i=1;i<=n;i++){
        scanf("%lld",&num[i]);
        loc[i]=i;
    }
    sort(loc+1,loc+n+1,cmd);
    long long ans=0;
    for(int i=1;i<=n;i++){
        ans+=query(loc[i]);
        add(loc[i],1);
    }
    printf("%lld\n",ans);
    return 0;
}

線段樹的核心在於push_down+lazytag,當然單點修改區間查詢和區間修改單點查詢這兩種情況是用不著的。藉助push_down和lazytag,線段樹可以處理更加複雜的區間維護問題。

對於單點查詢區間修改,線段樹上每個節點的sum是該區間內數字變動的字首。查詢時,從上到下降路過節點的sum加起來求一個字首和。

下方是完整的線段樹模板,包括區間查詢區間修改和乘除法。乘除法注意lazytag的操作與加法有所不同

/*
  完整的線段樹
*/
#include<bits/stdc++.h>
using namespace std;
long long n,m,p;
long long input[100005];
struct node
{
    long long l,r,sum,plz,mlz;
}tree[400005];

void build(long long i,long long l,long long r)
{
    tree[i].l=l;
    tree[i].r=r;
    tree[i].plz=0;
    tree[i].mlz=1;
    if(l==r){
        tree[i].sum=input[l]%p;
        return;
    }
    long long mid=(l+r)>>1;
    build(i<<1,l,mid);
    build(i<<1|1,mid+1,r);
    tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p;
}
inline void push_down(long long i)
{
    long long k1=tree[i].mlz;
    long long k2=tree[i].plz;
    tree[i<<1].mlz=(tree[i<<1].mlz*k1)%p;
    tree[i<<1|1].mlz=(tree[i<<1|1].mlz*k1)%p;
    tree[i<<1].plz=(tree[i<<1].plz*k1+k2)%p;
    tree[i<<1|1].plz=(tree[i<<1|1].plz*k1+k2)%p;
    tree[i<<1].sum=(tree[i<<1].sum*k1+k2*(tree[i<<1].r-tree[i<<1].l+1))%p;
    tree[i<<1|1].sum=(tree[i<<1|1].sum*k1+k2*(tree[i<<1|1].r-tree[i<<1|1].l+1))%p;
    tree[i].plz=0;
    tree[i].mlz=1;
}
void add(long long i,long long l,long long r,long long k)
{
    if(tree[i].l>=l&&tree[i].r<=r){
        tree[i].plz=(tree[i].plz+k)%p;
        tree[i].sum=(tree[i].sum+k*(tree[i].r-tree[i].l+1))%p;
        return;
    }
    push_down(i);
    if(tree[i<<1].r>=l){
        add(i<<1,l,r,k);
    }
    if(tree[i<<1|1].l<=r){
        add(i<<1|1,l,r,k);
    }
    tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p;
}
void mul(long long i,long long l,long long r,long long k)
{
    if(tree[i].l>=l&&tree[i].r<=r){
        tree[i].sum=(tree[i].sum*k)%p;
        tree[i].plz=(tree[i].plz*k)%p;
        tree[i].mlz=(tree[i].mlz*k)%p;
        return;
    }
    push_down(i);
    if(tree[i<<1].r>=l){
        mul(i<<1,l,r,k);
    }
    if(tree[i<<1|1].l<=r){
        mul(i<<1|1,l,r,k);
    }
    tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p;
}
long long getsum(int i,int l,int r)
{
    if(tree[i].l>=l&&tree[i].r<=r){
        return tree[i].sum;
    }
    if(tree[i].r<l||tree[i].l>r){
        return 0;
    }
    push_down(i);
    long long ans=0;
    if(tree[i<<1].r>=l){
        ans+=getsum(i<<1,l,r);
    }
    if(tree[i<<1|1].l<=r){
        ans+=getsum(i<<1|1,l,r);
    }
    return ans;
}

int main()
{
    scanf("%lld%lld%lld",&n,&m,&p);
    for(int i=1;i<=n;i++){
        long long temp;
        scanf("%lld",&temp);
        input[i]=temp%p;
    }
    build(1,1,n);
    long long s,x,y,k;
    while(m--){
        scanf("%lld",&s);
        if(s==1){
            scanf("%lld%lld%lld",&x,&y,&k);
            mul(1,x,y,k%p);
        }
        else if(s==2){
            scanf("%lld%lld%lld",&x,&y,&k);
            add(1,x,y,k%p);
        }
        else if(s==3){
            scanf("%lld%lld",&x,&y);
            printf("%lld\n",getsum(1,x,y)%p);
        }
    }
    return 0;
}

(定位當前節點的左孩子和右孩子涉及到位運算,務必關注優先順序和括號。或者保險一點直接i*2,i*2+1)

一些優秀的線段樹題目中,線段樹並不是考察的主要目標,而是求解答案的資料結構。POJ2991,考察區間之間的向量轉向問題。

POJ2828,思路是關鍵,之後利用樹狀陣列動態維護區間和

POJ2777,線段樹每個節點的關鍵量變為一個狀態壓縮的值。查詢時對每個狀態壓縮的值,找有幾個狀態

POJ2886,和POJ2828類似,需要注意細節

POJ1151,掃描線+離散化+線段樹,求矩形並。將每個矩形拆為兩條線段,排序,等待掃描;將線段兩端離散化;開始掃描,利用線段樹找出當前情況下,線段覆蓋的總長度,乘高度差即可得解。這類題總是對整個區間查詢,於是就沒有必要建樹...

#include<stdio.h>
#include<algorithm>
#include<string.h>
using namespace std;
const int maxn=250;
struct seg{
    double x1,x2,y;
    int flag;
    bool operator <(const seg &A) const{
        return y<A.y;
    }
}node[maxn];
int col[maxn*4];
double rec[maxn],sum[maxn*4];

void pushup(int i,int l,int r){
    if(col[i]) sum[i]=rec[r+1]-rec[l];
    else if(l==r) sum[i]=0;
    else sum[i]=sum[i*2]+sum[i*2+1];
}
void update(int L,int R,int k,int l,int r,int i){
    if(l>=L&&r<=R){
        col[i]+=k;
        pushup(i,l,r);
        return;
    }
    int m=(l+r)/2;
    if(L<=m) update(L,R,k,l,m,i*2);
    if(R>m) update(L,R,k,m+1,r,i*2+1);
    pushup(i,l,r);
}
int main(){
    int cas=1;
    int n;
    while(scanf("%d",&n)!=EOF){
        if(n==0) break;
        int cnt=0;
        for(int i=1;i<=n;i++){
            double a,b,c,d;
            scanf("%lf%lf%lf%lf",&a,&b,&c,&d);
            node[cnt].x1=a;node[cnt].x2=c;node[cnt].y=b;node[cnt].flag=1;rec[cnt]=a;cnt++;
            node[cnt].x1=a;node[cnt].x2=c;node[cnt].y=d;node[cnt].flag=-1;rec[cnt]=c;cnt++;
        }
        sort(node,node+cnt);
        sort(rec,rec+cnt);

        memset(col,0,sizeof(col));
        memset(sum,0,sizeof(sum));
        double ans=0;
        for(int i=0;i<cnt-1;i++){
            int l=lower_bound(rec,rec+cnt,node[i].x1)-rec;
            int r=lower_bound(rec,rec+cnt,node[i].x2)-rec-1;
            if(l<=r) update(l,r,node[i].flag,0,cnt-1,1);
            ans+=sum[1]*(node[i+1].y-node[i].y);
        }
        printf("Test case #%d\n",cas++);
        printf("Total explored area: %.2f\n\n",ans);
    }
}

HDU1255,在上一題的基礎上,需要求的是被覆蓋兩次以上的區間長度,改進pushup。

void pushup(int i,int l,int r){
    if(col[i]>=2){
        s2[i]=s1[i]=rec[r+1]-rec[l];
    }
    else if(col[i]==1){
        s1[i]=rec[r+1]-rec[l];
        if(l==r) s2[i]=0;
        else s2[i]=s1[i*2]+s1[i*2+1];
    }
    else{
        if(l==r) s1[i]=s2[i]=0;
        else{
            s1[i]=s1[i*2]+s1[i*2+1];
            s2[i]=s2[i*2]+s2[i*2+1];
        }
    }
}

POJ1177,求矩陣並的周長。和求矩陣並比較類似,分兩次求和

#include<stdio.h>
#include<algorithm>
#include<cmath>
#include<string.h>
using namespace std;
const int maxn=10005;
struct seg{
    int x1,x2,y,flag;
    bool operator < (const seg &A) const{
        return y<A.y;
    }
}node1[maxn],node2[maxn];
int n,rec1[maxn],rec2[maxn],col[maxn*4],sum[maxn*4];

void pushup(int i,int l,int r,int f){
    if(col[i]){
        if(f) sum[i]=rec1[r+1]-rec1[l];
        else sum[i]=rec2[r+1]-rec2[l];
    }
    else if(l==r) sum[i]=0;
    else sum[i]=sum[i*2]+sum[i*2+1];
}
void update(int i,int l,int r,int k,int L,int R,int f){
    if(l>=L&&r<=R){
        col[i]+=k;
        pushup(i,l,r,f);
        return;
    }
    int m=(l+r)/2;
    if(m>=L) update(i*2,l,m,k,L,R,f);
    if(R>m) update(i*2+1,m+1,r,k,L,R,f);
    pushup(i,l,r,f);
}
int main(){
    while(scanf("%d",&n)!=EOF){
        int cnt=0;
        for(int i=0;i<n;i++){
            int a,b,c,d;
            scanf("%d%d%d%d",&a,&b,&c,&d);
            node1[cnt].x1=a;node1[cnt].x2=c;node1[cnt].y=b;node1[cnt].flag=1;rec1[cnt]=a;
            node2[cnt].x1=b;node2[cnt].x2=d;node2[cnt].y=a;node2[cnt].flag=1;rec2[cnt++]=b;
            node1[cnt].x1=a;node1[cnt].x2=c;node1[cnt].y=d;node1[cnt].flag=-1;rec1[cnt]=c;
            node2[cnt].x1=b;node2[cnt].x2=d;node2[cnt].y=c;node2[cnt].flag=-1;rec2[cnt++]=d;
        }

        sort(node1,node1+cnt);
        sort(rec1,rec1+cnt);
        memset(col,0,sizeof(col));
        memset(sum,0,sizeof(sum));
        int ans=0,last=0;
        for(int i=0;i<cnt;i++){
            int l=lower_bound(rec1,rec1+cnt,node1[i].x1)-rec1;
            int r=lower_bound(rec1,rec1+cnt,node1[i].x2)-rec1-1;
            if(l<=r) update(1,0,cnt-1,node1[i].flag,l,r,1);
            ans+=abs(sum[1]-last);
            //printf("%d %d %d %d %d\n",i,l,r,sum[1],last);
            last=sum[1];
        }
        //printf("%d\n",ans);

        sort(node2,node2+cnt);
        sort(rec2,rec2+cnt);
        memset(col,0,sizeof(col));
        memset(sum,0,sizeof(sum));
        last=0;
        for(int i=0;i<cnt;i++){
            int l=lower_bound(rec2,rec2+cnt,node2[i].x1)-rec2;
            int r=lower_bound(rec2,rec2+cnt,node2[i].x2)-rec2-1;
            if(l<=r) update(1,0,cnt-1,node2[i].flag,l,r,0);
            ans+=abs(sum[1]-last);
            last=sum[1];
        }
        printf("%d\n",ans);
    }
}