CS/Algorithm

ํฌ๋ฃจ์Šค์นผ ์•Œ๊ณ ๋ฆฌ์ฆ˜(kruskal)

deo2kim 2020. 9. 30. 20:20
๋ฐ˜์‘ํ˜•

์ถœ์ฒ˜: https://onepwnman.github.io/MST/

๐Ÿ“” ํฌ๋ฃจ์Šค์นผ(kruskal) ์ด๋ž€

  • ๊ฐ„์„ ์„ ํ•˜๋‚˜์”ฉ ์„ ํƒํ•ด์„œ MST๋ฅผ ์ฐพ๋Š” ์•Œ๊ณ ๋ฆฌ์ฆ˜ ( ํ”„๋ฆผ์€ ์ •์ ์„ ์„ ํƒ, ํฌ๋ฃจ์Šค์นผ์€ ๊ฐ„์„ ์„ ์„ ํƒ )
    • ์ตœ์ดˆ ๋ชจ๋“  ๊ฐ„์„ ์„ ๊ฐ€์ค‘์น˜์— ๋”ฐ๋ผ ์˜ค๋ฆ„์ฐจ์ˆœ์œผ๋กœ ์ •๋ ฌ
    • ๊ฐ€์ค‘์น˜๊ฐ€ ๊ฐ€์žฅ ๋‚ฎ์€ ๊ฐ„์„ ๋ถ€ํ„ฐ ์„ ํƒํ•˜๋ฉด์„œ ํŠธ๋ฆฌ๋ฅผ ์ฆ๊ฐ€์‹œํ‚จ๋‹ค
      • ์‚ฌ์ดํด์ด ์กด์žฌํ•˜๋ฉด ํŒจ์“ฐ ( ์‚ฌ์ดํด์˜ ์กด์žฌ๋ฅผ ํŒŒ์•…ํ•˜๋Š”๊ฒŒ ์ค‘์š”! )
    • V-1๊ฐœ์˜ ๊ฐ„์„ ์ด ์„ ํƒ๋  ๋•Œ ๊นŒ์ง€ ๋ฐ˜๋ณต
  • ๊ฐ€์žฅ ์ ์€ ๋น„์šฉ์œผ๋กœ ๋ชจ๋“  ๋…ธ๋“œ๋ฅผ ์—ฐ๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉํ•˜๋Š” ์•Œ๊ณ ๋ฆฌ์ฆ˜
  • disjoint-set ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์ด์šฉํ•œ๋‹ค.

 

๐Ÿ“” ํฌ๋ฃจ์Šค์นผ(kruskal)  ๊ตฌํ˜„

def make_set(x):
    p[x] = x


def find_set(x):
    if p[x] == x:
        return x
    else:
        p[x] = find_set(p[x])
        return p[x]


def union(x, y):
    px = find_set(x)
    py = find_set(y)
    if rank[px] > rank[py]:
        p[py] = p[px]
    elif rank[py] > rank[px]:
        p[px] = p[py]
    else:
        p[px] = p[py]
        rank[py] += 1


V, E = 7, 11
edges = [
    [0, 5, 60],
    [0, 1, 32],
    [0, 2, 31],
    [0, 6, 51],
    [1, 2, 21],
    [2, 4, 46],
    [2, 6, 25],
    [3, 4, 34],
    [3, 5, 18],
    [4, 5, 40],
    [4, 6, 51],
]

# ๊ฐ„์„  ์ •๋ณด๋ฅผ ๊ฐ€์ค‘์น˜์— ๋”ฐ๋ผ ์˜ค๋ฆ„์ฐจ์ˆœ์œผ๋กœ ์ •๋ ฌ
edges.sort(key=lambda x: x[2])

# ๊ฐ ์ •์ ์˜ ๋ถ€๋ชจ ์ •๋ณด, Rank ์ •๋ณด(ํšจ์œจ์„ฑ)
p = [0] * V
rank = [0] * V

# ๊ฐ ์ •์ ์˜ ๋ถ€๋ชจ๋ฅผ ์ž์‹ ์œผ๋กœ ์„ค์ •ํ•˜๊ธฐ
for i in range(V):
    make_set(i)

# ๊ฐ„์„ ์„ ์„ ํƒํ•˜๋Š” ๊ฐœ์ˆ˜๋Š” V-1๊ฐœ, ๊ฐ„์„ ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๋ˆ„์ ํ•œ ๊ฒฐ๊ณผ๊ฐ’
cnt = 0
result = 0
for i in range(E):
    s, e, c = edges[i]

    # s์™€ e๊ฐ€ ๊ฐ™์€ ์ง‘ํ•ฉ์ด๋ฉด( ์‚ฌ์ดํด์ด๋ฉด ) ํŒจ์“ฐ
    if find_set(s) == find_set(e):
        continue

    # ๊ฐ€์ค‘์น˜๋ฅผ ๋”ํ•˜๊ณ 
    result += c

    # s์™€ e๋ฅผ ๊ฐ™์€ ์ง‘ํ•ฉ์œผ๋กœ ๋งŒ๋“ ๋‹ค
    union(s, e)

    # ๊ฐ„์„ ์„ V-1๊ฐœ ์„ ํƒํ–ˆ์œผ๋ฉด ๋
    cnt += 1
    if cnt == V - 1:
        break
        
print(result)

 

๋ฐ˜์‘ํ˜•