• 大小: 6KB
    文件类型: .cpp
    金币: 1
    下载: 0 次
    发布日期: 2021-06-02
  • 语言: C/C++
  • 标签:

资源简介

基于KNN实现的手写体数字识别C++代码,输出结果有混淆矩阵、召回率、训练准确率、预测数据输出等。

资源截图

代码片段和文件信息

#include
#include
#include
#include
#include
#define path “E:\\vs207\\train.txt“
#define path2 “E:\\vs207\\test.txt“
#define predictFile  “E:\\vs207\\predict.txt“

typedef const int cint;
typedef const char cchar;

/*一个手写数字的结构体*/
typedef struct
{
int pixel[1024];
int label;
}Digit;

/*一个有label的距离结构体*/
typedef struct
{
float distance;
int label;
}Distance;

/*文件路径+名称*/

/*每个数据集的数字个数*/
cint   ntrain = 1130;//943
cint    ntest = 501;//196
cint npredict = 50;

float calDistance(Digit digit1 Digit digit2)
/*求距离*/
{
int i squareSum = 0.0;
for (i = 0; i<1024; i++)
{
squareSum += pow(digit1.pixel[i] - digit2.pixel[i] 2.0);
}
return sqrtf(squareSum);//平方根
}

int loadDigit(Digit *digit FILE *fp int *labels)
/*读取digit*/
{
int index = 0;
for (index = 0; index<784; index++)
{
if (!fscanf(fp “%d“ &(digit->pixel[index])))
{
printf(“FILE already read finish.\n“);
return -1;
}
}
fscanf(fp “%d“ &(digit->label));
*labels = digit->label;

return 1;
}

void exchange(Distance *in int index1 int index2)
/*交换字符串两项*/
{
Distance tmp = (Distance)in[index1];
in[index1] = in[index2];
in[index2] = tmp;
}

void selectSort(Distance *in int length)
/*选择排序*/
{
int i j min;
int N = length;
for (i = 0; i {
min = i;
for (j = i + 1; j {
if (in[j].distance }
exchange(in i min);
}
}

int prediction(int K Digit in Digit *train int nt)//K Dtest[itest] Dtrain ntrain 943
/*利用训练数据预测一个数据digit*/
{
int i it;
Distance distance[1133];
/*求取输入digit与训练数据的距离*/
for (it = 0; it {
distance[it].distance = calDistance(in train[it]);
distance[it].label = train[it].label;
}
/*给计算的距离排序(选择排序)*/
int predict = 0;
int b0[10] = { 0 };

selectSort(distance nt);
for (i = 0; i {
//predict += distance[i].label;
switch (distance[i].label)
{
case 0:
b0[0]++;
break;
case 1:
b0[1]++;
break;
case 2:
b0[2]++;
break;
case 3:
b0[3]++;
break;
case 4:
b0[4]++;
break;
case 5:
b0[5]++;
break;
case 6:
b0[6]++;
break;
case 7:
b0[7]++;
break;
case 8:
b0[8]++;
break;
case 9:
b0[9]++;
break;
default:
break;
}
}
int max = 0;
for (int m = 0; m < 10; m++) {

if (b0[m] >= max) {
max = b0[m];
predict = m;
}
}
return predict;

}
void knn_classifiy(int K)
/*用测试数据集进行测试*/
{
printf(“knn_arithmetic_begin....\n“);
clock_t startfinish aa dd;
int i;
FILE *fp;

/*读入训练数据*/
int trainLabels[ntrain];
int trainCount[10] = { 0 };
Digit *Dtrain = (Digit*)calloc(ntrain  sizeof(Digit));
   fp = fopen(path “r“); //读文件
   printf(“load training digits...\n“);
start = clock();
for (i = 0; i {
loadDigit(&Dtrain[i] fp &trainLabels[i]);

trainCount[Dtrain[i

评论

共有 条评论

相关资源