最近公共祖先(LCA问题)

板子题:

例题:1172.祖孙询问

次小生成树应用:

例题:1171. 距离

树上差分应用:

例题:352. 闇の連鎖

三种方法如下:第二和第三种更实用一些

法一:向上标记法

时间复杂度:

  • 每次查询:最坏情况 O(n)

算法实现:

  • x 向上走到根节点,并标记所有经过的节点
  • y 向上走到根节点,当第一次遇到已经标记的节点时,就找到了 LCA(x, y)。

法二:树上倍增法

时间复杂度:

  • 预处理:O(nlogn)
  • 每次查询:O(logn)

算法实现:

  • 预处理:
    • 预处理数组 fa[i][j] :表示从 i 开始,向上走 2^j 所能走到的节点。0 <= j <= logn。
      • 如何预处理:dfs 或 bfs 都可以
      • 若fa[i][j]的节点不存在,则令 fa[i][j] == 0
      • j == 0时,fa[i][j] = i 的父节点
      • j > 0时,fa[i][j] = fa[fa[i][j - 1]][j - 1],即、先跳到2^(j-1) 步,再跳 2^(j-1) 步
    • 预处理数组 depth[i]:表示深度,或层数,规定根节点的深度为 1,子节点的深度为父节点深度 +1
  • 具体步骤如下:基于倍增的思想
    • 先将两个点跳到同一层,
    • 让两个点同时向上跳,一直跳都他们的最近公共祖先的下一层(因为,当二者相等时无法判断是否为最近的公共祖先,所以跳到最近公共祖先的下一层)
  • 哨兵:depth[0] = 0。如果从 i 开始跳 2^j 步会跳过根节点,那么 fa[i][j] = 0,此时depth[fa[i][j]] = 0,哨兵生效,停止跳跃。
    • 具体作用如下:
    • 保证 lca 两点跳到同一层时,越界的深度小于任何树中节点的深度,从而停止此次跳跃;
    • 保证 lca 查找公共祖先时,越界的二者的父节点相同,都是0,从而停止此次跳跃。

具体板子如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>

using namespace std;

const int N = 40010, M = N * 2, K = log2(N);

int n, m;
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][K + 1]; // 需要用 1~K 层, 开到 K + 1 保证不越界
int q[N];

void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}

// 求每个点的深度 depth[i];
// f[i][j] 表示:从i点开始,向上走2^j步,(0 <= j <= logn);
void bfs(int root)
{
memset(depth, 0x3f, sizeof depth);

// 哨兵:当f[i][j]为0时(即、i的向上走2^j步不存在时), depth[f[i][j]] = 0
depth[0] = 0;

// 根节点初始化
depth[root] = 1;
int hh = 0, tt = 0;
q[0] = root;

while(hh <= tt)
{
int t = q[hh ++];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if(depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q[ ++ tt] = j;
fa[j][0] = t; // 向上走2^0步(1步),为父节点

// 递推处理fa[i][j], 有点st表的感觉
for (int k = 1; k <= K; k ++ )
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}

int lca(int a, int b)
{
// 令 a 为较深的点
if(depth[a] < depth[b]) swap(a, b);

// 将 较深的点a 跳到和b的同一层
// depth[0] = 0作用:
// 保证 lca 两点跳到同一层时,越界的深度小于任何树中节点的深度,从而停止此次跳跃;
for (int k = K; k >= 0; k -- )
if(depth[fa[a][k]] >= depth[b])
a = fa[a][k];

// 此时 若 a 和 b 为同一点,则直接返回
if(a == b) return a;

// 若父节点不同,则变为父节点,再次比较
// depth[0] = 0作用:
// 保证 lca 查找公共祖先时,越界的二者的父节点相同,都是0,从而停止此次跳跃。
for (int k = K; k >= 0; k -- )
if(fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}

return fa[a][0];
}

int main()
{
scanf("%d", &n);

int root = 0;
memset(h, -1, sizeof h);
while(n -- )
{
int a, b;
scanf("%d%d", &a, &b);
if(b == -1) root = a;
else add(a, b), add(b, a);
}

bfs(root);

scanf("%d", &m);
while(m -- )
{
int a, b;
scanf("%d%d", &a, &b);
int p = lca(a, b);
if(p == a) puts("1");
else if(p == b) puts("2");
else puts("0");
}

return 0;
}

法三:Tarjan算法——离线求LCA

介绍:

算法本质是,使用并查集对向上标记法的优化。离线算法,读入所有询问,统一计算,统一输出

时间复杂度

  • O(n + m)

算法实现

  • 在深度遍历时,将所有点分为三大类
    • 第一类点:已经遍历过,并且已经回溯过的点,标记为 2
    • 第二类点:正在搜索的分支,标记为 1
    • 第三类点:还未搜索到的点,标记为 0
  • 若求lca(x, y),对于正在访问的节点 x ,即、x 的标记为 1。若 y 是已经遍历过且回溯过的节点,则 lca(x, y) 就是从 y 向上走到根,第一个遇见的标记为 1 的节点,其实就是向上标记法。
  • 并查集优化:
    • 当一个节点获得标记 2 时,把它所在的集合合并到其父节点所在的集合中(合并时,其父节点一定为 1)
    • 所有标记为 2 的节点,都有一个指针指向其祖宗节点。若 x 的标记为 1 ,y 的标记为 2,查询 y 的向上走的第一个标记为 1 的节点就是 y 的祖宗节点(一定标记为 1 ),此祖宗节点就是 lca(x, y)

具体板子如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

#define x first
#define y second

using namespace std;

typedef pair<int, int> PII;

const int N = 10010, M = N * 2;

int n, m;
int h[N], e[M], w[M], ne[M], idx;
int dist[N];
int st[N];
int p[N];
int res[M];
vector<PII> query[N];

void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}

void dfs(int u, int fa)
{
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if(j == fa) continue;
dist[j] = dist[u] + w[i];
dfs(j, u);
}
}

int find(int x)
{
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}

// 将所有点分为三类:
// 第一类. 已经遍历过且搜索过的点,标记为 2
// 第二类. 正在搜索的分支,标记为 1
// 第三类. 还未搜索的点,标记为 0
void tarjan(int u)
{
st[u] = 1;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if(!st[j])
{
tarjan(j);
p[j] = u;
}
}

// 遍历与此点的询问
for (auto it : query[u])
{
int son = it.x, id = it.y;
if(st[son] == 2)
{
int anc = find(son);
res[id] = dist[u] + dist[son] - dist[anc] * 2;
}
}

st[u] = 2;
}

int main()
{
scanf("%d %d", &n, &m);

memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i ++ )
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}

for (int i = 0; i < m; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
query[a].push_back({b, i});
query[b].push_back({a, i});
}

for (int i = 1; i <= n; i ++ ) p[i] = i;

dfs(1, -1);
tarjan(1);

for (int i = 0; i < m; i ++ ) printf("%d\n", res[i]);

return 0;
}

法四:dfs序列 + RMQ算法(例如,st表)

麻烦,且很不常用,了解即可

算法思路:

  • dfs 遍历得 dfs 序列
  • 若求 lca(x, y) ,则在 dfs 序列中,找任意的 x 和 y 求之间的最小值,此最小值就是 lca(x, y)

LCA 应用

LCA 树上倍增算法的应用 —— 次小生成树

将朴素求树上两点路径之间的最值,化为倍增求最值,从而将时间复杂度由 O(n^2) 优化为 O(nlogn)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#include <iostream>
#include <algorithm>
#include <cstring>

using namespace std;

typedef long long LL;

const int N = 100010, M = 300010, INF = 0x3f3f3f3f;

int n, m;
struct Edge
{
int a, b, w;
bool used;
bool operator< (const Edge &t)const
{
return w < t.w;
}
}edge[M];
int p[N];
int h[N], e[M], w[M], ne[M], idx;
int depth[N], fa[N][17], d1[N][17], d2[N][17];
int q[N];

void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}

int find(int x)
{
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}

LL kruskal()
{
for (int i = 1; i <= n; i ++ ) p[i] = i;
sort(edge, edge + m);
LL res = 0;
for (int i = 0; i < m; i ++ )
{
int a = find(edge[i].a), b = find(edge[i].b), w = edge[i].w;
if(a != b)
{
p[a] = b;
res += w;
edge[i].used = true;
}
}

return res;
}

void build()
{
memset(h, -1, sizeof h);
for (int i = 0; i < m; i ++ )
if(edge[i].used)
{
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
add(a, b, w), add(b, a, w);
}
}

void bfs()
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[1] = 1;
int hh = 0, tt = 0;
q[0] = 1;

while(hh <= tt)
{
int t = q[hh ++ ];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if(depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q[ ++ tt] = j;
fa[j][0] = t;
d1[j][0] = w[i], d2[j][0] = -INF;
for (int k = 1; k <= 16; k ++ )
{
int anc = fa[j][k - 1];
fa[j][k] = fa[anc][k - 1];
int distance[4] = {d1[j][k - 1], d2[j][k - 1], d1[anc][k - 1], d2[anc][k - 1]};
d1[j][k] = -INF, d2[j][k] = -INF;
for (int u = 0; u < 4; u ++ )
{
int d = distance[u];
if(d > d1[j][k]) d2[j][k] = d1[j][k], d1[j][k] = d;
else if(d != d1[j][k] && d > d2[j][k]) d2[j][k] = d;
}
}
}
}
}
}

int lca(int a, int b, int w)
{
static int distance[N * 2];
int cnt = 0;
if(depth[a] < depth[b]) swap(a, b);

for (int k = 16; k >= 0; k --)
if(depth[fa[a][k]] >= depth[b])
{
distance[cnt ++ ] = d1[a][k];
distance[cnt ++ ] = d2[a][k];
a = fa[a][k];
}

if(a != b)
{
for (int k = 16; k >= 0; k -- )
if(fa[a][k] != fa[b][k])
{
distance[cnt ++ ] = d1[a][k];
distance[cnt ++ ] = d2[a][k];
distance[cnt ++ ] = d1[b][k];
distance[cnt ++ ] = d2[b][k];
a = fa[a][k];
b = fa[b][k];
}
distance[cnt ++ ] = d1[a][0];
distance[cnt ++ ] = d1[b][0];
}

int dist1 = -INF, dist2 = -INF;
for (int i = 0; i < cnt; i ++ )
{
int d = distance[i];
if(d > dist1) dist2 = dist1, dist1 = d;
else if(d != dist1 && d > dist2) dist2 = d;
}

if(w > dist1) return w - dist1;
if(w > dist2) return w - dist2;

return INF;
}

int main()
{
scanf("%d%d", &n, &m);

for (int i = 0; i < m; i ++ )
{
int a, b, w;
scanf("%d%d%d", &a, &b, &w);
edge[i] = {a, b, w};
}

LL sum = kruskal();
build();
bfs();

LL res = 1e18;
for (int i = 0; i < m; i ++ )
if(!edge[i].used)
{
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
res = min(res, sum + lca(a, b, w));
}

printf("%lld\n", res);

return 0;
}

LCA 应用 —— 树上差分

点的权值表示边的权值,其中 d[i]:表示差分节点 i 的权值(即、i 通向父节点的边的权值);如果在 x, y的路径上 + c,则只需令 d[x] += c, d[y] += c, d[p] -= 2c;(其中 p = lca(x, y))。最终答案为d[res]:res 的所有子树的权值和

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

const int N = 100010, M = 200010, K = log2(N), INF = 0x3f3f3f3f;

int n, m;
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][K + 1];
int d[N], q[N];
int ans;

void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

void bfs()
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[1] = 1;

int hh = 0, tt = -1;
q[ ++ tt] = 1;

while (hh <= tt)
{
int t = q[hh ++ ];

for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q[ ++ tt] = j;
fa[j][0] = t;
for (int k = 1; k <= K; k ++ )
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}

int lca(int a, int b)
{
if (depth[a] < depth[b]) swap(a, b);
for (int k = K; k >= 0; k -- )
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return a;
for (int k = K; k >= 0; k -- )
if (fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}

return fa[a][0];
}

int dfs(int u, int father)
{
int res = d[u];
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != father)
{
int s = dfs(j, u);
if (s == 0) ans += m;
else if (s == 1) ans += 1;
res += s;
}
}

return res;
}

int main()
{
scanf("%d%d", &n, &m);

memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}

bfs();

for (int i = 0; i < m; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
int p = lca(a, b);
d[a] ++, d[b] ++, d[p] -= 2;
}

dfs(1, -1);
printf("%d\n", ans);

return 0;
}