使用PyTorch实现图像增广与模型训练实战
本文通过完整代码示例演示如何利用PyTorch和torchvision实现常用图像增广方法,并在CIFAR-10数据集上训练ResNet-18模型。我们将从基础图像变换到复杂数据增强策略逐步讲解,最终实现一个完整的训练流程。
一、图像增广基础操作
1.1 准备工作
#matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2ld2l.set_figsize()
img = d2l.Image.open('/workspace/data/cat.png')
d2l.plt.imshow(img)
1.2 图像变换工具函数
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5, titles=None):y = [aug(img) for _ in range(num_rows*num_cols)]d2l.show_images(y, num_rows, num_cols, titles, scale)
二、常用图像增广方法
2.1 水平/垂直翻转
# 水平翻转
apply(img, torchvision.transforms.RandomHorizontalFlip())# 垂直翻转
apply(img, torchvision.transforms.RandomVerticalFlip())
2.2 随机裁剪
shape_aug = torchvision.transforms.RandomResizedCrop((200,200), scale=(0.1,1), ratio=(0.5,2))
apply(img, shape_aug)
2.3 颜色调整
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.3, hue=0.5)
apply(img, color_aug)
2.4 组合增广策略
augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),color_aug,shape_aug
])
apply(img, augs)
三、CIFAR-10数据增强实战
3.1 数据加载与可视化
all_images = torchvision.datasets.CIFAR10(train=True, root='/workspace/data', download=True)
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8)
3.2 数据预处理配置
train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor()
])test_augs = torchvision.transforms.ToTensor()
3.3 数据加载函数
def load_cifar10(is_train, augs, batch_size):dataset = torchvision.datasets.CIFAR10(root='../data', train=is_train,transform=augs, download=True)return torch.utils.data.DataLoader(dataset, batch_size=batch_size,shuffle=is_train, num_workers=4)
四、模型训练实现
4.1 训练核心函数
def train_batch_ch13(net, X, y, loss, trainer, devices):if isinstance(X, list):X = [x.to(devices[0]) for x in X]else:X = X.to(devices[0])y = y.to(devices[0])net.train()trainer.zero_grad()pred = net(X)l = loss(pred, y)l.sum().backward()trainer.step()train_loss_sum = l.sum()train_acc_sum = d2l.accuracy(pred, y)return train_loss_sum, train_acc_sum
4.2 模型初始化
batch_size = 1024
devices = d2l.try_all_gpus()
net = d2l.resnet18(10, 3)def init_weights(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)
4.3 训练入口函数
def train_with_data_aug(train_augs, test_augs, net, lr=0.001):train_iter = load_cifar10(True, train_augs, batch_size)test_iter = load_cifar10(False, test_augs, batch_size)loss = nn.CrossEntropyLoss(reduction='none')optimizer = torch.optim.Adam(net.parameters(), lr=lr)d2l.train_ch13(net, train_iter, test_iter, loss, optimizer, 10, devices)# 启动训练
train_with_data_aug(train_augs, test_augs, net)
五、训练结果分析
执行训练后可以看到类似如下输出:
train loss 0.018, train acc 0.895
test acc 0.856
典型训练过程特征:
-
训练损失持续下降
-
验证准确率稳步提升
-
最终测试准确率可达85%以上
六、关键知识点总结
-
图像增广作用:通过随机变换增加数据多样性,提升模型泛化能力
-
组合策略:合理组合几何变换与颜色变换可以达到最佳效果
-
训练技巧:
-
使用Xavier初始化保证参数合理分布
-
Adam优化器自动调整学习率
-
多GPU并行加速训练
-
七、扩展改进方向
1.尝试更多增广组合:
advanced_augs = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(15),torchvision.transforms.RandomPerspective(),torchvision.transforms.RandomGrayscale(p=0.1)
])
2.调整网络结构:
net = d2l.resnet50(10, 3) # 使用更深层的ResNet-50
3.优化参数:
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
完整代码已通过测试,可直接复制到Jupyter Notebook中运行。实际效果可能因硬件配置有所差异,建议使用GPU环境进行训练。如果遇到数据集下载问题,请检查root
参数指定的路径是否正确。
相关文章:
使用PyTorch实现图像增广与模型训练实战
本文通过完整代码示例演示如何利用PyTorch和torchvision实现常用图像增广方法,并在CIFAR-10数据集上训练ResNet-18模型。我们将从基础图像变换到复杂数据增强策略逐步讲解,最终实现一个完整的训练流程。 一、图像增广基础操作 1.1 准备工作 #matplotli…...
PyTorch实现糖尿病预测的CNN模型:从数据加载到模型部署全解析【N折交叉验证、文末免费下载】
本文将详细介绍如何使用PyTorch框架构建一个卷积神经网络(CNN)来预测糖尿病,包含完整的代码实现、技术细节和可视化分析。 1. 项目概述 本项目使用经典的Pima Indians Diabetes数据集,通过5折交叉验证训练一个1D CNN模型,最终实现糖尿病预测…...
红队专题-漏洞挖掘-代码审计-反序列化
漏洞挖掘-代码审计-反序列化 加固/防御命令执行相关日志Tools-JNDIExploitJNDI Java Naming and Directory Interface Java命名目录接口注入原理payload参数渗透测试-php命令执行-RCE+Struts2拿webshell普通权限 命令执行 拿 webshellCMD echo 写入一句话 php文件菜刀连接Strut…...
【2025软考高级架构师】——计算机系统基础(7)
摘要 本文主要介绍了计算机系统的组成,包括硬件和软件两大部分。硬件由处理器、存储器、总线、接口和外部设备等组成,软件则涵盖系统软件和应用软件。文章还详细阐述了冯诺依曼计算机的组成结构,包括 CPU、主存储器、外存等,并解…...
【网络原理】TCP协议如何实现可靠传输(确认应答和超时重传机制)
目录 一. TCP协议 二. 确定应答 三. 超时重传 一. TCP协议 1)端口号 源端口号:发送方端口号目的端口号:接收方端口号 16位(2字节)端口号,可以表示的范围(0~65535) 源端口和目的…...
Java synchroinzed和ReentrantLock
synchronized —— JVM亲儿子的暗黑兵法 核心思想:“锁即对象,对象即锁!” 底层三板斧 对象头里的锁密码 每个Java对象头里藏了两个骚东西: Mark Word:32/64位的比特修罗场,存哈希码、GC年龄࿰…...
【Linux】vim配置----超详细
目录 一、插件管理器准备 二、目录准备 三、安装插件 一、插件管理器准备 Vim-plug 是一个Vim插件管理器,利用异步并行可以快速地安装、更新和卸载插件。它的安装和配置都非常简单,而且在操作过程中会给出很多易读的反馈信息,是一个自由、…...
驱动开发硬核特训 · Day 15:电源管理核心知识与实战解析
在嵌入式系统中,电源管理(Power Management)并不是“可选项”,而是实际部署中影响系统稳定性、功耗、安全性的重要一环。今天我们将以 Linux 电源管理框架 为基础,从理论结构、内核架构,再到典型驱动实战&a…...
如何使用人工智能大模型,免费快速写工作计划?
如何使用人工智能大模型,免费快速写工作计划? 具体视频教程https://edu.csdn.net/learn/40406/666579...
延长(暂停)Windows更新
延长(暂停)Windows更新 因为不关闭更新有时候就会出现驱动或者软硬件不兼容,导致蓝屏出现。 注:为什么选择延长更新而不是用软件暂停更新,因为使用软件暂停更新会出现一下问题,比如微软商店打不开等等 键…...
QT实现串口透传的功能
在一些产品的开发的时候,需要将一个串口的数据发送给另外一个串口进行转发。 具体的代码如下: #include "mainwindow.h" #include "ui_mainwindow.h"MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::Ma…...
分布类相关的可视化图像
目录 一、直方图(Histogram) 1.定义 2.特点 3.局限性 4.类型 5.应用场景 6.使用Python实现 二、密度图(Density Plot) 1.定义 2.特点 3.局限性 4.类型 5.应用场景 6.使用Python实现 三、箱线图(Box Plo…...
【android bluetooth 框架分析 02】【Module详解 12】【 BidiQueue、BidiQueueEnd、Queue介绍】
1. BidiQueue 和 BidiQueueEnd 蓝牙协议栈里面有很多 BidiQueue ,本节就专门来梳理这块内容。 2. BidiQueue 介绍 BidiQueue,是 Host 与 Controller 层通信的中枢之一, acl_queue_、sco_queue_、iso_queue_ 都是 BidiQueue 类型。让我们一起看一下这个…...
c++通讯录管理系统
通讯录是一个可以记录亲人,好友的信息工具。 功能包括: 1,添加联系人:向通讯录添加新人,包括(姓名,性别年龄,联系电话,家庭住址) 2,显示联系人…...
React 打包
路由懒加载 原本的加载方式 #使用lazy()函数声明的路由页面 使用Suspense组件进行加载 使用CDN优化...
day1 python训练营
变量与输出 print(1,2,3,sep\n,endsep用来区分两个变量,end会紧跟最后一个变量) print(1,2,3,sepaaa,endsep用来区分两个变量,3后面不会再输出aaa) 格式化字符串 变量名值 print(f"变量名{变量名}") 变量的基础运算 ,-*,/ 注意*不要忘写。比如2j就不…...
C语言状态字与库函数详解:概念辨析与应用实践
C语言状态字与库函数详解:概念辨析与应用实践 一、状态字与库函数的核心概念区分 在C语言系统编程中,"状态字"和"库函数"是两个经常被混淆但本质完全不同的概念,理解它们的区别是掌握系统编程的基础。 1. 状态字&…...
软件测试笔记(测试的概念、测试和开发模型介绍、BUG介绍)
软件测试笔记 认识测试 软件测试是啥? 说白了,就是检查软件的功能和效果是不是用户真正想要的东西。比如用户说“我要一个能自动算账的软件”,测试就是看这个软件到底能不能准确算账、有没有漏掉功能。 软件测试定义:软件测试就…...
Python多进程同步全解析:从竞争条件到锁、信号量的实战应用
1. 进程同步的必要性 在多进程编程中,当多个进程需要访问共享资源时,会出现竞争条件问题。例如火车票售卖系统中,如果多个售票窗口同时读取和修改剩余票数,可能导致数据不一致。 1.1 竞争条件示例 from multiprocessing import…...
Vue3 + TypeScript,关于item[key]的报错处理方法
处理方法1:// ts-ignore 注释忽略报错 处理方法2:item 设置为 any 类型...
Spring源码中关于抽象方法且是个空实现这样设计的思考
Spring源码抽象方法且空实现设计思想 在Spring源码中onRefresh()就是一个抽象方法且空实现,而refreshBeanFactory()方法就是一个抽象方法。 那么Spring源码中onRefresh方法定义了一个抽象方法且是个空实现,为什么这样设置,好处是什么。为…...
Pandas数据可视化
在当今这个数据驱动的时代,数据可视化已经成为数据分析不可或缺的一部分。通过图形化的方式展示数据,我们能够更直观地理解数据的分布、趋势和关系,从而做出更加精准的决策。Pandas,作为Python中最为流行的数据处理库,…...
string类(详解)
【本节目标】 1. 为什么要学习string类 2. 标准库中的string类 3. string类的模拟实现 4. 扩展阅读 1. 为什么学习string类? 1.1 C语言中的字符串 C 语言中,字符串是以 \0 结尾的一些字符的集合,为了操作方便, C 标准库中提供…...
零基础上手Python数据分析 (19):Matplotlib 高级图表定制 - 精雕细琢,让你的图表脱颖而出!
写在前面 —— 超越默认样式,掌握 Matplotlib 精细控制,打造专业级可视化图表 上一篇博客,我们学习了 Matplotlib 的基础绘图功能,掌握了如何绘制常见的折线图、柱状图、散点图和饼图,并进行了基本的图表元素定制,例如添加标题、标签、图例等。 这些基础技能已经能让我…...
【上位机——MFC】MFC入门
MFC库中相关类简介 CObject MFC类库中绝大部分类的父类,提供了MFC类库中一些基本的机制。 对运行时类信息的支持。对动态创建的支持。对序列化的支持。 CWinApp 应用程序类,封装了应用程序、线程等信息。 CDocument 文档类,管理数据 F…...
ASP.NET Core 最小 API:极简开发,高效构建(下)
在上篇文章 ASP.NET Core 最小 API:极简开发,高效构建(上) 中我们添加了 API 代码并且测试,本篇继续补充相关内容。 一、使用 MapGroup API 示例应用代码每次设置终结点时都会重复 todoitems URL 前缀。 API 通常具有…...
【leetcode100】一和零
1、题目描述 给你一个二进制字符串数组 strs 和两个整数 m 和 n 。 请你找出并返回 strs 的最大子集的长度,该子集中 最多 有 m 个 0 和 n 个 1 。 如果 x 的所有元素也是 y 的元素,集合 x 是集合 y 的 子集 。 示例 1: 输入:…...
代码随想录算法训练营第五十三天 | 105.有向图的完全可达性 106.岛屿的周长
105.有向图的完全可达性 题目链接:101. 孤岛的总面积 文章讲解:代码随想录 视频讲解:图论:岛屿问题再出新花样 | 深搜优先搜索 | 卡码网:101.孤岛总面积_哔哩哔哩_bilibili 思路: 1.确认递归函数&…...
在 Debian 10.x 安装和配置 Samba
1. 更新系统 sudo apt update sudo apt upgrade -y2. 安装 Samba sudo apt install samba -y3. 配置 Samba 备份默认配置文件 sudo cp /etc/samba/smb.conf /etc/samba/smb.conf.bak编辑配置文件 sudo nano /etc/samba/smb.conf示例配置(共享目录) …...
Python中的短路运算
近期在学习python的过程中遇到此问题,遂总结记录 在”and“逻辑判定布尔类型时: 若判定对象均为True,则输出最后一个判别为True的对象 若判定对象的数据类型中有布尔类型,且最终结果为False,则输出布尔类型False 若判定对象的…...
Java8-遍历list取出两个字段重新组成list集合
在Java 8中,可以使用Stream API遍历List并提取两个字段重新组合成新的List。 以下是几种常见方法: 方法1:使用自定义类 定义一个包含目标字段的类:public class FieldHolder {private final String field1;private final int field2;public FieldHolder(String field1, i…...
【C++ 程序设计】实战:C++ 实践练习题(31~40)
目录 31. 数列:s 1 + 2 + 3 + … + n 32. 数列:s 1 - 2 - 3 - … - n 33. 数列:s 1 + 2 - 3 + … - n 34. 数列:s 1 - 2 + 3 - … &#…...
【笔记】SpringBoot实现图片上传和获取图片接口
上传图片接口 接口接收图片文件和布尔类型的是否生成缩略图参数。 生成保存图片文件的文件夹,文件夹的命名为上传图片的日期“根目录\file\cover\202504”,如果文件夹已存在则不生成。接下来拼接文件名,生成30位的随机数拼接到原文件名防止文件名相同的…...
Linux 下依赖库的问题
假设你在 某用户 user_name 下安装了一个 rquests库。 然后你在命令行使用 python3 -c (...)验证。发现没有任何问题。 然后你使用python3 xxx.py 发现执行验证也没有问题。 这个时候你信心慢慢的写了一个C的代码在代码中system调用这个.py文件。 然…...
STM32 HAL 水位传感器驱动程序
工作原理是输出模拟量电压值,只需要使用stm32adc读取电压再转换一下即可 本代码中,水位传感器连接在PA0,可通过宏定义快速设置电压区间和水位之间的关系 water_level.c /***************************************************************…...
DeepSeek R1 7b,Langchain 实现 RAG 知识库 | LLMs
DeepSeek R1 7b,Langchain 实现 RAG 知识库 | LLMs DeepSeek R1 7b,Langchain 实现 RAG 知识库DeepSeek R1Chat via ConsoleChat via Browser LangchainFAQs GitHub https://github.com/hailiang-wang/ollama-get-started DeepSeek R1 7b,La…...
【C语言】char unsigned char signed char
在C语言中,char 和 unsigned char 虽然都是1字节(通常8位)的数据类型,但它们在符号处理、数值范围和用途上有显著区别。以下是详细对比: 1. 核心区别 特性charunsigned char符号性可能是signed或unsigned(由编译器决定)明确无符号(仅非负数)数值范围通常 -128 到 1270…...
硬件电路(24)-NE555振荡电路
一、概述 NE555 是一款能产生高精度定时脉冲的双极性集成电路。内部包括阈值比较器、触发比较器、RS触发 器、输出电路等四部分电路构成。它可通过外接少量的阻容器件,组成定时触发电路、脉宽调制电路、音 频振荡器等等电路。广泛应用于玩具、信号交通、自动化控制等…...
Transformer系列(二):自注意力机制框架
自注意力机制框架 一、K-Q-V的自注意力机制二、位置表征1. 通过学习嵌入来进行位置表征2. 通过直接改变 α \alpha α来进行位置表征 三、逐元素非线性变换四、未来掩码(future mask)五、总结 上篇博客:NLP中放弃使用循环神经网络架构讲解了循环神经网络…...
安全技术和防火墙
传输层4.7层防火墙 传输层(4)四层防火墙:ip地址 mac地址 协议 端口号来控制数据流量 应用层防火(7)墙/代理服务器: ip地址 mac地址 协议 端口号来控制数据流量 真实传输的数据(把前面的ip地址…...
深度可分离卷积与普通卷积的区别及原理
1. 普通卷积 普通卷积使用一个滤波器在输入特征图的所有通道上滑动,同时对所有通道进行加权求和,生成一个输出通道。如果有多个滤波器,则生成多个输出通道。假设上一层的特征图有 n 个通道,每个通道是一个二维的图像(…...
STM32时钟树
1、认识时钟树 H:high 高 L:low 低 S:speed 速度 I:internal 内部 E:external 外部 HSE就是高速外部时钟源 HSI就是告诉内部时钟源 外部时钟一般需要接一个时钟源,也就是晶振,这个需要外接&…...
致迈协创C1pro考勤系统简介
1.应用背景 该套件的“数据映射引擎”技术,完成了OA系统与考勤机硬件设备的无缝联接。V5具有良好交互特性和B/S的程序架构,使得客户管理层和HR相关管理人员通过V5能实时查询统计人员的考勤情况,从而及时有效的完成人员考勤的监控与管理&#…...
pivot_root:原理、用途及最简单 Demo
什么是 pivot_root pivot_root 是 Linux 系统中的一个系统调用(和对应的命令行工具),用于更改进程的根文件系统。与 chroot 类似,pivot_root 将一个指定目录设置为进程的新根目录(/),但它比 ch…...
【小沐杂货铺】基于Three.JS绘制卫星轨迹Satellite(GIS 、WebGL、vue、react,提供全部源代码)
🍺三维数字地球系列相关文章如下🍺:1【小沐学GIS】基于C绘制三维数字地球Earth(OpenGL、glfw、glut)第一期2【小沐学GIS】基于C绘制三维数字地球Earth(OpenGL、glfw、glut)第二期3【小沐学GIS】…...
MySQL -数据类型
博客主页:【夜泉_ly】 本文专栏:【暂无】 欢迎点赞👍收藏⭐关注❤️ 目录 前言数值类型intbitfloat 字符串charvarcharenum set 日期和时间类型 前言 在之前的操作篇, 我们用到的大多是DDL(数据定义语言)。 在建表时,…...
数据通信学习笔记之OSPF的邻居角色
邻居与邻接 OSPF 使用 Hello 报文发现和建立邻居关系 在以太网链路上,缺省时,OSPF 采用组播的形式发送 Hello 报文 (目的地址 224.0.0.5) OSPF Hello 报文中包含了路由器的 RouterID、邻居列表等信息。 邻居状态: 邻居:2-way 邻…...
2025第十六届蓝桥杯python B组满分题解(详细)
目录 前言 A: 攻击次数 解题思路: 代码: B: 最长字符串 解题思路: 代码: C: LQ图形 解题思路: 代码: D: 最多次数 解题思路: 代码: E: A * B Problem 解题思路&…...
计算机组成原理笔记(十七)——4.2定点加减运算
定点数的加减运算包括原码、补码和反码3种带符号数的加减运算,其中补码加减运算实现起来最方便。 4.2.1原码加减运算 原码加减运算详解 原码是计算机中表示数值的基本方式之一,其特点为最高位为符号位(0表正,1表负)…...
javase 学习
一、Java 三大版本 javaSE 标准版 (桌面程序; 控制台开发) javaME 嵌入式开发(手机、小家电)基本不用,已经淘汰了 javaEE E业级发开(web端、 服务器开发) 二、Jdk ,jre jvm 三…...