https://www.luogu.com.cn/problem/P3246
解法一:莫队+ST表+单调栈
考虑如何由[L,R]的答案推向[L,R+1]的答案
[L,R]向[L,R+1],增加的是[L,R+1] [L+1,R+1] [L+2,R+1] …… [R,R+1] [R,R] 这些区间每个区间的最小值
设[L,R+1]最小的数是a[p]
那么区间左端点在[L,p]之间的这p-L+1个区间的最小值都是a[p]
他们的贡献是a[p]*(p-L+1)
还剩下区间左端点在[p+1,R+1]之间的区间
设f[i]表示以i为右端点,[1,i]中的数为左端点的所有区间的贡献
f[R+1]包含的是 [1,R+1] [2,R+1] [3,R+1]……[R,R+1] [R+1,R+1]
f[p]包含的是[1,p] [2,p] [3,p] …… [p-1,p] [p,p]
因为a[p]是区间[L,R+1]的最小值
所以p<=R+1,且当i<=p时,[i,p]最小值等于[i,R+1]最小值
所以f[R+1]-f[p] 包含的是[p+1,R+1] [p+2,R+1] ……[R,R+1] [R+1,R+1],即我们需要得到的区间左端点在[p+1,R+1]之间的区间的贡献
由[L,R]的答案推向[L,R-1]就根据上面算R的贡献减去即可
移动左端点同理,f[i]就表示以i为左端点,[i,n]中的数为右端点的所有区间的贡献
有一点要注意一下,我才用的莫队写法要先动右端点,涉及到区间查询,不允许出现L>R的情况
#include<bits/stdc++.h> using namespace std; #define N 100003 typedef long long LL; int S; int a[N]; int l[N],r[N]; int s[N],top; LL fl[N],fr[N]; int lg2[N],st[N][17],stp[N][17]; long long ans; int L=1,R; struct node { int l,r,id,pos; long long out; }e[N]; bool cmp(node p,node q) { if(p.pos!=q.pos) return p.pos<q.pos; return p.r<q.r; } bool cmp2(node p,node q) { return p.id<q.id; } int query(int l,int r) { int len=lg2[r-l+1]; if(st[l][len]<st[r-(1<<len)+1][len]) return stp[l][len]; return stp[r-(1<<len)+1][len]; } void updater(int pos,int ty) { int p=query(L,pos); long long tmp=1ll*(p-L+1)*a[p]+fl[pos]-fl[p]; ans+=tmp*ty; } void updatel(int pos,int ty) { int p=query(pos,R); long long tmp=1ll*(R-p+1)*a[p]+fr[pos]-fr[p]; ans+=tmp*ty; } int main() { int n,m; scanf("%d%d",&n,&m); for(int i=1;i<=n;++i) scanf("%d",&a[i]); for(int i=2;i<=n;++i) lg2[i]=lg2[i>>1]+1; for(int i=1;i<=n;++i) { st[i][0]=a[i]; stp[i][0]=i; } for(int i=1;1<<i<=n;++i) for(int j=1;j+(1<<i)-1<=n;++j) { st[j][i]=min(st[j][i-1],st[j+(1<<i-1)][i-1]); if(st[j][i]==st[j][i-1]) stp[j][i]=stp[j][i-1]; else stp[j][i]=stp[j+(1<<i-1)][i-1]; } a[0]=-2e9; for(int i=1;i<=n;++i) { while(a[i]<=a[s[top]]) top--; l[i]=s[top]; fl[i]=fl[l[i]]+1ll*(i-l[i])*a[i]; s[++top]=i; } a[n+1]=-2e9; s[top=0]=n+1; for(int i=n;i;--i) { while(a[i]<=a[s[top]]) top--; r[i]=s[top]; fr[i]=fr[r[i]]+1ll*(r[i]-i)*a[i]; s[++top]=i; } S=sqrt(n); for(int i=1;i<=m;++i) { scanf("%d%d",&e[i].l,&e[i].r); e[i].id=i; e[i].pos=(e[i].l-1)/S+1; } sort(e+1,e+m+1,cmp); for(int i=1;i<=m;++i) { while(R<e[i].r) updater(++R,1); while(R>e[i].r) updater(R--,-1); while(L<e[i].l) updatel(L++,-1); while(L>e[i].l) updatel(--L,1); e[i].out=ans; } sort(e+1,e+m+1,cmp2); for(int i=1;i<=m;++i) printf("%lld\n",e[i].out); }