• 大小: 511KB
    文件类型: .zip
    金币: 1
    下载: 0 次
    发布日期: 2021-05-07
  • 语言: 其他
  • 标签: KNN  最近邻  

资源简介

用K近邻(KNN)做手写体识别(MNIST),准确率可以达到94%。关于具体原理可以看我的博客

资源截图

代码片段和文件信息

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

public class HandwritingRecogniton {
static final int K = 20;

public static double calDistance(int[] a int[] b) {
double temp = 0;
for (int x = 0; x < a.length; x++) {
temp += (a[x] - b[x]) * (a[x] - b[x]);
}
return temp = Math.sqrt(temp);
}

public static double cosDis(int[] a int[] b) {
double numerator = 0 aLength = 0 bLength = 0;
for (int x = 0; x < a.length; x++) {
numerator += a[x] * b[x];
aLength += a[x];
bLength += b[x];
}
return numerator / (Math.sqrt(aLength) * Math.sqrt(bLength));
}

public static double hanmingDis(int[] a int[] b) {
double result = 0;
for (int x = 0; x < a.length; x++) {
result += Math.abs(a[x] - b[x]);
}
return result;
}

public static int[] str2int(String[] a) {
int[] b = new int[a.length];
for (int x = 0; x < a.length; x++) {
b[x] = Integer.parseInt(a[x]);
}
return b;
}

public static int classify(String filename int[] a) throws IOException {
FileReader fr = new FileReader(filename);
BufferedReader bufr = new BufferedReader(fr);

double[] d = new double[K];//存放K近邻的距离

for (int x = 0; x < K; x++) {//先将所有K近邻的距离初始化为最大距离28
d[x] = 28;
}
double temp = 0;
int lable = 0;
int[] num = new int[K];//记录对应距离的类标
String str = null;
int t = 0;
while ((str = bufr.readLine()) != null && t++ < 10000) {
int[] b = str2int(str.substring(0 str.length() - 1).split(““));
temp = calDistance(a b);
lable = Integer.parseInt(str.substring(str.length() - 1));
for (int x = 0; x < K; x++) {//找到K近邻的样本
if (temp < d[x]) {
d[x] = temp;
num[x] = lable;
break;
}
}
}
bufr.close();
int[] count = new int[10];
for (int x = 0; x < K; x++) {//统计各数字出现次数
count[num[x]]++;
}
int result = 0;
for (int x = 1; x < 10; x++) {//找出出现次数最多的
if (count[x] > count[result])
result = x;
}
return result;
}

public static void main(String[] arg) throws IOException {
System.out.println(System.currentTimeMillis());
FileReader fr = new FileReader(“validation.txt“);
BufferedReader bufr = new BufferedReader(fr);

int right = 0 sum = 0;

String str = null;
while ((str = bufr.readLine()) != null) {
int[] a = str2int(str.substring(0 str.length() - 1).split(““));
int result = classify(“train.txt“ a);
int lable = Integer.parseInt(str.substring(str.length() - 1));

sum++;
if (result == lable) {
right++;
// System.out.println(“result of classicication is:“ + result +
// “ original lable is:“ + lable);
} else {
int cc[][] = new int[28][28];
int count = 0;
for (int x = 0; x < 28; x++) {
for (int y = 0; y < 28; y++) {
cc[x][y] = a[count++];
}
}
System.out.println(“result of classicication is:“ + result + “   original lable is:“ + lable);
for (int x = 0; x < 

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2017-01-21 19:17  ML_project\
     文件         301  2016-12-12 17:47  ML_project\.classpath
     文件         386  2016-12-12 17:47  ML_project\.project
     目录           0  2017-01-21 19:17  ML_project\.settings\
     文件         598  2016-12-12 17:47  ML_project\.settings\org.eclipse.jdt.core.prefs
     目录           0  2017-01-21 20:26  ML_project\bin\
     文件        3941  2017-01-21 20:26  ML_project\bin\HandwritingRecogniton.class
     目录           0  2017-01-21 20:26  ML_project\src\
     文件        3331  2017-01-21 20:26  ML_project\src\HandwritingRecogniton.java
     文件    15710000  2016-12-12 18:22  ML_project\train.txt
     文件      785500  2016-12-12 18:22  ML_project\validation.txt

评论

共有 条评论