HDU 5193
題意:給出n個數的序列a,m個操作。
操作1:[x,y] 將y插入到第x個人之後.
操作2:[x],將第x個人刪除(x+1,..n向前進一格).
n,m,a[i]<=2e4. 問每次操作後序列a的逆序對(i,j)有多少? (i<j && a[i]>a[j] ).
假如當前逆序對為res,那麼插入一個數y之後 要知道[x+1,n]有多個比y小,[1..x-1]直接有多個比y大.
插入,刪除操作 如何處理下標?
此時用到一個叫 塊狀鏈表的東西,鏈表中每個元素是一個數組,
數組大小最多為2sqrt(n) 若超過2sqrt(n),則用到分裂操作.
若相鄰兩個表 元素個數<=sqrt(n) 則合并這兩個表. 塊狀鏈表的插入和刪除操作都是sqrt(n)滴。
現在對鏈表中的每一塊,套一個樹狀數組.
x之後有多少個比x大 則鏈表往後走 每走一個向其BIT查詢大於x的個數.
x之前有多少個比x小 則鏈表往前走.查詢每塊中小於x的個數.
然後對於塊內的元素,sqrt(n)暴力查詢即可.
#include <bits/stdc++.h>using namespace std;typedef pair<int,int> ii;const int N=2e4+5,m=320;int lowbit(int x){return x&-x;}void add(int c[],int x,int val){ for(int i=x;i<N;i+=lowbit(i)) c[i]+=val;}int sum(int c[],int l,int r){ int sum1=0,sum2=0; while(l>0) { sum1+=c[l]; l-=lowbit(l); } while(r>0) { sum2+=c[r]; r-=lowbit(r); } return sum2-sum1;}struct data{ int s,a[N*2]; data *next; int c[N]; data() { memset(c,0,sizeof(c)); next=NULL; }};data *root;void insert(int x,int pos){ if(root==NULL) { root=new data; root->s=1; root->a[1]=x; add(root->c,x,1);// return; } data *k=root; while(pos> k->s && k->next!=NULL) { pos-=k->s; k=k->next; } memmove(k->a+pos+1,k->a+pos,sizeof(int)*(k->s-pos+1)); k->s++; k->a[pos]=x; add(k->c,x,1); //split if(k->s==2*m) { data *t=new data; t->next=k->next; k->next=t; memcpy(t->a+1,k->a+m+1,sizeof(int)*m); for(int i=1;i<=m;i++) { add(k->c,t->a[i],-1); add(t->c,t->a[i],1); } t->s=k->s=m; }}int find(int pos){ data *k=root; while(pos>k->s && k->next!=NULL) { pos-=k->s; k=k->next; } return k->a[pos];}int work(int pos){ int res=0; data *k=root; int x=find(pos); while(pos>k->s && k->next!=NULL) { pos-=k->s; res+=sum(k->c,x,N);//large than x k=k->next; } for(int i=1;i<pos;i++) if(k->a[i]>x) res++; for(int i=pos+1;i<=k->s;i++) if(k->a[i]<x) res++; while(k->next!=NULL) { k=k->next; res+=sum(k->c,0,x-1); } return res;}void destroy(data *k){ if(k->next!=NULL) destroy(k->next); delete k;}void del(int pos){ data *k=root; while(pos>k->s&&k->next!=NULL) { pos-=k->s; k=k->next; } add(k->c,k->a[pos],-1); memmove(k->a+pos,k->a+pos+1,sizeof(int)*(k->s -pos)); k->s--;}int main(){ int n,p; while(~scanf("%d%d",&n,&p)) { root=NULL; int ans=0,x; for(int i=1;i<=n;i++) { scanf("%d",&x); insert(x,i); ans+=work(i); } while(p--) { int q,x,y; scanf("%d",&q); if(q==0) { scanf("%d %d",&x,&y); x++; insert(y,x); ans+=work(x); } else { scanf("%d",&x); ans-=work(x); del(x); } printf("%d\n",ans); } destroy(root); } return 0;}