KNN算法详解

Wesley13
• 阅读 572

    简单的说,K近邻算法是采用不同特征值之间的距离方法进行分类。

   该方法优点:精确值高、对异常值不敏感、无数据输入假定

   缺点:计算复杂度高、空间复杂度高

   适用范围:数据型和标称型

   现在我们来讲KNN算法的工作原理:存在一个样本数据集,也称作训练样本集,并且样本中每条数据都存在标签。将新输入的没有标签的数据与训练样本数据集中每条数据进行距离计算,选择前K个最小距离。并统计出现次数最多的分类。将该分类作为新数据的标签。

  例如:使用k-近邻算法分类爱情片和动作片。

            训练样本数据集如下:

            KNN算法详解

           此时给定一部电影的统计数据:打斗镜头:18,接吻镜头:90,判断该电影属于哪种类型?

          分别计算该条数据与样本集中各条数据的距离:

          KNN算法详解

         假设K=3,将距离从小到大排列,选择前三条数据,并统计出该三条数据中所属类别最多的类别。并将该类别赋值给新输入数据。如上所属,该输入数据的类别为爱情片。

        K-近邻算法的一般流程:

       (1)收集数据

       (2)准备数据:距离计算所需的数值,最好是结构欧化数据格式

       (3)分析数据

       (4)训练数据:此步骤不适用与K-近邻算法

       (5)测试数据:计算错误率

       (6)使用数据:首先需要输入样本数据和结构化数据的输出结果,然后运用K近邻算法判断输入数据分别属于哪一类,最后应用对计算出的分类执行后续处理

        实现代码如下:

import numpy as np
import operator
from os import listdir
def createDateSet():
    group=np.array([[1,1.1],[1,1],[0,0],[0,0.1]])
    labels=['A','A','B','B']
    return group,labels
#分类
def classify(inX,dataSet,labels,k):
    dataSetSize=dataSet.shape[0]
    diffMat=np.tile(inX,(dataSetSize,1))-dataSet
    sqDiffMat=diffMat**2
    sqDistances=sqDiffMat.sum(axis=1)
    distances=sqDistances**0.5
    ##argsort()根据元素的值从大到小对元素进行排序,返回下标
    sortedDistance=distances.argsort()
    classCount={}
    for i in range(k):
        voteIlabel=labels[sortedDistance[i]]
        classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
    sortedClassCount=np.sort(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
#获取数据
def filematrix(filename):
    fr=open(filename)
    arrayLine=fr.readlines()
    numberOfLines=len(arrayLine)
    returnMat=np.zeros((numberOfLines,3))
    classLabelVector=[]
    index=0
    for line in arrayLine:
        line=line.strip().split('\t')
        returnMat[index,:]=line[0:3]
        classLabelVector.append(int(line[-1]))
        index+=1
    return returnMat,classLabelVector
#归一化数据
def autoNorm(dataSet):
    minVals=dataSet.min(0) #0表示列
    maxVals=dataSet.max(0)
    ranges=maxVals-minVals
    normDataSet=np.zeros(np.shape(dataSet))
    m=dataSet.shape[0]
    normDataSet=dataSet-np.tile(minVals,(m,1))
    normDataSet=normDataSet/np.tile(ranges,(m,1))
    return normDataSet,ranges,minVals

def datingClassTest():
    hoRatio=0.10
    datingDataMat,datingLabels=filematrix('../data/testSet.txt')
    normMat,rangs,minVals=autoNorm(datingDataMat)
    m=normMat.shape[0]
    numTestVecs=int(m*hoRatio)
    errorCount=0.0
    for i in range(numTestVecs):
        classifierResult=classify(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print("预测类别为:",classifierResult,"实际类别:",datingLabels[i])
        if(classifierResult!=datingLabels[i]):
            errorCount+=1
    print("总的错误率:",(errorCount/float(numTestVecs)))

def img2vector(filename):
    returnVect=np.zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readlines()
        for j in range(32):
            returnVect[0,32*i+j]=int(lineStr[j])
    return returnVect

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           #load the training set
    m = len(trainingFileList)
    trainingMat = np.zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify(vectorUnderTest, trainingMat, hwLabels, 3)
        print ("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
    print ("\nthe total number of errors is: %d" % errorCount)
    print ("\nthe total error rate is: %f" % (errorCount/float(mTest)))
点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
待兔 待兔
4个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Jacquelyn38 Jacquelyn38
3年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
皕杰报表(关于日期时间时分秒显示不出来)
在使用皕杰报表设计器时,数据据里面是日期型,但当你web预览时候,发现有日期时间类型的数据时分秒显示不出来,只有年月日能显示出来,时分秒显示为0:00:00。1.可以使用tochar解决,数据集用selecttochar(flowdate,"yyyyMMddHH:mm:ss")fromtablename2.也可以把数据库日期类型date改成timestamp
Stella981 Stella981
3年前
Prometheus监控学习笔记之PromQL简单示例
0x00简单的时间序列选择返回度量指标http_requests_total的所有时间序列样本数据:http_requests_total返回度量指标名称为http_requests_total,标签分别是job"apiserver",handler"/api/comments"
Stella981 Stella981
3年前
KVM调整cpu和内存
一.修改kvm虚拟机的配置1、virsheditcentos7找到“memory”和“vcpu”标签,将<namecentos7</name<uuid2220a6d1a36a4fbb8523e078b3dfe795</uuid
Wesley13 Wesley13
3年前
KNN分类算法原理分析及代码实现
1、分类与聚类的概念与区别分类:是从一组已知的训练样本中发现分类模型,并且使用这个分类模型来预测待分类样本。目前常用的分类算法主要有:朴素贝叶斯分类算法(NaïveBayes)、支持向量机分类算法(SupportVectorMachines)、KNN最近邻算法(kNearestNeighbors)、神经网络算法(NNet)以及决策树(De
Wesley13 Wesley13
3年前
KNN 算法
KNN算法的全称是KNearestNeighbor,中文为K近邻算法,它是基于距离的一种算法,简单有效。KNN算法即可用于分类问题,也可用于回归问题。1,准备电影数据假如我们统计了一些电影数据,包括电影名称,打斗次数,接吻次数,电影类型,如下:电影名称打斗次数接吻次数
Stella981 Stella981
3年前
Spark OneHotEncoder
1、概念独热编码(OneHotEncoding) 将表示为标签索引的分类特征映射到二进制向量,该向量最多具有一个单一的单值,该单值表示所有特征值集合中特定特征值的存在。此编码允许期望连续特征(例如逻辑回归)的算法使用分类特征。对于字符串类型的输入数据,通常首先使用StringIndexer