본문 바로가기
Python

[Python] K-Means 알고리즘 구현(2)

by 돌맹96 2023. 10. 29.
728x90
반응형

Step 2.생성된 400개 좌표값중 200개씩 임의로 선택해서
반반 나눔
그렇게 2개 그룹을 차트에서 빨강/파랑으로 표시한다.
차트 왼쪽위에 범례로 A그룹은 빨강 B그룹은 파랑이라는걸 표시한다.

2개 그룹 대한 평균값을 각각 구하고

두그룹의 평균값 구한걸로
모든 점들과 A평균값 거리
모든 점들과 B평균값 거리 구해봐서
다시 이제 가까운쪽으로 점들 마다 A평균값에 가까우면 A그룹
B평균값에 가까우면 B그룹으로 만든다

그렇게 바뀌어진 모습을 차트로 표시해준다.

 

Step 3. 저 두번째 상황을 계속 반복하는데 평균값이 변하지 않을때가 종료시점이다.
아마도 대부분 10번 이내에 나올것같다.
평균값구해서 거리 구해서 차트 만드는걸 반복하는데
그 차트들을 계속 캡쳐해서 변화되는모습을 찍어야한다.

 

python Code

import random
import matplotlib.pyplot as plt
import numpy as np

# 400개의 중복 없는 좌표 생성
x_coords = random.sample(range(0, 401), 400)
y_coords = random.sample(range(0, 401), 400)
points = list(zip(x_coords, y_coords))
random.shuffle(points)  # 생성한 좌표를 랜덤하게 섞는다.

# 좌표 중 200개를 임의로 선택하여 A 그룹과 B 그룹으로 나눔
A_group = random.sample(points, 200)
B_group = [point for point in points if point not in A_group]

def compute_mean(group):
    """해당 그룹의 좌표의 평균 값을 계산한다."""
    return np.mean(group, axis=0)

for iteration in range(10):
    # 각 그룹의 평균 좌표를 계산
    A_mean = compute_mean(A_group)
    B_mean = compute_mean(B_group)

    new_A_group = []
    new_B_group = []

    # 각 좌표를 가장 가까운 평균 좌표를 가진 그룹에 할당
    for point in points:
        distance_to_A = np.linalg.norm(np.array(point) - np.array(A_mean))
        distance_to_B = np.linalg.norm(np.array(point) - np.array(B_mean))

        if distance_to_A < distance_to_B:
            new_A_group.append(point)
        else:
            new_B_group.append(point)

    # 만약 그룹 할당이 변하지 않았다면 알고리즘이 수렴했음을 의미
    if set(A_group) == set(new_A_group) and set(B_group) == set(new_B_group):
        print(f"Converged after {iteration + 1} iterations.")
        break

    A_group = new_A_group
    B_group = new_B_group

    # 그룹별로 x, y 좌표를 분리하여 시각화 준비
    A_group_x, A_group_y = zip(*A_group)
    B_group_x, B_group_y = zip(*B_group)

    # 그룹 시각화
    plt.figure(figsize=(8, 8))
    plt.scatter(A_group_x, A_group_y, color='red', label="A Group", s=10)
    plt.scatter(B_group_x, B_group_y, color='blue', label="B Group", s=10)
    plt.legend()
    plt.title(f"Iteration {iteration + 1}")
    plt.xlabel("X-coordinate")
    plt.ylabel("Y-coordinate")
    plt.grid(True)
    plt.savefig(f"iteration_{iteration + 1}.png")
    plt.show()

결과사진

10번정도 반복하다가 멈추는걸 확인했다. K-Mean알고리즘 구현 참고하시길 바란다.

728x90
반응형