大致题意:给你一棵包含n个节点的无根树,然后q个询问,每个询问给出一系列点和m、r。表示问你,把根设置为r的情况下,把给定点分为至多m部分,要求每一部分至少一个给定点,且一个部分中不能出现两个给定点一个点是另一个点的祖先,问划分的方案数。
还是一样,我们先简化问题,如果这时一个有根树,然后每次是固定根去询问。那么,令dp[i][j]表示只考虑给定点的前i个点,分为j个部分的合法方案数。我们可以写出状态转移方程:dp[i][j]=dp[i-1][j-1]+(j-f[i])*dp[i-1][j]。其中f[i]表示在之前的i-1个给定点中,为i的祖先的点的个数。这个也很容易理解,对于一个新的给定点,要么独自一个部分,要么和之前的一些可用与它一个部分的点一个部分。
那么问题的关键就是如何求这个f[i]。对于固定根的树来说,很简单,直接统计某个点到根的路径上的点中有多少个给定点即可。对于无根树来说,其实也是类似。原本是统计从根到它的路径,现在的话就是统计一个指定点到它的路径。为了实现这个,我们完全可以认为定根,然后用LCA来解决。这样,对于无根树的情况,我们一样也可以求出这个f[i]。
然后就是我们dp的顺序的问题。因为如果一个后代节点比祖先节点先考虑的话,显然会多算一些方案,所以我们dp的顺序也是必须遵照先祖先后后代的顺序。但是如果严格遵照这个顺序,我们说了,这个根是不确定的,实现起来很难。不过,我们发现,这个f[i]的大小基本上可以判定两个点dp的先后关系。因为后代的f[i]的值肯定比它的祖先的要大,所以我们对f[i]进行排序,按照这个顺序进行dp。
最后的答案就是所有的部分数之和。具体见代码:
#include <bits/stdc++.h>
#define INF 1e18
#define LL long long
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
const int N = 100010;
const int mod = 1e9 + 7;
int l[N],r[N],dep[N],h[N],a[N],c[N],index;
int dp[N],f[N][20],n,q;
vector<int> g[N];
bool v[N];
void dfs(int x,int fa)
{
l[x]=++index;
for(int i=0;i<g[x].size();i++)
{
int y=g[x][i];
if (y==fa) continue;
dep[y]=dep[x]+1;
f[y][0]=x; dfs(y,x);
}
r[x]=index;
}
void ST()
{
for(int j=0;j<17;j++)
for(int i=1;i<=n;i++)
f[i][j+1]=f[f[i][j]][j];
}
inline int lca(int u,int v)
{
if (u==v) return u;
if (dep[v]>dep[u]) swap(u,v);
for(int i=18;i>=0;i--)
if (dep[f[u][i]]>=dep[v]) u=f[u][i];
if (u==v) return u;
for(int i=18;i>=0;i--)
if (f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
return f[u][0];
}
inline void update(int x,int y)
{
for(int i=x;i<N;i+=i&-i)
c[i]+=y;
}
inline int getsum(int x)
{
int res=0;
for(int i=x;i;i-=i&-i)
res+=c[i];
return res;
}
int main()
{
scc(n,q);
for(int i=1;i<n;i++)
{
int x,y;
scc(x,y);
g[x].push_back(y);
g[y].push_back(x);
}
dep[1]=1;
dfs(1,-1); ST();
while(q--)
{
int k,m,rt;
sccc(k,m,rt);
for(int i=1;i<=k;i++)
{
sc(a[i]);
v[a[i]]=1;
update(l[a[i]],1);
update(r[a[i]]+1,-1);
}
int root=getsum(l[rt]);
for(int i=1;i<=k;i++)
{
int LCA=lca(rt,a[i]);
h[i]=getsum(l[a[i]])+root-2*getsum(l[LCA])+(v[LCA]==1)-1;
}
sort(h+1,h+1+k);
for(int i=1;i<=k;i++)
{
v[a[i]]=0;
update(l[a[i]],-1);
update(r[a[i]]+1,1);
}
memset(dp,0,sizeof(dp));
dp[0]=1;
if (h[k]>=m) printf("0\n");
else
{
for(int i=1;i<=k;i++)
{
for(int j=min(i,m);j>=0;j--)
if (j<=h[i]) dp[j]=0;
else dp[j]=((LL)dp[j]*(j-h[i])%mod+dp[j-1])%mod;
}
LL ans=0;
for(int i=1;i<=m;i++)
ans=(ans+dp[i])%mod;
printf("%lld\n",ans);
}
}
return 0;
}