BP神经网络

【测试】trainlmBP的测试Demo

作者 : 老饼 发表日期 : 2022-06-09 04:38:51 更新日期 : 2023-02-22 19:15:51
本站原创文章,转载请说明来自《老饼讲解-BP神经网络》www.bbbdata.com




本代码用于测试《trainlmBP.m》的使用Demo

同时也测试trainlmBP.m的结果是否与matlab的工具箱用"trainlm"的训练结果是否一致




   01. trainlmBP的测试Demo  



本节展示trainlmBP的测试Demo代码及对代码进行简要说明



    代码简要解说    


DEMO代码主要干的事情如下:
 👉1. 先用工具箱训练一个两隐层的神经网络                        
👉2. 再用自写代码trainlmBP训练一个两隐层的神经网络  

然后分别打印工具箱的训练结果,与trainlmBP的训练结果





      trainlmBP的使用与测试Demo    


% ---------数据生成与参数预设-------------
clear 
% 生成输入输出数据
X  = [-1:0.2:1;-1:0.2:1];
y  = [sin(X( 1,:)) + X( 2,:);sin(X( 1,:).*X( 1,:)) + 0.5*X( 2,:)];
% 数据归一化
X = 2*(X-repmat(min(X,[],2),1,size(X,2)))./(repmat(max(X,[],2)-min(X,[],2),1,size(X,2)))-1;
y = 2*(y-repmat(min(y,[],2),1,size(y,2)))./(repmat(max(y,[],2)-min(y,[],2),1,size(y,2)))-1;

% 参数预设
hnn     = [4,2];        % 隐节点个数(hideNodeNum)
tf      = {'tansig','tansig','purelin'};
maxStep = 1000;         % 最大训练步数
goal    = 0.00001;      % 目标误差



%---------调用自写函数进行训练--------------
rand('seed',70);
[W,B] = trainlmBP(X,y,hnn,tf,goal,maxStep); % 网络训练
py    = predictBP(W,B,tf,X);                   % 网络预测
w12 = W{1,2}                                % 提取网络的权重
w23 = W{2,3}                                % 提取网络的权重
w34 = W{3,4}                                % 提取网络的权重
b2  = B{2};                                 % 提取网络阈值
b3  = B{3};                                 % 提取网络阈值
b4  = B{4};                                 % 提取网络阈值
% -----调用工具箱,与工具箱的结果比较------
rand('seed',70);
net = newff(X,y,hnn,{'tansig','tansig','purelin'},'trainlm');


%设置训练参数
net.trainparam.goal        = goal;      % 训练目标
net.trainparam.epochs      = maxStep;   % 最大训练次数.
net.divideParam.trainRatio = 1;         % 全部数据用于训练
net.divideParam.valRatio   = 0;         % 关掉泛化验证数据
net.divideParam.testRatio  = 0;         % 关掉测试数据

% 网络训练
[net,tr,py_tool] = train(net,X,y);   % 训练网络
w12_tool = net.IW{1}                 % 提取网络的权重
w23_tool = net.LW{2,1}               % 提取网络的权重
w34_tool = net.LW{3,2}               % 提取网络的权重

% 与工具箱的差异
maxECompareNet = max([max(abs(py(:)-py_tool(:))),max(abs(w12(:)-w12_tool(:))),max(abs(w23(:)-w23_tool(:))),max(abs(w34(:)-w34_tool(:)))]);
disp(['自写代码与工具箱权重阈值的最大差异:',num2str(maxECompareNet)])

版本:matlab 2018a





   02. 代码运行结果解说    



本节展示代码的运行结果,并进行简单解说



     代码运行结果解说     


运行结果共三部分

 1. 自写代码求得的网络权重与阈值
 
    
  
2. 调用工具箱求得的网络权重与阈值
 
 
 
  
3. 自写代码与工具箱的结果对比
 
 
✍️解说
 
从运行结果可以看到,自写代码与工具箱的结果一样
 
说明扒出的逻辑与工具箱的基本一致









 End 






联系老饼