You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

33 lines
1.3 KiB

'''
Author: SJ2050
Date: 2021-11-07 13:49:04
LastEditTime: 2021-11-07 15:05:48
Version: v0.0.1
Description: Homework3.2.
Copyright © 2021 SJ2050
'''
import numpy as np
from basic_kMeans import read_csv_data, show_figure, KMeans
# Number of clusters: the circles dataset contains two concentric rings.
N_CLUSTERS = 2
# Positional KMeans hyper-parameters, passed in the same order as before.
# Presumably the iteration cap and convergence tolerance — confirm against
# basic_kMeans.KMeans.
MAX_ITER = 1000
TOLERANCE = 1e-3
DATA_FILE = 'dataset_circles.csv'


def radial_distance(p1, p2):
    """Return the distance between two points measured only by their radii.

    The Euclidean norm of each point is its distance from the origin. Two
    points on the same circle centred at the origin are therefore at
    distance 0, which lets k-means separate concentric rings.

    Args:
        p1, p2: Array-likes holding one point each.

    Returns:
        ``|‖p1‖ - ‖p2‖|`` as a float.
    """
    return abs(np.linalg.norm(p1) - np.linalg.norm(p2))


def compute_center_point(points):
    """Return the representative "center" of a cluster of points.

    The cluster is summarised by its mean radius. The center is placed on
    the positive y-axis at that radius, i.e. ``[0, mean(‖p‖)]``, so that
    ``radial_distance`` from any point to it compares radii only.

    Args:
        points: Array-like of shape (n, d) with n >= 1 points (d = 2 here).

    Returns:
        A float ``np.ndarray`` ``[0, mean_radius]``.

    Raises:
        ValueError: If ``points`` is empty. The mean radius of an empty
            cluster is undefined.
    """
    points = np.asarray(points)
    if len(points) == 0:
        raise ValueError('cannot compute the center of an empty cluster')
    # Vectorised mean of the per-point norms, instead of Python-level sum()/n.
    mean_radius = np.linalg.norm(points, axis=1).mean()
    return np.array([0, mean_radius])


def main():
    """Cluster the circles dataset by radius, report the result, and plot it."""
    original_data = read_csv_data(DATA_FILE)
    # Only the first two columns are used as coordinates.
    clustered_points, center_points, iter_num, cost = KMeans(
        original_data[:, 0:2], N_CLUSTERS, radial_distance,
        compute_center_point, MAX_ITER, TOLERANCE).run()
    # Prints "iterations: ..., cost function value: ...".
    print(f'迭代次数: {iter_num}, 代价函数值为: {cost:.3f}')

    # Tag every point with its cluster index: (x, y, label).
    classified_points = np.array([
        (point[0], point[1], label)
        for label in range(N_CLUSTERS)
        for point in clustered_points[label]
    ])
    # Tag every center with its cluster index, matching classified_points.
    labeled_centers = [
        (center_points[label][0], center_points[label][1], label)
        for label in range(N_CLUSTERS)
    ]
    show_figure(classified_points, labeled_centers)


if __name__ == '__main__':
    main()