You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

140 lines
4.8 KiB

'''
Author: SJ2050
Date: 2021-11-05 23:21:46
LastEditTime: 2021-11-07 15:10:24
Version: v0.0.1
Description: Basic K-Means algorithm.
Copyright © 2021 SJ2050
'''
import csv
import random
import numpy as np
from matplotlib import pyplot as plt
def read_csv_data(data_file):
    """Read the first three columns of a CSV file into a float array.

    Args:
        data_file: Path to the raw data file (.csv).

    Returns:
        A 2-D ``np.float64`` array of shape ``(rows, 3)`` built from the first
        three fields of every non-blank row. An empty file (or one holding
        only blank rows) yields an array of shape ``(0, 3)``.

    Raises:
        IndexError: If a non-blank row has fewer than three fields.
        ValueError: If one of the first three fields is not a valid float.
    """
    # newline='' is required by the csv module so that line endings and quoted
    # fields are handled by the reader rather than by the text layer.
    with open(data_file, 'r', newline='', encoding='utf-8') as fp_inp_data:
        rows = [(row[0], row[1], row[2])
                for row in csv.reader(fp_inp_data)
                # csv.reader yields [] (or whitespace-only fields) for blank
                # lines; skip them instead of failing on row[0].
                if any(field.strip() for field in row)]
    # reshape keeps the (rows, 3) contract even when no rows were read.
    return np.array(rows, dtype=np.float64).reshape(-1, 3)
def show_figure(points, center_points):
    """Scatter-plot labelled points and optionally overlay highlighted centres.

    Args:
        points: Array-like of N rows. Column 0 is x, column 1 is y and
            column 2 is the class value used as the marker colour.
        center_points: Array-like with the same column layout, drawn as
            large 'x' markers. Nothing extra is drawn when it is empty.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    pts = np.array(points)
    plt.scatter(pts[:, 0], pts[:, 1], s=20, c=pts[:, 2])
    if len(center_points) != 0:
        centers = np.array(center_points)
        plt.scatter(centers[:, 0], centers[:, 1], s=100, c=centers[:, 2],
                    marker='x')
    plt.show()
class KMeans:
    """Basic K-Means clustering with pluggable distance and centre functions.

    Args:
        points: Sequence of points, converted to a float64 numpy array.
        class_num: Number of clusters.
        distance_func: Callable ``(p1, p2) -> float`` giving the distance
            between two points.
        compute_center_point_func: Callable mapping a non-empty list of
            points to the centre of that cluster.
        max_iter_num: Maximum number of re-clustering steps in ``run``.
        atol: ``run`` stops once the cost changes by less than this between
            two consecutive steps.
    """

    def __init__(self, points, class_num, distance_func, compute_center_point_func,
                 max_iter_num, atol):
        self.points = np.array(points, dtype=np.float64)
        self.class_num = class_num
        self.distance_func = distance_func
        self.compute_center_point_func = compute_center_point_func
        self.max_iter_num = max_iter_num
        self.atol = atol
        # clustered_points[k]: list of points currently assigned to cluster k.
        self.clustered_points = []
        # center_points[k]: current centre of cluster k.
        self.center_points = []

    @property
    def cost(self):
        """Sum of squared distances from every point to its cluster's centre."""
        # Messages: "cluster count differs from class_num" /
        # "number of centres differs from class_num".
        assert len(self.clustered_points) == self.class_num, "点簇数目与分类数不一致"
        assert len(self.center_points) == self.class_num, "中点的个数与分类数不一致"
        return sum(
            self.distance_func(point, center) ** 2
            for members, center in zip(self.clustered_points, self.center_points)
            for point in members
        )

    def choose_which_class_belonging_to(self, point):
        """Return the index of the centre nearest to ``point``.

        Ties resolve to the lowest index (``np.argmin`` semantics).
        """
        distances = [self.distance_func(point, center) for center in self.center_points]
        return np.argmin(distances)

    def _assign_points(self):
        """Rebuild ``clustered_points`` by assigning each point to its nearest centre."""
        self.clustered_points = [[] for _ in range(self.class_num)]
        for p in self.points:
            self.clustered_points[self.choose_which_class_belonging_to(p)].append(p)

    def initialize(self):
        """Use ``class_num`` distinct input points as initial centres, then assign points.

        Raises:
            ValueError: From ``random.sample`` if ``class_num`` exceeds the
                number of points.
        """
        indices = random.sample(range(len(self.points)), self.class_num)
        self.center_points = self.points[indices]
        self._assign_points()

    def cluster(self):
        """Perform one K-Means step: recompute the centres, then reassign the points.

        An empty cluster keeps its previous centre, because the centre
        function cannot be evaluated on an empty set (a mean would divide by
        zero).
        """
        # The comprehension reads the old centres before the new list is stored.
        self.center_points = [
            self.compute_center_point_func(members) if members else self.center_points[k]
            for k, members in enumerate(self.clustered_points)
        ]
        self._assign_points()

    def run(self):
        """Cluster until the cost converges or ``max_iter_num`` steps have run.

        Returns:
            A tuple ``(clustered_points, center_points, iter_num, cost)``.
            ``iter_num`` is the number of ``cluster()`` steps performed after
            initialisation, and ``cost`` is the final cost.
        """
        self.initialize()
        prev_cost = None
        curr_cost = self.cost
        iter_num = 0
        while iter_num < self.max_iter_num:
            # Compare against None explicitly. A previous cost of exactly 0.0
            # is valid and must not disable the convergence check.
            if prev_cost is not None and abs(curr_cost - prev_cost) < self.atol:
                break
            self.cluster()
            prev_cost = curr_cost
            curr_cost = self.cost
            iter_num += 1
        return self.clustered_points, self.center_points, iter_num, curr_cost
if __name__ == '__main__':
    # Demo: cluster the points in dataset_circles.csv and plot the result.
    data_file = 'dataset_circles.csv'
    class_num = 2  # single source of truth for fitting and for labelling the plot
    max_iter_num = 1000
    atol = 1e-3

    def compute_center_point(points):
        """Return the arithmetic mean of a non-empty list of points."""
        return sum(points) / len(points)

    def distance_func(p1, p2):
        """Return the Euclidean distance between two points."""
        return np.linalg.norm(p1 - p2)

    original_data = read_csv_data(data_file)
    # Only the first two columns (the coordinates) are clustered.
    clustered_points, center_points, iter_num, cost = KMeans(
        original_data[:, 0:2], class_num, distance_func, compute_center_point,
        max_iter_num, atol).run()
    # Prints the iteration count and the final cost value.
    print(f'迭代次数: {iter_num}, 代价函数值为: {cost:.3f}')
    # Append the cluster index as a third column so show_figure can colour by it.
    classified_points = np.array([(p[0], p[1], k)
                                  for k, members in enumerate(clustered_points)
                                  for p in members])
    labelled_centers = [(c[0], c[1], k) for k, c in enumerate(center_points)]
    show_figure(classified_points, labelled_centers)