当前位置: 首页 > news >正文

卷积神经网络--手写数字识别

本文我们通过搭建卷积神经网络模型,实现手写数字识别。

pytorch中提供了手写数字的数据集 ,我们可以直接从pytorch中下载

MNIST中包含70000张手写数字图像:60000张用于训练,10000张用于测试

图像是灰度的,28x28像素

首先,下载数据集

import torch
from torchvision import datasets #封装与图像相关的模型,数据集
from torchvision.transforms import ToTensor # #数据转换,张量,将其他类型的数据转换为tensor张量training_data=datasets.MNIST(root='data',#表示下载的手写数字到哪个路径train=True,#读取下载后数据中的训练集download=True,#如果之前已经下载过,就不用再下载transform=ToTensor(),#张量,图片不能直接传入神经网络模型
)test_data=datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),
)

打包数据

from torch.utils.data import DataLoader train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)

判断当前设备是否支持GPU

device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'using {device} device')

构建卷积神经网络模型

from torch import nn #导入神经网络模块class CNN(nn.Module):def __init__(self):#初始化类super(CNN,self).__init__()#初始化父类self.conv1=nn.Sequential(# 将多个层(如卷积、激活函数、池化等)按顺序打包,输入数据会​​依次通过这些层​​,无需手动编写每一层的传递逻辑。nn.Conv2d(#2D 卷积层,提取空间特征。in_channels=1,#输入通道数out_channels=16,#输出通道数kernel_size=3,#卷积核大小stride=1,#步长padding=1,#填充),nn.ReLU(),#激活函数,引入非线性变换,使得神经网络能够学习复杂的非线性变换,增强表达能力nn.MaxPool2d(kernel_size=2)# 2x2最大池化(尺寸减半))self.conv2=nn.Sequential(nn.Conv2d(16,32,3,1,1),nn.ReLU(),# nn.Conv2d(32,32,3,1,1),# nn.ReLU(),nn.MaxPool2d(2),)self.conv3=nn.Sequential(nn.Conv2d(32,64,3,1,1))self.out=nn.Linear(64*7*7,10)def forward(self,x):#前向传播x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)x=x.view(x.size(0),-1)# 展平为向量(保留batch_size,合并其他维度)output=self.out(x)  # 全连接层输出return output

返回的output结果大致如图所示

 模型传入GPU

model=CNN().to(device)
print(model)

  损失函数,衡量的是​​模型预测的概率分布​​与​​真实的类别分布​​之间的差异。

loss_fn=nn.CrossEntropyLoss()

  优化器,用于在训练神经网络时更新模型参数,目的是​​在神经网络训练过程中,自动调整模型的参数(权重和偏置),以最小化损失函数​​。

optimizer=torch.optim.Adam(model.parameters(),lr=0.01)

 模型训练

def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)loss=loss_fn(pred,y)# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()               #梯度值清零loss.backward()                     #反向传播计算得到每个参数的梯度值optimizer.step()                    #根据梯度更新网络参数loss_value=loss.item()if batch_size_num%100==0:print(f'loss:{loss_value:>7f}[number:{batch_size_num}]')batch_size_num+=1epochs=10for i in range(epochs):print(f'第{i}次训练')train(train_dataloader, model, loss_fn, optimizer)

模型测试

def test(dataloader,model,loss_fn):size = len(dataloader.dataset)# 测试集总样本数num_batches = len(dataloader)# 测试集总批次数model.eval()#进入到模型的测试状态,所有的卷积核权重被设为只读模式test_loss, correct = 0, 0# 初始化累计损失和正确预测数#禁用梯度计算with torch.no_grad():#一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f'Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}')test(test_dataloader,model,loss_fn)

得到结果如图所示

相关文章:

卷积神经网络--手写数字识别

本文我们通过搭建卷积神经网络模型,实现手写数字识别。 pytorch中提供了手写数字的数据集 ,我们可以直接从pytorch中下载 MNIST中包含70000张手写数字图像:60000张用于训练,10000张用于测试 图像是灰度的,28x28像素 …...

JavaScript-原型、原型链详解

一、构造函数 在 JavaScript 中,构造函数是一种特殊的函数,用于创建和初始化对象,它就像一个 “对象模板”。通过 new 关键字调用构造函数时,会创建一个新对象,并将构造函数中的属性和方法 “绑定” 到这个新对象上。…...

深度学习框架PyTorch——从入门到精通(3.3)YouTube系列——自动求导基础

这部分是 PyTorch介绍——YouTube系列的内容,每一节都对应一个youtube视频。(可能跟之前的有一定的重复) 我们需要Autograd做什么?一个简单示例训练中的自动求导开启和关闭自动求导自动求导与原地操作 自动求导分析器高级主题&…...

永磁同步电机控制算法-VF控制

一、原理介绍 V/F 控制又称为恒压频比控制,给定VF 控制曲线 电压是频率的tt例函数 即控制电压跟随频率变化而变化以保持磁通恒定不变。 二、仿真模型 在MATLAB/simulink里面验证所提算法,搭建仿真。采用和实验中一致的控制周期1e-4,电机部分计算周期为…...

【Docker 运维】Java 应用在 Docker 容器中启动报错:`unable to allocate file descriptor table`

文章目录 一、根本原因二、判断与排查方法三、解决方法1、限制 Docker 容器的文件描述符上限2、在执行脚本中动态设置ulimit的值3、升级至 Java 11 四、总结 容器内执行脚本时报错如下,Java 进程异常退出: library initialization failed - unable to a…...

SpringBoot + Vue 实现云端图片上传与回显(基于OSS等云存储)

前言 在实际生产环境中,我们通常会将图片等静态资源存储在云端对象存储服务(如阿里云OSS、七牛云、腾讯云COS等)上。本文将介绍如何改造之前的本地存储方案,实现基于云端存储的图片上传与回显功能。 一、技术选型 云存储服务&a…...

Session与Cookie的核心机制、用法及区别

Python中Session与Cookie的核心机制、用法及区别 在Web开发中,Session和Cookie是两种常用的用于跟踪用户状态的技术。它们在实现机制、用途和安全性方面都有显著区别。本文将详细介绍它们的核心机制、用法以及它们之间的主要区别。 一、Cookie的核心机制与用法 1…...

离线安装rabbitmq全流程

在麒麟系统(如银河麒麟)上离线安装 RabbitMQ 的具体操作步骤如下: 一、准备工作 确认系统版本:确认麒麟系统的版本,例如银河麒麟高级服务器 V10。确定 RabbitMQ 及依赖版本:根据系统版本确定兼容的 Rabbi…...

llama-webui docker实现界面部署

1. 启动ollama服务 [nlp server]$ ollama serve 2025/04/21 14:18:23 routes.go:1007: INFO server config env"map[OLLAMA_DEBUG:false OLLAMA_FLASH_ATTENTION:false OLLAMA_HOST: OLLAMA_KEEP_ALIVE:24h OLLAMA_LLM_LIBRARY: OLLAMA_MAX_LOADED_MODELS:4 OLLAMA_MAX_…...

第1 篇:你好,时间序列!—— 开启时间数据探索之旅

第 1 篇:你好,时间序列!—— 开启时间数据探索之旅 (图片来源: Stephen Dawson on Unsplash) 你有没有想过: 明天的天气会是怎样?天气预报是怎么做出来的?某支股票未来的价格走势如何预测?购物…...

C++算法(11):vector作为函数参数的三种传递方式详解

在C中,std::vector是最常用的动态数组容器之一。当我们需要将vector传递给函数时,不同的传递方式会对性能和功能产生显著影响。本文将详细介绍三种常见的传递方式及其适用场景,帮助开发者根据需求选择最合适的方法。 1. 按值传递(…...

版本控制利器——SVN简介

版本控制利器——SVN简介 在软件开发和项目管理的领域中,版本控制是一项至关重要的工作。它能帮助团队成员高效协作,确保代码的安全性和可追溯性。今天,我们就来详细介绍一款经典的版本控制系统——SVN(Subversion)。…...

链式栈和线性栈

‌1. 线性栈&#xff08;顺序栈&#xff09;‌ ‌结构定义‌&#xff1a; #include <iostream> using namespace std;#define MAX_SIZE 100 // 预定义最大容量// 线性栈结构体 typedef struct {int* data; // 存储数据的数组int top; // 栈顶指针&…...

消息中间件RabbitMQ:简要介绍及其Windows安装流程

一、简要介绍 定义&#xff1a;RabbitMQ 是一个开源消息中间件&#xff0c;用于实现消息队列和异步通信。 场景&#xff1a;适用于分布式系统、异步任务处理、消息解耦、负载均衡等场景。 比喻&#xff1a;RabbitMQ 就像是快递公司&#xff0c;负责在不同系统间安全快速地传递…...

足球 AI 智能体技术解析:从数据采集到比赛预测的全链路架构

一、引言 在足球运动数字化转型的浪潮中&#xff0c;AI 智能体正成为理解比赛、预测赛果的核心技术引擎。本文从工程实现角度&#xff0c;深度解析足球 AI 的技术架构&#xff0c;涵盖数据采集、特征工程、模型构建、实时计算到决策支持的全链路技术方案&#xff0c;揭示其背后…...

VTK知识学习(53)- 交互与Widget(四)

1、测量类Widget 1&#xff09;概述 与测量相关的主要 Widget如下&#xff1a; vtkDistanceWidget:用于在二维平面上测量两点之间的距离。vtkAngleWidget:用于二维平面的角度测量。vtkBiDimensionalWidget:用于测量二维平面上任意两个正交方向的轴长。 按照前面提到的步骤创…...

基础服务系列-Windows10 安装AnacondaJupyter

下载 https://www.anaconda.com/products/individual 安装 安装Jupyter 完成安装 启动Jupyter 浏览器访问 默认浏览器打开&#xff0c;IE不兼容&#xff0c;可以换个浏览器 修改密码 运行脚本...

使用c++调用deepseek的api(附带源码)

可以给服务器添加deepseek这样就相当于多了一个智能ai助手 deepseek的api申请地址使用格式测试效果源码 deepseek的api申请地址 这边使用硅基流动的api&#xff0c;注册就有免费额度 硅基流动: link 使用格式 api的调用格式&#xff0c;ds的api调用就是用固定协议然后发送到…...

HarmonyOS-ArkUI: animateTo 显式动画

什么是显式动画 啊, 尽管有点糙,但还是解释一下吧, 显式动画里面的“显式”二字, 是程序员在代码调用的时候,就三令五申,明明白白调用动画API而创建的动画。 这个API的名字就是: animateTo。这就是显式动画。说白了您可以大致理解为,显式动画,就是调用animateTo来完成…...

Spring AI MCP

MCP是什么 MCP是模型上下文协议&#xff08;Model Context Protocol&#xff09;的简称&#xff0c;是一个开源协议&#xff0c;由Anthropic&#xff08;Claude开发公司&#xff09;开发&#xff0c;旨在让大型语言模型&#xff08;LLM&#xff09;能够以标准化的方式连接到外…...

Kubernetes 创建 Jenkins 实现 CICD 配置指南

Kubernetes 创建 Jenkins 实现 CICD 配置指南 拉取 Jenkins 镜像并推送到本地仓库 # 从官方仓库拉取镜像&#xff08;若网络不通畅可使用国内镜像源&#xff09; docker pull jenkins/jenkins:lts-jdk11# 国内用户可去下面地址寻找镜像源并拉取&#xff1a; https://docker.a…...

01_Flask快速入门教程介绍

一、课程视频 01_Flask快速入门教程介绍 二、课程特点 讲课风格通俗易懂&#xff0c;理论与实战相结合 教程&#xff1a;视频 配套文档 配套的代码 最新本版&#xff0c;Python版本是3.12&#xff0c;Flask版本是3.10 即使是从没接触过Flsk的小白也看得懂学得会 三、适用人…...

SSH反向代理

SSH反向代理 一、过程 1、 确保树莓派和阿里云服务器的 SSH 服务正常运行 检查树莓派的ssh服务 sudo systemctl status ssh如果未启用&#xff0c;请启动并设置开机自启&#xff1a; sudo systemctl enable ssh sudo systemctl start ssh检查阿里云服务器的SSH服务 sudo …...

第 5 篇:初试牛刀 - 简单的预测方法

第 5 篇&#xff1a;初试牛刀 - 简单的预测方法 经过前面四篇的学习&#xff0c;我们已经具备了处理时间序列数据的基本功&#xff1a;加载、可视化、分解以及处理平稳性。现在&#xff0c;激动人心的时刻到来了——我们要开始尝试预测 (Forecasting) 未来&#xff01; 预测是…...

深度学习中的归一化技术:从原理到实战全解析

摘要&#xff1a;本文系统解析深度学习中的归一化技术&#xff0c;涵盖批量归一化&#xff08;BN&#xff09;、层归一化&#xff08;LN&#xff09;、实例归一化&#xff08;IN&#xff09;、组归一化&#xff08;GN&#xff09;等核心方法。通过数学原理、适用场景、优缺点对…...

流量抓取工具(wireshark)

协议 TCP/IP协议簇 网络接口层&#xff08;没有特定的协议&#xff09;PPPOE 物理层数据链路层 网络层: IP(v4/v6) ARP&#xff08;地址解析协议) RARP ICMP(Internet控制报文协议) IGMP传输层&#xff1a;TCP(传输控制协议&#xff09;UDP&#xff08;用户数据报协议)应用层…...

【原创】Ubuntu20.04 安装 Isaac Gym 仿真器

Isaac Gym 是 NVIDIA 开发的一个基于GPU的机器人仿真平台。其高效的 GPU 加速能力和大规模并行仿真性能&#xff0c;成为强化学习训练和机器人控制研究的重要选择。 本文将介绍 Isaac Gym 的安装过程【简易】。 1.配置环境 Ubuntu20.04 安装 NVIDIA 显卡驱动 Ubuntu20.04 安…...

AI 速读 SpecReason:让思考又快又准!

在大模型推理的世界里&#xff0c;速度与精度往往难以兼得。但今天要介绍的这篇论文带来了名为SpecReason的创新系统&#xff0c;它打破常规&#xff0c;能让大模型推理既快速又准确&#xff0c;大幅提升性能。想知道它是如何做到的吗&#xff1f;快来一探究竟&#xff01; 论…...

从“堆料竞赛”到“体验深耕”,X200 Ultra和X200s打响手机价值升维战

出品 | 何玺 排版 | 叶媛 vivo双旗舰来袭&#xff01; 4月21日&#xff0c;vivo X系列春季新品发布会盛大开启&#xff0c;带来了一场科技与创新的盛宴。会上&#xff0c;消费者期待已久的X200 Ultra及X200s两款旗舰新品正式发布。 vivo两款旗舰新品发布后&#xff0c;其打破…...

Macbook IntelliJ IDEA终端无法运行mvn命令

一、背景 idea工具里执行Maven命令mvn package&#xff0c;报错提示 zsh: command not found: mvn。 macOS&#xff0c;默认使用的是zsh&#xff0c;环境变量通常配置在 ~/.zshrc 文件中。 而我之前一直是配置在~/.bash_profile文件中。 二、环境变量 vi ~/.zshrc设置MAVE…...

CentOS 7进入救援模式——VirtualBox虚拟机

​ 目录 1. 在`VirtualBox`环境下,开机按F12,进入`VirtualBox temporary boot device selection `界面,按`c`键,选中`CD-ROM `回车。2. 选中`Troubleshooting`(故障排除),进入`Troubleshooting`界面3. 接下来会显示救援模式菜单,通常选择`"1) Continue"`(除非您…...

AI软件栈:LLVM分析(六)

LLVM后端代码生成的关键步骤 文章目录 指令选择指令调度寄存器分配 指令选择 完成从基于LLVM IR的DAG转换为基于特定目标平台的DAG&#xff08;注意&#xff0c;此时描述格式依然是DAG形态&#xff09;基于TabGen完成指令重映射&#xff08;典型的处理包括&#xff1a;指令拆散…...

【第十六届 蓝桥杯 省 C/Python A/Java C 登山】题解

题目链接&#xff1a;P12169 [蓝桥杯 2025 省 C/Python A/Java C] 登山 思路来源 一开始想的其实是记搜&#xff0c;但是发现还有先找更小的再找更大的这种路径&#xff0c;所以这样可能错过某些最优决策&#xff0c;这样不行。 于是我又想能不能从最大值出发往回搜&#xf…...

Github 热点项目 Jumpserver开源堡垒机让服务器管理效率翻倍

Jumpserver今日喜提160星&#xff0c;总星飙至2.6万&#xff01;这个开源堡垒机有三大亮点&#xff1a;① 像哆啦A梦的口袋&#xff0c;支持多云服务器一站式管理&#xff1b;② 安全审计功能超硬核&#xff0c;操作记录随时可回放&#xff1b;③ 网页终端无需装插件&#xff0…...

5V 1A充电标准的由来与技术演进——从USB诞生到智能手机时代的电力革命

点击下面图片带您领略全新的嵌入式学习路线 &#x1f525;爆款热榜 88万阅读 1.6万收藏 一、起源&#xff1a;USB标准与早期电力传输需求 1. USB的诞生背景 1996年&#xff0c;由英特尔、微软、IBM等公司组成的USB-IF&#xff08;USB Implementers Forum&#xff09;发布了…...

驱动开发硬核特训 · Day 16:字符设备驱动模型与实战注册流程

&#x1f3a5; 视频教程请关注 B 站&#xff1a;“嵌入式 Jerry” 一、为什么要学习字符设备驱动&#xff1f; 在 Linux 驱动开发中&#xff0c;字符设备&#xff08;Character Device&#xff09;驱动 是最基础也是最常见的一类驱动类型。很多设备&#xff08;如 LED、按键、…...

外网如何连接内网中的mysql数据库服务器

一、MySQL 产品简介 mysql是一款数据库产品&#xff0c;它主要用于存储、管理和检索数据&#xff0c;对用户的数据进行存储管理 二、运维人员遇到的问题 当内网服务器部署好mysql数据库后&#xff0c;外网如何安全的访问数据库进行增删改查&#xff0c;是运维人员遇到的一个…...

你的大模型服务如何压测:首 Token 延迟、并发与 QPS

写在前面 大型语言模型(LLM)API,特别是遵循 OpenAI 规范的接口(无论是 OpenAI 官方、Azure OpenAI,还是 DeepSeek、Moonshot 等众多兼容服务),已成为驱动下一代 AI 应用的核心引擎。然而,随着应用规模的扩大和用户量的增长,仅仅关注模型的功能是不够的,API 的性能表…...

4月谷歌新政 | Google Play今年对“数据安全”的管控将全面升级!

大家好&#xff0c;我是牢鹅&#xff01;每年的Q2季度是Google Play重要政策更新的时间节点&#xff0c;一般都伴随着重磅政策的更新&#xff0c;今年也不例外。4月10日&#xff0c;谷歌政策迎来2025年第二次更新&#xff0c;本次政策更新内容相较3月政策更新&#xff0c;不管是…...

第十四届蓝桥杯 2023 C/C++组 有奖问答

目录 题目&#xff1a; 题目描述&#xff1a; 题目链接&#xff1a; 思路&#xff1a; 核心思路&#xff1a; 思路详解&#xff1a; 代码&#xff1a; 代码详解&#xff1a; 题目&#xff1a; 题目描述&#xff1a; 题目链接&#xff1a; 蓝桥云课 有奖问答 思路&…...

【Redis】SpringDataRedis

Spring Data Redis 使得开发者能够更容易地与 Redis 数据库进行交互&#xff0c;并且支持不同的 Redis 客户端实现&#xff0c;如 Jedis 和 Lettuce。Spring Data Redis 会自动选择一个客户端&#xff0c;通常情况下&#xff0c;Spring Boot 默认使用 Lettuce 作为 Redis 客户端…...

XAttention

XAttention: Block Sparse Attention with Antidiagonal Scoring 革新Transformer推理的高效注意力机制资源​​ ​​论文链接​​&#xff1a;XAttention: Block Sparse Attention with Antidiagonal Scoring ​​代码开源​​&#xff1a;GitHub仓库 XAttention是韩松团队提…...

07.Python代码NumPy-排序sort,argsort,lexsort

07.Python代码NumPy-排序sort&#xff0c;argsort&#xff0c;lexsort 提示&#xff1a;帮帮志会陆续更新非常多的IT技术知识&#xff0c;希望分享的内容对您有用。本章分享的是NumPy的使用语法。前后每一小节的内容是存在的有&#xff1a;学习and理解的关联性&#xff0c;希望…...

无人机飞控运行在stm32上的RTOS实时操作系统上,而不是linux这种非实时操作系统的必要性

飞控程序需要运行在STM32等微控制器&#xff08;MCU&#xff09;的实时操作系统&#xff08;RTOS&#xff09;而非Linux等非实时操作系统&#xff08;如通用Linux内核&#xff09;&#xff0c;主要原因在于实时性、资源占用、硬件适配性以及系统可靠性等方面的实质性差异。以下…...

Leetcode - 周赛446

目录 一、3522. 执行指令后的得分二、3523. 非递减数组的最大长度三、3524. 求出数组的 X 值 I四、3525. 求出数组的 X 值 II 一、3522. 执行指令后的得分 题目链接 本题就是一道模拟题&#xff0c;代码如下&#xff1a; class Solution {public long calculateScore(String…...

Linux——系统安全及应用

目录 一&#xff1a;账号安全控制 1&#xff0c;基本安全措施 系统账号清理 密码安全控制 命令历史&#xff0c;自动注销 2&#xff0c;用户切换与提权 su命令的用法 PAM认证 3&#xff0c;sudo命令——提升执行权限 在配置文件/etc/sudoers中添加授权 通过sudo执行…...

随机面试--<二>

编译安装软件的流程 1-安装所需源代码 2-配置安装环境 3-进行相关设置 4-编译 5-安装 nginx安装新模块的流程 1-准备与原nginx版本相同的源码包&#xff0c;准备模块安装包 2-准备编译安装环境 3-配置参数 来源于nginx -V配置原模块 以及--add-module 增加模块 4-mak…...

LeetCode面试经典 150 题(Java题解)

一、数组、字符串 1、合并两个有序数组 从后往前比较&#xff0c;这样就不需要使用额外的空间 class Solution {public void merge(int[] nums1, int m, int[] nums2, int n) {int l mn-1, i m-1, j n-1;while(i > 0 && j > 0){if(nums1[i] > nums2[j])…...

【技术追踪】Differential Transformer(ICLR-2025)

Differential Transformer&#xff1a;大语言模型新架构&#xff0c; 提出了 differential attention mechanism&#xff0c;Transformer 又多了一个小 trick~ 论文&#xff1a;Differential Transformer 代码&#xff1a;https://github.com/microsoft/unilm/tree/master/Diff…...

报告系统状态的连续日期 mysql + pandas(连续值判断)

本题用到知识点&#xff1a;row_number(), union, date_sub(), to_timedelta()…… 目录 思路 pandas Mysql 思路 链接&#xff1a;报告系统状态的连续日期 思路&#xff1a; 判断连续性常用的一个方法&#xff0c;增量相同的两个列的差值是固定的。 让日期与行号 * 天数…...