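对每个位置 i，求出在 [1,i] 中以 a_i 结尾的最长严格上升子序列（LIS）长度 f1[i] 及其个数 g1[i]，以及在 [i,n] 中以 a_i 开头的 LIS 长度 f2[i] 及其个数 g2[i]。设整个数列的 LIS 长度为 len、LIS 总个数为 cnt。若 f1[i]+f2[i]-1=len，则 a_i 恰好出现在 g1[i]·g2[i] 个 LIS 中，因此在所有 LIS 中等概率随机选取一个时，它出现在其中的概率为 g1[i]·g2[i]/cnt（在模 998244353 意义下用逆元计算）；否则概率为 0。朴素 DP 的转移是 O(n^2) 的。将数值离散化后，用线段树在值域上维护“区间内 f 的最大值”以及“取到该最大值的方案数之和”，即可在 O(log n) 时间内求出每个 f 和 g，总复杂度为 O(n log n)。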
// For every element of a sequence, print the probability (mod 998244353) that
// it belongs to a longest strictly increasing subsequence (LIS) chosen
// uniformly at random among all LIS of the sequence.
//
// For every position i we compute
//   f1[i], g1[i]: length / number (mod p) of longest increasing subsequences
//                 of a[1..i] that END at position i;
//   f2[i], g2[i]: length / number (mod p) of longest increasing subsequences
//                 of a[i..n] that START at position i.
// Element i lies on some LIS iff f1[i] + f2[i] - 1 == LIS length. In that case
// it appears in exactly g1[i] * g2[i] of them, so its probability is
// g1[i] * g2[i] / (total number of LIS), using a modular inverse.
//
// Both DPs run in O(n log n). Values are compressed to ranks 1..m, and a
// segment tree over ranks 0..m+1 stores, for every range, the best length and
// the number of ways (mod p) to reach it. Ranks 0 and m+1 are sentinels that
// hold the empty subsequence (length 0, one way) for the forward and backward
// pass respectively.
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <vector>

namespace {

const int MOD = 998244353;

// base^exp mod MOD by binary exponentiation.
int power_mod(int base, int exp) {
    int result = 1;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) result = static_cast<int>(1LL * result * base % MOD);
        base = static_cast<int>(1LL * base * base % MOD);
    }
    return result;
}

// A (best length, number of ways mod MOD to achieve it) pair.
struct Best {
    int len;
    int ways;
};

// Merge candidate (len, ways) into acc: a longer length replaces acc, and an
// equal length adds its ways.
void merge_into(Best& acc, int len, int ways) {
    if (len > acc.len) {
        acc.len = len;
        acc.ways = ways;
    } else if (len == acc.len) {
        acc.ways = (acc.ways + ways) % MOD;
    }
}

// Segment tree over positions [lo, hi]. Every node keeps the maximum length
// in its range and the total number of ways achieving that maximum.
class MaxCountTree {
public:
    MaxCountTree(int lo, int hi)
        : lo_(lo), hi_(hi), node_(4 * static_cast<std::size_t>(hi - lo + 1)) {}

    // Reset every position to (0, 0).
    void clear() { std::fill(node_.begin(), node_.end(), Best()); }

    // Merge candidate (len, ways) into position p.
    void add(int p, int len, int ways) {
        // Every node on the root-to-leaf path covers p. Merging the candidate
        // into each of them keeps (max, count-of-max) exact, because values
        // only ever get merged in and are never removed.
        int node = 1;
        int L = lo_;
        int R = hi_;
        for (;;) {
            merge_into(node_[node], len, ways);
            if (L == R) break;
            const int mid = (L + R) >> 1;
            if (p <= mid) {
                node = 2 * node;
                R = mid;
            } else {
                node = 2 * node + 1;
                L = mid + 1;
            }
        }
    }

    // Best (length, ways) over positions [l, r]; requires lo <= l <= r <= hi.
    Best query(int l, int r) const {
        Best acc = {0, 0};
        query_rec(1, lo_, hi_, l, r, acc);
        return acc;
    }

private:
    // Merge into acc every node that exactly covers a piece of [l, r].
    void query_rec(int node, int L, int R, int l, int r, Best& acc) const {
        if (l == L && r == R) {
            merge_into(acc, node_[node].len, node_[node].ways);
            return;
        }
        const int mid = (L + R) >> 1;
        if (r <= mid) {
            query_rec(2 * node, L, mid, l, r, acc);
        } else if (l > mid) {
            query_rec(2 * node + 1, mid + 1, R, l, r, acc);
        } else {
            query_rec(2 * node, L, mid, l, mid, acc);
            query_rec(2 * node + 1, mid + 1, R, mid + 1, r, acc);
        }
    }

    int lo_;
    int hi_;
    std::vector<Best> node_;
};

}  // namespace

int main() {
    // File-based I/O, as in the original judge setup.
    std::freopen("a.in", "r", stdin);
    std::freopen("a.out", "w", stdout);

    int n = 0;
    // Nothing to print for an empty / unreadable input. This also avoids
    // building invalid ranges when n <= 0.
    if (std::scanf("%d", &n) != 1 || n <= 0) return 0;

    std::vector<int> a(n + 1);
    for (int i = 1; i <= n; ++i) std::scanf("%d", &a[i]);

    // Coordinate compression: a[i] becomes its rank in 1..m among distinct values.
    std::vector<int> sorted(a.begin() + 1, a.end());
    std::sort(sorted.begin(), sorted.end());
    sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end());
    const int m = static_cast<int>(sorted.size());
    for (int i = 1; i <= n; ++i) {
        a[i] = static_cast<int>(
                   std::lower_bound(sorted.begin(), sorted.end(), a[i]) - sorted.begin()) + 1;
    }

    std::vector<int> f1(n + 1), g1(n + 1), f2(n + 1), g2(n + 1);
    MaxCountTree tree(0, m + 1);

    // Forward pass: extend the best strictly smaller rank seen so far.
    tree.add(0, 0, 1);  // empty prefix sentinel
    for (int i = 1; i <= n; ++i) {
        const Best best = tree.query(0, a[i] - 1);
        f1[i] = best.len + 1;
        g1[i] = best.ways;
        tree.add(a[i], f1[i], g1[i]);
    }

    // Backward pass: prepend to the best strictly larger rank seen so far.
    tree.clear();
    tree.add(m + 1, 0, 1);  // empty suffix sentinel
    for (int i = n; i >= 1; --i) {
        const Best best = tree.query(a[i] + 1, m + 1);
        f2[i] = best.len + 1;
        g2[i] = best.ways;
        tree.add(a[i], f2[i], g2[i]);
    }

    // Every LIS has a unique first element, so the maximum over all f2 and
    // the summed ways give the LIS length and the total number of LIS.
    const Best whole = tree.query(0, m + 1);
    const int lis_len = whole.len;
    // NOTE: if the LIS count is a multiple of MOD, no inverse exists and every
    // probability below evaluates to 0 — inherent to counting modulo MOD.
    const int inv_total = power_mod(whole.ways, MOD - 2);

    for (int i = 1; i <= n; ++i) {
        int prob = 0;
        if (f1[i] + f2[i] - 1 == lis_len) {
            prob = static_cast<int>(1LL * g1[i] * g2[i] % MOD * inv_total % MOD);
        }
        std::printf("%d ", prob);
    }
    std::putchar('\n');
    return 0;
}