Python入门,实现KNN算法

avatar 2024年04月19日17:42:26 0 85 views
博主分享免费Java教学视频,B站账号:Java刘哥

KNN 流程图如下

完整代码

from numpy import *

import matplotlib.pyplot as plt

# 故事背景

# 游客   吃冰淇淋数目    喝水数目    互动时长   天热感觉
# 张三         8          4          2       非常热
# 李四         7          1          1       非常热
# 王五         1          4          4       一般热
# 刘六         3          0          5       一般热
# 陈七         2          4          4        ?


# 分别求 陈七 和 其他人的距离,分别是 6.3、6.5、1.0、4.2
# 进行排序获得索引,分别是 2、3、0、1。分别对应 王五、刘六、张三、李四
# 2、3、0、1 对应的标签分别是 一般热、一般热、非常热、非常热
# 取前k条数据,这里k=3,进行分类统计数据,得到 {'一般热': 2, '非常热': 1}
# 根据数量倒序排序后,得到 [('一般热', 2), ('非常热', 1)]
# 取结果的[0][0],即 “一般热” 是陈七的天热感觉预测值

# 总而言之,就是求陈七的点到其他数据点的距离,然后取前k个数据进行统计数量,获取数量最多的标签名,就是陈七的预测值

# 创建数据源,返回数据集和类标签
def create_dateset():
    datesets = array([[8, 4, 2],
                      [7, 1, 1],
                      [1, 4, 4],
                      [3, 0, 5]])  # 数据集,4个样本
    labels = ['非常热', '非常热', '一般热', '一般热']  # 类标签
    return datesets, labels


plt.rcParams['font.sans-serif'] = 'SimHei'


def analyze_data_plot(x, y):
    fig = plt.figure()
    # 将画布划分为1行1列1块
    ax = fig.add_subplot(111)  # 1行1列第1块
    ax.scatter(x, y)
    # 设置散点图标题和横纵坐标
    plt.title('游客冷热感知点散点图')
    plt.xlabel('天热吃冰淇淋数目')
    plt.ylabel('天热喝水数目')

    # 保存截图
    plt.savefig('data_plot.png')
    plt.show()


# 计算两点之间的欧氏距离3
def compute_euclidean_distance(newV, datasets):
    rowsize, colsize = datasets.shape  # 获取数据集的行数和列数
    diffMat = tile(newV, (rowsize, 1)) - datasets  # 将newV重复4次,生成一个4行1列的矩阵,然后减去数据集
    sqDiffMat = diffMat ** 2  # 对矩阵中的每个元素进行平方
    result = sqDiffMat.sum(axis=1) ** 0.5  # 对矩阵中的每一行进行求和,然后开方。axis=1表示按行求和
    return result


# KNN分类器, 作用是对新数据进行分类
def knn_classifier(newV, datasets, labels, k):
    import operator

    # 1 计算newV与datasets中每个点的距离
    SqrtDist = compute_euclidean_distance(newV, datasets)

    # 2 根据距离排序,获取索引
    sortedDistIndexs = SqrtDist.argsort(axis=0)  # 对距离进行排序,返回排序后的索引值。axis=0表示按列向量排序

    # 3 针对k个点,统计各个类别数量
    classCount = {}
    for i in range(k):
        votelabel = labels[sortedDistIndexs[i]]  # 获取排序后的索引值对应的类标签
        classCount[votelabel] = classCount.get(votelabel, 0) + 1  # 统计各个类标签的数量

    # 4 投票机制,少数服从多数原则,输入类别
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1),
                              reverse=True)  # 对classCount字典进行排序, itemgetter(1) 表示按照第二个元素进行排序. reverse=True表示降序
    # print(newV, 'KNN投票预测结果是:', sortedClassCount[0][0])
    return sortedClassCount[0][0]


def predict_temperature():
    # 1. 创建数据集和类标签
    datesets, labels = create_dateset()

    # 2. 采访新访客
    # newV = [2, 4, 4]
    # newV = [8, 8, 1]
    x1 = float(input('请输入吃冰淇淋个数:'))
    x2 = float(input('请输入喝水杯数:'))
    x3 = float(input('请输入户外活动小时数:'))
    newV = [x1, x2, x3]
    res = knn_classifier(newV, datesets, labels, 3)  # k=3表示取最近的3个点
    print('预测结果:', res)


if __name__ == '__main__':
    # 预测
    predict_temperature()

 

代码参考:https://www.imooc.com/learn/1069

流程图来自:https://blog.csdn.net/m0_74405427/article/details/133714384

  • 微信
  • 交流学习,有偿服务
  • weinxin
  • 博客/Java交流群
  • 资源分享,问题解决,技术交流。群号:590480292
  • weinxin
avatar

发表评论

avatar 登录者:匿名
匿名评论,评论回复后会有邮件通知

  

已通过评论:0   待审核评论数:0