문제

트리에서 리프 노드란, 자식의 개수가 0인 노드를 말한다.

트리가 주어졌을 때, 노드 하나를 지울 것이다. 그 때, 남은 트리에서 리프 노드의 개수를 구하는 프로그램을 작성하시오. 노드를 지우면 그 노드와 노드의 모든 자손이 트리에서 제거된다.

예를 들어, 다음과 같은 트리가 있다고 하자.

현재 리프 노드의 개수는 3개이다. (초록색 색칠된 노드) 이때, 1번을 지우면, 다음과 같이 변한다. 검정색으로 색칠된 노드가 트리에서 제거된 노드이다.

이제 리프 노드의 개수는 1개이다.


입력

첫째 줄에 트리의 노드의 개수 N이 주어진다. N은 50보다 작거나 같은 자연수이다. 둘째 줄에는 0번 노드부터 N-1번 노드까지, 각 노드의 부모가 주어진다. 만약 부모가 없다면 (루트) -1이 주어진다. 셋째 줄에는 지울 노드의 번호가 주어진다.


출력

첫째 줄에 입력으로 주어진 트리에서 입력으로 주어진 노드를 지웠을 때, 리프 노드의 개수를 출력한다.


코드

import sys
from collections import defaultdict
sys.setrecursionlimit(10000)
child = defaultdict(list)
parent = dict()

N = int(input())
L = [int(x) for x in sys.stdin.readline().split()]
def count_leaf(root): # tree에서 leaf 수를 counting하는 함수
    count = 0
    if len(child[root]) == 0: # leaf에 도달하면
        return 1
    for i in range(len(child[root])): # 모든 자식에 대해 재귀호출
        count = count + count_leaf(child[root][i])
    return count # 자식의 subtree의 leaf 합

 for i in range(len(L)):
    if L[i] == -1: # root는 기록
        root=i
        continue
    parent[i] = L[i] # 부모 기록
    child[L[i]].append(i) # 자식 기록
 count = 0
 M = int(input())
 if M==root: # root를 제거하는 경우는 0
     print(0)
 elif len(child[parent[M]]) == 1: # 내가 부모의 유일한 자식일 때
     print(count_leaf(root) - count_leaf(M) + 1)
 else: # 부모가 나 말고도 자식이 있을 때
     print(count_leaf(root) - count_leaf(M))

해설

트리에 관한 정보가 주어지고, 지울 Subtree root번호를 알려준다. Subtree를 제거한 후의 트리의 leaf 노드 수를 구하는 문제이다. 읽자 마자 든 생각은 Tree leaf 수를 counting하는 함수를 만들어서 전체 트리에 대한 함수 결과값에서 제거할 Subtree leaf 노드 수를 빼면 되지 않을까란 생각이 들었다. 하지만 조금 더 생각해보니 항상 그런 것은 아닌 것 같았다. 

 

첫번째는 제거할 노드가 부모의 유일한 자식이다. 그 자식을 제거한다면 부모가 대신 leaf가 된다. 

 

부모는 기존의 Tree에서는 leaf가 아니었다. 반면 제거할 노드의 부모의 유일한 자식이 아니다. 

그 자식이 제거되었을 때, 부모는 다른 자식이 있기 때문에 leaf가 되지 않는다. 

그러므로 두 경우에 1개의 leaf 차이가 생긴다. 

 

이러한 아이디어를 바탕으로 구현한다.

 

Leaf 수를 counting하는 함수는 재귀를 통해 구현한다.

Leaf에 도달하면 1을 반환하고 leaf가 아니라면 그 자식에 대한 함수를 재귀적으로 호출한다.

복사했습니다!