You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

32 lines
1.2 KiB

'''
Author: SJ2050
Date: 2021-11-05 23:21:46
LastEditTime: 2021-11-07 15:14:50
Version: v0.0.1
Description: Homework3.1.
Copyright © 2021 SJ2050
'''
import numpy as np
from basic_kMeans import read_csv_data, show_figure, KMeans
N_CLUSTERS = 2  # number of clusters requested from KMeans; used everywhere below


def distance_func(p1, p2):
    """Return the Euclidean distance between points *p1* and *p2*."""
    return np.linalg.norm(np.subtract(p1, p2))


def compute_center_point_func(points):
    """Return the centroid (coordinate-wise mean) of *points*.

    Args:
        points: a non-empty sequence of equal-length coordinate vectors
            (a list of arrays/tuples, or a 2-D array with one point per row).

    Returns:
        A 1-D numpy array holding the mean of each coordinate.

    Raises:
        ValueError: if *points* is empty, since a centroid is undefined there.
    """
    # len() rather than truthiness: `not points` is ambiguous for numpy arrays.
    if len(points) == 0:
        raise ValueError('cannot compute the centre of an empty cluster')
    return np.asarray(points).mean(axis=0)


def label_clustered_points(clustered_points):
    """Flatten per-cluster point lists into an (m, 3) array of (x, y, label) rows.

    Rows are emitted cluster by cluster, keeping point order within each
    cluster. ``label`` is the cluster's index in *clustered_points*. Only the
    first two coordinates of each point are kept.
    """
    rows = [(point[0], point[1], label)
            for label, cluster in enumerate(clustered_points)
            for point in cluster]
    # reshape keeps a (0, 3) shape when every cluster is empty, so callers can
    # still slice columns instead of getting a shapeless (0,) array.
    return np.array(rows).reshape(-1, 3)


def label_center_points(center_points):
    """Return (x, y, label) tuples, labelling each centre by its index."""
    return [(center[0], center[1], label)
            for label, center in enumerate(center_points)]


if __name__ == '__main__':
    data_file = 'dataset_circles.csv'
    original_data = read_csv_data(data_file)
    # Cluster on the first two columns only; any further columns are ignored.
    clustered_points, center_points, iter_num, cost = KMeans(
        original_data[:, 0:2],
        N_CLUSTERS,
        distance_func,
        compute_center_point_func,
        1000,  # NOTE(review): presumably the iteration cap - confirm in basic_kMeans
        1e-3,  # NOTE(review): presumably the convergence tolerance - confirm
    ).run()
    # Prints "iterations: ..., cost function value: ..." (message kept in Chinese).
    print(f'迭代次数: {iter_num}, 代价函数值为: {cost:.3f}')
    classified_points = label_clustered_points(clustered_points)
    center_points = label_center_points(center_points)
    show_figure(classified_points, center_points)