神经网络

【实例】蒙特卡罗手写数字识别

作者 : 老饼 发表日期 : 2023-06-07 02:54:33 更新日期 : 2024-12-05 10:01:19
本站原创文章,转载请说明来自《老饼讲解-BP神经网络》www.bbbdata.com



蒙特卡罗算法是一个强大的算法,本文讲解如何使用蒙特卡罗算法来解决手写数字识别问题

通过本文,可以进一步加强对蒙特卡罗算法的理解和实际应用





    01. 蒙特卡罗手写数字识别-数据介绍    





本节介绍手写数字识别的问题与数据说明





     手写数字数据介绍     


matlab2018a中自带的digitimages.mat数据就是写手数字的数据,
共包含3000个样本,每个样本是28*28的图片数据,
 不妨每个数字都打印5个样本示例,如下
     
下面我们使用蒙特卡罗法来识别手写数字







    02. 蒙特卡罗法识别手写数字-算法设计    





本节讲解如何设计一个蒙特卡罗算法来识别手写数字





     蒙特卡罗法用于手写数字识别的算法设计    


蒙特卡罗法用于数字识别的算法设计
数字识别实际也是类别识别
所以算法设计与《螃蟹识别》中的流程是一致的,
具体算法设计如下:

将历史样本存储起来作为一个样本库,
当来了一个新样本时,就在样本库各个类别各抽取n个样本
然后判别新样本中与抽出的样本哪个最相似,就判为哪一个类别
如此重复抽取t次,
最后统计t次中,被判为哪个类别的次数最多,就认为样本属于哪个类别
 ✍️备注:这里我们使用欧氏距离作为相似度的度量,欧氏距离越小,就认为越相似
蒙特卡罗法用于手写数字识别的算法流程图
 
算法流程如下:






    03. 蒙特卡罗法识别手写数字-代码实现    





本节通过代码实现蒙特卡罗法识别手写数字,并展示相关结果





     蒙特卡罗法识别手写数字-代码实现  


依据上述算法设计,编码蒙特卡罗法用于识别手写数字
共三部分代码
 
👉1. 预测流程脚本主函数 
👉2. 数据处理函数         
👉3. 预测函数               

具体代码如下:
%------代码说明:展示蒙特卡罗法预测手写数字 -----------------
% 来自《老饼讲解神经网络》www.bbbdata.com ,matlab版本:2018a 
global sample_x                                            % 定义样本库的x数据
global sample_y                                            % 定义样本库的y数据
load digitimages.mat                                       % 加载手写数字数据
setdemorandstream(88888);                                  
										                   
% 图片数据预处理                                           
[h,w,pic_num] = size(images);                              % 获取手写数字图片的大小与样本数量
sample_x = zeros(20*20,pic_num);                           % 初始化图片样本数据
for i = 1:pic_num                                          % 逐张图片处理
    cur_x = process_img(images(:,:,i));                    % 处理当前图片
    sample_x(:,i)=cur_x(:);                                % 存储当前图片,
end                                                        
sample_y = full(ind2vec(Y'+1));                            % 将图片对应的数字转为one-hot矩阵
										                   
% 数据分割                                                 
test_num = 300;   % 测试数据个数                           
test_idx = randperm(pic_num,test_num);                     % 随机选择test_num个作为测试样本
test_x   = sample_x(:,test_idx);                           % 从样本中抽出测试样本的x
test_y   = sample_y(:,test_idx);                           % 从样本中抽出测试样本的y
sample_x(:,test_idx) = [];                                 % 移除测试样本的x
sample_y(:,test_idx) = [];                                 % 移除测试样本的y
													       
% 通过样本库来预测测试样本对应的数字                       
py = zeros(size(test_y));                                  % 初始化预测结果
for i = 1:size(test_x,2)                                   % 逐个样本进行预测
   py(:,i)= mc_predict_number(test_x(:,i));                % 预测当前样本
   if(mod(i,10)==0)                                        % 每隔10个样本打印一次进度
       disp(['进度:',num2str(round(i/test_num*100)),'%'])  % 打印当前预测进度
   end
end

% 统计与打印预测准确率
y_label  =  vec2ind(test_y)-1;                             % 将测试数据的真实结果由one-hot格式转为类别标签形式
py_label = vec2ind(py)-1;                                  % 将预测结果的one-hot格式转为类别标签
acc_rate = sum(py_label==y_label)/length(y_label);         % 计算准确率
disp(['预测准确率:',num2str(acc_rate)])                    % 打印准确率

图片的预处理函数process_img如下:

% 预处理图片的函数
function deal_img = process_img(img)
deal_img = img>50;                            % 将值>50的作为1,<50的作为0
deal_img = truncImgsPadding(deal_img);        % 对图片上下左右空白处进行裁剪
deal_img = imresize(deal_img,[20,20]);        % 将图片转换为20*20的Size
end

% 裁剪图片空白边缘部分
function trunc_img = truncImgsPadding(imgs)
% 裁剪左右两边的空白处
sum_imgs      = sum(imgs);                    % 按列求和
csum_imgs     = cumsum(sum_imgs);             % 计算累计值
[~,right_idx] = max(csum_imgs);               % 根据累计值找出右边第一个非0列
left_idx      = find(csum_imgs>0);            % 根据累计值找出非0列
left_idx      = left_idx(1);                  %  第一个非0列就是左边第一个非0列
trunc_img     = imgs(:,left_idx:right_idx);   % 进行左右裁剪
											 
% 裁剪上下的空白处                           
sum_imgs    = sum(trunc_img,2);               % 按行求行
csum_imgs   = cumsum(sum_imgs);               % 计行累计值
[~,bot_idx] =  max(csum_imgs);                % 根据累计值找出底部第一个非0行
top_idx     = find(csum_imgs>0);              % 根据累计值找出非0行
top_idx     = top_idx(1);                     % 第一个非0行就是顶部第一个非0行
trunc_img   = trunc_img(top_idx:bot_idx,:);   % 对上下进行裁剪
end

蒙特卡罗预测函数mc_predict_number如下:

function y = mc_predict_number(x)
% 用蒙特卡罗法判断样本的类别     
global sample_x                                                      % 样本库的x数据
global sample_y                                                      % 样本库的y数据
setdemorandstream(88888);                                                            
t = 100;                                                             % 裁决次数
n = 200;                                                             % 每个类别抽样数量
[class_num,sample_num] = size(sample_y);                             % 类别个数与样本个数
class_idx = cell(class_num,1);                                       % 初始化各个类别的样本索引
for i = 1:class_num                                                  
    class_idx{i} = find(sample_y(i,:));                              % 找出属于第i个类别的样本索引
end                                                                  
															         
% 进行抽样裁决x的类别                                                                
rs = zeros(class_num,t);                                             % 初始化裁决结果表  
for i = 1:t                                                          
    select_idx = zeros(n*class_num,1);                               % 本次抽样的样本索引
    for j = 1:class_num                                              % 逐类别抽样
        cur_class_idx = class_idx{j};                                % 属于第i个类别的样本索引
        cur_select_idx = randperm(length(cur_class_idx),n);          % 随机抽出n个样本
        select_idx((j-1)*n+1:j*n) = cur_class_idx(cur_select_idx);   % 记录本次抽出的样本索引
    end
															        
   select_sample = sample_x(:,select_idx);                           % 抽取出本次抽样的样本
   d = sum((select_sample-x).^2);                                    % 计算各个样本与x的距离
   [~,win_idx] = min(d);                                             % 找出最小距离的样本作为本次胜出的样本
   win_y = sample_y(:,select_idx(win_idx));                          % 根据样本的索引找出y
   rs(:,i) = win_y;                                                  % 记录本次的获胜的y
end                                                                 
															        
% 统计多次抽样裁决的结果,用于决定最终x的所属类别                    
win_stat    = sum(rs,2);                                             % 统计各个类别胜出的次数
[~,win_idx] = max(win_stat);                                         % 找出哪个类别胜出次数最多
y = zeros(class_num,1);                                              % 初始化x的类别y
y(win_idx)  = 1;                                                     % 将胜出次数最多的类别,作为x的类别
end                                                                 



     运行结果    


运行结果如下:
 
从结果可以看到,只是简单的使用均方差作为图片的相似评估函数,预测准确率就已经达到了99.66%
可见,蒙特卡罗算法在手写数字识别上已经达到了较好的效果
比较不足的是,预测速度相对较忙,不如由于它是可并行的,如果改为并行,速度上会加快许多











 End 






联系老饼