tj
2026-02-03 19:16:01
发布于:四川
1阅读
0回复
0点赞
#include<algorithm>
#include<iostream>
#include<cstring>
using namespace std;
const int N=5e4+10,mod=1e9+7;
struct segment_tree{
int lson,rson;
long long sum;
int val;
};
segment_tree hjt[400*N];
int root[N],p[N],v[N],sta[N];//sta就是垃圾回收的栈
int tmp[2][20];
int n,m,tot,top;
inline int lowbit(int x){
return x&-x;
}
void hjt_ins(int &now,int l,int r,int pos,int val){
if(!now)
now=top?sta[top--]:++tot;//垃圾回收
hjt[now].val++;
hjt[now].sum=(hjt[now].sum+val)%mod;
if(l==r)
return ;
int mid=l+r>>1;
if(pos<=mid)
hjt_ins(hjt[now].lson,l,mid,pos,val);
else
hjt_ins(hjt[now].rson,mid+1,r,pos,val);
}
void hjt_del(int &now,int l,int r,int pos,int val){
if(!now)
now=top?sta[top--]:++tot;//垃圾回收
hjt[now].val--;
hjt[now].sum=(hjt[now].sum-val)%mod;
if(l==r)
return ;
int mid=l+r>>1;
if(pos<=mid)
hjt_del(hjt[now].lson,l,mid,pos,val);
else
hjt_del(hjt[now].rson,mid+1,r,pos,val);
if(!hjt[now].val)//判断是否要扔进栈里
sta[++top]=now,now=0;
}
long long hjt_query1(int l,int r,int pos,int val){
if(l==r)
return 0;
long long key=0,sum=0;
for(int i=1;i<=tmp[0][0];i++)
key=(key+hjt[hjt[tmp[0][i]].rson].val)%mod;
for(int i=1;i<=tmp[1][0];i++)
key=(key-hjt[hjt[tmp[1][i]].rson].val)%mod;
for(int i=1;i<=tmp[0][0];i++)
sum=(sum+hjt[hjt[tmp[0][i]].rson].sum)%mod;
for(int i=1;i<=tmp[1][0];i++)
sum=(sum-hjt[hjt[tmp[1][i]].rson].sum)%mod;
int mid=l+r>>1;
if(pos<=mid){
for(int i=1;i<=tmp[0][0];i++)
tmp[0][i]=hjt[tmp[0][i]].lson;
for(int i=1;i<=tmp[1][0];i++)
tmp[1][i]=hjt[tmp[1][i]].lson;
return (key*val%mod+sum+hjt_query1(l,mid,pos,val))%mod;
}
else{
for(int i=1;i<=tmp[0][0];i++)
tmp[0][i]=hjt[tmp[0][i]].rson;
for(int i=1;i<=tmp[1][0];i++)
tmp[1][i]=hjt[tmp[1][i]].rson;
return hjt_query1(mid+1,r,pos,val);
}
}
int hjt_query2(int l,int r,int pos,int val){
if(l==r)
return 0;
long long key=0,sum=0;
for(int i=1;i<=tmp[0][0];i++)
key=(key+hjt[hjt[tmp[0][i]].lson].val)%mod;
for(int i=1;i<=tmp[1][0];i++)
key=(key-hjt[hjt[tmp[1][i]].lson].val)%mod;
for(int i=1;i<=tmp[0][0];i++)
sum=(sum+hjt[hjt[tmp[0][i]].lson].sum)%mod;
for(int i=1;i<=tmp[1][0];i++)
sum=(sum-hjt[hjt[tmp[1][i]].lson].sum)%mod;
int mid=l+r>>1;
if(pos<=mid){
for(int i=1;i<=tmp[0][0];i++)
tmp[0][i]=hjt[tmp[0][i]].lson;
for(int i=1;i<=tmp[1][0];i++)
tmp[1][i]=hjt[tmp[1][i]].lson;
return hjt_query2(l,mid,pos,val);
}
else{
for(int i=1;i<=tmp[0][0];i++)
tmp[0][i]=hjt[tmp[0][i]].rson;
for(int i=1;i<=tmp[1][0];i++)
tmp[1][i]=hjt[tmp[1][i]].rson;
return (key*val%mod+sum+hjt_query2(mid+1,r,pos,val))%mod;
}
}
void bit_ins(int now,int pos,int val){
for(int i=now;i<=n;i+=lowbit(i))
hjt_ins(root[i],1,n,pos,val);
}
void bit_del(int now,int pos,int val){
for(int i=now;i<=n;i+=lowbit(i))
hjt_del(root[i],1,n,pos,val);
}
long long bit_query1(int l,int r,int pos,int val){
//查询l~r中所有优先级>pos与val的贡献和
if(l>r)//特判
return 0;
for(int i=0;i<20;i++)
tmp[0][i]=tmp[1][i]=0;
for(int i=r;i;i-=lowbit(i))
tmp[0][++tmp[0][0]]=root[i];
for(int i=l-1;i;i-=lowbit(i))
tmp[1][++tmp[1][0]]=root[i];
return hjt_query1(1,n,pos,val);
}
long long bit_query2(int l,int r,int pos,int val){
//查询l~r中所有优先级Kpos与val的贡献和
if(l>r)//特判
return 0;
for(int i=0;i<20;i++)
tmp[0][i]=tmp[1][i]=0;
for(int i=r;i;i-=lowbit(i))
tmp[0][++tmp[0][0]]=root[i];
for(int i=l-1;i;i-=lowbit(i))
tmp[1][++tmp[1][0]]=root[i];
return hjt_query2(1,n,pos,val);
}
int main(){
cin>>n>>m;
long long ans=0;
for(int i=1;i<=n;i++){
cin>>p[i]>>v[i];
bit_ins(i,p[i],v[i]);//插入节点
ans+=bit_query1(1,i-1,p[i],v[i]);//加上新增的逆序对
ans%=mod;
}
int x,y;
for(int i=1;i<=m;i++){
cin>>x>>y;
if(x>y)//特判x>y的情况
swap(x,y);
if(x==y)//特判x=y的情况
goto s;//goto用法,表示跳到s处,但是最好不要用
ans-=bit_query2(x+1,y-1,p[x],v[x]);//减去原来的逆序对
ans%=mod;
ans+=bit_query1(x+1,y-1,p[x],v[x]);//加上新增的逆序对
ans%=mod;
ans-=bit_query1(x+1,y-1,p[y],v[y]);//减去原来的逆序对
ans%=mod;
ans+=bit_query2(x+1,y-1,p[y],v[y]);//加上新增的逆序对
ans%=mod;
if(p[x]>p[y])//原来p[x]与p[y]是逆序对
ans-=v[x]+v[y];//减去贡献
else//原来p[x]与p[y]不是逆序对
ans+=v[x]+v[y];//加上贡献
ans%=mod;
bit_del(x,p[x],v[x]);//在x处删除x
bit_del(y,p[y],v[y]);//在y处删除y
bit_ins(x,p[y],v[y]);//在x处插入y
bit_ins(y,p[x],v[x]);//在y处插入x
swap(p[x],p[y]);
swap(v[x],v[y]);
//以上步骤完成x与y的交换
s://goto就跳到这里
if(ans<0)//特判ans减成负数的情况
ans+=mod;
cout<<ans<<'\n';
}
}
这里空空如也







有帮助,赞一个