今天才发现自己根本不会树形背包,我太菜了。
一般的树形背包是这样做的:
看上去,它的复杂度是 $O(nk^2)$ 的。
第一种优化:
这里,如果第二维的大小和子树大小有关,同时又不超过一个常数 $k$ 。例如:第二维表示子树内选了多少个点,那么通过一些精妙的分析和上界优化,复杂度就可以变成 $O(nk)$ 了。
以下的 $siz_x$ 表示合并 $son$ 这个子树前 $x$ 子树的大小(注意:不是 $x$ 的真实子树大小,这里很重要)。
这样分析出来的复杂度就是 $O(nk)$ .
证明:摘自这里;
首先,定义 $T(n)$ 为处理 $n$ 这棵子树时所用的时间,$f(n)$ 为处理 $n$ 这个点时所用的时间。
$T(x)=\left(\sum_{f_y=x} T_{y}\right)+f(x)\\f(x)=\min(m,siz(y_1))\times \min(m,siz(y_1))+\min(m,siz(y_1)+siz(y_2))\times \min(m,siz(y_1))\\ ~~~~~~~~~~~+\cdots+\min(m,siz(x))\times \min(m,siz(y_n))$
现在进行一番放缩,把每个乘法的前一项统一变成 $\min(m,siz(x))$ ,这样显然只会使答案变大,所以分析出来的复杂度上界就应该是正确的。
$f(x)=\min(m,siz(x))\times \left(\sum\limits_{f_y=x} \min(m,siz(y))\right)$
再次放缩,把后面括号里的 $min$ 直接扔掉,得:
$f(x)=\min(m,siz(x))\times \left(\sum\limits_{f_y=x} siz(y)\right)\\~~~~~~~~=\min(m,siz(x))\times siz(x)$
对于 $siz(x)<m$ 的点,首先考虑他的子树都是叶子的情况:
$T(x)=siz(x)^2+\sum 1$
对于任意 $siz(x)<m$ 的点,递归证明,由于 “平方和小于和的平方” ,所以 $T(x)$ 与 $siz(x)^2$ 同阶;
对于 $siz(x)>m$ 的点,首先考虑它的所有子树都小于 $m$ 的情况:
$T(x)=m\times siz(x)+\sum siz(j)^2$
接着放缩可得,$T(x)$ 与 $m\times siz(x)$ 同阶;
继续使用递归证明的技巧,考虑某一层出现了子树大于 $m$ 的情况:
$T(x)=m\times siz(x)+\sum siz(j)^2+\sum siz(j)\times m$
所以,$T(x)$ 还是与 $m\times siz(x)$ 同阶;
综上所述,这种做法的复杂度是 $n\times k$ 。
选课加强版:https://www.luogu.org/problem/U53204
1 # include <cstdio> 2 # include <iostream> 3 # include <cstring> 4 # include <vector> 5 # define R register int 6 7 using namespace std; 8 9 const int N=100005; 10 struct edge 11 { 12 int to,nex; 13 }; 14 int si,h=0,n,m; 15 edge g[N<<1]; 16 int firs[N],a[N],siz[N]; 17 bool vis[N]; 18 int dp[100000100]; 19 20 void add(int u,int v) 21 { 22 g[++h].to=v; 23 g[h].nex=firs[u]; 24 firs[u]=h; 25 } 26 27 void dfs(int x) 28 { 29 dp[x*(m+1)+1]=a[x]; 30 siz[x]=1; 31 vis[x]=true; 32 int j; 33 for (R i=firs[x];i;i=g[i].nex) 34 { 35 j=g[i].to; 36 if(vis[j]) continue; 37 dfs(j); 38 for (int k=min(siz[x]+siz[j],m);k>=1;--k) 39 for (int z=max(1,k-siz[x]);z<=min(siz[j],k-1);++z) 40 dp[x*(m+1)+k]=max(dp[x*(m+1)+k],dp[x*(m+1)+k-z]+dp[j*(m+1)+z]); 41 siz[x]+=siz[j]; 42 } 43 } 44 45 int read() 46 { 47 int x=0; 48 char c=getchar(); 49 while (!isdigit(c)) c=getchar(); 50 while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar(); 51 return x; 52 } 53 54 int main() 55 { 56 scanf("%d%d",&n,&m); m++; 57 memset(g,0,sizeof(g)); 58 for (R i=1;i<=n;i++) 59 { 60 si=read(),a[i]=read(); 61 add(i,si); 62 add(si,i); 63 } 64 dfs(0); 65 printf("%d",dp[m]); 66 return 0; 67 }