本站原创文章,转载请说明来自《老饼讲解-BP神经网络》www.bbbdata.com
本文是笔者细扒matlab神经网络工具箱newgrnn的源码后
去除冗余代码,重现的简版newgrnn代码,代码与newgrnn的结果完全一致
通过本代码的学习,可以完全细节的了解广义回归神经网络的实现逻辑
本节展示如何不依赖软件包自实现广义回归神经网络的代码
广义回归-自实现代码
function testGrnnNet()
%本代码来自 www.bbbdata.com
%本代码模仿matlab神经网络工具箱的newgrnn神经网络,用于训练《广义回归神经网络》,
%代码主旨用于教学,供大家学习理解newgrnn神经网络原理
%--------生成训练数据-------------------
x1 = 1:1:10;
X = [ x1; x1];
y = sin(X(1, :)) + X( 2, :);
test_x = [2 3]'; %测试数据
%---------参数预设----------------------
spread = 2; %扩展系数
%------调用自写Grnn函数获得广义回归神经网络-----------------
[w21,b2,w32] = trainGrnnNet( X,y,spread )
py = predictGrnnNet(w21,b2,w32,test_x) %模型预测
%------调用matlab神经网络工具箱训练广义回归神经网络
net=newgrnn(X,y,spread); % 用工具箱设计广义回归网络
pyByBox = sim(net, test_x) % 工具箱对测试数据的预测结果
% -------检查自写代码与工具箱的结果是否一致------------------------------
testResult = isequal( py, pyByBox);
disp(['testIsequal = ',num2str(testResult)]);
web('www.bbbdata.com')
end
% 广义神经网络的生成函数
function [w21,b2,w32] = trainGrnnNet(X,y,spread)
%生成广义神经网络只要将输入输出存到w21,w32中,
%再用spread生成影响径向基宽度的b2就可以
w21 = X';
w32 = y;
b2 = ones( size(X,2), 1)*sqrt( -log(.5))/spread;
end
% 广义神经网络的预测函数
function y = predictGrnnNet(w21,b2,w32,X)
y = [];
for i = 1:size(X,2)
cur_x = X(:,i);
hv = b2.*sqrt(sum((ones(size(w21,1),1)*cur_x' - w21).^2,2)); % 计算隐节点的值
ha = exp(-(hv.*hv)); % 计算隐节点激活值
cur_y = w32*(ha./sum(ha)); % 计算输出
y = [y,cur_y];
end
end
版本:matlab 2014b
代码运行结果
代码运行结果如下
1、训练好的网络参数
2、与工具箱结果的比较
✍️PASS
从运行结果可以看到,
自写代码与工具箱的结果一样
说明扒出的逻辑与工具箱的一致
本节对代码的结构进行简要说明
代码结构说明
代码包含了三个函数
具体如下
👉1. testGrnnNet:测试用例主函数,直接运行时就是执行该函数
1、数据:生成一个2输入1输出的训练数据,
2、用自写的函数构建一个广义回归网络,与用网络进行预测
3、使用工具箱newgrnn训练一个广义回归网络
比较自写函数与工具箱训练结果是否一致
👉2. trainGrnnNet:训练主函数
训练一个广义回归神经网络
👉3. predictGrnnNet:预测主函数
传入需要预测的X,与网络的权重矩阵,即可得到预测结果
End