KNN分类算法原理分析及代码实现

Wesley13
• 阅读 720

1、分类与聚类的概念与区别

分类:是从一组已知的训练样本中发现分类模型,并且使用这个分类模型来预测待分类样本。

目前常用的分类算法主要有:朴素贝叶斯分类算法(Naïve Bayes)、支持向量机分类算法(Support Vector Machines)、 KNN最近邻算法(k-Nearest Neighbors)、神经网络算法(NNet)以及决策树(Decision Tree)等等。

聚类:本身没有类别的样本聚集成不同的组。

聚类分析也称无监督学习, 因为和分类学习相比,聚类的样本没有标记,需要由聚类学习算法来自动确定。聚类分析是研究如何在没有训练的条件下把样本划分为若干类。

2、原理:根据距离函数计算待分类样本X和每个训练样本的距离,然后选出离这个数据最近的K个点,看这K个点属于什么类型,利用少数服从多数的原则,将新数据归类。如下图:

KNN分类算法原理分析及代码实现

若K=3,那么离绿色点(待分类样本)最近的有2个红色三角形和1个蓝色的正方形,于是绿色的这个待分类点属于红色的三角形。

若K=5,那么离绿色点(待分类样本)最近的有2个红色三角形和3个蓝色的正方形,于是绿色的这个待分类点属于蓝色的正方形。

3、根据上述原理,就可以准备数据了。

训练样本集knn-train.txt如下图:

KNN分类算法原理分析及代码实现

待分类样本knn.txt如下图:

KNN分类算法原理分析及代码实现

4、代码实现:

根据上述数据,首先我们需要一个Point类,将点的数据和类型作为两个变量。实现如下:

public class Point {
    private int type;
    private Vector v = new Vector();
    private String value;
    public Point(){}
    
    public Point(String value){
        this.value = value;
        String[] strs = value.split(" ");
        int index=0;
        //获得值
        for(;index<strs.length-1;){
            v.add(Double.parseDouble(strs[index]));
            index++;
        }
        //获得类型
        type = Integer.parseInt(strs[index]);
    }
    
    public String toString(){
        return value;
    }
    
    public int getType() {
        return type;
    }
    public void setType(int type) {
        this.type = type;
    }
    public Vector getV() {
        return v;
    }
    public void setV(Vector v) {
        this.v = v;
    }
}

因为是根据待分类样本数据和数据集中每个点计算距离,所以还需要一个工具类KNNUtils。实现如下:

public class KNNUtils {
    public static double getDiatance(Point p1, Point p2) {
        // 隐藏条件p1.size()==p2.size
        double result = 0.0;
        for (int i = 0; i < p1.getV().size(); i++) {
            result += Math.pow(p1.getV().get(i) - p2.getV().get(i), 2);
        }
        return Math.sqrt(result);
    }

 除此之外,知道待分类样本与所有已知样本的距离后,还需要比较之间的距离。如图:

KNN分类算法原理分析及代码实现

所以还定义了一个类,专门存储类别及距离,并且因为要实现根据距离来排序,所以需实现Comparable接口。实现如下:

public class KNNDisAndType implements Comparable{
    private int type;
    private double distance;
    public KNNDisAndType(){}
    
    public KNNDisAndType(String str){
        String[] strs = str.split(":");
        type = Integer.parseInt(strs[0]);
        distance = Double.parseDouble(strs[1]);
    }
    
    public KNNDisAndType(int type, double distance){
        this.type = type;
        this.distance = distance;
    }
    
    public int getType() {
        return type;
    }
    
    public void setType(int type) {
        this.type = type;
    }
    
    public double getDistance() {
        return distance;
    }
    
    public void setDistance(double distance) {
        this.distance = distance;
    }
    
    /**
     * 比较待分类样本与已知样本距离大小
     * @author ZD
     */
    @Override
    public int compareTo(KNNDisAndType o) {
        if(this.distance>o.distance){
            return 1;
        }else if(this.distance<o.distance){
            return -1;
        }
        return 0;
    }
    
    public String toString(){
        return type+":"+distance;
    }
}

一切准备就绪,最后只需在Reducer阶段统计类别次数,最终写入文件。实现如下:

/**
 * KNN算法原理实现
 * @author ZD
 */
public class KNNExer {
    private static final int NUM=5;
    
    public static class KNNExerMapper extends Mapper<LongWritable, Text, Text, Text>{
        private static List trains = new ArrayList();
        @Override
        protected void setup(Mapper<LongWritable, Text, Text, Text>.Context context)
                throws IOException, InterruptedException {
            FileSystem fs = FileSystem.get(context.getConfiguration());
            BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(new Path("/input/knn-train.txt"))));
            String line = "";
            while((line = br.readLine())!=null){
                Point p = new Point(line);
                trains.add(p);
            }
        }

        @Override
        protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, Text>.Context context)
                throws IOException, InterruptedException {
            FileSplit fSplit = (FileSplit)context.getInputSplit();
            if(fSplit.getPath().getName().equals("knn.txt")){
                //格式和数据集一样,0代表未知分类
                Point p1 = new Point(value.toString());
                for(Point p2:trains){
                    double distance = KNNUtils.getDiatance(p1, p2);
                    //当然也可以在map阶段就获取类别个数
                    context.write(new Text(p1.toString()), new Text(p2.getType()+":"+distance));
                }
            }
        }
    }
    
    private static class KNNExerReducer extends Reducer<Text, Text, Text, IntWritable>{

        @Override
        protected void reduce(Text value, Iterable datas, Reducer<Text, Text, Text, IntWritable>.Context context) throws IOException, InterruptedException {
            List list = new ArrayList();
            for (Text data : datas) {
                KNNDisAndType knnbean  = new KNNDisAndType(data.toString());
                list.add(knnbean);
            }
            Collections.sort(list);
            Map<Integer, Integer> map = new HashMap<Integer, Integer>();
            for(int i=0; i<NUM; i++){  //找距离最近的NUM个,根据少数服从多数原则判断待分类样本类别
                KNNDisAndType knn = list.get(i);
                int type = knn.getType();
                if(map.get(type)==null){
                    map.put(type, 1);
                }else{
                    map.put(type, map.get(type)+1);
                }
            }
            int finalType = 1;
            int count=0;
            for(Integer key:map.keySet()){
                if(map.get(key)>count){
                    count = map.get(key);
                    finalType = key;
                }
            }
            String[] strs = value.toString().split(" ");
            StringBuffer sb = new StringBuffer();
            for (int i=0; i<strs.length-1; i++) {
                sb.append(strs[i]).append(" ");
            }
            int len = sb.toString().length();
            context.write(new Text(sb.toString().substring(0, len-1)), new IntWritable(finalType));
        }
    }
    
    public static void main(String[] args) {
        try {
            Configuration cfg = HadoopCfg.getConfigration();
            Job job = Job.getInstance(cfg);
            job.setJobName("KNNExer");
            job.setJarByClass(KNNExer.class);
            job.setMapperClass(KNNExerMapper.class);
            job.setMapOutputKeyClass(Text.class);
            job.setMapOutputValueClass(Text.class);
            job.setReducerClass(KNNExerReducer.class);
            job.setOutputKeyClass(Text.class);
            job.setOutputValueClass(IntWritable.class);
            FileInputFormat.addInputPath(job, new Path("/input/knn"));
            FileOutputFormat.setOutputPath(job, new Path("/KNNExer/"));
            System.exit(job.waitForCompletion(true) ? 0 : 1);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

最后结果展示:

KNN分类算法原理分析及代码实现

写在最后:本人也是在慢慢学习中成长,希望能给大家带来收获。若有错误,望指出纠正。本次分享的KNN算法原理比较简单,实现起来也较为容易。下次将与大家分享朴素贝叶斯算法的原理分析与实现。

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
待兔 待兔
6个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Jacquelyn38 Jacquelyn38
3年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Stella981 Stella981
3年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Stella981 Stella981
3年前
CSS 分类 (Classification)
★★CSS分类属性(Classification)★★⑴CSS分类属性允许你控制如何显示元素,设置图像显示于另一元素中的何处,相对于其正常位置来定位元素,使用绝对值来定位元素,以及元素的可见度。⑵下面是常用的属性以及描述:!(https://oscimg.oschina.net/oscnet/00cb565
Wesley13 Wesley13
3年前
2、创建分类器笔记
创建分类器\\简介:\\分类是指利用数据的特性将其分类成若干类型的过程。分类与回归不同,回归的输出是实数。监督学习分类器就是用带标记的训练数据建立一个模型,然后对未知的数据进行分类。分类器可以实现分类功能的任意算法,最简单的分类器就是简单的数学函数。其中有二元(binary)分类器,将数据分成两类,也可多元(m
Wesley13 Wesley13
3年前
KNN 算法
KNN算法的全称是KNearestNeighbor,中文为K近邻算法,它是基于距离的一种算法,简单有效。KNN算法即可用于分类问题,也可用于回归问题。1,准备电影数据假如我们统计了一些电影数据,包括电影名称,打斗次数,接吻次数,电影类型,如下:电影名称打斗次数接吻次数
Wesley13 Wesley13
3年前
KNN算法详解
  简单的说,K近邻算法是采用不同特征值之间的距离方法进行分类。  该方法优点:精确值高、对异常值不敏感、无数据输入假定  缺点:计算复杂度高、空间复杂度高  适用范围:数据型和标称型  现在我们来讲KNN算法的工作原理:存在一个样本数据集,也称作训练样本集,并且样本中每条数据都存在标签。将新输入的没有标签的数据与训练样本数据集中
Python进阶者 Python进阶者
1年前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这