Computer Science/Algorithm

[Python 3] BOJ 11438 LCA 2 (+Binary Lifting)

무니화니 2024. 10. 5. 13:29

https://www.acmicpc.net/problem/11438

[English Translation]

Tree with N vertices is given. Each vertex of a tree is number from 1 to N, and the root is number 1.

If M pairs of nodes are given, return the closest LCA of these two nodes.

[Input]

in the first line, count of nodes N is given.

After that, n-1 lines of two connected vertices of tree is given. After that, desired count of the LCA of two nodes are given, and for net M lines, pairs of two nodes are given.

[Output]

Print M lines of LCA of two points.

해당 문제 알고리즘

  • 최소 공통 조상(LCA)
  • 희소 배열 (Sparse Matrix)

우선적으로 최소 공통 조상이란,

트리 구조에서 임의의 두 정점이 갖는 가장 가까운 조상 정점을 의미하게 된다.

여기서 조상 정점은, 트리의 루트 노드가 1일때를 기준으로, 루트의 하위 노드, 루트의 하위 노드의 하위 노드와 같은 식으로 하위 노드들이 생기게 된다.

두 노드의 공통 조상 중 제일 노드에서부터의 길이가 긴 노드가 최소 공통 조상이 되는 것이다.

따라서, 두 노드의 Depth를 같게 맞추고, 루트 방향으로 한 노드씩 올라가면서 공통 노드인지, 아닌지를 파악할 수 있다.

하지만, 이러한 Naive한 방법은

Sparse Table은 주어진 배열의 특정 구간에서의 최솟값, 최댓값, 합과 같은 계산을 바로 O(1)의 시간으로 찾을 수 있게끔 하는 배열이다. 여기서 포인트는 "미리" 계산을 한다는 것이다. 미리 계산을 O(NlogN)에 할 수 있다.

Sparse Table과 유사한 알고리즘인 Binary Lifting은 Sparse Table의 미리 계산을 해둔다는 사실을 이용한다. Binary Lifting은 부모 노드로 빠르게 이동하거나, LCA를 찾기 위해서 사용되는데, 이진수를 이용해서 Log 시간에 조상으로 이동할 수 있게끔 해준다.

import sys
input=sys.stdin.readline
sys.setrecursionlimit(int(1e5))

def dfs(x,d):
    visited[x]=1
    depth[x]=d
    for i in data[x]:
        if visited[i]:
            continue
        parents[i][0]=x
        dfs(i,d+1)

def lca(a,b):
    for i in range(16,-1,-1):
        if depth[a]>depth[b]:
            a,b=b,a
        if depth[b]-depth[a]>=2**i:
            b=parents[b][i]
    if a==b:
        return a
    for i in range(16,-1,-1):
        if parents[a][i]!=parents[b][i]:
            a=parents[a][i]
            b=parents[b][i]
    return parents[a][0]

n=int(input())
data=[list() for _ in range(n+1)]
parents=[list(0 for _ in range(17)) for _ in range(n+1)]
visited=list(0 for _ in range(n+1))
depth=[0 for _ in range(n+1)]
for _ in range(n-1):
    a,b=map(int,input().split())
    data[a].append(b)
    data[b].append(a)
m=int(input())
dfs(1,0)
for j in range(1,n+1):
    for i in range(1,17):
        parents[j][i]=parents[parents[j][i-1]][i-1]
for _ in range(m):
    a,b=map(int,input().split())
    print(lca(a,b))

먼저, input들을 받았다.

data는 어떤 노드가 다른 어떤 노드들과 연결이 되어있는지를 저장하는 리스트이고,

parents는 특정 노드의 부모 노드를 의미한다.

여기서 굉장히 중요한 포인트가 나온다.
기존에 parents 배열은 바로 위의 노드만 저장할 수 있었다면,
여기서 parents 배열은 이차원 리스트로,
parents[ 특정 노드 ] [ 2**i 번째 부모 ] 노드를 저장한다.

즉,

parents[3][0]은 3번째 노드의 2**0 (=1) 바로 위 노드를,

parents[23][2]은 23번째 노드의 2**2 (=4) 4개 위의 노드를 의미한다. 

 

depth는 루트 노드에서부터의 깊이, visited는 처음에 dfs를 진행할 때 이미 방문했던 노드를 방문하지 않기 위해 사용한다.

또한 숫자 17은 binary lifting을 할 때, 2**17, 131,072개의 노드로 이동할 수 있게끔 하기 위해서 사용하는 숫자이다.

예를 들어, 트리가 루트 노드에서부터 1-2-3-4-......-100,000까지 이어져있다고 쳐보자.

그러면 parents[100,000][0]은 99,999을, parents[100,000][1]은 99,998로 이어지고, parents[100,000][16]은 100,000에서 65536을 뺀 34,464번째 노드로 이어지게끔 저장할 수 있다.

 

먼저 dfs를 통해서 첫번째 노드에서부터 모든 노드들을 다 방문하면서, depth랑 parents를 기록한다.

이후 모든 노드들을 대상으로 binary lifting을 진행한다. 

이중 for 문으로, n개의 모든 노드들을 바탕으로2**16부터 2**0까지의 binary lifting을 한다.

이후 LCA를 구하고 싶은 노드들을 바탕으로 LCA 1에서 사용했던 lca를 진행하는데, 앞에서 구했던 binary lifting을 응용한다.

이때 더 낮은 depth에 있는 노드를 더 높이 있는 depth에 있는 노드에 binary lifting을 이용하여 같은 높이로 맞춰주고,

여기서부터 binary lifting을 한 번 더 이용하여 LCA를 찾는다.

 

예시를 들어, 노드 2는 depth 12에, 노드 3은 depth 15에 있다고 하자.

그러면 노드 3을 노드 2의 depth인 12에 맞추기 위해 2만큼, 1만큼 위로 올려보낸다.

이후, 2**16, 2**15, ... 2**3, 2**2, 2**1, 2**0만큼 올려보려고 노력해서 LCA를 구해본다.

LCA가 depth 3에 있다고 하면, 9만큼 위로 올려야 하기 때문에, 2**3만큼 두 노드를 올리고, 2**0만큼 두 노드를 올리면 된다.