• 大小: 5KB
    文件类型: .rar
    金币: 2
    下载: 1 次
    发布日期: 2021-06-17
  • 语言: Java
  • 标签: java  bp  

资源简介

利用java实现bp神经网络,给定了UCI数据库的疝气病证预测病马数据,使用训练集训练BP神经网络并预测测试集的标签,错误率控制在30%以下。

资源截图

代码片段和文件信息

package Bp;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;

class BPNN {
    private static int layer = 5; // 五层神经网络
    private static int NodeNum = 25; // 每层的最多节点数
    private static final int ADJUST = 15; // 隐层节点数调节常数
    private static final int MaxTrain = 100000; // 最大训练次数
    private static final double ACCU = 0.011; // 每次迭代允许的误差 0.011
    private double ETA_W = 0.3; // 权值学习效率0.3
    private double ETA_T = 0.3; // 阈值学习效率
    private double accu;


    private int in_num; // 输入层节点数
    private int hd_num; // 隐层节点数
    private int out_num; // 输入出节点数

    private ArrayList> list = new ArrayList(); // 输入输出数据

    private double[][] in_hd_weight; // BP网络in-hidden突触权值
    private double[][] hd_out_weight; // BP网络hidden_out突触权值
    private double[] in_hd_th; // BP网络in-hidden阈值
    private double[] hd_out_th; // BP网络hidden-out阈值

    private double[][] out; // 每个神经元的值经S型函数转化后的输出值,输入层就为原值
    private double[][] delta; // delta学习规则中的值

    // 获得网络五层中神经元最多的数量
    public int GetMaxNum() {
        return Math.max(Math.max(in_num hd_num) out_num);
    }

    // 设置权值学习率
    public void SetEtaW() {
        ETA_W = 0.5;
    }

    // 设置阈值学习率
    public void SetEtaT() {
        ETA_T = 0.5;
    }

    // BPNN训练
    public void Train(int in_number int out_number
            ArrayList> arraylist) throws IOException {
        list = arraylist;
        in_num = in_number;
        out_num = out_number;

        GetNums(in_num out_num); // 获取输入层、隐层、输出层的节点数
        // SetEtaW(); // 设置学习率
        // SetEtaT();

        InitNetWork(); // 初始化网络的权值和阈值

        int datanum = list.size(); // 训练数据的组数
        int createsize = GetMaxNum(); // 比较创建存储每一层输出数据的数组
        out = new double[3][createsize];

        for (int iter = 0; iter < MaxTrain; iter++) {
            for (int cnd = 0; cnd < datanum; cnd++) {
                // 第一层输入节点赋值

                for (int i = 0; i < in_num; i++) {
                    out[0][i] = list.get(cnd).get(i); // 为输入层节点赋值,其输入与输出相同
                }
                Forward(); // 前向传播
                Backward(cnd); // 误差反向传播

            }
            System.out.println(“This is the “ + (iter + 1)
                    + “ th trainning NetWork !“);
            accu = GetAccu();
            System.out.println(“All Samples Accuracy is “ + accu);
            if (accu < ACCU)
                break;

        }

    }

    // 获取输入层、隐层、输出层的节点数,in_number、out_number分别为输入层节点数和输出层节点数
    public void GetNums(int in_number int out_number) {
        in_num = in_number;
        out_num = out_number;
        hd_num = (int) Math.sqrt(in_num + out_num) + ADJUST;
        if (hd_num > NodeNum)
            hd_num = NodeNum; // 隐层节点数不能大于最大节点数
    }

    // 初始化网络的权值和阈值
    public void InitNetWork() {
        // 初始化上一次权值量范围为-0.5-0.5之间
        //in_hd_last = new double[in_num][hd_num];
        //hd_out_

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----

     文件      10265  2018-03-05 14:43  Bp\BPNN.java

     文件       7208  2018-03-05 14:43  Bp\DataUtil.java

     文件       3540  2018-03-05 14:55  Bp\Test.java

     目录          0  2018-05-24 17:07  Bp

----------- ---------  ---------- -----  ----

                21013                    4


评论

共有 条评论