PyTorch处理数据--Dataset和DataLoader
在 PyTorch 中,Dataset
和 DataLoader
是处理数据的核心工具。它们的作用是将数据高效地加载到模型中,支持批量处理、多线程加速和数据增强等功能。
一、Dataset:数据集的抽象
Dataset
是一个抽象类,用于表示数据集的接口。你需要继承 torch.utils.data.Dataset
并实现以下两个方法:
__len__()
: 返回数据集的总样本数。__getitem__(idx)
: 根据索引idx
返回一个样本(数据和标签)。
示例:自定义 Dataset
import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels, transform=None):self.data = dataself.labels = labelsself.transform = transform # 数据预处理/增强函数def __len__(self):return len(self.data)def __getitem__(self, idx):sample = {"data": self.data[idx], "label": self.labels[idx]}if self.transform:sample = self.transform(sample)return sample
使用场景
- 加载图像、文本、表格数据等。
- 支持数据预处理(如归一化、裁剪)和数据增强(如随机翻转)。
二、 DataLoader:高效加载数据
DataLoader
负责将 Dataset
包装成一个可迭代对象,支持批量加载、多线程加速和数据打乱。
基本用法
from torch.utils.data import DataLoader# 假设 dataset 是你的 CustomDataset 实例
data_loader = DataLoader(dataset,batch_size=32, # 批量大小shuffle=True, # 是否打乱数据(训练时建议开启)num_workers=4, # 多线程加载数据的进程数drop_last=False # 是否丢弃最后不足一个 batch 的数据
)
遍历 DataLoader
for batch in data_loader:data = batch["data"] # 形状:[batch_size, ...]labels = batch["label"] # 形状:[batch_size]# 将数据送入模型训练...
三、pytorch内置数据集
PyTorch 提供了一系列内置数据集,这些数据集可以直接用于训练模型。这些数据集涵盖了多种领域,如图像、文本、音频等。以下是一些常用的PyTorch内置数据集:
图像数据集
-
MNIST: 手写数字数据集,包含0到9的手写数字图片。
from torchvision import datasets mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
-
CIFAR10/CIFAR100: 包含彩色图片的数据集,CIFAR10有60000张32x32的彩色图片,分为10个类别;CIFAR100类似但有100个类别。
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
-
ImageNet: 包含超过1400万张图片的非常庞大的数据集,常用于图像识别和分类任务。
import torchvision.datasets as datasets imagenet_train = datasets.ImageNet(root='./data', split='train', download=True)
-
STL10: 一个用于计算机视觉研究的小型图像数据集,包含96x96的彩色图片。
stl10_train = datasets.STL10(root='./data', split='train', download=True)
-
SVHN: 包含数字图片的数据集,与MNIST类似但包含更多实际场景的图片。
svhn_train = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
文本数据集
1.Text8: 一个用于自然语言处理的小型文本数据集。
from torchtext.datasets import Text8
text8_train = Text8(split=('train',))
2. AG_NEWS: 包含新闻文章的文本数据集,分为4个类别。
from torchtext.datasets import AG_NEWS
ag_news_train = AG_NEWS(split=('train',))
音频数据集
1. Speech Commands: 一个用于语音识别的数据集,包含约65,000个单词发音的音频文件。
from torchaudio.datasets import SPEECHCOMMANDS
speech_commands = SPEECHCOMMANDS(root="./data", download=True)
使用方法
要使用这些数据集,首先需要导入torchvision
(对于图像数据集)、torchtext
(对于文本数据集)或torchaudio
(对于音频数据集),然后使用其提供的类来加载数据。通常还包括一些数据预处理步骤,例如转换(transforms)。
import torchvision.transforms as transforms
from torchvision import datasetstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
四、完整代码示例
步骤 1:创建数据集
import numpy as np
from torch.utils.data import Dataset, DataLoader# 生成示例数据(假设是 10 个样本,每个样本是长度为 5 的向量)
data = np.random.randn(10, 5)
labels = np.random.randint(0, 2, size=(10,)) # 二分类标签class MyDataset(Dataset):def __init__(self, data, labels):self.data = torch.tensor(data, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.data)def __getitem__(self, idx):return {"data": self.data[idx],"label": self.labels[idx]}dataset = MyDataset(data, labels)
步骤 2:创建 DataLoader
data_loader = DataLoader(dataset,batch_size=2,shuffle=True,num_workers=2
)
步骤 3:使用 DataLoader 训练模型
model = ... # 你的模型
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch in data_loader:x = batch["data"]y = batch["label"]# 前向传播outputs = model(x)loss = loss_fn(outputs, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()
五、常见问题解决
(1)数据格式不匹配
- 问题:
DataLoader
返回的数据形状与模型输入不匹配。 - 解决:检查
Dataset
的__getitem__
返回的数据类型和形状,确保与模型输入一致。
(2)多线程加载卡顿
- 问题:设置
num_workers>0
时程序卡死或报错。 - 解决:在 Windows 系统中,多线程可能需要将代码放在
if __name__ == "__main__":
块中运行。
(3)数据增强
- 使用
torchvision.transforms
中的工具(如RandomCrop
、RandomHorizontalFlip
)对图像数据进行增强:from torchvision import transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5]), ])
(4)内存不足
- 对于大型数据集,使用
torch.utils.data.DataLoader
的persistent_workers=True
(PyTorch 1.7+)或优化数据加载逻辑。
六、高级功能
- 分布式训练:使用
torch.utils.data.distributed.DistributedSampler
配合多 GPU。 - 预加载数据:使用
torch.utils.data.TensorDataset
直接加载 Tensor 数据。 - 自定义采样器:通过
sampler
参数控制数据采样顺序(如平衡类别采样)。
相关文章:
PyTorch处理数据--Dataset和DataLoader
在 PyTorch 中,Dataset 和 DataLoader 是处理数据的核心工具。它们的作用是将数据高效地加载到模型中,支持批量处理、多线程加速和数据增强等功能。 一、Dataset:数据集的抽象 Dataset 是一个抽象类,用于表示数据集的接口。你…...
Linux搭建NFS服务
1.概述 Network File System的缩写,它最大的功能是可以通过网络使用挂载的方式,让不同的机器、不同的操作系统可以共享彼此的文件 2.名称 软件名 nfs-utils服务名 nfs或者nfs-server 3.端口 nfs-server tcp/2049 负责建立连接 rpcbind tcp/111 负责…...
ubuntu服务器server版安装,ssh远程连接xmanager管理,改ip网络连接。图文教程
ventoy启动服务器版iso镜像,注意看server名称,跟之前desktop版ubuntu不一样。没有gui界面。好,进入命令行界面。语言彻底没汉化了,选英文吧,别的更看不懂。 跟桌面版ubuntu类似,选择是否精简系统࿰…...
GC overhead limit exceeded---Java 虚拟机 (JVM) 在进行垃圾回收内存量非常少解决
背景: 我正在跑一个数据处理较为复杂的程序。然后调试了很多遍,出现了GC问题,如下图bug. GC overhead limit exceeded-这个bug错误通常表示 Java 虚拟机 (JVM) 在进行垃圾回收时花费了过多的时间,并且回收的内存量非常少。…...
Pytorch学习笔记(十二)Learning PyTorch - NLP from Scratch
这篇博客瞄准的是 pytorch 官方教程中 Learning PyTorch 章节的 NLP from Scratch 部分。 官网链接:https://pytorch.org/tutorials/intermediate/nlp_from_scratch_index.html 完整网盘链接: https://pan.baidu.com/s/1L9PVZ-KRDGVER-AJnXOvlQ?pwdaa2m 提取码: …...
学习日记0327
A cross-domain knowledge tracing model based on graph optimal transport 我们使用gnn来学习这些节点的特征。在此基础上,我们使用显式分布距离度量对齐来自两个不同域的特征向量,旨在最小化域差异,实现最大的跨域知识转移。 AEGOT-CDKT…...
Postman 下载文件指南:如何请求 Excel/PDF 文件?
在 Postman 中进行 Excel/PDF 文件的请求下载和导出,以下是简明的步骤,帮助你轻松完成任务。首先,我们将从新建接口开始,逐步引导你完成整个过程。 Postman 请求下载/导出 excel/pdf 文件教程...
【HTML】验证与调试工具
个人主页:Guiat 归属专栏:HTML CSS JavaScript 文章目录 1. HTML 验证工具概述1.1 验证的重要性1.2 常见 HTML 错误类型 2. W3C 验证服务2.1 W3C Markup Validation Service2.2 使用 W3C 验证器2.3 验证结果解读 3. 浏览器开发者工具3.1 Chrome DevTools…...
头歌实践教学平台--【数据库概论】--SQL
一、表结构与完整性约束的修改(ALTER) 1.修改表名 USE TestDb1; alter table your_table rename TO my_table; 2.添加与删除字段 #语句1:删除表orderDetail中的列orderDate alter table orderDetail drop orderDate; #语句2:添加列unitPrice alter t…...
2025.03.27【基因分析新工具】| MAST:解锁基因表达差异分析与网络构建
文章目录 1. MAST工具简介:探索生物信息分析的新利器1.1 什么是MAST工具?1.2 MAST工具的优势1.3 MAST工具的应用场景 2. MAST的安装方法:轻松入门的第一步2.1 安装R语言环境2.2 安装MAST包2.3 安装依赖库 3. MAST常用命令:掌握数据…...
JVM - 垃圾回收基本问题
通过一些问题来讨论在 JVM 中,垃圾回收的一些基本问题 为什么要有垃圾回收?Java 垃圾回收中是如何判断一个对象死亡的?请简单介绍一下刚才说到了引用计数法,引用计数法存在什么问题?刚才说到了可达性分析,…...
Python 爬虫案例
以下是一些常见的 Python 爬虫案例,涵盖了不同的应用场景和技术点: 1. 简单网页内容爬取 案例:爬取网页标题和简介 import requests from bs4 import BeautifulSoup url "https://www.runoob.com/" response requests.get(url) …...
从零构建大语言模型全栈开发指南:第三部分:训练与优化技术-3.1.3分布式数据加载与并行处理(PyTorch DataLoader优化)
👉 点击关注不迷路 👉 点击关注不迷路 👉 点击关注不迷路 文章大纲 3.1.3 分布式数据加载与并行处理(`PyTorch DataLoader`优化)1. 大规模数据加载的挑战与瓶颈分析1.1 数据加载流程的时间分解2. PyTorch DataLoader的深度优化策略2.1 核心参数调优2.2 分布式数据分片策…...
2025年- G31-Lc105-102. 二叉树层次遍历--java版
1.题目描述 2.思路 思路一: 使用 队列 Queue 来存储当前层的所有节点。关键点在于 levelSize queue.size() 这一行,它决定了当前层的节点数量。 3.代码实现 /*** Definition for a binary tree node.* public class TreeNode {* int val;* Tr…...
Redis 和 MySQL双写一致性的更新策略有哪些?常见面试题深度解答。
目录 一. 业务数据查询,更新顺序简要分析 二. 更新数据库、查询数据库、更新缓存、查询缓存耗时对比 2.1 更新数据库(最慢) 2.2 查询数据库(较慢) 2.3 更新缓存(次快) 2.4 查询缓存&#…...
【DFS】羌笛何须怨杨柳,春风不度玉门关 - 4. 二叉树中的深搜
本篇博客给大家带来的是二叉树深度优先搜索的解法技巧,在后面的文章中题目会涉及到回溯和剪枝,遇到了一并讲清楚. 🐎文章专栏: DFS 🚀若有问题 评论区见 ❤ 欢迎大家点赞 评论 收藏 分享 如果你不知道分享给谁,那就分享给薯条. 你们的支持是我不断创作的…...
【Exception】MybatisPlusException: can not find lambda cache for this entity
文章目录 环境 | Environment复现步骤 | Reproduction steps报错日志 | Error log源码 | Source CodeUserServiceImpl.javaAddressServiceImpl.javaAbstractSubTableBaseServiceImpl.javaUserEntity.javaAddressEntity.javaSubTableBaseEntity.java 原因分析 | Analysis解决方案…...
Spring Security 全面指南:从基础到高级实践
一、Spring Security 概述与核心概念 1.1 Spring Security 简介 Spring Security 是 Spring 生态系统中的安全框架,为基于 Java 的企业应用提供全面的安全服务。它起源于 2003 年的 Acegi Security 项目,2008 年正式成为 Spring 官方子项目,…...
IP组播 C++简单应用
引言 在当今的网络世界中,数据的传输效率和带宽的合理利用是至关重要的。传统的单播和广播通信方式在某些场景下存在着局限性,而IP组播技术的出现为解决这些问题提供了一种有效的方案。本文将详细介绍IP组播的概念、工作原理、应用场景,并通…...
CentOS 7安装 mysql
CentOS 7安装 mysql 1. yum 安装 mysql 配置mysql源 yum -y install mysql57-community-release-el7-10.noarch.rpm安装MySQL服务器 yum -y install mysql-community-server启动MySQL systemctl start mysqld.service查看MySQL运行状态,运行状态如图ÿ…...
“十五五”时期航空弹药发展环境分析
1.“十五五”时期航空弹药发展环境分析 (标题:小二号宋体居中) 一、建言背景介绍 (一级标题:黑体三号,首行空两格) 航空弹药作为现代战争的核心装备,其发展水平直接关乎…...
es6的100个问题
基础概念 解释 let、const 和 var 的区别。什么是块级作用域?ES6 如何实现它?箭头函数和普通函数的主要区别是什么?解释模板字符串(Template Literals)的用途,并举例嵌套变量的写法。解构赋值的语法是什么…...
在直播间如何和观众进行互动
在抖音直播间实现高效互动需要**技术话术工具**的立体化组合,以下是程序员可落地的深度互动方案: --- ### 一、技术驱动型互动策略 #### 1. **实时代码演示(硬核互动)** - **OBS虚拟摄像头屏幕共享** python # 用Flask创建实…...
mysql--用户管理
MySQL 用户管理完整指南 1. 查看用户信息 查看所有用户 SELECT User, Host, authentication_string FROM mysql.user;查看用户详细信息 SELECT * FROM mysql.user \G查看当前登录用户 SELECT CURRENT_USER();查看特定用户的权限 SHOW GRANTS FOR usernamehost;2. 创建用户…...
.NET三层架构详解
.NET三层架构详解 文章目录 .NET三层架构详解引言什么是三层架构表示层(Presentation Layer)业务逻辑层(Business Logic Layer,BLL)数据访问层(Data Access Layer,DAL) .NET三层架构…...
机器学习之回归
1. 引言 回归分析是机器学习中的基本技术之一,广泛用于预测连续型变量。本文调研了线性回归、多项式回归、岭回归、Lasso回归及弹性网络回归,重点分析其数学原理、算法推导、求解方法及应用场景。 2. 线性回归 2.1 概述 线性回归假设因变量与自变量之间存在线性关系,其目…...
危险化合物安全处理,有机反应淬灭操作解析
化学淬灭操作是指在化学反应过程中,通过人为干预快速终止反应的技术。在有机化学反应中,某一反应底物是过量的,当化学反应进行到一定程度,目标产物已经获得,该过量反应底物继续存在会进一步反应生成副产物或者影响后处…...
【前端】使用 HTML、CSS 和 JavaScript 创建一个数字时钟和搜索功能的网页
文章目录 ⭐前言⭐一、项目结构⭐二、HTML 结构⭐三、CSS 样式⭐四、JavaScript 功能⭐五、运行效果⭐总结 标题详情作者JosieBook头衔CSDN博客专家资格、阿里云社区专家博主、软件设计工程师博客内容开源、框架、软件工程、全栈(,NET/Java/Python/C)、数…...
【Linux】调试器——gdb使用
目录 一、预备知识 二、常用指令 三、调试技巧 (一)监视变量的变化指令 watch (二)更改指定变量的值 set var 正文 一、预备知识 程序的发布形式有两种,debug和release模式,Linux gcc/g出来的二进制…...
Windows10清理机器大全集
Windows10清理机器大全集 写在前面先这么个标题,逐渐补充禁止Update移除Microsoft Compatibility Telemetrywindows-defender-remover其它 写在前面 看到标题,读者已经就吐了。 我是说,我非常认可: IT从业者,如果你银子比较充足&…...
解决IDEA中maven找不到依赖项的问题
直接去官网找到对应的依赖项jar包,并且下载到本地,然后安装到本地厂库中。 Maven官网:https://mvnrepository.com/ 一、使用mvn install:install-file命令 Maven提供了install:install-file插件,用于手动将jar包安装到本地仓库…...
端游熊猫脚本游戏精灵助手2025游戏办公脚本工具!游戏脚本软件免费使用
在当下这个崇尚高效与便捷的时代,自动化工具已然成为诸多开发者与企业提升工作效率的关键选择。熊猫精灵脚本助手作为一款极具实力的自动化工具,凭借其多样的功能以及广泛的应用场景,逐步成为众多用户的首要之选。 熊猫精灵脚本助手整合了丰…...
知识就是力量——物联网应用技术
基础知识篇 一、常用电子元器件1——USB Type C 接口引脚详解特点接口定义作用主从设备关于6P引脚的简介 2——常用通信芯片CH343P概述特点引脚定义 CH340概述特点封装 3——蜂鸣器概述类型驱动电路原文链接 二、常用封装介绍贴片电阻电容封装介绍封装尺寸与功率关系࿱…...
第4.1节:使用正则表达式
1 第4.1节:使用正则表达式 将正则表达式用斜杠括起来,就能用作模式。随后,该正则表达式会与每条输入记录的完整文本进行比对。(通常情况下,它只需匹配文本的部分内容就能视作匹配成功。)例如,以…...
Linux目录及文件管理
目录 一.Linux目录基本结构 1.常见目录及其作用 二.常用文件处理命令 1.七类常见的linux的文件 2.cat(查看文件内容) 3.more(分页查看文件内容) 4.less(分页查看文件内容) 5.head(从头部查看文件内容࿰…...
【MySQL】从零开始:掌握MySQL数据库的核心概念(五)
由于我的无知,我对生存方式只有一个非常普通的信条:不许后悔。 前言 这是我自己学习mysql数据库的第五篇博客总结。后期我会继续把mysql数据库学习笔记开源至博客上。 上一期笔记是关于mysql数据库的增删查改,没看的同学可以过去看看…...
进军场景智能体,云迹机器人又快了一步
(图片来源:Pixels) 2025年,AI和机器人行业都发生了巨大改变。 数科星球原创 作者丨苑晶 编辑丨大兔 2025年,酒店行业正掀起一股批量采购具备AI功能的软硬一体解决方案的热潮。 在DeepSeek、Manus等国产AI软件的推动…...
【实战ES】实战 Elasticsearch:快速上手与深度实践-5.2.1 多字段权重控制(标题、品牌、类目)
👉 点击关注不迷路 👉 点击关注不迷路 👉 点击关注不迷路 文章大纲 电商商品搜索实战:多字段权重控制策略1. 业务场景与核心挑战1.1 典型搜索问题1.2 权重失衡的影响数据 2. 权重控制核心方案2.1 字段权重分配矩阵2.2 多策略组合方…...
Ubuntu24.04 离线安装 MySQL8.0.41
一、环境准备 1.1 官方下载MySQL8.0.41 完整包 1.2 上传包 & 解压 上传包名称是:mysql-server_8.0.41-1ubuntu24.04_amd64.deb-bundle.tar # 切换到上传目录 cd /home/MySQL8 # 解压: tar -xvf mysql-server_8.0.41-1ubuntu24.04_amd64.deb-bundl…...
【Django】教程-3-数据库相关介绍
【Django】教程-1-安装创建项目目录结构介绍 【Django】教程-2-前端-目录结构介绍 4.数据库连接配置 需要手动创建数据库,数据库无法自动创建 ,ORM可以创建表,操作表 注意:负责app下mondels.py写类时,无法在数据库中…...
OpenGL绘制文本
一:QPainter绘制 在 OpenGL 渲染的窗口中(如 QOpenGLWidget),通过 QPainter 直接绘制文本。Qt 会自动将 2D 内容(文本、图形)与 OpenGL 内容合成。在paintGL()里面绘制,如果有其他纹理…...
DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)之添加行拖拽排序功能示例6,TableView16_06 分页表格拖拽排序
前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 Deep…...
【解决】导入PNG图片,转 Sprite 格式成功但资产未生效问题
开发平台:Unity 6.0 图片格式:.png 问题描述 当 PNG 成功转换为 Sprite(精灵)时,资产状态将显示扩展箭头,即表明该资产可 Sprite 使用。 解决方法:设置正确的 Sprite Mode Single 关于 Spr…...
【科研绘图系列】R语言绘制重点物种进化树图(taxa phylogenetic tree)
禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载导入数据数据预处理画图输出图片系统信息介绍 【科研绘图系列】R语言绘制重点物种进化树图(taxa phylogenetic tree) 加载R包 library(tidyverse) library(ape…...
Flutter入门教程:从零开始的Flutter开发指南
Flutter入门教程:从环境搭建到应用发布 概述 本文提供了全面的Flutter入门教程,涵盖环境搭建、基础Widget使用、界面设计与美化,以及实战项目开发等内容。通过本教程,开发者能够快速上手Flutter开发,掌握开发跨平台应…...
CentOS 7 源码安装libjsoncpp-1.9.5库
安装依赖工具 sudo yum install cmake make gcc cmake 需要升级至 3.8.0 以上可参考:CentOS安装CMakegcc 需要升级至9.0 以上可参考:CentOS 7升级gcc版本 下载源码 wget https://github.com/open-source-parsers/jsoncpp/archive/refs/tags/1.9.5.…...
调用高德天气Api,并展示对应天气图标
1、申请高德key 点击高德官网申请 必须有key才能调用高德api 小提示:每日/每秒调用api次数有限,尽量不要循环调用。 每日大概5000,每秒3次 2、查看文档 高德官网天气api接口文档 请求示例: https://restapi.amap.com/v3/weat…...
DSP开发板的JTAG接口
(1)普中DSP28335 (2)研旭DSP28388 (3)延华DSP28335 (3)M新动力28377D电机控制板...
1.25-20GHz/500ns超快跳频!盛铂SWFA300国产捷变频频率综合器模块赋能雷达/5G/电子战高频精密控制 本振/频综模块
盛铂SWFA300捷变频频率综合器模块简述: 盛铂科技国产SWFA300捷变频频率综合器是一款在频率范围内任意两点频率的跳频时间在500nS以内的高速跳频源,其输出频率范围为1.25GHz至20GHz,频率的最小步进为10kHz。同时它拥有优秀的相位噪声特性&…...
nestjs 多环境配置
这里使用yaml进行多环境配置,需要安装nestjs/config、js-yaml、types/js-yaml js-yaml、types/js-yaml 主要用来读取yaml文件以及指定类型使用 官方教程:Documentation | NestJS - A progressive Node.js framework 1、下载 npm i --save nestjs/confi…...