当前位置: 首页 > news >正文

Tensorflow图像识别 Tensorflow手写体识别(二)

资源介绍

我们从

MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

这条链接(MNIST官网)中下载好数据集,如下:

        下载下来以后整理成包含四个压缩包的文件MNIST_data(不要解压),然后把数据集直接拷贝到我们的代码目录下面,执行一个复制,粘贴到当前目录下面

这是本次项目中要用到的所有文件,您可以下载我的链接:

https://download.csdn.net/download/llf000000/86922664

数据集

数据集介绍

        之前我以为这60000个样本都是png图片,但是如果真的下载下来的话会很占内存。

        MNIST数据库由NIST的特殊数据库3和特殊数据库1构成,其中包含手写数字的二进制图像。
NIST最初指定SD-3为训练集,SD-1为测试集。然而,SD-3比SD-1更干净、更容易识别。究其原因,可以发现SD-3是在人口普查局员工中收集的,而SD-1是在高中生中收集的。从学习实验中得出合理的结论要求实验结果不依赖于训练集的选择和样本集的测试。因此,有必要通过混合NIST的数据集来建立一个新的数据库。

        MNIST训练集由SD-3的30000个模式和SD-1的30000个模式组成。我们的测试集由来自SD-3的5000个模式和SD-1的5000个模式组成。60000个模式训练集包含了大约250个作家的例子。我们确保训练集和测试集的编写器集是不相交的。

        SD-1包含了58527位数字图像,由500位不同的作家编写。与SD-3不同,SD-3中来自每个writer的数据块按顺序出现,SD-1中的数据被置乱。SD-1的Writer标识是可用的,我们使用这些信息来解读Writer。然后我们把SD-1分成两部分:第一批250名作家写的字符进入了我们的新训练集。剩下的250个作家被放在我们的测试集中。因此,我们有两组,每一组有近30000个例子。(SD1加起来总数是6000)

数据集格式

        文件中的所有整数都以多数非英特尔处理器使用的MSB first(高端)格式存储。英特尔处理器和其他低端计算机的用户必须翻转标头的字节。

有4个文件:
train-images-idx3-ubyte:训练集图像
train-labels-idx1-ubyte:训练集标签
t10k-images-idx3-ubyte:测试集图像
t10k-labels-idx1-ubyte:测试集标签

(1)训练集包含60000个示例。
(2)测试集包含10000个示例。测试集的前5000个示例取自原始的NIST训练集。最后5000个是从最初的NIST测试集中提取的。前5000个比后5000个更干净、更简单。

训练集是有60000个用例的,也就是说这个文件里面包含了60000个标签内容,每一个标签的值为0到9之间的一个数;

参考链接:

MNIST数据集的图片读取显示,并保存图片(python代码)_weixin_43094275的博客-CSDN博客_mnist图片显示

代码查看手写体识别案例

        由于数据集每个图片直接下载下来不现实,,而这些数据集已经被神秘力量整理成可被代码识别的压缩包,通过查阅资料得知我们可以通过编写代码可视化这四个数据集。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt# MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签
mnist = input_data.read_data_sets('../MNIST_data/', one_hot=True)
# load data
train_X = mnist.train.images
train_Y = mnist.train.labels
print(train_X.shape, train_Y.shape)  # 输出训练集样本和标签的大小
# 查看数据,例如训练集中第一个样本的内容和标签
print(train_X[0])  # 是一个包含784个元素且值在[0,1]之间的向量
print(train_Y[0])
# 可视化样本,下面是输出了训练集中前4个样本
fig, ax = plt.subplots(nrows=2, ncols=2, sharex='all', sharey='all')
ax = ax.flatten()
for i in range(4):img = train_X[i].reshape(28, 28)# ax[i].imshow(img,cmap='Greys')ax[i].imshow(img)
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

 

手写体识别案例

# 03_mnist.py
# 手写体识别案例
# 模型:全连接模型
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import pylab  # 用于显示图片# 定义样本读取对象
# 这个就是定义一个专门用于mnist数据集的读取的对象
mnist = input_data.read_data_sets("MNIST_data/",  # 数据集所在目录one_hot=True)  # 标签是否采用独热编码
# 定义占位符,用于表示图像数据、标签
# 因为这些数据都要从样本中读进来,穿进来,所以我们要定义一个占位符
x = tf.placeholder(tf.float32, [None, 784])  # 图像数据,N行784列
y = tf.placeholder(tf.float32, [None, 10])  # 标签(图像真实类别), N行784列# 定义权重、偏置
w = tf.Variable(tf.random_normal([784, 10]))  # 权重,784行10列
b = tf.Variable(tf.zeros([10]))  # 偏置, 10个偏置   十路输出所以又10个偏置# 构建模型,计算预测结果
pred_y = tf.nn.softmax(tf.matmul(x, w) + b)
'''
把x和w相乘,n行784列的矩阵×784行10列的矩阵,产生一个n行十列的输出,输出的10个值加上偏置,
这就是神经网络的计算公式,然后把它交给softmax函数进行挤压,转换成0到1的相对概率,
这个就作为我们的预测值'''
# 损失函数
cross_entropy = -tf.reduce_sum(y * tf.log(pred_y), reduction_indices=1)
'''
真实的值×预测的值求对数,然后求和,reduce_sum在指定的维度上求和
'''
cost = tf.reduce_mean(cross_entropy)  # 求均值
# 梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
'''
学习率为0.01,然后调用这个对象的minimize方法,把损失函数的值优化到最小
,优化的目标函数就是cost
'''
batch_size = 100  # 批次大小
saver = tf.train.Saver()  # saver
'''
用于模型的保存和加载
'''
model_path = "model/mnist/mnist_model.ckpt"  # 模型路径
'''
checkpoint
'''with tf.Session() as sess:sess.run(tf.global_variables_initializer())  # 初始化# 开始训练for epoch in range(10):# 计算总批次total_batch = int(mnist.train.num_examples / batch_size)avg_cost = 0.0for i in range(total_batch):# 从训练集读取一个批次的样本batch_xs, batch_ys = mnist.train.next_batch(batch_size)'''xs是图像数据,ys是标签'''params = {x: batch_xs, y: batch_ys}  # 参数字典o, c = sess.run([optimizer, cost],  # 执行的opfeed_dict=params)  # 喂入参数'''第一个op 执行optimizer执行梯度下降第二个op 执行cost取得损失函数的值'''avg_cost += (c / total_batch)  # 计算平均损失值print("epoch:%d, cost=%.9f" % (epoch + 1, avg_cost))print("训练结束.")# 模型评估# 比较预测结果和真实结果,返回布尔类型的数组correct_pred = tf.equal(tf.argmax(pred_y, 1),  # 求预测结果中最大值的索引tf.argmax(y, 1))  # 求真实结果中最大的索引# 将布尔类型数组转换为浮点数,并计算准确率# 因为计算均值、准确率公式相同,所以调用计算均值的函数计算准确率accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))  # cast将correct_pred转换成浮点类型print("accuracy:", accuracy.eval({x: mnist.test.images,  # 测试集下的图像数据y: mnist.test.labels}))  # 测试集下图像的真实类别'''eval等价于放到session里面去run()'''# 保存模型save_path = saver.save(sess, model_path)print("模型已保存:", save_path)# 从测试集中随机读取2张图像,执行预测
with tf.Session() as sess:sess.run(tf.global_variables_initializer())saver.restore(sess, model_path)  # 加载模型# 从测试集中读取样本batch_xs, batch_ys = mnist.test.next_batch(2)output = tf.argmax(pred_y, 1)  # 直接取出预测结果中的最大值output_val, predv = sess.run([output, pred_y],  # 执行的opfeed_dict={x: batch_xs})  # 预测,所以不需要传入标签print("预测最终结果:\n", output_val, "\n")print("真实结果:\n", batch_ys, "\n")print("预测概率:\n", predv, "\n")# 显示图片im = batch_xs[0]  # 第一个测试样本im = im.reshape(-1, 28)  # 28列,-1表示经过计算的值(是多少就是多少),行数根据图形的大小来算,算出来有多少行就有多少行pylab.imshow(im)  # 显示图像pylab.show()im = batch_xs[1]  # 第二个测试样本im = im.reshape(-1, 28)  # 28列,-1表示经过计算的值pylab.imshow(im)pylab.show()

 

相关文章:

git pull 和 git fetch的区别?

`git pull`和`git fetch`都是Git版本控制系统中用于与远程仓库交互的命令,但它们在操作和结果上有一些关键的区别: 1. **操作内容**: - `git fetch`:这个命令仅仅下载远程仓库的更新信息(即远程分支的最新提交),并将这些更新保存到本地仓库的远程分支跟踪信息中。它不…...

如何在服务器上传/下载文件

从服务器下载文件到本地 打开xshell,输入:ssh root159.xxx.xxx.xx 然后需要输入密码 cd到目录文件夹下 cd /enmotech apt install zip zip -r uploads.zip uploads apt install lrzsz sz uploads.zip 从本地上传文件到服务器 如果文件是放在E盘…...

Camera基础知识三

参考资料:极客笔记 侵权联删Camera sensor状态机:状态机:POWER OFF、hardware standby、software、streaming 没电的时候就是power off状态,上电了进入hardware standby状态,xshutdown也就是reset,进入software standby状态。PLL寄存器配置进去之后就进入streaming状态Ca…...

D. Birthday Gift

原题链接 题解 1.异或是01变1,11变0,或是01变1,11变1,所以或的越多(即分的组越多),结果越大 2.我们令x=x+1,这样小于等于x的 问题就变成了小于x 的问题。 3.对于某一位而言,如果有奇数个元素在这一位上是1,那么不管怎么分,最后的结果肯定是1,如果是偶数,那么最后的…...

P7137 [THUPC2021 初赛] 切切糕 题解

题目传送门 前置知识 博弈论 解法 由于本题是 CF1628D1 Game on Sum (Easy Version) 的扩展,故先从 CF1628D1 Game on Sum (Easy Version) 讲解。 CF1628D1 Game on Sum (Easy Version) 设 \(x_{i}\) 表示第 \(i\) 轮时 Alice 选择的数。 设 \(f_{i,j}\) 表示已经进行了 \(i\)…...

那么iPaaS平台的应用场景有哪些呢?

随着全球步入数字化转型的关键阶段,企业的各类业务功能正面临前所未有的颠覆性革新。传统的、孤立的信息系统和业务管理模式已难以适应快速变化的市场环境和日益增长的业务复杂性。iPaaS(Integration Platform as a Service,集成平台即服务)技术解决方案应运而生,以其强大…...

Tensorflow图像识别 Tensorflow手写体识别(二)

资源介绍 我们从 MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges 这条链接(MNIST官网)中下载好数据集,如下: 下载下来以后整理成包含四个压缩包的文件MNIST_data(不要解压&#x…...

盘点上海IB国际学校,你会选哪一所呢?

之前,小编给大家盘点了上海热门的AP学校和Alevel学校,同时也介绍了国际课程的具体情况;今天就和大家聊聊上海的IB国际学校。IB即是国际文凭组织IBO(International Baccalaureate Organisation)为全球学生开设从幼儿园到大学预科的课程&#x…...

jmockit-01-test 之 jmockit 入门使用案例

拓展阅读 jmockit-01-jmockit 入门使用案例 jmockit-02-概览 jmockit-03-Mocking 模拟 jmockit-04-Faking 伪造 jmockit-05-代码覆盖率 mockito-01-入门介绍 mockito-02-springaop 整合遇到的问题,失效 jmockit 说明 jmockit 可以提供基于 mock 的测试能力…...

010——服务器开发环境搭建及开发方法(下)

目录 三、 第一个驱动程序 四、 buildroot 4.1 制作根文件系统 4.2 buildroot使用 五、 uboot 009——服务器开发环境搭建及开发方法(上)-CSDN博客 三、 第一个驱动程序 # 1. 使用不同的开发板内核时, 一定要修改KERN_DIR # 2. KERN_DIR中的内核要…...

Machine Learning机器学习之统计分析

目录 前言 机器学习之统计分析 统计学的主要目标包括: 统计学核心概念: 统计基础: 训练误差: 常见的损失函数: 正则化和交叉验证 博主介绍:✌专注于前后端、机器学习、人工智能应用领域开发的优质创作者、秉…...

蚂蚁庄园今日答案

蚂蚁庄园是一款爱心公益游戏,用户可以通过喂养小鸡,产生鸡蛋,并通过捐赠鸡蛋参与公益项目。用户每日完成答题就可以领取鸡饲料,使用鸡饲料喂鸡之后,会可以获得鸡蛋,可以通过鸡蛋来进行爱心捐赠。其中&#…...

Java:接口应用(Comparable接口与比较器)

目录 1.引例2.Comparable接口使用3.Comparable接口的局限性4.使用comparaTo实现排序5.比较器(Comparator接口) 1.引例 class Student{private String name;private int age;public Student(String name, int age) {this.name name;this.age age;} } p…...

LeetCode 1027——最长等差数列

阅读目录 1. 题目2. 解题思路3. 代码实现 1. 题目 2. 解题思路 假设我们以 f[d][nums[i]]表示以 nums[i] 为结尾元素间距为 d 的等差数列的最大长度,那么,如果 nums[i]-d 也存在于 nums 数组中,则有: f [ d ] [ n u m s [ i ] ] …...