You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
32 lines
1.2 KiB
32 lines
1.2 KiB
'''
|
|
Author: SJ2050
|
|
Date: 2021-11-05 23:21:46
|
|
LastEditTime: 2021-11-07 15:14:50
|
|
Version: v0.0.1
|
|
Description: Homework3.1.
|
|
Copyright © 2021 SJ2050
|
|
'''
|
|
import numpy as np
|
|
from basic_kMeans import read_csv_data, show_figure, KMeans
|
|
|
|
if __name__ == '__main__':
    # Cluster the dataset read from dataset_circles.csv with k-means, report
    # the iteration count and cost, then plot the labelled points and centers.
    data_file = 'dataset_circles.csv'
    # Single source of truth for the cluster count; previously the literal 2
    # was repeated in the KMeans call and in both labelling comprehensions.
    n_clusters = 2

    original_data = read_csv_data(data_file)
    # Only the first two columns are clustered (and later plotted).
    features = original_data[:, 0:2]

    def distance_func(p1, p2):
        """Return the Euclidean (L2) distance between points p1 and p2."""
        return np.linalg.norm(p1 - p2)

    def compute_center_point_func(points):
        """Return the centroid (coordinate-wise mean) of a cluster's points.

        Raises:
            ValueError: if ``points`` is empty. Failing loudly is preferred
                over a bare division error or a silent NaN center.
        """
        if len(points) == 0:
            raise ValueError('cannot compute the center of an empty cluster')
        # np.mean accepts both a list of arrays and a 2-D array/nested list,
        # unlike the builtin sum(), which needs elements supporting `+`.
        return np.mean(points, axis=0)

    # NOTE(review): 1000 and 1e-3 are passed positionally to KMeans; they look
    # like a maximum iteration count and a convergence tolerance -- confirm
    # against basic_kMeans.KMeans.
    clustered_points, center_points, iter_num, cost = KMeans(
        features, n_clusters, distance_func, compute_center_point_func, 1000, 1e-3
    ).run()

    # Prints "iterations: ..., cost function value: ..."
    print(f'迭代次数: {iter_num}, 代价函数值为: {cost:.3f}')

    # Flatten the clusters into rows of (col0, col1, cluster_label) for plotting.
    # Clusters are accessed by label index, as in the original loop.
    classified_points = np.array([
        (point[0], point[1], label)
        for label in range(n_clusters)
        for point in clustered_points[label]
    ])
    # Give the centers the same (col0, col1, label) layout. Use a new name
    # rather than rebinding center_points.
    labelled_centers = [
        (center_points[label][0], center_points[label][1], label)
        for label in range(n_clusters)
    ]

    show_figure(classified_points, labelled_centers)
|