本站原创文章,转载请说明来自《老饼讲解-BP神经网络》www.bbbdata.com
本文展示如何使用BP神经网络来实现图像类别的识别,作为BP神经网络多类别的模型示例
通过本文,可以了解BP神经网络是如何实现多类别识别的,以及如何应用于图片类别识别
本节介绍手写数字识别的问题与数据说明
手写数字数据介绍
matlab2018a中自带的digitimages.mat数据就是写手数字的数据,
共包含3000个0-9的手写数字样本,每个样本是28*28的图片数据
不妨每个数字都打印5个样本示例,如下
手写数字识别属于一个多分类问题,下面我们使用BP神经网络算法来识别手写数字
BP神经网络应用于手写数字识别-代码实现
在matlab中使用patternnet函数就可以构建一个用于模式识别的BP神经网络
具体代码实现如下:
% 本代码用于展示BP神经网络应用于手写数字识别(多分类模型)
% 转载请说明来自 《老饼讲解神经网络》 www.bbbdata.com
clear all ;close all;
setdemorandstream(88); % 老饼为了每次运行的结果一致设定随机种子,实际中可以去掉
% 加载数据
load digitimages.mat % 加载手写数字数据
[h,w,pic_num] = size(images); % 获取手写数字图片的大小与样本数量
X = double(reshape(images,[h*w,pic_num])); % 将图片转为列向量
y = full(ind2vec(Y'+1)); % 将图片对应的数字转为one-hot矩阵
net = patternnet(120); % 建立模式识别网络,隐层设为120个
[net,tr,py] = train(net,X,y); % 将数据放到网络中训练
% 打印训练与测试效果
train_y = y(:,tr.trainInd); % 训练样本的y
train_py = py(:,tr.trainInd); % 训练样本的预测y
test_y = y(:,tr.testInd); % 测试样本的y
test_py = py(:,tr.testInd); % 测试样本的预测y
train_err_rate = sum( vec2ind(train_py)==vec2ind(train_y))/size(train_y,2) % 训练错误率
test_err_rate = sum( vec2ind(test_py)==vec2ind(test_y))/size(test_y,2) % 测试错误率
运行结果如下:
可以看到,测试数据集的准确率达到了98%
End