PyTorch系列教程:评估和推理模式下模型预测
使用PyTorch时,将模型从训练阶段过渡到推理阶段是至关重要的一步。在推理过程中,该模型用于对以前从未见过的新数据进行预测。这种转换的一个重要方面是使用推理模式,它通过禁用仅在训练期间需要的操作来帮助优化模型的性能。
理解推理模式
训练期间需要的某些特征(例如autograd相关操作)来加快张量计算速度。当使用推理模式时,不构建计算图,这减少了内存的使用,并加快了前传过程中的内存分配和释放。
推理模式的引入是对传统 torch.no_grad() 上下文的一种改进,专门针对那些已知不需要梯度的应用场景。虽然这两种方法都能通过不存储梯度信息来节省内存,但推理模式更进一步,进行了更多的优化。
设置评估模式
PyTorch提供了一种简化的方式来设置评估模式,这可以使用模型的.eval()方法来完成。该方法将模型的状态从训练模式切换到评估模式。这里有一个快速的演示:
import torch
import torch.nn as nn# Define a simple neural network
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 2)def forward(self, x):return self.linear(x)# Initialize the model
model = SimpleModel()# Set model to evaluation mode
model.eval()
通过调用model.eval()
,某些层(如dropout和批处理归一化)的行为将与训练期间不同,从而确保模型的预测尽可能准确。
禁用梯度计算
在推理过程中,不需要计算梯度,可以节省计算资源,加快评估速度。PyTorch提供了上下文管理器torch.no_grad()
来关闭这些计算:
input_tensor = torch.randn(1, 10) # Example input# Disabling gradient computation
with torch.no_grad():output = model(input_tensor)print(output)
使用torch.no_grad(), PyTorch
在调用模型时防止跟踪历史和未来的计算,使评估过程更快,内存更高效。
设置推理模式
在PyTorch中,实现推理模式非常简单。PyTorch库提供的key方法是torch.inference_mode()。以下是如何利用此功能:
import torch# Sample PyTorch Model
class SimpleModel(torch.nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = torch.nn.Linear(1, 1)def forward(self, x):return self.linear(x)# Initialize model
model = SimpleModel()# Data for making prediction
input_data = torch.tensor([[5.0]])# Using Inference Mode
with torch.inference_mode():prediction = model(input_data)print(f"Prediction: {prediction.item()}")
在上面的代码中,我们首先定义一个简单的线性模型。通过将预测代码逻辑包装在with torch.inference_mode()块中,你可以指示PyTorch在推理模式下运行,从而提高执行速度并减少所定义操作的内存占用。
理解模型输出
PyTorch
模型输出原始分数,而不是概率。为了将这些输出转换成适合解释的形式(例如,用于分类的softmax),通常需要额外的步骤:
# Using the same output from the model
# Apply Softmax to convert scores into probabilities
probabilities = torch.softmax(output, dim=1)# Getting the predicted class
_, predicted_class = torch.max(probabilities, 1)print('Predicted probabilities:', probabilities)
print('Predicted class:', predicted_class)
torch.softmax
函数通常用于将模型输出转换为分类任务,torch.max
决定了最可能类别的概率。
加载预训练模型
在探索推理时,通常使用PyTorch
社区提供的预训练模型。让我们看看如何加载预训练的模型,比如ResNet:
from torchvision import models# Load a pre-trained ResNet model
resnet_model = models.resnet18(pretrained=True)# Set the model to eval mode
resnet_model.eval()
-
models.resnet18
:ResNet(残差网络)是由微软研究院的Kaiming He等人在2015年提出的深度卷积神经网络架构。ResNet通过引入“残差连接”解决了深层网络训练中的梯度消失问题,使得训练更深的网络成为可能。resnet18
是ResNet系列中的一种,包含18层(包括卷积层和全连接层)。
-
参数
pretrained=True
:- 当设置为
True
时,表示加载在ImageNet数据集上预训练好的模型权重。这些权重是通过在大规模ImageNet数据集上训练得到的,能够有效地提取图像特征。 - 预训练模型对于迁移学习非常有用,可以在新的任务上快速适应和提升性能,而无需从头开始训练模型。
- 当设置为
这段代码的主要目的是加载一个在ImageNet上预训练好的ResNet-18模型,并将其设置为评估模式,以便在后续步骤中进行推理(如图像分类)而不进行参数更新。这在迁移学习、特征提取或任何需要使用预训练模型进行预测的任务中非常常见。这些模型加载完成后即可用于推理,这极大地加快了开发周期,因为它们允许在无需从头构建模型的情况下对最先进的架构进行实验。
比较推理模式和无梯度模式
虽然torch.no_grad()
和 torch.inference_mode()
都通过防止梯度跟踪来优化模型的推理过程,但它们在用例和效率上存在差异。考虑一下这个简单的时间比较:
import time# Experiment setup
iterations = 1000# Measure with no_grad
start_time = time.time()
with torch.no_grad():for _ in range(iterations):_ = model(input_data)
end_time = time.time()
print(f'Using no_grad: {end_time - start_time:.5f} seconds')# Measure with inference_mode
start_time = time.time()
with torch.inference_mode():for _ in range(iterations):_ = model(input_data)
end_time = time.time()
print(f'Using inference_mode: {end_time - start_time:.5f} seconds')
上面的脚本说明了如何比较两种模式下多次运行的时间,以查看执行时间的差异。通常,期望torch.inference_mode()通过比torch.no_grad()更有效地执行来提供更好的性能。
- 用例和注意事项
在大多数生产场景中,应该使用Inference Mode,特别是在部署预测速度至关重要的模型时。这包括实时数据处理应用程序、嵌入式设备和在资源受限环境中运行的应用程序。但是,请记住,这不会阻止自动梯度,因此不应该在需要计算梯度的情况下使用它,例如在训练期间或当您仍然需要进一步操作的梯度信息时。
最后总结
评估模式是在实际应用程序中有效部署和利用PyTorch模型的一个关键方面。通过利用model.eval()、torch.no_grad(),并了解如何处理模型输出,你可以显著地最大化机器学习应用程序的性能。
将推理模式集成到PyTorch应用程序中可以显著提高性能。随着模型变得越来越复杂,数据集越来越大,优化预测时间的需求变得越来越重要。凭借PyTorch API的简单性,利用这些优化有效地与未来高效,高性能的机器学习模型保持一致。
相关文章:
PyTorch系列教程:评估和推理模式下模型预测
使用PyTorch时,将模型从训练阶段过渡到推理阶段是至关重要的一步。在推理过程中,该模型用于对以前从未见过的新数据进行预测。这种转换的一个重要方面是使用推理模式,它通过禁用仅在训练期间需要的操作来帮助优化模型的性能。 理解推理模式 …...
Linux注册进程终止处理函数
atexit() 是一个标准库函数,用于注册在进程正常终止时要调用的函数。通过 atexit(),你可以确保在程序结束时自动执行一些清理工作,比如释放资源、保存状态等。 函数原型如下: #include <stdlib.h> int atexit(void (*func…...
Lumerical INTERCONNECT 中的自相位调制 (SPM)
一、自相位调制的数学介绍 A.非线性薛定谔方程(NLSE): NLSE 是光学中的一个关键方程。它告诉我们光脉冲在具有非线性和色散特性的介质中的行为方式。该方程如下所示: i ∂A/∂z β2/2 ∂A/∂t γ|A|A 0 其中: - …...
DICOM服务中的C-STORE、 C-FIND、C-MOVE、C-GET、Worklist
DICOM服务说明 DICOM(Digital Imaging and Communications in Medicine)是一种用于处理、存储、打印和传输医学影像的标准。DICOM定义了多种服务类,其中C-STORE、C-FIND、C-MOVE和C-GET是与影像数据查询和检索相关的四个主要服务类ÿ…...
Python的pdf2image库将PDF文件转换为PNG图片
您可以使用Python的pdf2image库将PDF文件转换为PNG图片。以下是一个完整的示例,包含安装步骤、代码示例和注意事项。 安装依赖库 首先,您需要安装pdf2image库: pip install pdf2imagepdf2image依赖于poppler库来解析PDF文件。 Windows系统…...
在Blender中给SP分纹理组
在Blender中怎么分SP的纹理组/纹理集 其实纹理组就是材质 把同一组的材质分给同一组的模型 导入到sp里面自然就是同一个纹理组 把模型导入SP之后 就自动分好了...
import模块到另一个文件夹报错:ModuleNotFoundError: No module named xxx
1. 问题 打开项目文件夹my_code,将bb.py的函数或者类import到aa.py中,然后运行aa.py文件,可能会报错ModuleNotFoundError: No module named xxx。 E:\Desktop\my_code ├── a │ ├── train.sh │ └── aa.py └── b└── b…...
[SystemVerilog]例化
SystemVerilog 的例化方式和Verilog 类似 如果信号输入输出name一致 abc abc_inst( .a(a), .b(b), c(c) ); 使用SystemVerilog abc abc_inst( .a, .b, .c ); 或者 abc abc_inst( .* ); 在SystemVerilog中,可以简化例化方式。 可以使用…...
Java方法详解
Java方法详解 方法1.方法的概念(1).什么是方法(2).方法的定义(3).实参与形参的关系 2.方法重载(1).方法重载的概念 3.递归(C语言详细讲过) 方法 1.方法的概念 (1).什么是方法 方法类似于C语言中的函数,我们重在体会与理解,不必…...
springboot自动插入创建时间和更新时间到数据库
springboot自动插入创建时间和更新时间到数据库 1.添加TableField注解2.添加TimeMetaObjectHandler配置3.测试 1.添加TableField注解 /*** 创建时间*/TableField(fill FieldFill.INSERT) // 插入时生效private LocalDateTime createTime;/*** 修改时间*/TableField(fill Fiel…...
如何将JAR交由Systemctl管理?
AI越来越火了,我们想要不被淘汰就得主动拥抱。推荐一个人工智能学习网站,通俗易懂,风趣幽默,最重要的屌图甚多,忍不住分享一下给大家。点击跳转到网站 废话不多说,进入正题。下面开始说如何使用 systemctl…...
VMware Workstation Pro安装openKylin 2.0全流程指南
原文链接:VMware Workstation Pro安装openKylin 2.0全流程指南 Hello,大家好啊!今天给大家带来一篇在VMware Workstation Pro 上安装 openKylin 2.0 SP1 的文章。openKylin 2.0 作为国产开源桌面操作系统,目前已经发布了最新版本&…...
网络安全检查漏洞内容回复 网络安全的漏洞
网络安全的核心目标是保障业务系统的可持续性和数据的安全性,而这两点的主要威胁来自于蠕虫的暴发、黑客的攻击、拒绝服务攻击、木马。蠕虫、黑客攻击问题都和漏洞紧密联系在一起,一旦有重大安全漏洞出现,整个互联网就会面临一次重大挑战。虽…...
数据仓库的特点
数据仓库的主要特点可以概括为:面向主题、集成性、非易失性、时变性、高性能和可扩展性、支持复杂查询和分析、分层架构以及数据质量管理。 1. 面向主题(Subject-Oriented) 数据仓库是面向主题的,而不是面向事务的。这意味着数据…...
02_NLP文本预处理之文本张量表示法
文本张量表示法 概念 将文本使用张量进行表示,一般将词汇表示为向量,称为词向量,再由各个词向量按顺序组成矩阵形成文本表示 例如: ["人生", "该", "如何", "起头"]># 每个词对应矩阵中的一个向量 [[1.32, 4,32, 0,32, 5.2],[3…...
青蛙跳杯子(BFS)
#include <iostream> #include <queue> #include <string> #include <unordered_set> using namespace std;int main() {string a, b;cin >> a >> b; int n a.size(); // 字符串长度int d[] {1, -1, -2, 2, -3, 3}; // 跳跃距离queue&…...
【前端基础】1、HTML概述(HTML基本结构)
一、网页组成 HTML:网页的内容CSS:网页的样式JavaScript:网页的功能 二、HTML概述 HTML:全称为超文本标记语言,是一种标记语言。 超文本:文本、声音、图片、视频、表格、链接标记:由许许多多…...
Arm64架构的Linux服务器安装jdk8
一、下载 JDK8 打开浏览器,访问 oracle官网找到适用于自己服务器的 arm64 架构的 JDK8 安装包。 二、安装 JDK8 将下载好的 JDK 压缩包上传到服务器上 解压 JDK 压缩包: tar -zxvf jdk-8uXXX-linux-arm64.tar.gz选择安装目录,我将 JDK 安装…...
深入探索Python机器学习算法:模型调优
深入探索Python机器学习算法:模型调优 文章目录 深入探索Python机器学习算法:模型调优模型调优1. 超参数搜索方法1.1 网格搜索(Grid Search)1.2 随机搜索(Random Search)1.3 贝叶斯优化(Bayesia…...
【Linux】冯诺依曼体系结构-操作系统
一.冯诺依曼体系结构 我们所使用的计算机,如笔记本等都是按照冯诺依曼来设计的: 截止目前,我们所知道的计算机都是由一个一个的硬件组装起来的,这些硬件又由于功能的不同被分为了输入设备,输出设备,存储器…...
Linux第五讲----gcc与g++,makefile/make
1.代码编译 1.1预处理 我们通过vim编辑完文件之后,想看一下运行结果这时我们便可以试用gcc编译C语言,g编译c. 编译代码: 上述两种方法均可,code.c是我的c语言文件,mycode是我给编译后产生的二进制文件起的名&#x…...
FastGPT 源码:基于 LLM 实现 Rerank (含Prompt)
文章目录 基于 LLM 实现 Rerank函数定义预期输出实现说明使用建议完整 Prompt 基于 LLM 实现 Rerank 下边通过设计 Prompt 让 LLM 实现重排序的功能。 函数定义 class LLMReranker:def __init__(self, llm_client):self.llm llm_clientdef rerank(self, query: str, docume…...
Virtual Box虚拟机安装Mac苹果Monterey和big sur版本实践
虚拟机安装苹果实践,在Windows10系统,安装Virtual Box7.1.6,安装虚拟苹果Monterey版本Monterey (macOS 12) 。碰到的主要问题是安装光盘不像Windows那么容易拿到,而且根据网上很多文章制作的光盘,在viritualBox里都无法…...
【高并发】Java 并行与串行深入解析:性能优化与实战指南
Java 并行与串行深入解析:性能优化与实战指南 在高性能应用开发中,我们常常会面临 串行(Serial) 和 并行(Parallel) 的选择。串行执行任务简单直观,但并行能更高效地利用 CPU 资源,…...
软考中级-数据库-3.2 数据结构-数组和矩阵
数组 一维数组是长度固定的线性表,数组中的每个数据元素类型相同。n维数组是定长线性表在维数上的扩张,即线性表中的元素又是一个线性表。 例如一维数组a[5][a1,a2,a3,a4,a5] 二维数组a[2][3]是一个2行2列的数组 第一行[a11,a12,a13] 第二行[a21,a22,a23…...
LeetCode 解题思路 9(Hot 100)
解题思路: 遍历并调整数组: 对于每个元素 nums[i],若其值为正且不超过数组长度 len,则将其逐步交换到它应该在的位置。查找缺失的正整数: 遍历调整后的数组,若某个位置的值不等于其索引加1,则说…...
交叉编译 perl-5.40.0 perl-cross-1.5.3
1.下载地址: https://www.cpan.org/src/5.0/ https://github.com/arsv/perl-cross/tags2.编译 # 进入源码目录 cd /opt/snmp/perl # 合并perl-cross到Perl源码 cp -R perl-cross-1.5.3/* perl-5.40.0/ cd perl-5.40.0./configure --targetaarch64-poky-linux --p…...
go前后端开源项目go-admin,本地启动
https://github.com/go-admin-team/go-admin 教程 1.拉取项目 git clone https://github.com/go-admin-team/go-admin.git 2.更新整理依赖 go mod tidy会整理依赖,下载缺少的包,移除不用的,并更新go.sum。 # 更新整理依赖 go mod tidy 3.编…...
突破光学成像局限:全视野光学血管造影技术新进展
全视野光学血管造影(FFOA)作为一种实时、无创的成像技术,能够提取生物血液微循环信息,为深入探究生物组织的功能和病理变化提供关键数据。然而,传统FFOA成像方法受到光学镜头景深(DOF)的限制&am…...
RefuseManualStart/Stop增强Linux系统安全性?详解systemd单元保护机制
一、引子:一个“手滑”引发的血案 某天凌晨,运维工程师小张在维护生产服务器时,误输入了 systemctl start reboot.target,导致整台服务器瞬间重启,线上服务中断30分钟,直接损失数十万元。事后排查发现&…...
国产编辑器EverEdit - 超级丰富的标签样式设置!
1 设置-高级-标签 1.1 设置说明 选择主菜单工具 -> 设置 -> 常规,在弹出的选项窗口中选择标签分类,如下图所示: 1.1.1 多文档标签样式 默认 平坦 渐变填充 1.1.2 停靠窗格标签样式 默认 平坦 渐变填充 1.1.3 激活Tab的…...
装饰器模式:灵活扩展对象功能的利器
一、从咖啡加料说起:什么是装饰器模式? 假设您走进咖啡馆点单: 基础款:美式咖啡(15元)加料需求:加牛奶(3元)、加焦糖(5元)、加奶油(…...
# [Linux] [Anaconda]解决在 WSL Ubuntu 中安装 Anaconda 报错问题
在 Windows 10 中安装了 WSL(Windows Subsystem for Linux)并使用 Ubuntu 后,你可能会下载 Anaconda 的 Linux 版本进行安装。但在安装过程中,可能会遇到 tar (child): bzip2: Cannot exec: No such file or directory 这样的错误…...
【回溯】216. 组合总和 III
题目 216. 组合总和 III 思路 不知道for有几层时,使用回溯,比上一题多了一个条件,组合需要和为n。 代码 class Solution { private:vector<vector<int>>result;vector<int>path;void backtracking(int target,int k,i…...
AI编程工具-(四)
250304今天用【通义灵码】做了下简单的分析建模工作。不够丝滑,但是在数据预处理方面还是有用。 目录 准备工作一分析工作建模结论 这个数据集是网上随手找的时许指标数据,然后分析时序指标A和B关联关系。 准备工作一 问大模型,这个场景有哪…...
一种事件驱动的设计模式-Reactor 模型
Reactor 模型 是一种事件驱动的设计模式,主要用于处理高并发的 I/O 操作(如网络请求、文件读写等)。其核心思想是通过事件分发机制,将 I/O 事件的监听和处理解耦,从而高效管理大量并发连接,避免传统多线程模…...
AI-Ollama本地大语言模型运行框架与Ollama javascript接入
1.Ollama Ollama 是一个开源的大型语言模型(LLM)平台,旨在让用户能够轻松地在本地运行、管理和与大型语言模型进行交互。 Ollama 提供了一个简单的方式来加载和使用各种预训练的语言模型,支持文本生成、翻译、代码编写、问答等多种…...
XPath路径表达式
1. 绝对路径表达式 语法:/根元素/子元素/子子元素... 特点**:**必须从根元素开始,完整地逐层写路径。 示例代码: <!-- XML结构 --> <school> <class id"1"> <student>小明</student> &l…...
大语言模型的逻辑:从“鹦鹉学舌”到“举一反三”
引言 近年来,大语言模型(LLM)在自然语言处理领域取得了突破性进展,其强大的文本生成和理解能力令人惊叹。然而,随着应用的深入,人们也开始关注LLM的“逻辑”问题:它究竟是机械地模仿人类语言&a…...
从0到1构建AI深度学习视频分析系统--基于YOLO 目标检测的动作序列检查系统:(0)系统设计与工具链说明
文章大纲 系统简介Version 1Version2环境摄像机数据流websocket 发送图像帧RTSP 视频流树莓派windows消息队列参考文献项目地址提示词系统简介 Version 1 Version2 环境 # 配置 conda 源 # 配置conda安装源 conda config --add channels https://mirrors.tuna.tsinghua.edu.c…...
在Linux环境部署SpringBoot项目
在xshell中手动开放8080端口 sudo ufw allow 8080/tcp systemctl reload ufw systemctl restart ufw 配置文件要求 也可以使用maven来分平台 部署到linux服务器上 1.建一个文件夹 2.将jar包拖拽到文件夹中 3.运行nohup java -jar jar包 &的命令启动程序 //后台启动 …...
8. 保存应用数据
一、课程笔记 1.0 引入 针对那些体积小,访问频率高,且对它的速度有一定要求的轻量化数据。例如,用户偏好设置用配置参数等,使用传统的惯性数据库进行存储,不惊险的笨重,还可能引入不必要的性能开销。 此时…...
ADC采集模块与MCU内置ADC性能对比
2.5V基准电压源: 1. 精度更高,误差更小 ADR03B 具有 0.1% 或更小的初始精度,而 电阻分压方式的误差主要来自电阻的容差(通常 1% 或 0.5%)。长期稳定性更好,分压电阻容易受到温度、老化的影响,长…...
量子算法:英译名、概念、历史、现状与展望?
李升伟 整理 #### 英译名 量子算法的英文为 **Quantum Algorithm**。 #### 概念 量子算法是利用量子力学原理(如叠加态、纠缠态和干涉)设计的算法,旨在通过量子计算机高效解决经典计算机难以处理的问题。其核心在于利用量子比特(…...
水仙花数(华为OD)
题目描述 所谓水仙花数,是指一个n位的正整数,其各位数字的n次方和等于该数本身。 例如153是水仙花数,153是一个3位数,并且153 13 53 33。 输入描述 第一行输入一个整数n,表示一个n位的正整数。n在3到7之间&#x…...
基于编程语言的建筑行业施工图设计系统开发可行性研究————从参数化建模到全流程自动化的技术路径分析
基于编程语言的建筑行业施工图设计系统开发可行性研究————从参数化建模到全流程自动化的技术路径分析 文章目录 **基于编程语言的建筑行业施工图设计系统开发可行性研究————从参数化建模到全流程自动化的技术路径分析** 摘要引言一、技术可行性深度剖析1.1 现有编程语言…...
【Linux】【网络】UDP打洞-->不同子网下的客户端和服务器通信(未成功版)
【Linux】【网络】UDP打洞–>不同子网下的客户端和服务器通信(未成功版) 上次说基于UDP的打洞程序改了五版一直没有成功,要写一下问题所在,但是我后续又查询了一些资料,成功实现了,这次先写一下未成功的…...
C# 中的Action和Func是什么?Unity 中的UnityAction是什么? 他们有什么区别?
所属范围:Action 和 Func 是 C# 语言标准库中的委托类型,可在任何 C# 项目里使用;UnityAction 是 Unity 引擎专门定义的委托类型,只能在 Unity 项目中使用。 返回值:Action 和 UnityAction 封装的方法没有返回值&…...
SparkStreaming之03:容错、语义、整合kafka、Exactly-Once、ScalikeJDBC
SparkStreaming进阶 一 、要点:star:4.1 SparkStreaming容错4.1.1 SparkStreaming运行流程4.1.2 如果Executor失败?:star:4.1.3 如果Driver失败?4.1.4 数据丢失如何处理:star:4.1.5 当一个task很慢容错 :star:4.2 SparkSreaming语义4.3 SparkStreaming与…...
让单链表不再云里雾里
一日不见,如三月兮!接下来与我一起创建单链表吧! 目录 单链表的结构: 创建单链表: 增加结点: 插入结点: 删除结点: 打印单链表: 单链表查找: 单链表…...