You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
140 lines
4.8 KiB
140 lines
4.8 KiB
'''
|
|
Author: SJ2050
|
|
Date: 2021-11-05 23:21:46
|
|
LastEditTime: 2021-11-07 15:10:24
|
|
Version: v0.0.1
|
|
Description: Basic K-Means algorithm.
|
|
Copyright © 2021 SJ2050
|
|
'''
|
|
import csv
|
|
import random
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
|
|
def read_csv_data(data_file):
    """Read three-column numeric data from a file (.csv format).

    Only the first three fields of every row are used; blank lines in the
    file are ignored.

    Args:
        data_file: Raw data file (.csv).

    Returns:
        A 2d numpy array (float64, one row of three values per data row) of
        data from the input file. A file without any data rows yields an
        array of shape (0, 3).
    """
    # newline='' leaves newline handling to the csv module, as recommended
    # by the csv documentation.
    with open(data_file, 'r', newline='') as fp_inp_data:
        reader = csv.reader(fp_inp_data)
        # csv.reader yields an empty list for a blank line; skip such rows
        # instead of failing with IndexError on item[0].
        result = [(item[0], item[1], item[2]) for item in reader if item]

    # reshape keeps the result two-dimensional even when no rows were read.
    return np.array(result, dtype=np.float64).reshape(-1, 3)
|
|
|
|
def show_figure(points, center_points):
    """Draw the points, and optionally the center points, in a scatter plot.

    Args:
        points: Points (size: N*3). For each point, the first and second
            components are the location and the third component is the
            class type, which is used as the marker color value.
        center_points: Center points to be highlighted, laid out like
            `points`. Nothing extra is drawn when this is empty.

    Returns:
        None.
    """
    data = np.array(points)
    plt.scatter(data[:, 0], data[:, 1], s=20, c=data[:, 2])

    # Centers are drawn on top with larger 'x' markers so they stand out.
    if len(center_points) > 0:
        centers = np.array(center_points)
        plt.scatter(centers[:, 0], centers[:, 1], s=100, c=centers[:, 2], marker='x')

    plt.show()
|
|
|
|
class KMeans:
    """Basic K-Means clustering driven by caller-supplied callbacks.

    The distance between two points and the center of a group of points are
    both computed by functions passed to the constructor.
    """

    def __init__(self, points, class_num, distance_func, compute_center_point_func,
                 max_iter_num, atol):
        """Store the data set and the clustering parameters.

        Args:
            points: Points to cluster; converted to a float64 numpy array.
            class_num: Number of clusters.
            distance_func: Callable taking (point, center) and returning
                their distance.
            compute_center_point_func: Callable taking the list of points of
                one cluster and returning the center of that cluster.
            max_iter_num: Maximum number of clustering passes in `run`.
            atol: `run` stops once the absolute change of the cost between
                two consecutive passes is below this value.
        """
        self.points = np.array(points, dtype=np.float64)
        self.class_num = class_num
        self.distance_func = distance_func
        self.compute_center_point_func = compute_center_point_func
        self.max_iter_num = max_iter_num
        self.atol = atol
        # clustered_points[k] is the list of points currently assigned to
        # center_points[k]; both are filled by initialize().
        self.clustered_points = []
        self.center_points = []

    @property
    def cost(self):
        """Sum of squared distances from every point to its cluster center.

        Raises:
            AssertionError: If the number of clusters or of centers differs
                from `class_num` (e.g. before `initialize` has been called).
        """
        # Message: "number of point clusters does not match the class count".
        assert len(self.clustered_points) == self.class_num, "点簇数目与分类数不一致"
        # Message: "number of center points does not match the class count".
        assert len(self.center_points) == self.class_num, "中点的个数与分类数不一致"

        return sum(
            self.distance_func(point, self.center_points[k])**2
            for k in range(self.class_num)
            for point in self.clustered_points[k]
        )

    def choose_which_class_belonging_to(self, point):
        """Return the index of the center with the smallest distance to `point`."""
        distances = [self.distance_func(point, center) for center in self.center_points]
        return np.argmin(distances)

    def _assign_points(self):
        """Rebuild `clustered_points` by assigning each point to its nearest center."""
        clusters = [[] for _ in range(self.class_num)]
        for p in self.points:
            clusters[self.choose_which_class_belonging_to(p)].append(p)
        self.clustered_points = clusters

    def initialize(self):
        """Pick `class_num` distinct data points at random as centers and assign all points."""
        center_points_indices = random.sample(range(len(self.points)), self.class_num)
        self.center_points = self.points[center_points_indices]
        self._assign_points()

    def cluster(self):
        """Run one pass: recompute every center, then reassign all points."""
        new_centers = []
        for k in range(self.class_num):
            members = self.clustered_points[k]
            if members:
                new_centers.append(self.compute_center_point_func(members))
            else:
                # An empty cluster (possible with duplicate points) has no
                # members to compute a center from; keep its previous center
                # rather than calling the callback with an empty list.
                new_centers.append(self.center_points[k])
        self.center_points = new_centers

        self._assign_points()

    def run(self):
        """Cluster the points until the cost settles or the pass limit is hit.

        Returns:
            A tuple (clustered_points, center_points, iter_num, cost) where
            iter_num is the number of clustering passes performed and cost
            is the cost after the last pass.
        """
        self.initialize()
        prev_cost = None
        curr_cost = self.cost
        iter_num = 0
        for _ in range(self.max_iter_num):
            # Compare against None explicitly: a previous cost of exactly 0
            # is a valid value and must still allow the convergence test.
            if prev_cost is not None and abs(curr_cost - prev_cost) < self.atol:
                break
            self.cluster()
            prev_cost = curr_cost
            curr_cost = self.cost
            iter_num += 1

        return self.clustered_points, self.center_points, iter_num, curr_cost
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: split the points of a CSV data set into clusters and plot them.
    class_num = 2

    def compute_center_point(points):
        """Return the arithmetic mean of a sequence of points."""
        return sum(points) / len(points)

    def distance_func(p1, p2):
        """Return the norm of the difference of two points."""
        return np.linalg.norm(p1 - p2)

    data_file = 'dataset_circles.csv'
    original_data = read_csv_data(data_file)

    # Only the first two columns (the location) take part in the clustering.
    model = KMeans(original_data[:, 0:2], class_num, distance_func,
                   compute_center_point, 1000, 1e-3)
    clustered_points, center_points, iter_num, cost = model.run()
    # Prints the iteration count and the final cost value.
    print(f'迭代次数: {iter_num}, 代价函数值为: {cost:.3f}')

    # Attach the cluster index as a third component so it can drive the color.
    labelled = []
    for label in range(class_num):
        for point in clustered_points[label]:
            labelled.append((point[0], point[1], label))
    classified_points = np.array(labelled)

    center_points = [(center_points[label][0], center_points[label][1], label)
                     for label in range(class_num)]
    show_figure(classified_points, center_points)
|