資料結構:樹狀陣列、線段樹
樹狀陣列/線段樹都可以把原來樸素的O(n2)變為O(n*logn),用於高效計算數列的字首和。具體主要表現為3種情況:區間修改單點查詢;單點修改區間查詢;區間修改區間查詢,這3種情況是一個遞進關係,理解規律之後就比較好記。
樹狀陣列的具體原理見https://www.cnblogs.com/xenny/p/9739600.html,這裡就不詳細描述了.....樹狀陣列的顯著特點就是藉助lowbit來確定修改或者查詢的位置。如下是區間修改區間查詢的完整板子程式碼:
#include<bits/stdc++.h> using namespace std; int n,m; int s1[500005],s2[500005],a[500005]; int lowbit(int x) { return x&(-x); } void update(int i,int k) { int x=i; while(i<=n){ s1[i]+=k; s2[i]+=k*(x-1); //注意 i+=lowbit(i); } } int getSum(int i) { int res=0,x=i; while(i>0){ res+=s1[i]*x-s2[i]; //注意 i-=lowbit(i); }return res; } int main() { scanf("%d%d",&n,&m); memset(a,0,sizeof a); memset(s1,0,sizeof s1); memset(s2,0,sizeof s2); for(int i=1;i<=n;i++){ scanf("%d",&a[i]); update(i,a[i]-a[i-1]); //完全版樹狀陣列在構建時輸入a[i]-a[i-1] } int s,x,y,k; while(m--){ scanf("%d",&s); if(s==1){ scanf("%d%d%d",&x,&y,&k); update(x,k); update(y+1,-k); } else if(s==2){ scanf("%d",&x); printf("%d\n",getSum(x)-getSum(x-1)); } } return 0; }
樹狀陣列的應用包括RMQ問題,求逆序對等等。樹狀陣列用在RMQ問題需要在查詢函式改變一下,詳細情況見部落格RMQ問題。關於求逆序對數,將原數列從開始一個一個地加入元素到和它大小所對應的顛倒位置。因為某個較大的數先出現在數列中,所以它先被加入到樹狀陣列中,對較小的數(樹狀陣列中位置靠後的數)產生影響,從而達到統計逆序對的功能。隨著插入新數,順便求和,可得到逆序對數。例題洛谷P1774、P2678
#include<bits/stdc++.h> using namespace std; long long n,num[500005],loc[500005],tree[500005]; bool cmd(int a,int b) { return num[a]==num[b]?a>b:num[a]>num[b]; } int lowbit(int x) { return x&(-x); } void add(int k,int v) { while(k<=n){ tree[k]+=v; k+=lowbit(k); } } int query(int k) { int ans=0; while(k>0){ ans+=tree[k]; k-=lowbit(k); } return ans; } int main() { scanf("%lld",&n); for(int i=1;i<=n;i++){ scanf("%lld",&num[i]); loc[i]=i; } sort(loc+1,loc+n+1,cmd); long long ans=0; for(int i=1;i<=n;i++){ ans+=query(loc[i]); add(loc[i],1); } printf("%lld\n",ans); return 0; }
線段樹的核心在於push_down+lazytag,當然單點修改區間查詢和區間修改單點查詢這兩種情況是用不著的。藉助push_down和lazytag,線段樹可以處理更加複雜的區間維護問題。
對於單點查詢區間修改,線段樹上每個節點的sum是該區間內數字變動的字首。查詢時,從上到下降路過節點的sum加起來求一個字首和。
下方是完整的線段樹模板,包括區間查詢區間修改和乘除法。乘除法注意lazytag的操作與加法有所不同
/* 完整的線段樹 */ #include<bits/stdc++.h> using namespace std; long long n,m,p; long long input[100005]; struct node { long long l,r,sum,plz,mlz; }tree[400005]; void build(long long i,long long l,long long r) { tree[i].l=l; tree[i].r=r; tree[i].plz=0; tree[i].mlz=1; if(l==r){ tree[i].sum=input[l]%p; return; } long long mid=(l+r)>>1; build(i<<1,l,mid); build(i<<1|1,mid+1,r); tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p; } inline void push_down(long long i) { long long k1=tree[i].mlz; long long k2=tree[i].plz; tree[i<<1].mlz=(tree[i<<1].mlz*k1)%p; tree[i<<1|1].mlz=(tree[i<<1|1].mlz*k1)%p; tree[i<<1].plz=(tree[i<<1].plz*k1+k2)%p; tree[i<<1|1].plz=(tree[i<<1|1].plz*k1+k2)%p; tree[i<<1].sum=(tree[i<<1].sum*k1+k2*(tree[i<<1].r-tree[i<<1].l+1))%p; tree[i<<1|1].sum=(tree[i<<1|1].sum*k1+k2*(tree[i<<1|1].r-tree[i<<1|1].l+1))%p; tree[i].plz=0; tree[i].mlz=1; } void add(long long i,long long l,long long r,long long k) { if(tree[i].l>=l&&tree[i].r<=r){ tree[i].plz=(tree[i].plz+k)%p; tree[i].sum=(tree[i].sum+k*(tree[i].r-tree[i].l+1))%p; return; } push_down(i); if(tree[i<<1].r>=l){ add(i<<1,l,r,k); } if(tree[i<<1|1].l<=r){ add(i<<1|1,l,r,k); } tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p; } void mul(long long i,long long l,long long r,long long k) { if(tree[i].l>=l&&tree[i].r<=r){ tree[i].sum=(tree[i].sum*k)%p; tree[i].plz=(tree[i].plz*k)%p; tree[i].mlz=(tree[i].mlz*k)%p; return; } push_down(i); if(tree[i<<1].r>=l){ mul(i<<1,l,r,k); } if(tree[i<<1|1].l<=r){ mul(i<<1|1,l,r,k); } tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%p; } long long getsum(int i,int l,int r) { if(tree[i].l>=l&&tree[i].r<=r){ return tree[i].sum; } if(tree[i].r<l||tree[i].l>r){ return 0; } push_down(i); long long ans=0; if(tree[i<<1].r>=l){ ans+=getsum(i<<1,l,r); } if(tree[i<<1|1].l<=r){ ans+=getsum(i<<1|1,l,r); } return ans; } int main() { scanf("%lld%lld%lld",&n,&m,&p); for(int i=1;i<=n;i++){ long long temp; scanf("%lld",&temp); input[i]=temp%p; } build(1,1,n); long long s,x,y,k; while(m--){ scanf("%lld",&s); if(s==1){ scanf("%lld%lld%lld",&x,&y,&k); mul(1,x,y,k%p); } else if(s==2){ scanf("%lld%lld%lld",&x,&y,&k); add(1,x,y,k%p); } else if(s==3){ scanf("%lld%lld",&x,&y); printf("%lld\n",getsum(1,x,y)%p); } } return 0; }
(定位當前節點的左孩子和右孩子涉及到位運算,務必關注優先順序和括號。或者保險一點直接i*2,i*2+1)
一些優秀的線段樹題目中,線段樹並不是考察的主要目標,而是求解答案的資料結構。POJ2991,考察區間之間的向量轉向問題。
POJ2828,思路是關鍵,之後利用樹狀陣列動態維護區間和
POJ2777,線段樹每個節點的關鍵量變為一個狀態壓縮的值。查詢時對每個狀態壓縮的值,找有幾個狀態
POJ2886,和POJ2828類似,需要注意細節
POJ1151,掃描線+離散化+線段樹,求矩形並。將每個矩形拆為兩條線段,排序,等待掃描;將線段兩端離散化;開始掃描,利用線段樹找出當前情況下,線段覆蓋的總長度,乘高度差即可得解。這類題總是對整個區間查詢,於是就沒有必要建樹...
#include<stdio.h> #include<algorithm> #include<string.h> using namespace std; const int maxn=250; struct seg{ double x1,x2,y; int flag; bool operator <(const seg &A) const{ return y<A.y; } }node[maxn]; int col[maxn*4]; double rec[maxn],sum[maxn*4]; void pushup(int i,int l,int r){ if(col[i]) sum[i]=rec[r+1]-rec[l]; else if(l==r) sum[i]=0; else sum[i]=sum[i*2]+sum[i*2+1]; } void update(int L,int R,int k,int l,int r,int i){ if(l>=L&&r<=R){ col[i]+=k; pushup(i,l,r); return; } int m=(l+r)/2; if(L<=m) update(L,R,k,l,m,i*2); if(R>m) update(L,R,k,m+1,r,i*2+1); pushup(i,l,r); } int main(){ int cas=1; int n; while(scanf("%d",&n)!=EOF){ if(n==0) break; int cnt=0; for(int i=1;i<=n;i++){ double a,b,c,d; scanf("%lf%lf%lf%lf",&a,&b,&c,&d); node[cnt].x1=a;node[cnt].x2=c;node[cnt].y=b;node[cnt].flag=1;rec[cnt]=a;cnt++; node[cnt].x1=a;node[cnt].x2=c;node[cnt].y=d;node[cnt].flag=-1;rec[cnt]=c;cnt++; } sort(node,node+cnt); sort(rec,rec+cnt); memset(col,0,sizeof(col)); memset(sum,0,sizeof(sum)); double ans=0; for(int i=0;i<cnt-1;i++){ int l=lower_bound(rec,rec+cnt,node[i].x1)-rec; int r=lower_bound(rec,rec+cnt,node[i].x2)-rec-1; if(l<=r) update(l,r,node[i].flag,0,cnt-1,1); ans+=sum[1]*(node[i+1].y-node[i].y); } printf("Test case #%d\n",cas++); printf("Total explored area: %.2f\n\n",ans); } }
HDU1255,在上一題的基礎上,需要求的是被覆蓋兩次以上的區間長度,改進pushup。
void pushup(int i,int l,int r){ if(col[i]>=2){ s2[i]=s1[i]=rec[r+1]-rec[l]; } else if(col[i]==1){ s1[i]=rec[r+1]-rec[l]; if(l==r) s2[i]=0; else s2[i]=s1[i*2]+s1[i*2+1]; } else{ if(l==r) s1[i]=s2[i]=0; else{ s1[i]=s1[i*2]+s1[i*2+1]; s2[i]=s2[i*2]+s2[i*2+1]; } } }
POJ1177,求矩陣並的周長。和求矩陣並比較類似,分兩次求和
#include<stdio.h> #include<algorithm> #include<cmath> #include<string.h> using namespace std; const int maxn=10005; struct seg{ int x1,x2,y,flag; bool operator < (const seg &A) const{ return y<A.y; } }node1[maxn],node2[maxn]; int n,rec1[maxn],rec2[maxn],col[maxn*4],sum[maxn*4]; void pushup(int i,int l,int r,int f){ if(col[i]){ if(f) sum[i]=rec1[r+1]-rec1[l]; else sum[i]=rec2[r+1]-rec2[l]; } else if(l==r) sum[i]=0; else sum[i]=sum[i*2]+sum[i*2+1]; } void update(int i,int l,int r,int k,int L,int R,int f){ if(l>=L&&r<=R){ col[i]+=k; pushup(i,l,r,f); return; } int m=(l+r)/2; if(m>=L) update(i*2,l,m,k,L,R,f); if(R>m) update(i*2+1,m+1,r,k,L,R,f); pushup(i,l,r,f); } int main(){ while(scanf("%d",&n)!=EOF){ int cnt=0; for(int i=0;i<n;i++){ int a,b,c,d; scanf("%d%d%d%d",&a,&b,&c,&d); node1[cnt].x1=a;node1[cnt].x2=c;node1[cnt].y=b;node1[cnt].flag=1;rec1[cnt]=a; node2[cnt].x1=b;node2[cnt].x2=d;node2[cnt].y=a;node2[cnt].flag=1;rec2[cnt++]=b; node1[cnt].x1=a;node1[cnt].x2=c;node1[cnt].y=d;node1[cnt].flag=-1;rec1[cnt]=c; node2[cnt].x1=b;node2[cnt].x2=d;node2[cnt].y=c;node2[cnt].flag=-1;rec2[cnt++]=d; } sort(node1,node1+cnt); sort(rec1,rec1+cnt); memset(col,0,sizeof(col)); memset(sum,0,sizeof(sum)); int ans=0,last=0; for(int i=0;i<cnt;i++){ int l=lower_bound(rec1,rec1+cnt,node1[i].x1)-rec1; int r=lower_bound(rec1,rec1+cnt,node1[i].x2)-rec1-1; if(l<=r) update(1,0,cnt-1,node1[i].flag,l,r,1); ans+=abs(sum[1]-last); //printf("%d %d %d %d %d\n",i,l,r,sum[1],last); last=sum[1]; } //printf("%d\n",ans); sort(node2,node2+cnt); sort(rec2,rec2+cnt); memset(col,0,sizeof(col)); memset(sum,0,sizeof(sum)); last=0; for(int i=0;i<cnt;i++){ int l=lower_bound(rec2,rec2+cnt,node2[i].x1)-rec2; int r=lower_bound(rec2,rec2+cnt,node2[i].x2)-rec2-1; if(l<=r) update(1,0,cnt-1,node2[i].flag,l,r,0); ans+=abs(sum[1]-last); last=sum[1]; } printf("%d\n",ans); } }