今天才发现自己根本不会树形背包,我太菜了。

  一般的树形背包是这样做的:

树形依赖背包的两种做法

  看上去,它的复杂度是 $O(nk^2)$ 的。

 

第一种优化:

  这里,如果第二维的大小和子树大小有关,同时又不超过一个常数 $k$ 。例如:第二维表示子树内选了多少个点,那么通过一些精妙的分析和上界优化,复杂度就可以变成 $O(nk)$ 了。

  以下的 $siz_x$ 表示合并 $son$ 这个子树前 $x$ 子树的大小(注意:不是 $x$ 的真实子树大小,这里很重要)。

树形依赖背包的两种做法

  这样分析出来的复杂度就是 $O(nk)$ .


  证明:摘自这里

  首先,定义 $T(n)$ 为处理 $n$ 这棵子树时所用的时间,$f(n)$ 为处理 $n$ 这个点时所用的时间。

  $T(x)=\left(\sum_{f_y=x} T_{y}\right)+f(x)\\f(x)=\min(m,siz(y_1))\times \min(m,siz(y_1))+\min(m,siz(y_1)+siz(y_2))\times \min(m,siz(y_1))\\ ~~~~~~~~~~~+\cdots+\min(m,siz(x))\times \min(m,siz(y_n))$

  现在进行一番放缩,把每个乘法的前一项统一变成 $\min(m,siz(x))$ ,这样显然只会使答案变大,所以分析出来的复杂度上界就应该是正确的。

  $f(x)=\min(m,siz(x))\times \left(\sum\limits_{f_y=x} \min(m,siz(y))\right)$

  再次放缩,把后面括号里的 $min$ 直接扔掉,得:

  $f(x)=\min(m,siz(x))\times \left(\sum\limits_{f_y=x} siz(y)\right)\\~~~~~~~~=\min(m,siz(x))\times siz(x)$

  对于 $siz(x)<m$ 的点,首先考虑他的子树都是叶子的情况:

  $T(x)=siz(x)^2+\sum 1$

  对于任意 $siz(x)<m$ 的点,递归证明,由于 “平方和小于和的平方” ,所以 $T(x)$ 与 $siz(x)^2$ 同阶;

  对于 $siz(x)>m$ 的点,首先考虑它的所有子树都小于 $m$ 的情况:

  $T(x)=m\times siz(x)+\sum siz(j)^2$

  接着放缩可得,$T(x)$ 与 $m\times siz(x)$ 同阶;

  继续使用递归证明的技巧,考虑某一层出现了子树大于 $m$ 的情况:

  $T(x)=m\times siz(x)+\sum siz(j)^2+\sum siz(j)\times m$

  所以,$T(x)$ 还是与 $m\times siz(x)$ 同阶;

  综上所述,这种做法的复杂度是 $n\times k$ 。


  选课加强版:https://www.luogu.org/problem/U53204

  
 1 # include <cstdio>
 2 # include <iostream>
 3 # include <cstring>
 4 # include <vector>
 5 # define R register int
 6 
 7 using namespace std;
 8 
 9 const int N=100005;
10 struct edge
11 {
12     int to,nex;
13 };
14 int si,h=0,n,m;
15 edge g[N<<1];
16 int firs[N],a[N],siz[N];
17 bool vis[N];
18 int dp[100000100];
19 
20 void add(int u,int v)
21 {
22     g[++h].to=v;
23     g[h].nex=firs[u];
24     firs[u]=h;
25 }
26 
27 void dfs(int x)
28 {
29     dp[x*(m+1)+1]=a[x];
30     siz[x]=1;
31     vis[x]=true;
32     int j;
33     for (R i=firs[x];i;i=g[i].nex)
34     {
35         j=g[i].to;
36         if(vis[j]) continue;
37         dfs(j);
38         for (int k=min(siz[x]+siz[j],m);k>=1;--k)
39             for (int z=max(1,k-siz[x]);z<=min(siz[j],k-1);++z)
40                 dp[x*(m+1)+k]=max(dp[x*(m+1)+k],dp[x*(m+1)+k-z]+dp[j*(m+1)+z]);
41         siz[x]+=siz[j];
42     }
43 }
44 
45 int read()
46 {
47     int x=0;
48     char c=getchar();
49     while (!isdigit(c)) c=getchar();
50     while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
51     return x;
52 }
53 
54 int main()
55 {
56     scanf("%d%d",&n,&m); m++;
57     memset(g,0,sizeof(g));
58     for (R i=1;i<=n;i++)
59     {
60         si=read(),a[i]=read();
61         add(i,si);
62         add(si,i);
63     }
64     dfs(0);
65     printf("%d",dp[m]);
66     return 0;
67 }
选课

相关文章: