PyTorch DataLoader 参数详解
在使用 PyTorch 的 DataLoader
时,有许多参数可以调整,这些参数能够帮助我们平衡数据加载效率、内存使用和训练过程的稳定性。下面介绍几个常用参数,并讲解它们的作用:
-
dataset
- 含义: 数据集对象,必须实现
__len__
和__getitem__
方法。 - 作用: 定义数据的存储及如何获取单个数据样本。
- 示例: 在下例中,我们使用自定义的
SDDataset
类来加载数据。
- 含义: 数据集对象,必须实现
-
batch_size
- 含义: 每个批次加载的数据样本数。
- 作用: 控制模型每次前向传播时同时输入的数据量。过大可能会导致显存溢出,过小可能训练效率低。
- 示例:
batch_size=16
表示每次加载 16 个样本。
-
shuffle
- 含义: 是否在每个 epoch 开始时随机打乱数据的顺序。
- 作用: 能够帮助模型更好地收敛,防止固定数据顺序带来的潜在相关性问题。
- 示例: 一般在训练阶段设置为
shuffle=True
,而在测试阶段则可能关闭该选项。
-
sampler / batch_sampler
- 含义: 通过自定义采样器来控制数据的采样方式。
- 作用: 当数据集有不平衡问题或需要分布式采样时,可以自定义样本提取逻辑。
- 注意: 一旦指定了
sampler
,则不要同时设置shuffle
参数,否则会冲突。
-
num_workers
- 含义: 数据加载时启动的子进程数。
- 作用: 增加并行数据加载可以提升数据预处理速度,尤其在 CPU 运算较重的情况下。
- 示例:
num_workers=4
意味着会使用 4 个子进程并行加载数据。 - 注意: 在 Windows 下或某些嵌入式环境中,过多的子进程可能会导致问题。
-
pin_memory
- 含义: 是否将加载的数据复制到 CUDA 固定内存中。
- 作用: 加快 GPU 与 CPU 之间的数据传输,适用于 GPU 训练场景。
- 示例: 设置
pin_memory=True
用于加速训练。
-
drop_last
- 含义: 当数据集总数不能被批次大小整除时,是否丢弃最后一个不完整的 batch。
- 作用: 确保每个批次的样本数一致,特别是当模型中使用 BatchNorm 等层时较为重要。
- 示例:
drop_last=True
在训练过程中使用。
-
collate_fn
- 含义: 合并数据样本为一个 batch 的函数。
- 作用: 当数据样本结构不统一或者需要特殊处理(比如不同大小图像的归一化、文本数据的填充)时,可以自定义数据合并逻辑。
-
timeout
- 含义: 指定数据加载子进程等待数据的时间(秒)。
- 作用: 在数据加载缓慢或某些子进程异常时,可以在超时后终止等待,避免进程挂起。
-
prefetch_factor(从 PyTorch 1.7 开始支持)
- 含义: 每个 worker 在返回 batch 前预取的样本数。
- 作用: 提前准备数据以提高加载效率,但会占用更多内存。
-
persistent_workers(较新版本支持)
- 含义: 设置为
True
时,worker 进程在整个训练过程中保持活跃,避免每个 epoch 重新启动的开销。 - 作用: 适用于数据加载开销较大的场景,从而进一步提高效率。
- 含义: 设置为
示例:加载 SD 模型训练数据(图片 & JSON)
下面给出一个完整的代码示例,用于演示如何定义自定义数据集,并使用 DataLoader
加载训练 Stable Diffusion 模型时的图片和对应 JSON 注释数据。
import os
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms# 自定义数据集类,用于加载图片和 JSON 数据
class SDDataset(Dataset):def __init__(self, image_dir, json_path, transform=None):"""image_dir: 存放图片文件的目录json_path: 存放图片对应标注(例如描述文本或其他元数据)的 JSON 文件路径transform: 对图片进行的预处理转换(例如 resize, normalization 等)"""self.image_dir = image_dirself.transform = transform# 加载 JSON 数据, 假设 JSON 文件格式为一个字典的列表,每个字典包含 "filename" 和 "label" 键with open(json_path, 'r') as f:self.data_info = json.load(f)def __len__(self):return len(self.data_info)def __getitem__(self, index):# 获取第 index 个数据样本的信息item = self.data_info[index]image_path = os.path.join(self.image_dir, item['filename'])# 打开并转换图片格式image = Image.open(image_path).convert('RGB')# 如果定义了预处理,则进行转换if self.transform:image = self.transform(image)# 读取标签数据,例如描述文本或其他元数据label = item['label']return image, label# 定义图片预处理操作
transform = transforms.Compose([transforms.Resize((256, 256)), # 调整图片大小transforms.ToTensor() # 转换为 Tensor 格式
])# 创建数据集实例
dataset = SDDataset(image_dir='/path/to/images', # 图片所在目录json_path='/path/to/annotations.json', # JSON 注释数据文件路径transform=transform
)# 创建 DataLoader
data_loader = DataLoader(dataset,batch_size=16, # 每批次加载 16 个样本shuffle=True, # 每个 epoch 前打乱数据顺序num_workers=4, # 使用 4 个子进程并行加载数据pin_memory=True, # 使用 CUDA 固定内存,加快数据传输速度drop_last=True # 如果最后一个批次不完整,则丢弃该批次
)# 遍历 DataLoader 示例
for batch_idx, (images, labels) in enumerate(data_loader):# 模型训练或其他操作print(f"Batch {batch_idx}:")print("Images shape:", images.shape)print("Labels:", labels)
代码说明
• 自定义 Dataset:
在 SDDataset 类中,init 方法加载 JSON 文件,getitem 方法则根据 JSON 中提供的文件名加载图片并返回对应的标签数据。这样可以方便地将图片和与之相关的元信息(如描述文本)同时提供给训练过程。
• 图片预处理:
通过 transforms.Compose 进行图片大小调整和转换为 Tensor,确保图片满足模型输入格式要求。
• DataLoader 参数设置:
• batch_size=16 控制每个批次的样本数量
• shuffle=True 在每个 epoch 打乱数据顺序,提升训练效果
• num_workers=4 利用多进程加速数据加载
• pin_memory=True 优化 GPU 数据传输
• drop_last=True 确保每个批次样本数量一致
总结
使用 DataLoader
时,可以根据具体任务和硬件环境调整主要参数:
- batch_size:控制每个批次的样本数;
- shuffle:决定数据顺序是否随机打乱;
- num_workers:通过多进程加速数据加载;
- pin_memory:用于加快 GPU 数据传输;
- drop_last:确保批次样本数一致;
- collate_fn:自定义 batch 拼接逻辑。
在训练 Stable Diffusion(SD)模型时,通过自定义 Dataset
类整合图片和 JSON 数据,并合理设置 DataLoader
的参数,可以显著提升数据加载效率和训练稳定性。
💬 如果你觉得这篇整理有帮助,欢迎点赞收藏!
相关文章:
PyTorch DataLoader 参数详解
在使用 PyTorch 的 DataLoader 时,有许多参数可以调整,这些参数能够帮助我们平衡数据加载效率、内存使用和训练过程的稳定性。下面介绍几个常用参数,并讲解它们的作用: dataset 含义: 数据集对象,必须实现 …...
PowerBI 计算时间用EDATE
我在原表基础上,根据日期字段,计算去年时间 CONCATENATEX(DISTINCT(SELECTCOLUMNS(VALUES(日期表),"去年", FORMAT(DATEADD([日期], -1, YEAR), "yyyyMM"))), [YearMonth],",") 我发现很奇怪的现象,假如某个日…...
GRBL运动控制算法(四)加减速运算
前言 在数控系统和运动控制领域,GRBL 作为一款高效、轻量化的开源固件,因其卓越的性能和简洁的架构被广泛应用于各类嵌入式运动控制场景。GRBL加减速算法的实现尤为关键——它直接决定了运动控制的精度、效率与设备稳定性。 本文深入解析加减速运算的核…...
CSS 学习提升网站或者项目
有几个不错的开源项目可以帮助你练习和提升CSS技能: CSS-Tricks CSS-Tricks 提供了很多关于CSS的技巧和教程,可以通过实践它们来提高CSS技能。你可以在CSS-Tricks上找到很多有趣的项目和代码示例。 Frontend Mentor Frontend Mentor 是一个非常适合练习…...
PolarDB 读已提交事务隔离级别 select ... for update, where条件未用索引,查不到数据的时候不会锁表
由于没有给字段设置唯一性,所以改为通过查询语句加锁确保唯一性,但是发现select count(*) 为0时,不会加锁,所以在insert方法后面需要加锁二次查询确保唯一性。 在 PolarDB 的读已提交事务隔离级别下,SELECT ... FOR UP…...
Python基础——Matplotlib库
绘图基础 Matplotlib 库太大,画图通常仅仅使用其中的核心模块 matplotlib.pyplot,并给其一个别名 plt,即 import matplotlib.pyplot as plt。为了使图形在展示时能很好的嵌入到 Jupyter 的 Out[ ] 中,需要使用%matplotlib inline…...
群晖Hyper Backup备份的东西怎么还原?
一、背景 前面写了一篇文章关于群晖NAS中最简单的备份方案,Hyper Backup 方案 群晖NAS最简单的备份教程(只备份需要的目录到不同的硬盘),留了个尾,即怎么还原备份的东西,这里完结一下。 二、还原方案 2.…...
记录IBM服务器检测到备份GPT损坏警告排查解决过程
服务器设备:IBM x3550 M4 Server IMM默认IP地址:192.168.70.125 用户名:USERID 密码:PASSW0RD(注意是零0) 操作系统:Windows Hyper-V Server 2016 IMM Web System Status Warning࿱…...
蓝桥杯嵌入式十五届模拟二(串口DMA,占空比的另一种测量方式)
一.LED 先配置LED的八个引脚为GPIO_OutPut,锁存器PD2也是,然后都设置为起始高电平,生成代码时还要去解决引脚冲突问题 二.按键 按键配置,由原理图按键所对引脚要GPIO_Input 生成代码,在文件夹中添加code文件夹&#…...
22 | 如何继续提升 Go 开发技术?
提示: 所有体系课见专栏:Go 项目开发极速入门实战课;欢迎加入 云原生 AI 实战营 星球,12 高质量体系课、20 高质量实战项目助你在 AI 时代建立技术竞争力(聚焦于 Go、云原生、AI Infra)。 「Go 项目开发极速…...
一文详解OpenCV环境搭建:Windows使用CLion配置OpenCV开发环境
在计算机视觉和图像处理领域,OpenCV 是一个不可或缺的工具。其为开发者提供了一系列广泛的算法和实用工具,支持多种编程语言,并且可以在多个平台上运行。对于希望在其项目中集成先进视觉功能的开发者来说,掌握如何配置和使用OpenC…...
云原生周刊:深入探索 kube-scheduler-simulator
开源项目推荐 mcp-server-kubernetes mcp-server-kubernetes 是一个实现了模型上下文协议(MCP)的服务器,旨在通过自然语言与 K8s 集群进行交互。它支持连接到 K8s 集群,列出所有 Pod、服务、部署和节点,创建、描述、…...
总结一下常见的EasyExcel面试题
说一下你了解的POI和EasyExcel POI(Poor Obfuscation Implementation):它是 Apache 软件基金会的一个开源项目,为 Java 程序提供了读写 Microsoft Office 格式文件的功能,支持如 Excel、Word、PowerPoint 等多种文件格…...
【Java设计模式】第2章 UML急速入门
2-1 本章导航 UML类图与时序图入门 UML定义 统一建模语言(Unified Modeling Language):第三代非专利建模语言。特点:开放方法,支持可视化构建面向对象系统,涵盖模型、流程、代码等。UML分类(2.2版本) 结构式图形:系统静态建模(类图、对象图、包图)。行为式图形:事…...
Excel处理控件Spire.XLS系列教程:C# 设置 Excel 中的数字格式
在 Excel 工作表中,原始数据通常显示为缺乏直观性的普通数字。通过设置数字格式,可以将这些数字转换成更容易理解的形式。例如,将销售额数据设置为货币格式,即添加货币符号和千位分隔符,可使所代表的金额一目了然。将市…...
脚本启动 Java 程序
如果你想在后台启动一个 Java 程序,并在终端窗口中显示一个自定义的名字,可以通过编写一个简单的脚本来实现。以下是一个基于 Linux/macOS 的解决方案,使用 Bash 脚本启动 Java 程序,并在终端窗口中显示自定义标题。 示例脚本 创建…...
UniappX动态引入在线字体图标,不兼容css时可用。
优缺点 优点:不需要占用本地存储,可直接在线同步库图标,不用再手动引入ttf文件,不用手动添加键值对对应表。 缺点:受网速影响,字体库cdn路径可能会更改,ios端首次加载,可能会无图标…...
机器学习 | 强化学习基本原理 | MDP | TD | PG | TRPO
文章目录 📚什么是强化学习🐇监督学习 vs 强化学习🐇马尔科夫决策过程(MDP)📚基本算法(value-based & policy-based)🐇时序差分算法(TD)🐇SARSA和Q-learning🐇策略梯度算法(PG)🐇REINFORCE和Actor-Critic🐇信任区域策略优化算法(TRPO)学习视频…...
k8s之Service类型详解
1.ClusterIP 类型 2.NodePort 类型 3.LoadBalancer 类型 4.ExternalName 类型 类型为 ExternalName 的 Service 将 Service 映射到 DNS 名称,而不是典型的选择算符, 例如 my-service 或者 cassandra。你可以使用 spec.externalName 参数指定这些服务…...
AI平台如何实现推理?数算岛是一个开源的AI平台(主要用于管理和调度分布式AI训练和推理任务。)
数算岛是一个开源的AI平台,主要用于管理和调度分布式AI训练和推理任务。它基于Kubernetes构建,支持多种深度学习框架(如TensorFlow、PyTorch等)。以下是数算岛实现模型推理的核心原理、架构及具体实现步骤: 一、数算岛…...
linux开发环境
1.虚拟机环境搭建 在 Ubuntu 系统中,打开(如图中显示的窗口 )常见快捷键有: Ctrl Alt T:这是最常用的打开终端的快捷键组合 ,按下后会快速弹出一个新的终端窗口。 在 VMware 虚拟机环境中,若…...
OSPF复习
OSPF OSPF---开放最短路径优先协议 动态路由判定依据:选路,收敛速度,占用资源 OSPFV2和RIPV2的相同点: 1.都是无类别的路由协议; 2.都是通过组播来传播信息的;(RIP:224.0.0.9&am…...
AWS S3深度剖析:云存储的瑞士军刀
1. 引言 在当今数据驱动的世界中,高效、可靠、安全的数据存储解决方案至关重要。Amazon Simple Storage Service (S3)作为AWS生态系统中的核心服务之一,为企业和开发者提供了一个强大而灵活的对象存储平台。本文将全面解析S3的核心特性,帮助读者深入理解如何充分利用这一&q…...
pyTorch中 tensorboard的使用
目录 01.导包、 transforms数据转化、torchvision数据集、创建dataloaders、展示图片的封装函数 02定义模型 03定义损失函数与优化器 1.tensorboard的安装 2.tensorboard的使用 2.1添加图片 2.2 添加模型结构图 2.3 添加损失的变化 #pyTorch中的tensorboard 与 tens…...
Android audio(2)-audioservice
AudioService是Android的系统服务(systemservice),由SystemServer负责启动。提供Android APK 所需的非数据通路(playback/capture)相关的audio 功能实现,是binder通信中的server端,与之对应的 C…...
星城幻境:科技与千年文脉的交响诗-长沙
故事背景 故事发生在中国湖南长沙,通过六个充满未来感的城市景观,展现人工智能修复古建筑、生态摩天楼、全息水幕许愿等场景,描绘科技赋能下历史文脉与未来城市的共生图景。 故事内容 从岳麓书院清晨的智能修复到湘江夜空的数字烟花ÿ…...
记录学习的第二十三天
老样子,每日一题开胃。 我一开始还想着暴力解一下试试呢,结果不太行😂 接着两道动态规划。 这道题我本来是想用最长递增子序列来做的,不过实在是太麻烦了,实在做不下去了。 然后看了题解,发现可以倒着数。 …...
sql-labs靶场 less-1
文章目录 sqli-labs靶场less 1 联合注入 sqli-labs靶场 每道题都从以下模板讲解,并且每个步骤都有图片,清晰明了,便于复盘。 sql注入的基本步骤 注入点注入类型 字符型:判断闭合方式 (‘、"、’、“”…...
AI-人工智能-基于LC-MS/MS分子网络深度分析的天然产物成分解析的新策略
Anal Chem∣张卫东教授团队开发基于LC-MS/MS分子网络深度分析的天然产物成分解析的新策略 2024年9月23日,海军军医大学张卫东教授团队在Analytical Chemistry(IF6.7)上发表了题为“In-Depth Analysis of Molecular Network Based on Liquid …...
IntelliJ IDEA使用技巧(json字符串格式化)
文章目录 一、IDEA自动格式化json字符串二、配置/查找格式化快捷键 本文主要讲述idea中怎么将json字符串转换为JSON格式的内容并且有层级结构。 效果: 转换前: 转换后: 一、IDEA自动格式化json字符串 步骤一:首先创建一个临…...
【Java设计模式】第8章 单列模式讲解
8-1 单例模式讲解 定义与类型 定义:保证一个类仅有一个实例,并提供一个全局访问点。类型:创建型模式。适用场景 需要确保任何情况下绝对只有一个实例。实际应用: 网站计数器(单服务)。应用配置、线程池、数据库连接池。优点 减少内存开销(仅一个实例)。避免资源多重占…...
【Java设计模式】第4章 简单工厂讲解
4. 简单工厂模式 4.1 简单工厂讲解 定义:由一个工厂对象决定创建哪种产品类的实例,属于创建型模式,但不属于GoF 23种设计模式。适用场景: 工厂类负责创建的对象较少。客户端仅需传入参数,无需关心对象创建逻辑。优点: 客户端只需传入参数即可获取对象,无需知道创建细节…...
Spring Boot 常用依赖介绍
依赖总括 1. 核心依赖:Spring Web、Spring Data JPA、MySQL Driver。 2. 开发工具:Lombok、Spring Boot DevTools。 3. 安全与权限:Spring Security。 4. 测试与文档:Spring Boot Starter Test、Swagger。 5. 性能优化&#…...
判断矩阵A是否可以相似对角化
【例题1】 【例题2】...
第三方软件测试公司进行安全性测试有哪些好处?
在信息技术飞速发展的今天,软件已成为各行业运作的核心组成部分。然而,伴随而来的软件安全问题也愈发显著,因此软件产品安全性测试不容忽视。随着软件市场的激烈竞争,企业为了更好的专心产品开发,会将安全性测试服务交…...
下一代楼宇自控的中枢神经:ARM终端的生态
某跨国半导体工厂的洁净车间突然触发气体泄漏报警。此时,ARM应急广播终端在200毫秒内完成全楼宇语音播报,同步联动门禁系统解锁逃生通道,指挥中心大屏自动弹出事故区域监控画面——这套价值27万元的预警系统,在投产首年就避免了可…...
R语言进行判别分析
Fisher判别法、距离判别法、Bayes判别法基本理论、方法: Fisher判别法 理论基础: Fisher判别法旨在通过选择合适的投影方向,最大化不同类别之间的类间差异性,同时最小化类内差异性。这种投影方向使得在低维空间中样本点的类别可…...
Nacos 服务发现的流程是怎样的?客户端如何获取最新的服务实例列表?
服务发现是微服务架构的核心组件,它允许一个服务(消费者)动态地找到它需要调用的另一个服务(提供者)的网络地址(IP 和端口),而无需硬编码这些地址。 整体流程概览: 服务提供者 (Pr…...
Java全栈项目--校园快递管理与配送系统(5)
源代码续 <template><div class"app-container"><el-card class"box-card"><div slot"header" class"clearfix"><span>通知统计</span><div class"header-operations"><el-d…...
UE5 本地化
文章目录 打开本地化面板设置本地化翻译设置文本收集路径添加语言收集需要翻译的文本手动翻译导入导出编译 使用本地化启动代码修改语言 打开本地化面板 UE4: UE5: 设置本地化翻译 设置文本收集路径 UE5可以自动帮我们收集需要显示的文本ÿ…...
用c语言写一个linux进程之间通信(聊天)的简单程序
使用talk 用户在同一台机器上talk指令格式如下: talk 用户名ip地址 [用户终端号] 如果用户只登录了一个终端,那么可以不写用户终端号,如: talk userlocalhost可以使用who指令来查看当前有哪些用户登录,他的终端号…...
同时支持Vue2/Vue3的图片懒加载组件(支持懒加载 v-html 指令梆定的 html 内容)
🚀 vue-lazyload-imgs(LazyLoadImgs) 组件简介 详情见:https://npmjs.com/package/vue-lazyload-imgs 安装方法: npm i vue-lazyload-imgs(不要安装为开发依赖,应为产品依赖) 适用环…...
Qt容器类在元对象系统中使用
解释 “QVector没有被注册到Qt的元对象系统中”这句话的意思是:QVector<double>这种数据类型没有被Qt的元对象系统(Meta-Object System)识别和管理。Qt的元对象系统是Qt框架的核心部分,它提供了信号与槽机制、动态属性系统…...
Qt中的信号与槽及其自定义
信号源:哪个控件发的信号 信号的类型:用户进行不同的操作就会触发不同的信号 如点击按钮,在输入框移动光标,勾选一个复选框,选 择一个下拉框 信号的处理方式:槽(slot)----也就是函数,Qt中用con…...
【已完结STM32】--自学江协科技笔记汇总
以下学习笔记代码均来自b站江协科技视频 笔记汇总完结 文章笔记对应江科大视频新建工程【2-2】新建工程江科大STM32-GPIO输出 点亮LED,LED闪烁,LED流水灯,蜂鸣器(学习笔记)_unit32-t rcc-apb2periph-CSDN博客 【3-1】…...
科技快讯 | 索诺瓦携手清华大学共筑听力无障碍未来;中国探月工程总设计师:未来月球上能打电话;Shopify要求员工证明AI无法取代其工作
索诺瓦携手清华大学共筑听力无障碍未来 2024年末,60岁以上人口超3.1亿,听力损失比例高达11%。清华大学无障碍发展研究院与索诺瓦集团深化合作,共同推动听力无障碍环境建设。2023年9月,《无障碍环境建设法》实施,2024年…...
[实战] 天线阵列波束成形原理详解与仿真实战(完整代码)
天线阵列波束成形原理详解与仿真实战 1. 引言 在无线通信、雷达和声学系统中,波束成形(Beamforming)是一种通过调整天线阵列中各个阵元的信号相位和幅度,将电磁波能量集中在特定方向的技术。其核心目标是通过空间滤波增强目标方…...
北京自在科技:让万物接入苹果Find My网络的″钥匙匠″
在AirTag掀起全球防丢热潮的今天,越来越多的第三方产品开始接入苹果Find My网络——从充电宝到电动车,从行李箱到保温杯,用户只需打开iPhone的「查找」App,就能实时定位这些物品。 北京自在科技有限责任公司早在苹果推出Find My开…...
区块链是怎么存储块怎么找到前一个块
前言:学习区块链的过程中在想怎么管理区块链呢 📌 推荐项目回顾: 👉 Jeiwan 的 blockchain_go 项目 GitHub 地址:https://github.com/Jeiwan/blockchain_go ❓它是怎么存储区块 & 找前一个区块的? 项…...
聚类算法 ap 聚类 谱聚类
AP聚类(Affinity Propagation Clustering)是一种基于消息传递的聚类算法,由Brendan J. Frey和Delbert Dueck于2007年提出。与传统的聚类算法(如K-Means)不同,AP聚类不需要预先指定聚类数量,而是…...