第一次打树链剖分,也挺容易的嘛~~~,两次dfs后建线段树维护就行了~~~
CODE:
1 |
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity of every per-vertex array. Vertices are numbered from 1, index 0 is
// used as a "none" sentinel.
const int maxn = 30010;

// One directed arc of the adjacency list.
struct edges {
    int to;    // vertex the arc points to
    int next;  // index in edge[] of the next arc leaving the same vertex, 0 ends the list
};

// Arc pool. init() stores each input edge as two arcs, hence 2 * maxn slots.
edges edge[maxn * 2];

// next[v] is the index in edge[] of the arc most recently added for v (0 = no arcs).
// With `using namespace std;` in effect this name collides with std::next, so
// code that reads or writes it should spell it ::next.
int next[maxn];

// Number of arcs stored so far; addedge() pre-increments it, so slot 0 stays unused.
int l;
int addedge(int from,int to){
edge[++l]=(edges){to,next[from]};
next[from]=l;return 0;
}int dep[maxn],par[maxn],si[maxn],ch[maxn];
bool b[maxn];
int dfs(int u){
dep[u]=dep[par[u]]+1;
b[u]=0;
si[u]=1;
for (int i=next[u];i;i=edge[i].next)
if (b[edge[i].to]) {
par[edge[i].to]=u;
dfs(edge[i].to);
si[u]+=si[edge[i].to];
if (si[ch[u]]<si[edge[i].to]) ch[u]=edge[i].to;
}
return 0;
}int top[maxn],id[maxn],arr[maxn],pos[maxn];
int ind,num;
int heavy_edge(int x,bool flag){
if (flag) top[++ind]=x;
id[x]=ind;
b[x]=0;
arr[pos[x]=++num]=x;
if (ch[x]) heavy_edge(ch[x],0);
for (int i=next[x];i;i=edge[i].next)
if (b[edge[i].to]){
heavy_edge(edge[i].to,1);
}
return 0;
}struct node{
int l,r,Max,Sum;
}t[maxn*4];int w[maxn];
int buildtree(int x,int l,int r){
t[x].l=l;t[x].r=r;
if (l==r) {t[x].Max=t[x].Sum=w[arr[l]];return 0;}
buildtree(x<<1,l,(l+r)>>1);
buildtree((x<<1)+1,((l+r)>>1)+1,r);
t[x].Max=max(t[x<<1].Max,t[(x<<1)+1].Max);
t[x].Sum=t[x<<1].Sum+t[(x<<1)+1].Sum;
return 0;
}int n;
int init(){
scanf("%d",&n);
for (int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
addedge(x,y);addedge(y,x);
}
memset(b,true,sizeof(b));
dfs(1);
memset(b,true,sizeof(b));
heavy_edge(1,1);
for (int i=1;i<=n;i++) scanf("%d",w+i);
buildtree(1,1,n);
return 0;
}int change(int x,int y){
int l=t[x].l,r=t[x].r;
if (y<l||y>r) return 0;
if (l==r) {t[x].Max=t[x].Sum=w[arr[y]];return 0;}
change(x<<1,y);change((x<<1)+1,y);
t[x].Max=max(t[x<<1].Max,t[(x<<1)+1].Max);
t[x].Sum=t[x<<1].Sum+t[(x<<1)+1].Sum;
return 0;
}int gmax(int x,int x1,int y1){
int l=t[x].l,r=t[x].r;
if (l>y1||r<x1) return -30002;
if (l>=x1&&r<=y1) return (t[x].Max);
return(max(gmax(x<<1,x1,y1),gmax((x<<1)+1,x1,y1)));
}int gsum(int x,int x1,int y1){
int l=t[x].l,r=t[x].r;
if (l>y1||r<x1) return 0;
if (l>=x1&&r<=y1) return (t[x].Sum);
return(gsum(x<<1,x1,y1)+gsum((x<<1)+1,x1,y1));
}int MAX(int l,int r){
int ans=-30001;
while (1){
if (dep[top[id[l]]]<dep[top[id[r]]]) swap(l,r);
if (id[l]!=id[r]) {ans=max(ans,gmax(1,pos[top[id[l]]],pos[l]));l=par[top[id[l]]];}
else {if (dep[l]<dep[r]) swap(l,r);ans=max(ans,gmax(1,pos[r],pos[l]));return ans;}
}
}int SUM(int l,int r){
int ans=0;
while (1){
if (dep[top[id[l]]]<dep[top[id[r]]]) swap(l,r);
if (id[l]!=id[r]) {ans+=gsum(1,pos[top[id[l]]],pos[l]);l=par[top[id[l]]];}
else {if (dep[l]<dep[r]) swap(l,r);ans+=gsum(1,pos[r],pos[l]);return ans;}
}
}int m;
// Reads m and then up to m operations of the form "<word> <int> <int>" from stdin.
//   word starts with 'C'        : set w[x] = y and refresh the segment tree,
//   second letter of word is 'M': print the path maximum between x and y,
//   second letter of word is 'S': print the path sum between x and y.
// (Presumably the words are CHANGE / QMAX / QSUM - confirm against the task statement.)
// Stops early if an operation cannot be read completely. Always returns 0.
int work() {
    scanf("%d", &m);
    for (int i = 1; i <= m; i++) {
        int x, y;
        char s[10];
        // %9s bounds the read to the buffer; a short read would otherwise leave
        // s, x or y uninitialised, so stop instead of acting on garbage.
        if (scanf("%9s%d%d", s, &x, &y) != 3) break;
        if (s[0] == 'C') {
            w[x] = y;
            change(1, pos[x]);
        }
        if (s[1] == 'M') {
            printf("%d\n", MAX(x, y));
        }
        if (s[1] == 'S') {
            printf("%d\n", SUM(x, y));
        }
    }
    return 0;
}

// Program entry point: build the structures, then answer the operations.
int main() {
    init();
    work();
    return 0;
}