【题意】给定n个点的树,m条从下往上的链,每条链代价ci,求最少代价使得链覆盖所有边。n,m<=3*10^5,ci<=10^9,time=4s。
【算法】树形DP+线段树||可并堆
【题解】从每条边都需要一条链来覆盖的角度出发,令f[i]表示覆盖子树 i 以及 i到fa[i]的边(i->fa[i])的最小代价,整个过程通过dfs从下往上做。
由于f[son[i]]已知,所以f[i]的转移实际上是考虑覆盖i->fa[i]的链,定义这条链为主链。那么f[i]=min(c+Σf[k]),c是主链代价,k是主链上在i子树内的所有点的子节点(不含主链上点),所有起点在子树i内终点在i的祖先的链都可以作为主链,取最小值。
自然地,可以在递归的过程中将Σf[k]并入c中。具体而言,对于每个点x:
1.删。将终点在x的链删除。
2.加。记sum=Σf[son[i]],son[i]子树内所有的链c+=sum-f[son[i]](就是把Σf[k]并入c中),特别地,起点在i的链c+=sum。
3.取。f[i]是子树i中所有的链c的最小值。
现在需要快速支持子树加值和子树求最小值的操作,可以用线段树按dfs序维护所有链实现(把链按起点的dfs序作为线段树下标)。
复杂度O(n log n)。
#include<cstdio> #include<cctype> #include<vector> #include<algorithm> #define ll long long using namespace std; int read(){ char c;int s=0,t=1; while(!isdigit(c=getchar()))if(c=='-')t=-1; do{s=s*10+c-'0';}while(isdigit(c=getchar())); return s*t; } const int maxn=300010; const ll inf=1e15; struct tree{int l,r;ll delta,mins;}t[maxn*4]; struct edge{int v,from;}e[maxn*2]; vector<int>v[maxn]; int n,m,ku[maxn],kv[maxn],kw[maxn],kp[maxn],tot=0,dfsnum=0,first[maxn],be[maxn],ed[maxn]; ll a[maxn],f[maxn]; void ins(int u,int v){tot++;e[tot].v=v;e[tot].from=first[u];first[u]=tot;} void dfs_order(int x,int fa){ be[x]=dfsnum+1; for(int i=0;i<(int)v[x].size();i++){ kp[v[x][i]]=++dfsnum; a[dfsnum]=kw[v[x][i]]; } for(int i=first[x];i;i=e[i].from)if(e[i].v!=fa){ dfs_order(e[i].v,x); } ed[x]=dfsnum; if(be[x]>ed[x]){printf("-1");exit(0);} } void modify(int k,ll x){t[k].mins+=x;t[k].delta+=x;} void up(int k){t[k].mins=min(t[k<<1].mins,t[k<<1|1].mins);} void down(int k){ if(t[k].delta){ modify(k<<1,t[k].delta); modify(k<<1|1,t[k].delta); t[k].delta=0; } } void build(int k,int l,int r){ t[k].l=l;t[k].r=r;t[k].delta=0; if(l==r){t[k].mins=a[l];}else{ int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); up(k); } } void add(int k,int l,int r,ll x){ if(l<=t[k].l&&t[k].r<=r){modify(k,x);return;} down(k); int mid=(t[k].l+t[k].r)>>1; if(l<=mid)add(k<<1,l,r,x); if(r>mid)add(k<<1|1,l,r,x); up(k); } ll ask(int k,int l,int r){ if(l<=t[k].l&&t[k].r<=r){return t[k].mins;} down(k); int mid=(t[k].l+t[k].r)>>1; ll ans=inf; if(l<=mid)ans=ask(k<<1,l,r); if(r>mid)ans=min(ans,ask(k<<1|1,l,r)); return ans; } ll dp(int x,int fa){ f[x]=0;ll sum=0; for(int i=first[x];i;i=e[i].from)if(e[i].v!=fa)sum+=dp(e[i].v,x); for(int i=0;i<(int)v[x].size();i++)add(1,v[x][i],v[x][i],inf); add(1,be[x],ed[x],sum); for(int i=first[x];i;i=e[i].from)if(e[i].v!=fa){ add(1,be[e[i].v],ed[e[i].v],-f[e[i].v]); } f[x]=ask(1,be[x],ed[x]); if(x!=1&&f[x]>=inf){printf("-1");exit(0);} return f[x]; } int main(){ n=read();m=read(); for(int i=1;i<n;i++){ int u=read(),v=read(); ins(u,v);ins(v,u); } for(int i=1;i<=m;i++){ ku[i]=read(),kv[i]=read(),kw[i]=read(); v[ku[i]].push_back(i); } dfsnum=0; dfs_order(1,0); build(1,1,dfsnum); for(int i=1;i<=m;i++)v[ku[i]].clear(); for(int i=1;i<=m;i++)v[kv[i]].push_back(kp[i]); dp(1,0); ll ans=0; for(int i=first[1];i;i=e[i].from)ans+=f[e[i].v]; printf("%lld",ans); return 0; }