神经网络-一篇入门

【实例】蒙特卡罗手写数字识别-GUI版本

作者 : 老饼 发表日期 : 2023-03-25 00:02:23 更新日期 : 2023-11-25 03:56:12
本站原创文章,转载请说明来自《老饼讲解-BP神经网络》www.bbbdata.com




蒙特卡罗算法是一个强大的算法,本文讲解如何使用蒙特卡罗算法来解决手写数字识别问题

本文提供实现蒙特卡罗算法识别手写数字的实现代码,并附加一个GUI界面用于手写数字并进行识别

通过本文,可以进一步加强对蒙特卡罗算法的理解和实际应用





    01. 蒙特卡罗用于数字识别-数据介绍    



本节介绍手写数字识别的问题与数据说明



     手写数字数据介绍     


matlab2018a中自带的digitimages.mat数据就是写手数字的数据,
共包含3000个样本,每个样本是28*28的图片数据,
 不妨每个数字都打印5个样本示例,如下
     
下面我们使用蒙特卡罗法来识别手写数字






    02. 蒙特卡罗法识别手写数字-算法设计    



本节讲解如何设计一个蒙特卡罗算法来识别手写数字



     蒙特卡罗法用于手写数字识别的算法设计    


蒙特卡罗法用于数字识别的算法设计
数字识别实际也是类别识别
所以算法设计与《螃蟹识别》中的流程是一致的,
具体算法设计如下:

将历史样本存储起来作为一个样本库,
当来了一个新样本时,就在样本库各个类别各抽取n个样本
然后判别新样本中与抽出的样本哪个最相似,就判为哪一个类别
如此重复抽取t次,
最后统计t次中,被判为哪个类别的次数最多,就认为样本属于哪个类别
 ✍️备注:这里我们使用欧氏距离作为相似度的度量,欧氏距离越小,就认为越相似
蒙特卡罗法用于手写数字识别的算法流程图
 
算法流程如下:






    03. 蒙特卡罗法识别手写数字-代码实现(GUI版本)    



本节通过代码实现蒙特卡罗法识别手写数字,并展示相关结果




     蒙特卡罗法识别手写数字-代码实现  


依据上述算法设计,编码蒙特卡罗法用于识别手写数字,并配上GUI界面
共四部分代码
 
👉1. GUI预测界面主函数               
👉2. 数据加载函数                       
👉3. 数据处理函数                      
👉4. 蒙特卡罗手写数字预测函数 

备注:数据处理函数3、预测函数4与非GUI版本的《蒙特卡罗识别手写数字》是一致的
具体代码如下:
%------代码说明:展示蒙特卡罗法预测手写数字(GUI版本) -----------------
% 来自《老饼讲解神经网络》www.bbbdata.com ,matlab版本:2018a 
function initDrawGUI()
clc;clear all ;close all ;
global draw_enable                                                % 绘画状态
global draw_x                                                     % 绘画的x坐标
global draw_y                                                     % 绘画的y坐标
draw_enable = 0;                                                  % 初始化绘画状态
init_img_sample()                                                 % 初始化数据
% 界面控件                                                        
hMainFig = figure('Tag','mainFig','Name','手写数字识别');         % 新建一个界面
set(hMainFig,'WindowButtonDownFcn',@ButttonDownFcn)               % 设置界面鼠标按下的回调函数
set(hMainFig,'WindowButtonUpFcn',@ButttonUpFcn)                   % 设置界面鼠标按上的回调函数
set(hMainFig,'WindowButtonMotionFcn',@ButttonMotionFcn)           % 设置界面鼠标移动的回调函数
% 坐标轴控件
haxes = axes('Parent',hMainFig);                                  % 新建一个坐标轴
set(haxes,'position',[0.1 0.2  0.8  0.7 ]);                       % 设置坐标轴控件的位置                   
set(haxes, 'XLim', [-3,3], 'YLim', [-2,2],'Box','on');            % 设置坐标轴的范围
haxes.XAxis.Visible = 'off';                                      % 隐藏坐标轴x轴
haxes.YAxis.Visible = 'off';                                      % 隐藏坐标轴y轴

% 建一个按钮-用于清空图像
hbuttonClear =  uicontrol(...
    'Parent',hMainFig,...
    'String','清空',...
    'position',[150 30 100 40],...
    'Callback',@buttonClearCallBack,...
    'Style','pushbutton');

% 建一个按钮-用于识别图像
hbutton =  uicontrol(...
    'Parent',hMainFig,...
    'String','识别',...
    'position',[350 30 100 40],...
    'Callback',@buttonRecCallBack,...
    'Style','pushbutton');

% 清除按钮的回调函数(用于清除画面)
    function buttonClearCallBack(hObject,eventdata)
        cla;                                                         % 清除当前图像
    end
% 识别按钮的回调函数(用于数字识别)
    function buttonRecCallBack(hObject,eventdata)
        % 保存图片
        tmp_f_handle = figure('visible','off');                      % 新建一个figure
        tmp_axes     = copyobj(haxes,tmp_f_handle);                  % 将坐标轴内容复制一份
        tmp_axes.Title.Visible='off';                                % 隐藏标题
        set(tmp_axes,'units','default','position','default');        % 新坐标轴的设置
        print(tmp_f_handle, '-djpeg', 'tmp_img_for_recognize.jpg');  % 保存图片
        delete(tmp_f_handle);                                        % 删除临时figure
        img_rgb = imread('tmp_img_for_recognize.jpg');               % 读取图片
        img2    = rgb2gray(img_rgb);                                 % 将图片转为灰度图片
        img     = zeros(size(img2));                                 % 初始化图片
        img(img2==255) = 0;                                          % 将灰度图片中为白色的地方转为0
        img(img2~=255) = 255;                                        % 将灰度图片中不是白色的地方转为255
        if(all(img(:)==0)||all(img(:)==255))                         % 检测是否没有绘画
            title(['请先绘画']);                                     % 提示先绘画
            drawnow                                                  % 显示标题
            return                                                   % 直接返回
        end
        img = process_img(img);                                      % 处理图片
        title(['识别中....']);                                       % 标记正在识别
        drawnow                                                      % 显示标题
        predict_y = mc_predict_number(img(:));                       % 将图片使用蒙特卡罗法进行预测
        number = find(predict_y)-1;                                  % 将识别的one-hot转回数字
        title(['识别结果',num2str(number)]);                         % 显示识别结果
    end
% 鼠标按下的回调函数
    function  ButttonDownFcn(src,event)
        draw_enable = 1;                                             % 标记当前为绘画状态
        p = get(haxes,'CurrentPoint');                               % 获取当前的鼠标坐标
        draw_x(1) = p(1,1);                                          % 将当前的鼠标x坐标更新为绘画起点的x
        draw_y(1) = p(1,2);                                          % 将当前的鼠标y坐标更新为绘画起点的y
    end

% 鼠标弹起的回调函数
    function ButttonUpFcn(src,event)
        draw_enable = 0;                                             % 标记当前为非绘画状态
    end

% 鼠标移动回调函数(用于画图)
    function ButttonMotionFcn(src,event)
        if (draw_enable==1)                                          % 如果处于画画状态
            p= get(haxes,'CurrentPoint');                            % 获取鼠标位置点
            draw_x(2) = p(1,1);                                      % 将当前鼠标点的x作为画图结束点的x
            draw_y(2) = p(1,2);                                      % 将当前鼠标点的y作为画图结束点的y
            hold on                                                  % 保留之前的画图
            line(haxes,draw_x,draw_y,'LineWidth',12)                 % 画图
            draw_x(1) = draw_x(2);                                   % 将绘画终点的x更新为下次绘画起点的x
            draw_y(1) = draw_y(2);                                   % 将绘画终点的y更新为下次绘画起点的y
        end
    end
end


图片数据的加载函数init_img_sample.m代码如下:

function init_img_sample()
% 本部分加载手写数字的样本数据
global sample_x                          % 样本库的x数据
global sample_y                          % 样本库的y数据
load digitimages.mat                     % 加载手写数字数据
[h,w,pic_num] = size(images);            % 获取手写数字图片的大小与样本数量
sample_x = zeros(20*20,pic_num);         % 初始化图片样本数据
for i =1:pic_num                         % 逐张图片处理
    cur_x = process_img(images(:,:,i));  % 处理当前图片
    sample_x(:,i)=cur_x(:);              % 存储当前图片
end
sample_y = full(ind2vec(Y'+1));          % 将图片对应的数字转为one-hot矩阵
end

图片的处理函数process_img.m代码如下:

% 预处理图片的函数
function deal_img = process_img(img)
deal_img = img>50;                            % 将值>50的作为1,<50的作为0
deal_img = truncImgsPadding(deal_img);        % 对图片上下左右空白处进行裁剪
deal_img = imresize(deal_img,[20,20]);        % 将图片转换为20*20的Size
end

% 裁剪图片空白边缘部分
function trunc_img = truncImgsPadding(imgs)
% 裁剪左右两边的空白处
sum_imgs      = sum(imgs);                    % 按列求和
csum_imgs     = cumsum(sum_imgs);             % 计算累计值
[~,right_idx] = max(csum_imgs);               % 根据累计值找出右边第一个非0列
left_idx      = find(csum_imgs>0);            % 根据累计值找出非0列
left_idx      = left_idx(1);                  %  第一个非0列就是左边第一个非0列
trunc_img     = imgs(:,left_idx:right_idx);   % 进行左右裁剪
											 
% 裁剪上下的空白处                           
sum_imgs    = sum(trunc_img,2);               % 按行求行
csum_imgs   = cumsum(sum_imgs);               % 计行累计值
[~,bot_idx] =  max(csum_imgs);                % 根据累计值找出底部第一个非0行
top_idx     = find(csum_imgs>0);              % 根据累计值找出非0行
top_idx     = top_idx(1);                     % 第一个非0行就是顶部第一个非0行
trunc_img   = trunc_img(top_idx:bot_idx,:);   % 对上下进行裁剪
end

蒙特卡罗法的预测手写数字的函数mc_predict_number代码如下:

function y = mc_predict_number(x)
% 用蒙特卡罗法判断样本的类别     
global sample_x                                                      % 样本库的x数据
global sample_y                                                      % 样本库的y数据
setdemorandstream(88888);                                                            
t = 100;                                                             % 裁决次数
n = 200;                                                             % 每个类别抽样数量
[class_num,sample_num] = size(sample_y);                             % 类别个数与样本个数
class_idx = cell(class_num,1);                                       % 初始化各个类别的样本索引
for i = 1:class_num                                                  
    class_idx{i} = find(sample_y(i,:));                              % 找出属于第i个类别的样本索引
end                                                                  
															         
% 进行抽样裁决x的类别                                                                
rs = zeros(class_num,t);                                             % 初始化裁决结果表  
for i = 1:t                                                          
    select_idx = zeros(n*class_num,1);                               % 本次抽样的样本索引
    for j = 1:class_num                                              % 逐类别抽样
        cur_class_idx = class_idx{j};                                % 属于第i个类别的样本索引
        cur_select_idx = randperm(length(cur_class_idx),n);          % 随机抽出n个样本
        select_idx((j-1)*n+1:j*n) = cur_class_idx(cur_select_idx);   % 记录本次抽出的样本索引
    end
															        
   select_sample = sample_x(:,select_idx);                           % 抽取出本次抽样的样本
   d = sum((select_sample-x).^2);                                    % 计算各个样本与x的距离
   [~,win_idx] = min(d);                                             % 找出最小距离的样本作为本次胜出的样本
   win_y = sample_y(:,select_idx(win_idx));                          % 根据样本的索引找出y
   rs(:,i) = win_y;                                                  % 记录本次的获胜的y
end                                                                 
															        
% 统计多次抽样裁决的结果,用于决定最终x的所属类别                    
win_stat    = sum(rs,2);                                             % 统计各个类别胜出的次数
[~,win_idx] = max(win_stat);                                         % 找出哪个类别胜出次数最多
y = zeros(class_num,1);                                              % 初始化x的类别y
y(win_idx)  = 1;                                                     % 将胜出次数最多的类别,作为x的类别
end                                                                 


     运行结果    


将上述四个函数保存后,运行initDrawGUI.m,显示如下界面:
  
在上面进行书写后,点击识别,结果如下:
 
可以看到,已经可以正确地预测出所写的数字




    效果分析与补充    


经笔者测试,有些数字并不是那么的准确,
在《蒙特卡罗手写数字识别》一文中,使用样本库的样本进行测试,准确率达到99.66%,
但在本文的手写板中,笔者发现,准确率并不是那么的高,时不时就会发生预测不准的情况

粗略分析,主要来源于两方面,
1.样本库字体与实际手写不一致
样本库中的样本来源于国外,与我们的手写字体并不是那么的一致
这应该是引起预测错误的最主要原因
2.算法设计较为粗糙
在算法设计中,为了学习的简便性,
只是简单的使用mse函数来评估图片的相似度,
这对于实际应用中的复杂场景来说,过于粗糙,
精细化图片相似度的评估函数后应该能大大提高准确率











 End 






联系老饼