Description
Solution
一道很简单的题,自己没有想出来,相当不应该
一开始我先想的假如直接询问w(i,j)怎么做,但是完全没有思路(后来发现询问w(i,j)比这题难的多)
首先那个排列a似乎没什么用,那么以下的表示都忽略掉(把x看做a[x])
先化一波式子,可以得到这个结果
我们令,假如求出了Si,再做一个Si的前缀和和Si*i的前缀和就可以算出ans_k了,这是线性的。
问题在于如何求Si
一种方法是轻重链剖分
我们将两点距离拆成的形式
那么有
其中dep[x]*x以及dep[i]*x的和都是容易计算的
现在就是要算
考虑dep[lca]等于什么,它显然可以看做是lca到根路径上的点的个数
从左到右扫,每扫到一个点i,就在它祖先到根的链上每个节点打上+i的标记,那么后面的点求dep[lca]就只需要祖先到根的路径求和即可,这个就用轻重链剖分+线段树维护。
另一种方法是点分治。
考虑直接计算所有的
对于每个分治中心,将分治子树中的所有点拉出来排序,从前向后一个个加,这个是容易计算的。
此时还要减掉同一个子树中的,那么再分别将每个子树同样做一遍减去即可。
Code
#include <bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define N 200005
#define LL long long
#define mo 998244353
using namespace std;
int fs[N],nt[2*N],dt[2*N],n,q,a[N],m1,top[N],sz[N],son[N],dfn[N],n1,t[N][2],dep[N],ft[N];
LL s1[N],s2[N],sm[N],lz[N];
void link(int x,int y)
{
nt[++m1]=fs[x];
dt[fs[x]=m1]=y;
}
void dfs(int k,int fa)
{
sz[k]=1;
ft[k]=fa;
dep[k]=dep[fa]+1;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa)
{
dfs(p,k);
sz[k]+=sz[p];
if(sz[p]>sz[son[k]]) son[k]=p;
}
}
}
void make(int k,int fa)
{
dfn[k]=++dfn[0];
if(son[k]) top[son[k]]=top[k],make(son[k],k);
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa&&p!=son[k]) top[p]=p,make(p,k);
}
}
void build(int k,int l,int r)
{
if(l==r) return;
int mid=(l+r)>>1;
t[k][0]=++n1,build(n1,l,mid);
t[k][1]=++n1,build(n1,mid+1,r);
}
void down(int k,LL le,LL re)
{
if(lz[k])
{
lz[t[k][0]]=(lz[t[k][0]]+lz[k])%mo;
lz[t[k][1]]=(lz[t[k][1]]+lz[k])%mo;
sm[t[k][0]]=(sm[t[k][0]]+lz[k]*le)%mo;
sm[t[k][1]]=(sm[t[k][1]]+lz[k]*re)%mo;
lz[k]=0;
}
}
void add(int k,int l,int r,int x,int y,LL v)
{
if(x>y||y<l||x>r) return;
if(x<=l&&r<=y) lz[k]=(lz[k]+v)%mo,sm[k]=(sm[k]+(LL)(r-l+1)*v)%mo;
else
{
int mid=(l+r)>>1;
down(k,mid-l+1,r-mid);
add(t[k][0],l,mid,x,y,v);
add(t[k][1],mid+1,r,x,y,v);
sm[k]=(sm[t[k][0]]+sm[t[k][1]])%mo;
}
}
LL query(int k,int l,int r,int x,int y)
{
if(x>y||y<l||x>r||!sm[k]) return 0;
if(x<=l&&r<=y) return sm[k];
int mid=(l+r)>>1;
down(k,mid-l+1,r-mid);
return (query(t[k][0],l,mid,x,y)+query(t[k][1],mid+1,r,x,y))%mo;
}
void put(int k,LL v)
{
while(k)
{
add(1,1,n,dfn[top[k]],dfn[k],v);
k=ft[top[k]];
}
}
LL get(int k)
{
LL s=0;
while(k)
{
s=(s+query(1,1,n,dfn[top[k]],dfn[k]))%mo;
k=ft[top[k]];
}
return s;
}
int main()
{
cin>>n>>q;
bool pd=1;
fo(i,1,n-1)
{
int x,y;
scanf("%d%d",&x,&y);
link(x,y),link(y,x);
}
fo(i,1,n) scanf("%d",&a[i]);
dfs(1,0);
top[1]=1;
make(1,0);
n1=1;
build(1,1,n);
LL sp=0,si=0;
fo(i,1,n)
{
s1[i]=(si*(LL)dep[a[i]]%mo+sp-(LL)2*get(a[i])+mo+mo)%mo;
put(a[i],i);
s2[i]=(s2[i-1]+s1[i]*(LL)i)%mo;
s1[i]=(s1[i-1]+s1[i])%mo;
si=(si+i)%mo;
sp=(sp+(LL)i*(LL)dep[a[i]])%mo;
}
fo(i,1,q)
{
int x;
scanf("%d",&x);
printf("%lld\n",(s1[x]*(LL)(x+1)-s2[x]+mo)%mo);
}
}