본문 바로가기

백준 문제풀이 코드저장소/Gold

Baekjoon 1197. 최소 스패닝 트리 / Python

728x90

https://www.acmicpc.net/problem/1197

2024.12.16 - [알고리즘/알고리즘 정리] - 알고리즘 유형공부 02 - 최소신장트리 (MST)

 

알고리즘 유형공부 02 - 최소신장트리 (MST)

최소신장트리, 최소스패닝트리 - Minimum Spanning Tree개요신장트리란 무엇일까?신장트리(Spanning Tree)란 - 그래프에서 모든 정점을 포함하며, 모든 정점에 대한 최소한의 연결만을 가진 그래프이다.그

developer-traxer.tistory.com

알고리즘 설명은 위 글을 참고하면 좋다.

문제는 그냥 이 최소스패닝트리를 구현하는 것이다.

1. PRIM 알고리즘

이 경우에는 시간초과가 난다. 그 이유는 프림알고리즘의 시간복잡도는 I(V * ElogV)이므로 O(10^11)이기 때문이다.

보통 1초에 1억번 연산을 수행한다고 알려져있는 백준 컴파일러에서는 1초의 시간당 O(10^8)의 시간복잡도를 맞춰주어야한다. 하지만 이 알고리즘을 사용하면 1000초...가 걸린다. 

import sys
# sys.stdin = open("C:/Users/ghtjd/Desktop/tmp/python/input.txt", "r")

input = sys.stdin.readline

from heapq import heappop, heappush

v, e = map(int, input().split())
arr = [[] for _ in range(v + 1)]
for _ in range(e):
    a, b, c = map(int, input().split())
    arr[a].append((c, b))
    arr[b].append((c, a))

res = 1e10

def bfs(s):
    global res
    q = []
    visit = [0] * (v + 1)
    heappush(q, (0, s))
    sum_weight = 0

    while q:
        w, now = heappop(q)
        if visit[now]:
            continue

        visit[now] = 1
        sum_weight += w

        for next_weight, next_node in arr[now]:
            if not visit[next_node]:
                heappush(q, (next_weight, next_node))
    
    res = min(res, sum_weight)

for i in range(1, v + 1):
    bfs(i)

print(res)

이에 반해 크루스칼 알고리즘의 경우 시간복잡도는 O(ElogE)이므로 계산하면 약 O(2*10^7)이 나온다.

따라서, 제한시간 1초(O(10^8))에 충분히 부합할 것이다.

문제의 난이도가 올라갈수록, 알고리즘과 자료구조 선택에 신중을 기해야 한다.

2. KRUSKAL 알고리즘

참고 : 백준의 재귀한도는 1000으로 정해져있기 때문에, 재귀한도를 늘려줄 필요가 있다.

sys.setrecursionlimit(100000)

import sys
# sys.stdin = open("C:/Users/ghtjd/Desktop/tmp/python/input.txt", "r")

input = sys.stdin.readline
sys.setrecursionlimit(100000)

v, e = map(int, input().split())
arr = []

for _ in range(e):
    a, b, c = map(int, input().split())
    arr.append([a, b, c])

arr.sort(key=lambda x: x[2])

sum_weight = 0
parents = [i for i in range(v + 1)]

def union(x):
    if parents[x] == x: return x
    parents[x] = union(parents[x])
    return parents[x]

def union_set(x, y):
    x, y = union(x), union(y)
    if x == y:return
    if x < y: parents[y] = x
    if x > y: parents[x] = y

for a, b, c in arr:
    if union(a) == union(b):continue
    union_set(a, b)
    sum_weight += c

print(sum_weight)
반응형