ABC401-F: Add One Edge 3
Page content
問題概要
- 2 つの木が与えられる
- 各木から頂点を 1 つずつ選んで辺を追加すると新しい木が得られる
- 全ペアについて新しい木の直径の総和を求めよ
後付けで言語化した思考過程
まず
木1の最遠点までの距離 + 木2の最遠点までの距離 + 1
を考えたくなる- 最遠点といえば直径だが、任意の点に対する最遠点はどうすれば?
- そもそも木の直径を求めるアルゴリズムが ↓ なので、直径の両端のどっちかだろ多分
- 適当な点から最遠点を求める
- 求めた最遠点から最遠点を求める
- 適当な点、端点 1、端点 2 で計 3 回 bfs して、端点 1 と端点 2 の max をとれば良さそう
- そもそも木の直径を求めるアルゴリズムが ↓ なので、直径の両端のどっちかだろ多分
- 最遠点といえば直径だが、任意の点に対する最遠点はどうすれば?
木 1 の各頂点からの最遠点までの距離 dist1 と、木 2 ついての dist2 が得られる
sum(dist1) * N2 + sum(dist2) * N1 + N1 * N2
や! -> WA- あくまで求めるのは直径の総和で、両方の木を通る最長パスではないですね
- クソデカ木にクソチビ木を合成しても直径が変わらないよね
- 木 1: 1-2-3-4-5
- 木 2: 1
- 木 2 を 3 に付けても直径は 4
- クソデカ木にクソチビ木を合成しても直径が変わらないよね
両方の木を使うかどうかの境界を出したい
ソートして二分探索でよさそう(実家)
なんかこんな感じのことをする(添字をバグらせがちなのでガチャガチャがんばる)
# m2は木2の直径 for i in range(N1): j = bisect_left(dist2, m2 - dist1[i]) ans += cumsum_2[N2] - cumsum_2[j] + (N2 - j) * dist1[i] + (N2 - j) ans += m2 * j
↑ だと木 1 の直径 > 木 2 の直径のとき負数になって面倒なことがわかったので適当に置換しておく。おりゃー
m1 = max(dist1) m2 = max(dist2) if m1 > m2: dist1, dist2 = dist2, dist1 N1, N2 = N2, N1 m1, m2 = m2, m1
AC
実装
コード
from bisect import bisect_left
from collections import deque
import sys
input = sys.stdin.readline
N1 = int(input())
E1 = [[] for _ in range(N1)]
for _ in range(N1 - 1):
u, v = list(map(int, input().split()))
u -= 1
v -= 1
E1[u].append(v)
E1[v].append(u)
N2 = int(input())
E2 = [[] for _ in range(N2)]
for _ in range(N2 - 1):
u, v = list(map(int, input().split()))
u -= 1
v -= 1
E2[u].append(v)
E2[v].append(u)
INF = 10**18
def calc_longest_path_lengths(N, E):
dist = [INF] * N
dist[0] = 0
deq = deque([0])
while deq:
v = deq.popleft()
for u in E[v]:
if dist[u] > dist[v] + 1:
dist[u] = dist[v] + 1
deq.append(u)
dist2 = [INF] * N
v = dist.index(max(dist))
dist2[v] = 0
deq = deque([v])
while deq:
v = deq.popleft()
for u in E[v]:
if dist2[u] > dist2[v] + 1:
dist2[u] = dist2[v] + 1
deq.append(u)
dist3 = [INF] * N
v = dist2.index(max(dist2))
dist3[v] = 0
deq = deque([v])
while deq:
v = deq.popleft()
for u in E[v]:
if dist3[u] > dist3[v] + 1:
dist3[u] = dist3[v] + 1
deq.append(u)
dist = [max(dist2[i], dist3[i]) for i in range(N)]
return dist
dist1 = calc_longest_path_lengths(N1, E1)
dist2 = calc_longest_path_lengths(N2, E2)
dist1.sort()
dist2.sort()
m1 = max(dist1)
m2 = max(dist2)
if m1 > m2:
dist1, dist2 = dist2, dist1
N1, N2 = N2, N1
m1, m2 = m2, m1
cumsum_2 = [0] * (N2 + 1)
for i in range(N2):
cumsum_2[i + 1] = cumsum_2[i] + dist2[i]
ans = 0
for i in range(N1):
j = bisect_left(dist2, m2 - dist1[i])
ans += cumsum_2[N2] - cumsum_2[j] + (N2 - j) * dist1[i] + (N2 - j)
ans += m2 * j
print(ans)