ABC401-F: Add One Edge 3

Page content

問題概要

  • 2 つの木が与えられる
  • 各木から頂点を 1 つずつ選んで辺を追加すると新しい木が得られる
  • 全ペアについて新しい木の直径の総和を求めよ

後付けで言語化した思考過程

  • まず 木1の最遠点までの距離 + 木2の最遠点までの距離 + 1 を考えたくなる

    • 最遠点といえば直径だが、任意の点に対する最遠点はどうすれば?
      • そもそも木の直径を求めるアルゴリズムが ↓ なので、直径の両端のどっちかだろ多分
        • 適当な点から最遠点を求める
        • 求めた最遠点から最遠点を求める
      • 適当な点、端点 1、端点 2 で計 3 回 bfs して、端点 1 と端点 2 の max をとれば良さそう
  • 木 1 の各頂点からの最遠点までの距離 dist1 と、木 2 ついての dist2 が得られる

    • sum(dist1) * N2 + sum(dist2) * N1 + N1 * N2 や! -> WA
    • あくまで求めるのは直径の総和で、両方の木を通る最長パスではないですね
      • クソデカ木にクソチビ木を合成しても直径が変わらないよね
        • 木 1: 1-2-3-4-5
        • 木 2: 1
        • 木 2 を 3 に付けても直径は 4
  • 両方の木を使うかどうかの境界を出したい

    • ソートして二分探索でよさそう(実家)

    • なんかこんな感じのことをする(添字をバグらせがちなのでガチャガチャがんばる)

      • # m2は木2の直径
        for i in range(N1):
          j = bisect_left(dist2, m2 - dist1[i])
          ans += cumsum_2[N2] - cumsum_2[j] + (N2 - j) * dist1[i] + (N2 - j)
          ans += m2 * j
        
    • ↑ だと木 1 の直径 > 木 2 の直径のとき負数になって面倒なことがわかったので適当に置換しておく。おりゃー

      • m1 = max(dist1)
        m2 = max(dist2)
        
        if m1 > m2:
          dist1, dist2 = dist2, dist1
          N1, N2 = N2, N1
          m1, m2 = m2, m1
        
  • AC

実装

提出

コード

from bisect import bisect_left
from collections import deque
import sys

input = sys.stdin.readline

N1 = int(input())
E1 = [[] for _ in range(N1)]
for _ in range(N1 - 1):
    u, v = list(map(int, input().split()))
    u -= 1
    v -= 1
    E1[u].append(v)
    E1[v].append(u)

N2 = int(input())
E2 = [[] for _ in range(N2)]
for _ in range(N2 - 1):
    u, v = list(map(int, input().split()))
    u -= 1
    v -= 1
    E2[u].append(v)
    E2[v].append(u)

INF = 10**18


def calc_longest_path_lengths(N, E):
    dist = [INF] * N
    dist[0] = 0
    deq = deque([0])
    while deq:
        v = deq.popleft()
        for u in E[v]:
            if dist[u] > dist[v] + 1:
                dist[u] = dist[v] + 1
                deq.append(u)

    dist2 = [INF] * N
    v = dist.index(max(dist))
    dist2[v] = 0
    deq = deque([v])
    while deq:
        v = deq.popleft()
        for u in E[v]:
            if dist2[u] > dist2[v] + 1:
                dist2[u] = dist2[v] + 1
                deq.append(u)

    dist3 = [INF] * N
    v = dist2.index(max(dist2))
    dist3[v] = 0
    deq = deque([v])
    while deq:
        v = deq.popleft()
        for u in E[v]:
            if dist3[u] > dist3[v] + 1:
                dist3[u] = dist3[v] + 1
                deq.append(u)

    dist = [max(dist2[i], dist3[i]) for i in range(N)]
    return dist


dist1 = calc_longest_path_lengths(N1, E1)
dist2 = calc_longest_path_lengths(N2, E2)

dist1.sort()
dist2.sort()
m1 = max(dist1)
m2 = max(dist2)

if m1 > m2:
    dist1, dist2 = dist2, dist1
    N1, N2 = N2, N1
    m1, m2 = m2, m1

cumsum_2 = [0] * (N2 + 1)
for i in range(N2):
    cumsum_2[i + 1] = cumsum_2[i] + dist2[i]

ans = 0
for i in range(N1):
    j = bisect_left(dist2, m2 - dist1[i])
    ans += cumsum_2[N2] - cumsum_2[j] + (N2 - j) * dist1[i] + (N2 - j)
    ans += m2 * j

print(ans)