次小生成树

例题:1148. 秘密的牛奶运输

定理:对于一张无向图,如果存在最小生成树和(严格)次小生成树,那么对于任何一棵最小生成树,都存在一棵(严格)次小生成树,使得着两棵树只有一条边不同

法一

具体实现:

  • 先求最小生成树,再枚举删去最小生成树中的边求解。

时间复杂度

  • O(mlogm + nm)

证明:

  • 化零为整:把最小生成树与次小生成树得不同得一条边分为n-1类(1~n-1)
  • 只需求出每一类得最小值,则一定可以求出次小生成树

弊端

  • 求非严格次小生成树可以,但求严格次小生成树不太方便

法二

具体实现

  • 先求最小生成树,然后依次枚举非树边,将该边加入树中,同时从树中去掉一条边,使得最终得图仍然是一棵树。则一定可以求出次小生成树。

时间复杂度:

  • 朴素算法 O(m + n^2 + mlogm))
  • lca 倍增优化 O(nlogn)

证明

  • 声明:
    • 设T为图G得一棵生成树,对于非树边a和树边b,插入边a,并删除边b的操作记为(+a,-b)。
    • 如果T+a-b之后,仍然是一棵生成树,称(+a,-b)是T的一个可行交换。
    • 由T进行一次可行变换所得到的新的生成树的集合称为T的邻集。
  • 定理:
    • 次小生成树一定再最小生成树的邻集中。
  • 证明如下:
    • 反证法:假设存在某一个图,次小生成树和最小生成树至少两条边不同。
    • 先考虑非严格次小生成树:
      • 先最小生成树从小到大排序,枚举所有的最小生成树的边,枚举该边是否在次小生成树中,若不在,则将该边加上,形成环,去掉后面的边,权值小于等于之前的生成树,结果不会变坏,且次小生成树与最小生成树仍然由差异,则该操作成立,如此反复,直到次小生成树和最小生成树只剩一条边不同为止。
    • 再考虑严格次小生成树:
      • 类似上述操作,两边相等则不变(不用担心完全一样的问题,因为严格次小一定大于最小),若当前枚举的最小生成树的边较小,则替换。如此可构造出只剩一条边不同的次小生成树

朴素代码如下;

注意:在求严格次小生成树时,不能只预处理两点之间最大的树边,因为当最大树边和当前枚举的非树边长度相同时,就不能替换了,但此时却可以替换长度次大的树边。因此还需同时预处理出长度次大的树边。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <iostream>
#include <algorithm>
#include <cstring>

using namespace std;

typedef long long LL;

const int N = 510, M = 10010;

int n, m;
struct Edge
{
int a, b, w;
bool f;
bool operator< (const Edge &t)const
{
return w < t.w;
}
}edge[M];

int p[N];
int dist1[N][N], dist2[N][N];
int h[N], e[N * 2], w[N * 2], ne[N * 2], idx;

void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}

int find(int x)
{
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}

LL kruskal()
{
LL sum = 0;
for (int i = 0; i < m; i ++ )
{
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
int pa = find(a), pb = find(b);
if(pa != pb)
{
p[pa] = pb;
sum += w;
edge[i].f = true;
add(a, b, w), add(b, a, w);
}
}

return sum;
}

void dfs(int u, int fa, int maxd1, int maxd2, int d1[], int d2[])
{
d1[u] = maxd1, d2[u] = maxd2;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if(j != fa)
{
int td1 = maxd1, td2 = maxd2;
if(w[i] > td1) td2 = td1, td1 = w[i];
else if(w[i] < td1 && w[i] > td2) td2 = w[i];
dfs(j, u, td1, td2, d1, d2);
}
}
}

int main()
{
scanf("%d %d", &n, &m);

memset(h, -1, sizeof h);

for (int i = 0; i < m; i ++ )
{
int a, b, w;
scanf("%d%d%d", &a, &b, &w);
edge[i] = {a, b, w};
}

sort(edge, edge + m);

for (int i = 1; i <= n; i ++ )
p[i] = i;

LL sum = kruskal();

for (int i = 1; i <= n; i ++ )
dfs(i, -1, -1e9, -1e9, dist1[i], dist2[i]);

LL res = 1e18;
for (int i = 0; i < m; i ++ )
if(!edge[i].f)
{
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
LL t = 0;
if(w > dist1[a][b]) t = sum + w - dist1[a][b];
else if(w > dist2[a][b]) t = sum + w - dist2[a][b];
res = min(res, t);
}

printf("%lld\n", res);

return 0;
}

lca 倍增优化版代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#include <iostream>
#include <algorithm>
#include <cstring>

using namespace std;

typedef long long LL;

const int N = 100010, M = 300010, INF = 0x3f3f3f3f;

int n, m;
struct Edge
{
int a, b, w;
bool used;
bool operator< (const Edge &t)const
{
return w < t.w;
}
}edge[M];
int p[N];
int h[N], e[M], w[M], ne[M], idx;
int depth[N], fa[N][17], d1[N][17], d2[N][17];
int q[N];

void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}

int find(int x)
{
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}

LL kruskal()
{
for (int i = 1; i <= n; i ++ ) p[i] = i;
sort(edge, edge + m);
LL res = 0;
for (int i = 0; i < m; i ++ )
{
int a = find(edge[i].a), b = find(edge[i].b), w = edge[i].w;
if(a != b)
{
p[a] = b;
res += w;
edge[i].used = true;
}
}

return res;
}

void build()
{
memset(h, -1, sizeof h);
for (int i = 0; i < m; i ++ )
if(edge[i].used)
{
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
add(a, b, w), add(b, a, w);
}
}

void bfs()
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[1] = 1;
int hh = 0, tt = 0;
q[0] = 1;

while(hh <= tt)
{
int t = q[hh ++ ];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if(depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q[ ++ tt] = j;
fa[j][0] = t;
d1[j][0] = w[i], d2[j][0] = -INF;
for (int k = 1; k <= 16; k ++ )
{
int anc = fa[j][k - 1];
fa[j][k] = fa[anc][k - 1];
int distance[4] = {d1[j][k - 1], d2[j][k - 1], d1[anc][k - 1], d2[anc][k - 1]};
d1[j][k] = -INF, d2[j][k] = -INF;
for (int u = 0; u < 4; u ++ )
{
int d = distance[u];
if(d > d1[j][k]) d2[j][k] = d1[j][k], d1[j][k] = d;
else if(d != d1[j][k] && d > d2[j][k]) d2[j][k] = d;
}
}
}
}
}
}

int lca(int a, int b, int w)
{
static int distance[N * 2];
int cnt = 0;
if(depth[a] < depth[b]) swap(a, b);

for (int k = 16; k >= 0; k --)
if(depth[fa[a][k]] >= depth[b])
{
distance[cnt ++ ] = d1[a][k];
distance[cnt ++ ] = d2[a][k];
a = fa[a][k];
}

if(a != b)
{
for (int k = 16; k >= 0; k -- )
if(fa[a][k] != fa[b][k])
{
distance[cnt ++ ] = d1[a][k];
distance[cnt ++ ] = d2[a][k];
distance[cnt ++ ] = d1[b][k];
distance[cnt ++ ] = d2[b][k];
a = fa[a][k];
b = fa[b][k];
}
distance[cnt ++ ] = d1[a][0];
distance[cnt ++ ] = d1[b][0];
}

int dist1 = -INF, dist2 = -INF;
for (int i = 0; i < cnt; i ++ )
{
int d = distance[i];
if(d > dist1) dist2 = dist1, dist1 = d;
else if(d != dist1 && d > dist2) dist2 = d;
}

if(w > dist1) return w - dist1;
if(w > dist2) return w - dist2;

// 此返回值不会被用到
// 当 dist1 无法被用时,dist2 一定小于 w
// 证明:当 w == dist1 时:dist2 < dist1;当 w < dist1 时,dist2 < dist1
// 所以 w > dist2 时一定成立的
return INF;
}

int main()
{
scanf("%d%d", &n, &m);

for (int i = 0; i < m; i ++ )
{
int a, b, w;
scanf("%d%d%d", &a, &b, &w);
edge[i] = {a, b, w};
}

LL sum = kruskal();
build();
bfs();

LL res = 1e18;
for (int i = 0; i < m; i ++ )
if(!edge[i].used)
{
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
res = min(res, sum + lca(a, b, w));
}

printf("%lld\n", res);

return 0;
}