
Implementing the BP Neural Network Algorithm in Java

At work I needed to predict how long a process would take, and a BP neural network seemed a natural fit for that prediction.

Introduction

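A BP neural network (Back Propagation Neural Network) is an artificial neural network whose weights and thresholds are adjusted with the BP (backpropagation) algorithm. In the 1980s several researchers independently developed backpropagation algorithms for training multilayer perceptrons, and the version presented by David Rumelhart and James McClelland was the most influential. The algorithm has two main processes: forward propagation of the working signal, which computes the network's output, and backward propagation of the error signal, which updates the weights and thresholds. Forward propagation produces the actual output of the network. Backward propagation then corrects the weights and thresholds layer by layer, from the output back toward the input, so that the actual output moves closer to the expected output.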

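(1) Forward propagation of the working signal. The input signal enters at the input layer and travels through the synapses to the hidden-layer neurons. There it is transformed by the transfer function and passed on to the output layer, which computes and emits the output signal. During forward propagation the weights and thresholds stay fixed, and the state of each layer depends only on the output of the previous layer and on the weights and thresholds. If the output layer produces the expected output, learning ends and the current weights and thresholds are kept. Otherwise, the weights and thresholds are corrected during backward propagation of the error signal.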
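The forward pass described above can be sketched for a single sample as follows. This is a minimal illustration rather than the article's Matrix-based code. The class name `ForwardPassSketch`, the choice of the sigmoid, and all weight values are assumptions. The sketch only shows that each neuron adds its threshold to the weighted sum of the previous layer and then applies the transfer function, while the weights and thresholds stay fixed.

```java
import java.util.Arrays;

// Forward pass of one sample through a small I-J-P network (illustrative values only).
final class ForwardPassSketch {
    // Sigmoid transfer function 1 / (1 + e^-x)
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // One layer: out[j] = f(b[j] + sum_i x[i] * w[i][j]); w has x.length rows and b.length columns
    static double[] layer(double[] x, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double net = b[j];                  // start from the threshold
            for (int i = 0; i < x.length; i++) {
                net += x[i] * w[i][j];          // weighted sum over the previous layer
            }
            out[j] = sigmoid(net);              // apply the transfer function
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = {0.2, 0.8};                        // input layer (I = 2)
        double[][] wIJ = {{0.5, -0.3}, {0.1, 0.4}};     // input -> hidden weights (J = 2)
        double[] b1 = {0.0, 0.1};                       // hidden-layer thresholds
        double[][] wJP = {{0.7}, {-0.2}};               // hidden -> output weights (P = 1)
        double[] b2 = {0.05};                           // output-layer threshold
        double[] hidden = layer(x, wIJ, b1);            // weights and thresholds are not changed here
        double[] output = layer(hidden, wJP, b2);
        System.out.println(Arrays.toString(output));
    }
}
```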

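(2) Backward propagation of the error signal. If forward propagation does not produce the expected output, an error signal is computed and propagated backward. This signal is the difference between the network's actual output and the expected output. It is passed from the output layer back toward the input layer, one layer at a time. At each layer along the way, that layer's weights and thresholds are corrected, and the process continues until it reaches the input layer. The goal is to bring the network's output closer to the expected result.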
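For one sample and a single sigmoid output neuron, the backward step can be sketched as below. The class and method names are hypothetical. The sketch assumes the sigmoid, whose derivative can be written in terms of the neuron output y as y(1 - y). The hidden-layer error terms are computed with the weights as they were before correction. The input-to-hidden weights would then be corrected the same way, using each hidden error term and the inputs.

```java
import java.util.Arrays;

// One backward step for a single sample, assuming sigmoid neurons and one output neuron.
final class BackwardStepSketch {
    // Output-layer error term: (expected - actual) * f'(net); for the sigmoid f'(net) = y * (1 - y)
    static double outputDelta(double y, double t) {
        return (t - y) * y * (1 - y);
    }

    // Hidden-layer error terms: deltaP sent back through each hidden -> output weight
    static double[] hiddenDeltas(double[] h, double[] wJP, double deltaP) {
        double[] deltaJ = new double[h.length];
        for (int j = 0; j < h.length; j++) {
            deltaJ[j] = deltaP * wJP[j] * h[j] * (1 - h[j]);
        }
        return deltaJ;
    }

    // Corrects the weights feeding a neuron with error term delta: w[j] += step * delta * in[j]
    static void correctWeights(double[] w, double[] in, double delta, double step) {
        for (int j = 0; j < w.length; j++) {
            w[j] += step * delta * in[j];
        }
    }

    public static void main(String[] args) {
        double[] h = {0.6, 0.3};       // hidden-layer outputs from the forward pass
        double[] wJP = {0.7, -0.2};    // hidden -> output weights
        double y = 0.55, t = 0.8;      // actual and expected output
        double deltaP = outputDelta(y, t);
        double[] deltaJ = hiddenDeltas(h, wJP, deltaP);   // uses the weights before correction
        correctWeights(wJP, h, deltaP, 0.05);
        System.out.println(deltaP + " " + Arrays.toString(deltaJ) + " " + Arrays.toString(wJP));
    }
}
```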

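If the error is still too large after one forward pass and one backward pass, the process repeats. It stops when the error meets the required precision or when another stopping condition is met, such as reaching the maximum number of iterations.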
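The repeat-until-done loop can be sketched as follows. To keep it short, the sketch assumes a single sigmoid neuron with two inputs and no hidden layer, fixed starting weights, and batch corrections summed over all samples. `TrainingLoopSketch` and its parameters are illustrative names. `precision` and `maxTimes` play the role of the two stopping conditions described above.

```java
// Repeats forward and backward passes until E <= precision or the iteration limit is reached.
// Simplified to one sigmoid neuron with two inputs (no hidden layer) and batch corrections.
final class TrainingLoopSketch {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Returns {error of the last forward pass, number of corrections performed}
    static double[] train(double[][] xs, double[] ts, double step, double precision, int maxTimes) {
        double w0 = 0.1, w1 = -0.1, b = 0.0;   // fixed starting values so runs are reproducible
        double error = Double.NaN;
        int times = 0;
        while (times < maxTimes) {
            // forward pass over all samples, accumulating the error and the corrections
            double sumSq = 0, g0 = 0, g1 = 0, gb = 0;
            for (int n = 0; n < xs.length; n++) {
                double y = sigmoid(w0 * xs[n][0] + w1 * xs[n][1] + b);
                double e = ts[n] - y;
                double delta = e * y * (1 - y);          // error term e * f'(net)
                sumSq += e * e;
                g0 += delta * xs[n][0];
                g1 += delta * xs[n][1];
                gb += delta;
            }
            error = 0.5 * sumSq;
            if (error <= precision) {
                break;                                   // precision reached
            }
            // backward pass: correct the weights and the threshold
            w0 += step * g0;
            w1 += step * g1;
            b += step * gb;
            times++;
        }
        return new double[] {error, times};
    }

    public static void main(String[] args) {
        double[][] xs = {{0.2, 0.2}, {0.2, 0.8}, {0.8, 0.2}, {0.8, 0.8}};
        double[] ts = {0.2, 0.4, 0.4, 0.8};
        double[] result = train(xs, ts, 0.5, 1e-4, 100000);
        System.out.println("error=" + result[0] + ", corrections=" + (int) result[1]);
    }
}
```

With a small enough step, each correction lowers the error. The loop therefore ends either at the precision target or at `maxTimes`, whichever comes first.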

For the derivation, see https://zh.wikipedia.org/wiki/%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95

BPNN structure

This BPNN has a single input layer, a single hidden layer, and a single output layer.
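In matrix form, one forward pass and the standard gradient-descent corrections for this three-layer structure can be written as follows, with the momentum term omitted. The notation is chosen here for illustration and each Δ is added to its parameter:

- X (N x I) holds one sample per row.
- W_IJ (I x J) and W_JP (J x P) are the weights.
- b_1 (1 x J) and b_2 (1 x P) are the threshold row vectors.
- 1 (N x 1) is a column of ones.
- f is the transfer function, applied element-wise.
- T (N x P) holds the expected outputs.
- η is the learning rate.
- ⊙ denotes element-wise multiplication.

```latex
\begin{aligned}
J_{\mathrm{in}} &= X W_{IJ} + \mathbf{1} b_1, & J_{\mathrm{out}} &= f(J_{\mathrm{in}}),\\
P_{\mathrm{in}} &= J_{\mathrm{out}} W_{JP} + \mathbf{1} b_2, & P_{\mathrm{out}} &= f(P_{\mathrm{in}}),\\
e &= T - P_{\mathrm{out}}, & E &= \tfrac{1}{2}\textstyle\sum_{n,p} e_{np}^{2},\\
\delta_P &= e \odot f'(P_{\mathrm{in}}), & \delta_J &= \left(\delta_P W_{JP}^{\top}\right) \odot f'(J_{\mathrm{in}}),\\
\Delta W_{JP} &= \eta\, J_{\mathrm{out}}^{\top} \delta_P, & \Delta b_2 &= \eta\, \mathbf{1}^{\top} \delta_P,\\
\Delta W_{IJ} &= \eta\, X^{\top} \delta_J, & \Delta b_1 &= \eta\, \mathbf{1}^{\top} \delta_J.
\end{aligned}
```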

Project structure

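  • ActivationFunction: interface for activation functions
  • BPModel: entity class for the trained BP model
  • BPNeuralNetworkFactory: BP neural network factory, covering training, computation, serialization, and related functions
  • BPParameter: entity class holding the BP network's parameters
  • Matrix: matrix entity class
  • Sigmoid: the sigmoid transfer function, an implementation of the ActivationFunction interface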

Implementation

Matrix entity class

Implements the basic matrix operations.

import java.io.Serializable;

public class Matrix implements Serializable {
    private double[][] matrix;
    //number of columns
    private int matrixColNums;
    //number of rows
    private int matrixRowNums;

    /**
     * Constructs an empty matrix
     */
    public Matrix() {
        this.matrix = null;
        this.matrixColNums = 0;
        this.matrixRowNums = 0;
    }

    /**
     * Constructs a matrix backed by the given 2-D array
     * @param matrix
     */
    public Matrix(double[][] matrix) {
        this.matrix = matrix;
        this.matrixRowNums = matrix.length;
        this.matrixColNums = matrix[0].length;
    }

    /**
     * Constructs a rowNums x colNums matrix filled with 0
     * @param rowNums
     * @param colNums
     */
    public Matrix(int rowNums,int colNums) {
        double[][] matrix = new double[rowNums][colNums];
        for (int i = 0; i < rowNums; i++) {
            for (int j = 0; j < colNums; j++) {
                matrix[i][j] = 0;
            }
        }
        this.matrix = matrix;
        this.matrixRowNums = rowNums;
        this.matrixColNums = colNums;
    }

    /**
     * Constructs a rowNums x colNums matrix filled with val
     * @param val
     * @param rowNums
     * @param colNums
     */
    public Matrix(double val,int rowNums,int colNums) {
        double[][] matrix = new double[rowNums][colNums];
        for (int i = 0; i < rowNums; i++) {
            for (int j = 0; j < colNums; j++) {
                matrix[i][j] = val;
            }
        }
        this.matrix = matrix;
        this.matrixRowNums = rowNums;
        this.matrixColNums = colNums;
    }

    public double[][] getMatrix() {
        return matrix;
    }

    public void setMatrix(double[][] matrix) {
        this.matrix = matrix;
        this.matrixRowNums = matrix.length;
        this.matrixColNums = matrix[0].length;
    }

    public int getMatrixColNums() {
        return matrixColNums;
    }

    public int getMatrixRowNums() {
        return matrixRowNums;
    }

    /**
     * Returns the value at row x, column y
     *
     * @param x
     * @param y
     * @return
     */
    public double getValOfIdx(int x, int y) throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        if (x < 0 || x > matrixRowNums - 1) {
            throw new Exception("Index x out of bounds");
        }
        if (y < 0 || y > matrixColNums - 1) {
            throw new Exception("Index y out of bounds");
        }
        return matrix[x][y];
    }

    /**
     * Returns row x as a 1 x n matrix
     *
     * @param x
     * @return
     */
    public Matrix getRowOfIdx(int x) throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        if (x < 0 || x > matrixRowNums - 1) {
            throw new Exception("Index x out of bounds");
        }
        double[][] result = new double[1][];
        //copy the row so the new matrix does not share storage with this one
        result[0] = matrix[x].clone();
        return new Matrix(result);
    }

    /**
     * Returns column y as an m x 1 matrix
     *
     * @param y
     * @return
     */
    public Matrix getColOfIdx(int y) throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        if (y < 0 || y > matrixColNums - 1) {
            throw new Exception("Index y out of bounds");
        }
        double[][] result = new double[matrixRowNums][1];
        for (int i = 0; i < matrixRowNums; i++) {
            result[i][0] = matrix[i][y];
        }
        return new Matrix(result);
    }

    /**
     * Matrix product: (m x n) times (n x p) gives (m x p)
     *
     * @param a
     * @return
     * @throws Exception
     */
    public Matrix multiple(Matrix a) throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new Exception("Argument matrix is empty");
        }
        if (matrixColNums != a.getMatrixRowNums()) {
            throw new Exception("Matrix dimensions do not match; cannot compute");
        }
        double[][] result = new double[matrixRowNums][a.getMatrixColNums()];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < a.getMatrixColNums(); j++) {
                for (int k = 0; k < matrixColNums; k++) {
                    result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j];
                }
            }
        }
        return new Matrix(result);
    }

    /**
     * Multiplies every element by the scalar a
     *
     * @param a
     * @return
     */
    public Matrix multiple(double a) throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        double[][] result = new double[matrixRowNums][matrixColNums];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < matrixColNums; j++) {
                result[i][j] = matrix[i][j] * a;
            }
        }
        return new Matrix(result);
    }

    /**
     * Element-wise (Hadamard) product; both matrices must have the same shape
     *
     * @param a
     * @return
     */
    public Matrix pointMultiple(Matrix a) throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new Exception("Argument matrix is empty");
        }
        if (matrixRowNums != a.getMatrixRowNums() || matrixColNums != a.getMatrixColNums()) {
            throw new Exception("Matrix dimensions do not match; cannot compute");
        }
        double[][] result = new double[matrixRowNums][matrixColNums];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < matrixColNums; j++) {
                result[i][j] = matrix[i][j] * a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Matrix addition; both matrices must have the same shape
     *
     * @param a
     * @return
     */
    public Matrix plus(Matrix a) throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new Exception("Argument matrix is empty");
        }
        if (matrixRowNums != a.getMatrixRowNums() || matrixColNums != a.getMatrixColNums()) {
            throw new Exception("Matrix dimensions do not match; cannot compute");
        }
        double[][] result = new double[matrixRowNums][matrixColNums];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < matrixColNums; j++) {
                result[i][j] = matrix[i][j] + a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Matrix subtraction; both matrices must have the same shape
     *
     * @param a
     * @return
     */
    public Matrix subtract(Matrix a) throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        if (a.getMatrix() == null) {
            throw new Exception("Argument matrix is empty");
        }
        if (matrixRowNums != a.getMatrixRowNums() || matrixColNums != a.getMatrixColNums()) {
            throw new Exception("Matrix dimensions do not match; cannot compute");
        }
        double[][] result = new double[matrixRowNums][matrixColNums];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < matrixColNums; j++) {
                result[i][j] = matrix[i][j] - a.getMatrix()[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums each row, giving an m x 1 matrix
     *
     * @return
     */
    public Matrix sumRow() throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        double[][] result = new double[matrixRowNums][1];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < matrixColNums; j++) {
                result[i][0] += matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums each column, giving a 1 x n matrix
     *
     * @return
     */
    public Matrix sumCol() throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        double[][] result = new double[1][matrixColNums];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < matrixColNums; j++) {
                result[0][j] += matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Sums all elements
     *
     * @return
     */
    public double sumAll() throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        double result = 0;
        for (double[] doubles : matrix) {
            for (int j = 0; j < matrixColNums; j++) {
                result += doubles[j];
            }
        }
        return result;
    }

    /**
     * Squares every element
     *
     * @return
     */
    public Matrix square() throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        double[][] result = new double[matrixRowNums][matrixColNums];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < matrixColNums; j++) {
                result[i][j] = matrix[i][j] * matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    /**
     * Transpose
     *
     * @return
     */
    public Matrix transpose() throws Exception {
        if (matrix == null) {
            throw new Exception("Matrix is empty");
        }
        double[][] result = new double[matrixColNums][matrixRowNums];
        for (int i = 0; i < matrixRowNums; i++) {
            for (int j = 0; j < matrixColNums; j++) {
                result[j][i] = matrix[i][j];
            }
        }
        return new Matrix(result);
    }

    @Override
    public String toString() {
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("\r\n");
        for (int i = 0; i < matrixRowNums; i++) {
            stringBuilder.append("# ");
            for (int j = 0; j < matrixColNums; j++) {
                stringBuilder.append(matrix[i][j]).append("\t ");
            }
            stringBuilder.append("#\r\n");
        }
        stringBuilder.append("\r\n");
        return stringBuilder.toString();
    }
}
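To check the Matrix methods by hand, the following standalone sketch implements three of the operations (product, transpose, column sums) on plain `double[][]` arrays. It also includes the shape check that element-wise operations need. `MatrixOpsSketch` is a hypothetical name. It mirrors the intended semantics of those methods rather than reusing the class above.

```java
// Reference versions of three matrix operations on plain double[][] arrays.
final class MatrixOpsSketch {
    // Matrix product: (r x k) times (k x c) gives (r x c)
    static double[][] multiply(double[][] a, double[][] b) {
        if (a[0].length != b.length) {
            throw new IllegalArgumentException("inner dimensions differ");
        }
        double[][] r = new double[a.length][b[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b[0].length; j++) {
                for (int k = 0; k < b.length; k++) {
                    r[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return r;
    }

    // Transpose: element (i, j) moves to (j, i)
    static double[][] transpose(double[][] a) {
        double[][] r = new double[a[0].length][a.length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[0].length; j++) {
                r[j][i] = a[i][j];
            }
        }
        return r;
    }

    // Column sums as a 1 x c row vector: the column index j selects the output slot
    static double[][] sumCol(double[][] a) {
        double[][] r = new double[1][a[0].length];
        for (double[] row : a) {
            for (int j = 0; j < row.length; j++) {
                r[0][j] += row[j];
            }
        }
        return r;
    }

    // Element-wise operations need BOTH dimensions to agree
    static boolean sameShape(double[][] a, double[][] b) {
        return a.length == b.length && a[0].length == b[0].length;
    }
}
```

For example, multiplying [[1, 2], [3, 4]] by the column [[5], [6]] gives [[17], [39]], and the column sums of [[1, 2], [3, 4], [5, 6]] are [[9, 12]].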

ActivationFunction interface

public interface ActivationFunction {
    //computes the function value
    double computeValue(double val);
    //computes the derivative
    double computeDerivative(double val);
}
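Because the factory only talks to ActivationFunction, another transfer function can be plugged into BPParameter through setActivationFunction. Here is a hypothetical example that is not part of the original project: a hyperbolic tangent implementation. The interface is repeated so the snippet compiles on its own. As with the sigmoid, computeDerivative receives the neuron's net input, which is what the factory passes to it. The class also implements Serializable, because the activation function is saved together with the model.

```java
import java.io.Serializable;

interface ActivationFunction {
    double computeValue(double val);
    double computeDerivative(double val);
}

// Hypothetical alternative transfer function: hyperbolic tangent, with outputs in (-1, 1)
final class Tanh implements ActivationFunction, Serializable {
    private static final long serialVersionUID = 1L;

    @Override
    public double computeValue(double val) {
        return Math.tanh(val);
    }

    @Override
    public double computeDerivative(double val) {
        double t = Math.tanh(val);
        return 1 - t * t;        // d/dx tanh(x) = 1 - tanh(x)^2
    }
}
```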

Sigmoid

import java.io.Serializable;

public class Sigmoid implements ActivationFunction, Serializable {
    @Override
    public double computeValue(double val) {
        return 1 / (1 + Math.exp(-val));
    }

    @Override
    public double computeDerivative(double val) {
        double s = computeValue(val);    // sigmoid'(x) = s(x) * (1 - s(x)), evaluated once
        return s * (1 - s);
    }
}
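The derivative relies on the identity σ'(x) = σ(x)(1 - σ(x)). A quick check, sketched below, compares that closed form with a central finite difference. The class name and the step h = 1e-5 are arbitrary choices.

```java
// Compares the closed-form sigmoid derivative with a central finite difference.
final class SigmoidDerivativeCheck {
    static double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    // Closed form: s(x) * (1 - s(x))
    static double derivative(double x) {
        double s = sigmoid(x);
        return s * (1 - s);
    }

    // Central difference (s(x + h) - s(x - h)) / (2h)
    static double numericDerivative(double x, double h) {
        return (sigmoid(x + h) - sigmoid(x - h)) / (2 * h);
    }

    public static void main(String[] args) {
        for (double x : new double[] {-4, -1, 0, 0.5, 3}) {
            System.out.printf("x=%5.2f  closed=%.10f  numeric=%.10f%n",
                    x, derivative(x), numericDerivative(x, 1e-5));
        }
    }
}
```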

BPParameter

Holds the parameters needed to train the BP neural network.

import java.io.Serializable;

public class BPParameter implements Serializable {

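    //number of input-layer neurons
    private int inputLayerNeuronNum = 3;
    //number of hidden-layer neurons
    private int hiddenLayerNeuronNum = 3;
    //number of output-layer neurons
    private int outputLayerNeuronNum = 1;
    //normalization range
    private double normalizationMin = 0.2;
    private double normalizationMax = 0.8;
    //learning rate (step size)
    private double step = 0.05;
    //momentum factor
    private double momentumFactor = 0.2;
    //activation function
    private ActivationFunction activationFunction = new Sigmoid();
    //precision: training stops once the error E is at or below this value
    private double precision = 0.000001;
    //maximum number of iterations
    private int maxTimes = 1000000;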
    private int maxTimes = 1000000;

    public double getMomentumFactor() {
        return momentumFactor;
    }

    public void setMomentumFactor(double momentumFactor) {
        this.momentumFactor = momentumFactor;
    }

    public double getStep() {
        return step;
    }

    public void setStep(double step) {
        this.step = step;
    }

    public double getNormalizationMin() {
        return normalizationMin;
    }

    public void setNormalizationMin(double normalizationMin) {
        this.normalizationMin = normalizationMin;
    }

    public double getNormalizationMax() {
        return normalizationMax;
    }

    public void setNormalizationMax(double normalizationMax) {
        this.normalizationMax = normalizationMax;
    }

    public int getInputLayerNeuronNum() {
        return inputLayerNeuronNum;
    }

    public void setInputLayerNeuronNum(int inputLayerNeuronNum) {
        this.inputLayerNeuronNum = inputLayerNeuronNum;
    }

    public int getHiddenLayerNeuronNum() {
        return hiddenLayerNeuronNum;
    }

    public void setHiddenLayerNeuronNum(int hiddenLayerNeuronNum) {
        this.hiddenLayerNeuronNum = hiddenLayerNeuronNum;
    }

    public int getOutputLayerNeuronNum() {
        return outputLayerNeuronNum;
    }

    public void setOutputLayerNeuronNum(int outputLayerNeuronNum) {
        this.outputLayerNeuronNum = outputLayerNeuronNum;
    }

    public ActivationFunction getActivationFunction() {
        return activationFunction;
    }

    public void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    public double getPrecision() {
        return precision;
    }

    public void setPrecision(double precision) {
        this.precision = precision;
    }

    public int getMaxTimes() {
        return maxTimes;
    }

    public void setMaxTimes(int maxTimes) {
        this.maxTimes = maxTimes;
    }
}
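The normalizationMin and normalizationMax parameters define a min-max scaling of each data column, by default into [0.2, 0.8]. The article does not say why it avoids [0, 1]. A plausible reason is that the sigmoid reaches 0 and 1 only asymptotically, so targets inside the interval keep the output neurons away from saturation. The standalone sketch below shows the mapping and its inverse. `MinMaxNormalization` is a hypothetical name, and the sketch simply rejects a constant column, which cannot be scaled.

```java
// Min-max scaling of one value into [lo, hi] and back (hypothetical helper, not the article's code).
final class MinMaxNormalization {
    // Maps x from [min, max] (the column's range in the training data) to [lo, hi]
    static double normalize(double x, double min, double max, double lo, double hi) {
        if (max == min) {
            throw new IllegalArgumentException("constant column: max equals min");
        }
        return lo + (x - min) / (max - min) * (hi - lo);
    }

    // Maps y from [lo, hi] back to [min, max]
    static double inverse(double y, double min, double max, double lo, double hi) {
        return min + (max - min) * (y - lo) / (hi - lo);
    }

    public static void main(String[] args) {
        double min = 12, max = 48;                 // illustrative column range
        for (double x : new double[] {12, 30, 48}) {
            double y = normalize(x, min, max, 0.2, 0.8);
            System.out.println(x + " -> " + y + " -> " + inverse(y, min, max, 0.2, 0.8));
        }
    }
}
```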

BPModel

The BP neural network model, holding the weights and thresholds, the training parameters, and other attributes.

import java.io.Serializable;

public class BPModel implements Serializable {
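    //weights and thresholds of the BP network
    private Matrix weightIJ;
    private Matrix b1;
    private Matrix weightJP;
    private Matrix b2;
    /*per-column max/min of the training data, used for normalization and inverse normalization*/
    private Matrix inputMax;
    private Matrix inputMin;
    private Matrix outputMax;
    private Matrix outputMin;
    /*training parameters*/
    private BPParameter bpParameter;
    /*training outcome: final error and number of iterations*/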
    private double error;
    private int times;

    public Matrix getWeightIJ() {
        return weightIJ;
    }

    public void setWeightIJ(Matrix weightIJ) {
        this.weightIJ = weightIJ;
    }

    public Matrix getB1() {
        return b1;
    }

    public void setB1(Matrix b1) {
        this.b1 = b1;
    }

    public Matrix getWeightJP() {
        return weightJP;
    }

    public void setWeightJP(Matrix weightJP) {
        this.weightJP = weightJP;
    }

    public Matrix getB2() {
        return b2;
    }

    public void setB2(Matrix b2) {
        this.b2 = b2;
    }

    public Matrix getInputMax() {
        return inputMax;
    }

    public void setInputMax(Matrix inputMax) {
        this.inputMax = inputMax;
    }

    public Matrix getInputMin() {
        return inputMin;
    }

    public void setInputMin(Matrix inputMin) {
        this.inputMin = inputMin;
    }

    public Matrix getOutputMax() {
        return outputMax;
    }

    public void setOutputMax(Matrix outputMax) {
        this.outputMax = outputMax;
    }

    public Matrix getOutputMin() {
        return outputMin;
    }

    public void setOutputMin(Matrix outputMin) {
        this.outputMin = outputMin;
    }

    public BPParameter getBpParameter() {
        return bpParameter;
    }

    public void setBpParameter(BPParameter bpParameter) {
        this.bpParameter = bpParameter;
    }

    public double getError() {
        return error;
    }

    public void setError(double error) {
        this.error = error;
    }

    public int getTimes() {
        return times;
    }

    public void setTimes(int times) {
        this.times = times;
    }
}
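BPModel, BPParameter, Matrix, and Sigmoid all implement Serializable because Java serialization writes the whole object graph. If any non-transient field refers to an object that is not Serializable, ObjectOutputStream.writeObject throws java.io.NotSerializableException. The in-memory sketch below uses hypothetical classes to show a round trip and that failure. It also declares a serialVersionUID. Without one, the JVM derives a default from the class structure, so changing BPModel and recompiling may make previously saved models fail to deserialize with InvalidClassException.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// In-memory serialization round trip (hypothetical classes standing in for BPModel).
final class SerializationSketch {
    // Every non-transient field must itself be serializable (double[][] and int are)
    static final class TinyModel implements Serializable {
        private static final long serialVersionUID = 1L;   // explicit version for saved files
        final double[][] weights;
        final int times;

        TinyModel(double[][] weights, int times) {
            this.weights = weights;
            this.times = times;
        }
    }

    // Declares Serializable but holds a plain Object, so writing it fails
    static final class Broken implements Serializable {
        private static final long serialVersionUID = 1L;
        final Object lock = new Object();
    }

    static byte[] toBytes(Object o) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(o);
        }
        return bytes.toByteArray();
    }

    static Object fromBytes(byte[] data) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        }
    }
}
```

The factory's serialize and deSerialization methods use the same ObjectOutputStream and ObjectInputStream calls, wrapped around file streams instead of byte arrays.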

BPNeuralNetworkFactory

The BP neural network factory, providing training of the BP network and related functions.

import java.io.*;
import java.util.*;

public class BPNeuralNetworkFactory {
    /**
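     * Trains a BP neural network model
     * @param bpParameter training parameters
     * @param inputAndOutput training data; each row holds the inputs followed by the expected outputs
     * @return the trained model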
     */
    public BPModel trainBP(BPParameter bpParameter, Matrix inputAndOutput) throws Exception {
        //the trained model to return
        BPModel result = new BPModel();
        result.setBpParameter(bpParameter);

        ActivationFunction activationFunction = bpParameter.getActivationFunction();
        int inputNum = bpParameter.getInputLayerNeuronNum();
        int hiddenNum = bpParameter.getHiddenLayerNeuronNum();
        int outputNum = bpParameter.getOutputLayerNeuronNum();
        double normalizationMin = bpParameter.getNormalizationMin();
        double normalizationMax = bpParameter.getNormalizationMax();
        double step = bpParameter.getStep();
        double momentumFactor = bpParameter.getMomentumFactor();
        double precision = bpParameter.getPrecision();
        int maxTimes = bpParameter.getMaxTimes();

        if(inputAndOutput.getMatrixColNums() != inputNum + outputNum){
            throw new Exception("Neuron counts do not match the number of data columns; please adjust them");
        }
        //initialize the weights
        Matrix weightIJ = initWeight(inputNum, hiddenNum);
        Matrix weightJP = initWeight(hiddenNum, outputNum);

        //initialize the thresholds
        Matrix b1 = initThreshold(hiddenNum);
        Matrix b2 = initThreshold(outputNum);

        //momentum terms: the corrections from the previous iteration
        Matrix deltaWeightIJ0 = new Matrix(inputNum, hiddenNum);
        Matrix deltaWeightJP0 = new Matrix(hiddenNum, outputNum);
        Matrix deltaB10 = new Matrix(1, hiddenNum);
        Matrix deltaB20 = new Matrix(1, outputNum);

        Matrix input = new Matrix(new double[inputAndOutput.getMatrixRowNums()][inputNum]);
        Matrix output = new Matrix(new double[inputAndOutput.getMatrixRowNums()][outputNum]);
        for (int i = 0; i < inputAndOutput.getMatrixRowNums(); i++) {
            for (int j = 0; j < inputNum; j++) {
                input.getMatrix()[i][j] = inputAndOutput.getValOfIdx(i,j);
            }
            for (int j = 0; j < inputAndOutput.getMatrixColNums() - inputNum; j++) {
                output.getMatrix()[i][j] = inputAndOutput.getValOfIdx(i,inputNum+j);
            }
        }

        //normalize inputs and outputs to [normalizationMin, normalizationMax]
        Map<String,Object> inputAfterNormalize = normalize(input, normalizationMin, normalizationMax);
        input = (Matrix) inputAfterNormalize.get("res");
        Matrix inputMax = (Matrix) inputAfterNormalize.get("max");
        Matrix inputMin = (Matrix) inputAfterNormalize.get("min");
        result.setInputMax(inputMax);
        result.setInputMin(inputMin);

        Map<String,Object> outputAfterNormalize = normalize(output, normalizationMin, normalizationMax);
        output = (Matrix) outputAfterNormalize.get("res");
        Matrix outputMax = (Matrix) outputAfterNormalize.get("max");
        Matrix outputMin = (Matrix) outputAfterNormalize.get("min");
        result.setOutputMax(outputMax);
        result.setOutputMin(outputMin);

        int times = 1;
        double E = 0;//total error
        while (times < maxTimes) {
            /*----------------- forward propagation ---------------------*/
            //hidden-layer input
            Matrix jIn = input.multiple(weightIJ);
            double[][] b1CopyArr = new double[jIn.getMatrixRowNums()][b1.getMatrixColNums()];
            //broadcast the thresholds to every row (one row per sample)
            for (int i = 0; i < jIn.getMatrixRowNums(); i++) {
                b1CopyArr[i] = b1.getMatrix()[0];
            }
            Matrix b1Copy = new Matrix(b1CopyArr);
            //add the thresholds
            jIn = jIn.plus(b1Copy);
            //hidden-layer output
            Matrix jOut = computeValue(jIn,activationFunction);
            //output-layer input
            Matrix pIn = jOut.multiple(weightJP);
            double[][] b2CopyArr = new double[pIn.getMatrixRowNums()][b2.getMatrixColNums()];
            //broadcast the thresholds to every row (one row per sample)
            for (int i = 0; i < pIn.getMatrixRowNums(); i++) {
                b2CopyArr[i] = b2.getMatrix()[0];
            }
            Matrix b2Copy = new Matrix(b2CopyArr);
            //add the thresholds
            pIn = pIn.plus(b2Copy);
            //output-layer output
            Matrix pOut = computeValue(pIn,activationFunction);
            //error signal: expected output minus actual output
            Matrix e = output.subtract(pOut);
            E = computeE(e);//total error
            //stop once the required precision is reached
            if (Math.abs(E) <= precision) {
                System.out.println("Precision reached");
                break;
            }

            /*----------------- backward propagation ---------------------*/
            //output-layer error term: e multiplied element-wise by f'(pIn)   (samples x outputNum)
            Matrix deltaO = e.pointMultiple(computeDerivative(pIn, activationFunction));
            //hidden-layer error term: (deltaO x weightJP^T) multiplied element-wise by f'(jIn)   (samples x hiddenNum)
            Matrix deltaH = deltaO.multiple(weightJP.transpose())
                    .pointMultiple(computeDerivative(jIn, activationFunction));
            //row vector of ones; multiplying by it sums a matrix over the samples
            Matrix ones = new Matrix(1.0, 1, input.getMatrixRowNums());

            //correction of the J-P weights: step * jOut^T x deltaO   (hiddenNum x outputNum)
            Matrix deltaWeightJP = jOut.transpose().multiple(deltaO).multiple(step);
            //correction of the P-layer thresholds: step * column sums of deltaO   (1 x outputNum)
            Matrix deltaThresholdP = ones.multiple(deltaO).multiple(step);

            //correction of the I-J weights: step * input^T x deltaH   (inputNum x hiddenNum)
            Matrix deltaWeightIJ = input.transpose().multiple(deltaH).multiple(step);
            //correction of the J-layer thresholds: step * column sums of deltaH   (1 x hiddenNum)
            Matrix deltaThresholdJ = ones.multiple(deltaH).multiple(step);

            if (times == 1) {
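                //first iteration: apply the corrections directly
                weightIJ = weightIJ.plus(deltaWeightIJ);
                weightJP = weightJP.plus(deltaWeightJP);
                b1 = b1.plus(deltaThresholdJ);
                b2 = b2.plus(deltaThresholdP);
            }else{
                //later iterations: also add momentumFactor times the previous correction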
                weightIJ = weightIJ.plus(deltaWeightIJ).plus(deltaWeightIJ0.multiple(momentumFactor));
                weightJP = weightJP.plus(deltaWeightJP).plus(deltaWeightJP0.multiple(momentumFactor));
                b1 = b1.plus(deltaThresholdJ).plus(deltaB10.multiple(momentumFactor));
                b2 = b2.plus(deltaThresholdP).plus(deltaB20.multiple(momentumFactor));
            }

            deltaWeightIJ0 = deltaWeightIJ;
            deltaWeightJP0 = deltaWeightJP;
            deltaB10 = deltaThresholdJ;
            deltaB20 = deltaThresholdP;

            times++;
        }

        result.setWeightIJ(weightIJ);
        result.setWeightJP(weightJP);
        result.setB1(b1);
        result.setB2(b2);
        result.setError(E);
        result.setTimes(times);
        System.out.println("Iterations: " + times + ", error: " + E);

        return result;
    }

    /**
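     * Computes the output of a trained BP network for the given input
     * @param bpModel the trained model
     * @param input one sample per row, with inputLayerNeuronNum columns
     * @return the network output, mapped back to the original scale
     */
    public Matrix computeBP(BPModel bpModel,Matrix input) throws Exception {
        if (input.getMatrixColNums() != bpModel.getBpParameter().getInputLayerNeuronNum()) {
            throw new Exception("Input matrix has the wrong dimensions");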
        }
        ActivationFunction activationFunction = bpModel.getBpParameter().getActivationFunction();
        Matrix weightIJ = bpModel.getWeightIJ();
        Matrix weightJP = bpModel.getWeightJP();
        Matrix b1 = bpModel.getB1();
        Matrix b2 = bpModel.getB2();
        double normMin = bpModel.getBpParameter().getNormalizationMin();
        double normMax = bpModel.getBpParameter().getNormalizationMax();
        double[][] normalizedInput = new double[input.getMatrixRowNums()][input.getMatrixColNums()];
        for (int i = 0; i < input.getMatrixRowNums(); i++) {
            for (int j = 0; j < input.getMatrixColNums(); j++) {
                double min = bpModel.getInputMin().getValOfIdx(0,j);
                double max = bpModel.getInputMax().getValOfIdx(0,j);
                //scale with the min/max of the training inputs; a constant column (max == min) cannot be scaled and is mapped to 0
                normalizedInput[i][j] = max == min ? 0 : normMin + (input.getValOfIdx(i,j) - min) / (max - min) * (normMax - normMin);
            }
        }
        Matrix normalizedInputMatrix = new Matrix(normalizedInput);
        Matrix jIn = normalizedInputMatrix.multiple(weightIJ);
        double[][] b1CopyArr = new double[jIn.getMatrixRowNums()][b1.getMatrixColNums()];
        //broadcast the thresholds to every row (one row per sample)
        for (int i = 0; i < jIn.getMatrixRowNums(); i++) {
            b1CopyArr[i] = b1.getMatrix()[0];
        }
        Matrix b1Copy = new Matrix(b1CopyArr);
        //add the thresholds
        jIn = jIn.plus(b1Copy);
        //hidden-layer output
        Matrix jOut = computeValue(jIn,activationFunction);
        //output-layer input
        Matrix pIn = jOut.multiple(weightJP);
        double[][] b2CopyArr = new double[pIn.getMatrixRowNums()][b2.getMatrixColNums()];
        //broadcast the thresholds to every row (one row per sample)
        for (int i = 0; i < pIn.getMatrixRowNums(); i++) {
            b2CopyArr[i] = b2.getMatrix()[0];
        }
        Matrix b2Copy = new Matrix(b2CopyArr);
        //add the thresholds
        pIn = pIn.plus(b2Copy);
        //output-layer output
        Matrix pOut = computeValue(pIn,activationFunction);
        //inverse normalization back to the original scale
        Matrix result = inverseNormalize(pOut, bpModel.getBpParameter().getNormalizationMax(), bpModel.getBpParameter().getNormalizationMin(), bpModel.getOutputMax(), bpModel.getOutputMin());

        return result;

    }

    //initialize weights with uniform random values in [-1, 1)
    private Matrix initWeight(int x,int y){
        Random random=new Random();
        double[][] weight = new double[x][y];
        for (int i = 0; i < x; i++) {
            for (int j = 0; j < y; j++) {
                weight[i][j] = 2*random.nextDouble()-1;
            }
        }
        return new Matrix(weight);
    }
    //initialize thresholds with uniform random values in [-1, 1)
    private Matrix initThreshold(int x){
        Random random = new Random();
        double[][] result = new double[1][x];
        for (int i = 0; i < x; i++) {
            result[0][i] = 2*random.nextDouble()-1;
        }
        return new Matrix(result);
    }

    /**
     * Applies the activation function element-wise
     * @param a
     * @return
     */
    private Matrix computeValue(Matrix a, ActivationFunction activationFunction) throws Exception {
        if (a.getMatrix() == null) {
            throw new Exception("引數值為空");
        }
        double[][] result = new double[a.getMatrixRowNums()][a.getMatrixColNums()];
        for (int i = 0; i < a.getMatrixRowNums(); i++) {
            for (int j = 0; j < a.getMatrixColNums(); j++) {
                result[i][j] = activationFunction.computeValue(a.getValOfIdx(i,j));
            }
        }
        return new Matrix(result);
    }

    /**
     * Applies the derivative of the activation function element-wise
     * @param a
     * @return
     */
    private Matrix computeDerivative(Matrix a , ActivationFunction activationFunction) throws Exception {
        if (a.getMatrix() == null) {
            throw new Exception("引數值為空");
        }
        double[][] result = new double[a.getMatrixRowNums()][a.getMatrixColNums()];
        for (int i = 0; i < a.getMatrixRowNums(); i++) {
            for (int j = 0; j < a.getMatrixColNums(); j++) {
                result[i][j] = activationFunction.computeDerivative(a.getValOfIdx(i,j));
            }
        }
        return new Matrix(result);
    }

    /**
     * Min-max normalization, column by column
     * @param a the data to normalize
     * @param normalizationMin lower bound of the target range
     * @param normalizationMax upper bound of the target range
     * @return a map holding the normalized data ("res") and the per-column maximum ("max") and minimum ("min")
     */
    private Map<String, Object> normalize(Matrix a, double normalizationMin, double normalizationMax) throws Exception {
        HashMap<String, Object> result = new HashMap<>();
        double[][] maxArr = new double[1][a.getMatrixColNums()];
        double[][] minArr = new double[1][a.getMatrixColNums()];
        double[][] res = new double[a.getMatrixRowNums()][a.getMatrixColNums()];
        for (int i = 0; i < a.getMatrixColNums(); i++) {
            List<Double> tmp = new ArrayList<>();
            for (int j = 0; j < a.getMatrixRowNums(); j++) {
                tmp.add(a.getValOfIdx(j,i));
            }
            double max = Collections.max(tmp);
            double min = Collections.min(tmp);
            //normalize the column (a constant column, where max equals min, cannot be scaled and is left at 0)
            if (max != min) {
                for (int j = 0; j < a.getMatrixRowNums(); j++) {
                    res[j][i] = normalizationMin + (a.getValOfIdx(j,i) - min) / (max - min) * (normalizationMax - normalizationMin);
                }
            }
            maxArr[0][i] = max;
            minArr[0][i] = min;
        }
        result.put("max", new Matrix(maxArr));
        result.put("min", new Matrix(minArr));
        result.put("res", new Matrix(res));
        return result;
    }

    /**
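     * Inverse normalization, column by column
     * @param a the data to map back
     * @param normalizationMax upper bound of the normalized range
     * @param normalizationMin lower bound of the normalized range
     * @param dataMax per-column maximum of the original data
     * @param dataMin per-column minimum of the original data
     * @return the data on the original scale
     */
    private Matrix inverseNormalize(Matrix a, double normalizationMax, double normalizationMin , Matrix dataMax,Matrix dataMin) throws Exception {
        double[][] res = new double[a.getMatrixRowNums()][a.getMatrixColNums()];
        for (int i = 0; i < a.getMatrixColNums(); i++) {
            //map the column back to the original scale (a column whose min and max are both 0 stays 0)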
            if (dataMin.getValOfIdx(0,i) != 0 || dataMax.getValOfIdx(0,i) != 0) {
                for (int j = 0; j < a.getMatrixRowNums(); j++) {
                    res[j][i] = dataMin.getValOfIdx(0,i) + (dataMax.getValOfIdx(0,i) - dataMin.getValOfIdx(0,i)) * (a.getValOfIdx(j,i) - normalizationMin) / (normalizationMax - normalizationMin);
                }
            }
        }
        return new Matrix(res);
    }

    /**
     * Computes the error E = 0.5 * (sum of the squared elements of e)
     * @param e
     * @return
     */
    private double computeE(Matrix e) throws Exception {
        e = e.square();
        return 0.5*e.sumAll();
    }

    /**
     * Serializes the BP model to a local file
     * @param bpModel the model to save
     * @param path target file path
     * @throws IOException
     */
    public void serialize(BPModel bpModel,String path) throws IOException {
        File file = new File(path);
        System.out.println(file.getAbsolutePath());
        //try-with-resources closes the stream even if writing fails
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file))) {
            out.writeObject(bpModel);
        }
    }

    /**
     * Deserializes a BP model from a local file
     * @param path path of a file written by serialize
     * @return the model
     * @throws IOException
     * @throws ClassNotFoundException
     */
    public BPModel deSerialization(String path) throws IOException, ClassNotFoundException {
        File file = new File(path);
        try (ObjectInputStream oin = new ObjectInputStream(new FileInputStream(file))) {
            return (BPModel) oin.readObject(); // cast to BPModel
        }
    }
}
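A hand-written backward pass like the one in trainBP is easy to get subtly wrong, so comparing it with finite differences is a useful check. The sketch below assumes one sample, one sigmoid output neuron, and made-up weights. For each input-to-hidden weight, it compares the delta-rule correction with a central-difference estimate of -dE/dw, where E = 0.5(t - y)^2. GradientCheckSketch and its methods are hypothetical names.

```java
// Compares the delta-rule correction for input -> hidden weights with finite differences
// (one sample, one sigmoid output neuron, illustrative values).
final class GradientCheckSketch {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Hidden-layer outputs h[j] = f(b1[j] + sum_i x[i] * wIJ[i][j])
    static double[] hidden(double[] x, double[][] wIJ, double[] b1) {
        double[] h = new double[b1.length];
        for (int j = 0; j < b1.length; j++) {
            double net = b1[j];
            for (int i = 0; i < x.length; i++) {
                net += x[i] * wIJ[i][j];
            }
            h[j] = sigmoid(net);
        }
        return h;
    }

    // Single output neuron y = f(b2 + sum_j h[j] * wJP[j])
    static double output(double[] h, double[] wJP, double b2) {
        double net = b2;
        for (int j = 0; j < h.length; j++) {
            net += h[j] * wJP[j];
        }
        return sigmoid(net);
    }

    // E = 0.5 * (t - y)^2
    static double error(double[] x, double t, double[][] wIJ, double[] b1, double[] wJP, double b2) {
        double e = t - output(hidden(x, wIJ, b1), wJP, b2);
        return 0.5 * e * e;
    }

    // Delta rule: -dE/dwIJ[i][j] = x[i] * deltaP * wJP[j] * h[j] * (1 - h[j])
    static double deltaRule(double[] x, double t, double[][] wIJ, double[] b1, double[] wJP, double b2,
                            int i, int j) {
        double[] h = hidden(x, wIJ, b1);
        double y = output(h, wJP, b2);
        double deltaP = (t - y) * y * (1 - y);
        return x[i] * deltaP * wJP[j] * h[j] * (1 - h[j]);
    }

    // Central difference: -(E(w + eps) - E(w - eps)) / (2 * eps)
    static double numeric(double[] x, double t, double[][] wIJ, double[] b1, double[] wJP, double b2,
                          int i, int j, double eps) {
        double saved = wIJ[i][j];
        wIJ[i][j] = saved + eps;
        double ePlus = error(x, t, wIJ, b1, wJP, b2);
        wIJ[i][j] = saved - eps;
        double eMinus = error(x, t, wIJ, b1, wJP, b2);
        wIJ[i][j] = saved;                       // restore the original weight
        return -(ePlus - eMinus) / (2 * eps);
    }

    public static void main(String[] args) {
        double[] x = {0.3, 0.7};
        double[][] wIJ = {{0.2, -0.4, 0.1}, {0.5, 0.3, -0.2}};
        double[] b1 = {0.1, -0.1, 0.05};
        double[] wJP = {0.4, -0.6, 0.3};
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 3; j++) {
                System.out.printf("w[%d][%d]: delta rule %.10f, numeric %.10f%n", i, j,
                        deltaRule(x, 0.6, wIJ, b1, wJP, 0.05, i, j),
                        numeric(x, 0.6, wIJ, b1, wJP, 0.05, i, j, 1e-6));
            }
        }
    }
}
```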

Usage

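The workflow is as follows:

1. Create a BPNeuralNetworkFactory object.
2. Call its trainBP(BPParameter bpParameter, Matrix inputAndOutput) method with a BPParameter object and the training data. It returns a BPModel object.
3. Save the BPModel to a local file with the factory's serialization method, or keep it in a cache.
4. When you need a prediction, deserialize the BPModel from the local file and call the factory's computeBP(BPModel bpModel, Matrix input) method to get the computed value.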

For usage details, see: https://github.com/ineedahouse/top-algorithm-set-doc/blob/master/doc/bpnn/BPNeuralNetwork.md

Source code on GitHub

https://github.com/ineedahouse/top-algorithm-set

If this helped you, please give the repository a Star. Thanks!

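Reference: 趙逸翔 (Zhao Yixiang). 基於BP神經網路的無約束優化方法研究及應用 (Research and Application of Unconstrained Optimization Methods Based on BP Neural Networks) [D]. Northeast Agricultural University, 2019.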