luogu4092[HEOI2016]树题解--树链剖分

题目链接:

https://www.luogu.org/problemnew/show/P4092

分析

瞎扯—$O(Q \log^3 N)$解法

这道先yy出了一个$O(Q \log^3 N)$,的做法,先树链剖分。

对于加标记操作,找到那个点所在的链,将其$top$标记一下,然后该点到根节点区间和+1.

对于查询操作,先看这个点所在链有没有标记,如果没有,就一直向上跳直到找到一条标记了的链,然后在那条链上根据到根节点区间和进行倍增/二分

然后出去吃饭的时候忽然想到了$O(Q \log^2 N)$的解法,于是刚刚这个解法刚打完还没有查错,放在这做一个参考

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cctype>
#include <algorithm>
#include <cmath>
#define ll long long
#define ri register int
using namespace std;
const int maxn=100005;
const int inf=0x7fffffff;
template <class T>inline void read(T &x){
x=0;int ne=0;char c;
while(!isdigit(c=getchar()))ne=c=='-';
x=c-48;
while(isdigit(c=getchar()))x=(x<<3)+(x<<1)+c-48;
x=ne?-x:x;
return ;
}
int n,q;
struct Edge{
int ne,to;
}edge[maxn<<1];
int h[maxn],num_edge=0;
inline void add_edge(int f,int t){
edge[++num_edge].ne=h[f];
edge[num_edge].to=t;
h[f]=num_edge;
return ;
}
int dep[maxn],fa[maxn],size[maxn],son[maxn],top[maxn],dfn[maxn],rnk[maxn],cnt=0;
void dfs_1(int now){
int v;size[now]=1;
for(ri i=h[now];i;i=edge[i].ne){
v=edge[i].to;
if(v==fa[now])continue;
fa[v]=now,dep[v]=dep[now]+1;
dfs_1(v);
size[now]+=size[v];
if(!son[now]||size[son[now]]<size[v])son[now]=v;
}
return ;
}
void dfs_2(int now,int t){
int v;top[now]=t;
dfn[now]=++cnt,rnk[cnt]=now;
if(!son[now])return ;
dfs_2(son[now],t);
for(ri i=h[now];i;i=edge[i].ne){
v=edge[i].to;
if(v==fa[now]||v==son[now])continue;
dfs_2(v,v);
}
return ;
}
int sum[maxn<<2],tag[maxn<<2],L,R,dta,ok[maxn];
void build(int now,int l,int r){
if(l==r){
sum[now]=ok[rnk[l]];
return ;
}
int mid=(l+r)>>1;
build(now<<1,l,mid);
build(now<<1|1,mid+1,r);
return ;
}
void pushdown(int now,int ln,int rn){
if(tag[now]){
sum[now<<1]+=tag[now]*ln;
sum[now<<1|1]+=tag[now]*rn;
tag[now<<1]+=tag[now];
tag[now<<1|1]+=tag[now];
tag[now]=0;
}
return ;
}
void update(int now,int l,int r){
if(L<=l&&r<=R){
sum[now]+=dta*(r-l+1);
tag[now]+=dta;
return ;
}
int mid=(l+r)>>1;
pushdown(now,mid-l+1,r-mid);
if(L<=mid)update(now<<1,l,mid);
if(mid<R)update(now<<1|1,mid+1,r);
sum[now]=sum[now<<1]+sum[now<<1|1];
return ;
}
int query(int now,int l,int r){
if(L<=l&&r<=R){
return sum[now];
}
int mid=(l+r)>>1,ans=0;
pushdown(now,mid-l+1,r-mid);
if(L<=mid)ans+=query(now<<1,l,mid);
if(mid<R)ans+=query(now<<1|1,mid+1,r);
sum[now]=sum[now<<1]+sum[now<<1|1];
return ans;
}
void update_path(int x,int y){
dta=1;ok[top[x]]=1;//该条链上有一个标记的点
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
L=dfn[top[x]],R=dfn[x];
update(1,1,n);
}
if(dfn[x]<dfn[y])swap(x,y);
L=dfn[x],R=dfn[y];
update(1,1,n);
return ;
}
inline int solve(int x,int y){
int tmp,val,p=0,k=1,len,ans=0;
bool flag=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
len=dfn[x]-dfn[top[x]];
if(ok[top[x]]){
L=dfn[top[x]],R=dfn[x],
tmp=query(1,1,n);
p=0,k=1,flag=0;
while(k!=0){
L=dfn[x+p+k],R=dfn[x];
if(query(1,1,n)>tmp)flag=1,k=k>>1;
else p=p+k,k=k<<1;
while(p+k>len)k=k>>1;
}
if(flag)return ans+dfn[x+p]-dfn[x];
}
ans+=len;
x=fa[top[x]];
}
if(dfn[x]>dfn[y])swap(x,y);
L=dfn[x],R=dfn[y],len=dfn[y]-dfn[x];
tmp=query(1,1,n);
p=0,k=1;
//cout<<y<<endl;
if(x==y)return ans;
while(k!=0){
L=dfn[x+p+k],R=dfn[x];
if(query(1,1,n)>tmp)k=k>>1;
else p=p+k,k=k<<1;
//if(y==3)cout<<k<<' '<<p<<endl;
while(p+k>len)k=k>>1;
}
return ans+dfn[x+p]-dfn[x];
}
int main(){
char opt[5];
int x,y,z;
read(n),read(q);
for(ri i=1;i<n;i++){
read(x),read(y);
add_edge(x,y);
add_edge(y,x);
}
dep[1]=1,fa[1]=0;
dfs_1(1);
dfs_2(1,1);
ok[dfn[1]]=1;
build(1,1,n);
while(q--){
scanf("%s",opt);
if(opt[0]=='C'){
read(x);
//cout<<x<<"-----"<<endl;
update_path(1,x);
}
else{
read(x);
//cout<<x<<"***"<<endl;
printf("%d\n",solve(x,1));
}
}
return 0;
}

$O(Q \log^2 N)$解法

首先我想到了一个错误的解法,就是因为链是线段树上一个连续的区间,每个$[dfn[x],dfn[top[x]]]$线段树区间有个$mx$值,表示,$x$到$top[x]$路径中距离它最近标记的祖先,加标记时比较原有标记深度与新标记深度然后更新。查询的时候查询$x$到$top[x]$的区间最大之就可以了,如果没有,就一直往上跳直至找到

然而这个解法有个错误我SB地没有发现,就是你更新区间最大值时,$x$上的祖先节点也会被更新到(因为深度更小),再次感谢wjyyy和creed_两位大佬指出我的错误

正解应该是更新子树,将子树的最大值更新,查询照样,相比于我错误的代码只需改一句话

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cctype>
#include <algorithm>
#include <cmath>
#define ll long long
#define ri register int
using namespace std;
const int maxn=100005;
const int inf=0x7fffffff;
template <class T>inline void read(T &x){
x=0;int ne=0;char c;
while(!isdigit(c=getchar()))ne=c=='-';
x=c-48;
while(isdigit(c=getchar()))x=(x<<3)+(x<<1)+c-48;
x=ne?-x:x;
return ;
}
int n,q;
struct Edge{
int ne,to;
}edge[maxn<<1];
int h[maxn],num_edge=0;
inline void add_edge(int f,int t){
edge[++num_edge].ne=h[f];
edge[num_edge].to=t;
h[f]=num_edge;
return ;
}
int dep[maxn],fa[maxn],size[maxn],son[maxn],top[maxn],dfn[maxn],rnk[maxn],cnt=0;
void dfs_1(int now){
int v;size[now]=0;
for(ri i=h[now];i;i=edge[i].ne){
v=edge[i].to;
if(v==fa[now])continue;
fa[v]=now,dep[v]=dep[now]+1;
dfs_1(v);
size[now]+=size[v];
if(!son[now]||size[son[now]]<size[v])son[now]=v;
}
return ;
}
void dfs_2(int now,int t){
int v;top[now]=t;
dfn[now]=++cnt,rnk[cnt]=now;
if(!son[now])return ;
dfs_2(son[now],t);
for(ri i=h[now];i;i=edge[i].ne){
v=edge[i].to;
if(v==fa[now]||v==son[now])continue;
dfs_2(v,v);
}
return ;
}
int mx[maxn<<2],L,R,dta;
void build(int now,int l,int r){
if(l==r){
if(rnk[l]==1)mx[now]=1;
else mx[now]=0;
return ;
}
int mid=(l+r)>>1;
build(now<<1,l,mid);
build(now<<1|1,mid+1,r);
if(dep[mx[now<<1]]>dep[mx[now<<1|1]]){
mx[now]=mx[now<<1];
}
else mx[now]=mx[now<<1|1];
return ;
}
void update(int now,int l,int r){
if(L<=l&&r<=R){
if(dep[mx[now]]<dep[dta]){
mx[now]=dta;
}
return ;
}
int mid=(l+r)>>1;
if(L<=mid)update(now<<1,l,mid);
if(mid<R)update(now<<1|1,mid+1,r);
if(dep[mx[now<<1]]>dep[mx[now<<1|1]]){
mx[now]=mx[now<<1];
}
else mx[now]=mx[now<<1|1];
return ;
}
int query(int now,int l,int r){
if(L<=l&&r<=R){
return mx[now];
}
int mid=(l+r)>>1,ans=0,tmp;
if(L<=mid){
int tmp=query(now<<1,l,mid);
if(dep[ans]<dep[tmp])ans=tmp;
}
if(mid<R){
int tmp=query(now<<1|1,mid+1,r);
if(dep[ans]<dep[tmp])ans=tmp;
}
return ans;
}
void update_path(int x){
dta=x;
//L=R=dfn[x];
L=dfn[x],R=dfn[x]+size[x];
update(1,1,n);
return ;
}
int query_path(int x){
int ans=0;
while(top[x]!=1){
L=dfn[top[x]],R=dfn[x];
ans=query(1,1,n);
if(ans!=0)return ans;
x=fa[top[x]];
}
L=dfn[1],R=dfn[x];
ans=query(1,1,n);
return ans;
}
int main(){
char opt[5];
int x,y,z;
read(n),read(q);
for(ri i=1;i<n;i++){
read(x),read(y);
add_edge(x,y);
add_edge(y,x);
}
dep[0]=-1,dep[1]=1,fa[1]=0;
dfs_1(1);
dfs_2(1,1);
build(1,1,n);
while(q--){
//cout<<q<<endl;
scanf("%s",opt);
if(opt[0]=='C'){
read(x);
update_path(x);
}
else{
read(x);
printf("%d\n",query_path(x));
}
}
return 0;
}