• 大小: 3KB
    文件类型: .zip
    金币: 1
    下载: 0 次
    发布日期: 2021-06-14
  • 语言: Java
  • 标签: KNN  

资源简介

knn(java实现)http://blog.csdn.net/u011067360/article/details/45937327

资源截图

代码片段和文件信息

package Marchinglearning.knn2;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * KNN算法主体类
 */
public class KNN {
/**
 * 设置优先级队列的比较函数,距离越大,优先级越高
 */
private Comparator comparator = new Comparator() {
public int compare(KNNNode o1 KNNNode o2) {
if (o1.getDistance() >= o2.getDistance()) {
return 1;
} else {
return 0;
}
}
};
/**
 * 获取K个不同的随机数
 * @param k 随机数的个数
 * @param max 随机数最大的范围
 * @return 生成的随机数数组
 */
public List getRandKNum(int k int max) {
List rand = new ArrayList(k);
for (int i = 0; i < k; i++) {
int temp = (int) (Math.random() * max);
if (!rand.contains(temp)) {
rand.add(temp);
} else {
i--;
}
}
return rand;
}
/**
 * 计算测试元组与训练元组之前的距离
 * @param d1 测试元组
 * @param d2 训练元组
 * @return 距离值
 */
public double calDistance(List d1 List d2) {
double distance = 0.00;
for (int i = 0; i < d1.size(); i++) {
distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
}
return distance;
}
/**
 * 执行KNN算法,获取测试元组的类别
 * @param datas 训练数据集
 * @param testData 测试元组
 * @param k 设定的K值
 * @return 测试元组的类别
 */
public String knn(List> datas List testData int k) {
PriorityQueue pq = new PriorityQueue(k comparator);
List randNum = getRandKNum(k datas.size());
System.out.println(“randNum:“+randNum.toString());
for (int i = 0; i < k; i++) {
int index = randNum.get(i);
List currData = datas.get(index);
String c = currData.get(currData.size() - 1).toString();
//System.out.println(“currData:“+currData+“c:“+c+“testData“+testData);
//计算测试元组与训练元组之前的距离
KNNNode node = new KNNNode(index calDistance(testData currData) c);
pq.add(node);
}
for (int i = 0; i < datas.size(); i++) {
List t = datas.get(i);
//System.out.println(“testData:“+testData);
//System.out.println(“t:“+t);
double distance = calDistance(testData t);
//System.out.println(“distance:“+distance);
KNNNode top = pq.peek();
if (top.getDistance() > distance) {
pq.remove();
pq.add(new KNNNode(i distance t.get(t.size() - 1).toString()));
}
}

return getMostClass(pq);
}
/**
 * 获取所得到的k个最近邻元组的多数类
 * @param pq 存储k个最近近邻元组的优先级队列
 * @return 多数类的名称
 */
private String getMostClass(PriorityQueue pq) {
Map classCount = new HashMap();
for (int i = 0; i < pq.size(); i++) {
KNNNode node = pq.remove();
String c = node.getC();
if (classCount.containsKey(c)) {
classCount.put(c classCount.get(c) + 1);
} else {
classCount.put(c 1);
}
}
int maxIndex = -1;
int maxCount = 0;
object[] classes = classCount.keySet().toArray();
for (int i = 0;

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     文件        3372  2015-05-23 21:39  knn2\KNN.java
     文件         712  2015-05-23 21:30  knn2\KNNNode.java
     文件        1907  2015-05-23 21:31  knn2\TestKNN.java
     文件         203  2015-05-23 20:24  knndata2\datafile.data
     文件         191  2015-05-23 20:24  knndata2\testfile.data

评论

共有 条评论