PyTorch 模型转换为 ONNX 格式
PyTorch 模型转换为 ONNX 格式
在深度学习领域,模型的可移植性和可解释性是非常重要的。本文将介绍如何使用 PyTorch 训练一个简单的卷积神经网络(CNN)来分类 MNIST 数据集,并将训练好的模型转换为 ONNX 格式。我们还将讨论 PTH 和 ONNX 格式的区别,并介绍如何使用 Netron 可视化 ONNX 模型。
1. PTH 和 ONNX 的区别
PTH 格式
-
定义:PTH 是 PyTorch 框架的专有格式,通常用于保存模型的状态字典(state_dict),包括模型的结构和训练好的参数。
-
兼容性:
- PTH 文件只能在 PyTorch 中使用,无法直接在 C++ 环境中加载。虽然 PyTorch 提供了 C++ API(LibTorch),但 PTH 文件的加载和使用主要依赖于 Python 环境。
- 在 C++ 中使用 PTH 文件需要将模型转换为 PyTorch 的 C++ 格式,这可能会增加复杂性和开发时间。
-
用途:
- PTH 格式适合在 Python 环境中进行模型训练和调试,但在 C++ 中进行模型部署时,通常需要将模型转换为其他格式(如 ONNX)以便于跨平台使用。
- 在 C++ 中,使用 PTH 文件的灵活性较低,尤其是在需要与其他框架或系统集成时。
ONNX 格式
-
定义:ONNX(Open Neural Network Exchange)是一个开放的深度学习模型交换格式,旨在促进不同深度学习框架之间的互操作性。
-
兼容性:
- ONNX 文件可以在多个深度学习框架中使用,包括 PyTorch、TensorFlow、Caffe2 等,这使得它在 C++ 环境中的兼容性更强。
- ONNX 模型可以通过 ONNX Runtime、TensorRT、OpenVINO 等推理引擎在 C++ 中高效运行,支持多种硬件加速。
-
用途:
- ONNX 格式非常适合模型的部署和推理,特别是在需要跨平台或跨框架使用时。它允许开发者在 C++ 中轻松加载和运行模型,而无需依赖于 Python 环境。
- 在 C++ 中,使用 ONNX 模型可以简化工程化流程,便于与其他系统集成,提升模型的可移植性和可扩展性。
总结
在 C++ 进行深度学习模型的工程化时,选择 ONNX 格式通常更为合适,因为它提供了更好的跨平台兼容性和灵活性。PTH 格式虽然在 PyTorch 环境中非常方便,但在 C++ 中的使用受到限制,通常需要额外的转换步骤。ONNX 的开放性和广泛支持使其成为在多种环境中部署深度学习模型的首选格式。
2. 训练 MNIST 数据集的 CNN 模型
以下是使用 PyTorch 训练 MNIST 数据集的完整代码示例:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 1. 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)) # MNIST 数据集的均值和标准差
])# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 2. 定义 CNN 模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) # 输入通道为1,输出通道为32self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) # 输入通道为32,输出通道为64self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 最大池化层self.fc1 = nn.Linear(64 * 7 * 7, 128) # 全连接层self.fc2 = nn.Linear(128, 10) # 输出层def forward(self, x):x = self.pool(torch.relu(self.conv1(x))) # 第一层卷积 + 激活 + 池化x = self.pool(torch.relu(self.conv2(x))) # 第二层卷积 + 激活 + 池化x = x.view(x.size(0), -1) # 展平输入x = torch.relu(self.fc1(x)) # 第一个全连接层x = self.fc2(x) # 输出层return x# 3. 训练模型
model = SimpleCNN().to(device) # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss() # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # 优化器# 训练过程
num_epochs = 5
for epoch in range(num_epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device) # 将数据移动到 MPS 设备optimizer.zero_grad() # 清空梯度outputs = model(images) # 前向传播loss = criterion(outputs, labels) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device) # 将数据移动到 MPS 设备outputs = model(images)_, predicted = torch.max(outputs.data, 1) # 获取预测结果total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')# 5. 转换为 ONNX 格式
onnx_file_path = 'mnist_cnn_model.onnx'
dummy_input = torch.randn(1, 1, 28, 28).to(device) # 示例输入,形状为 [batch_size, channels, height, width]
torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True,opset_version=11, do_constant_folding=True,input_names=['input'], output_names=['output'],dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})print(f'Model has been converted to ONNX format and saved as {onnx_file_path}.')
3. 使用 Netron 可视化 ONNX 模型
一旦您将模型转换为 ONNX 格式,您可以使用 Netron 来可视化模型结构。Netron 是一个开源的模型可视化工具,支持多种深度学习框架的模型文件格式,包括 ONNX。
使用步骤:
-
下载 Netron:
- 您可以访问 Netron 的官方网站 在线使用,或者下载桌面版本。
-
打开 ONNX 模型:
- 如果使用在线版本,直接将
mnist_cnn_model.onnx
文件拖放到浏览器窗口中。 - 如果使用桌面版本,打开 Netron 应用,选择“File” > “Open Model”,然后选择您的 ONNX 文件。
- 如果使用在线版本,直接将
-
查看模型结构:
- 在 Netron 中,您可以查看模型的层次结构、输入输出形状、参数数量等信息。通过可视化,您可以更好地理解模型的设计和工作原理。
- 在 Netron 中,您可以查看模型的层次结构、输入输出形状、参数数量等信息。通过可视化,您可以更好地理解模型的设计和工作原理。
相关文章:
PyTorch 模型转换为 ONNX 格式
PyTorch 模型转换为 ONNX 格式 在深度学习领域,模型的可移植性和可解释性是非常重要的。本文将介绍如何使用 PyTorch 训练一个简单的卷积神经网络(CNN)来分类 MNIST 数据集,并将训练好的模型转换为 ONNX 格式。我们还将讨论 PTH …...
大数据-234 离线数仓 - 异构数据源 DataX 将数据 从 HDFS 到 MySQL
点一下关注吧!!!非常感谢!!持续更新!!! Java篇开始了! 目前开始更新 MyBatis,一起深入浅出! 目前已经更新到了: Hadoop࿰…...
【人工智能】使用Python实现序列到序列(Seq2Seq)模型进行机器翻译
解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 序列到序列(Sequence-to-Sequence, Seq2Seq)模型是解决序列输入到序列输出任务的核心架构,广泛应用于机器翻译、文本摘要和问答系统等自然语言处理任务中。本篇文章深入介绍 Seq2Seq 模型的原理及其核心组件(…...
elasticsearch安装ik分词器
本文主要记录如何安装ik分词器,如果你刚好刷到了这篇文章,希望对你有所帮助。 IKAnalyzer是一个开源的,基于java语言开发的轻量级的中文分词工具包。采用了特有的“正向迭代最细粒度切分算法“,支持细粒度和最大词长两种切分模式&…...
QT6之主站freemodbus1.6移植
本次使用的QT是6.8 下载1.6的freemodbus资源包:至少以上的吧 随便下载:官网也可以这个是STM芯片的教程,移植基本一样,略有不同; STM32 移植FreeModbus详细过程-CSDN博客 移植freemodbus: 添加资源文件&a…...
【错误❌】——槽函数定义好但未初始化
public slots:void onClose(); 初始化即可成功:...
数据结构(理解)
探索数据结构:计算机世界的基石 在计算机科学的领域中,数据结构就如同建筑中的基石,它们支撑着整个软件世界的运行。无论是简单的应用程序,还是复杂的大型系统,数据结构都在其中起着至关重要的作用。 一、什么是数据结…...
ROS2 细节知识学习
1. rosidl_generate_interfaces() 在 ROS2 中,rosidl_generate_interfaces是一个关键的构建工具功能。它主要用于从接口定义文件(如.msg消息文件、.srv服务文件和.action动作文件)生成不同编程语言(如 C、Python 等)可…...
SQL进阶——JOIN操作详解
在数据库设计中,数据通常存储在多个表中。为了从这些表中获取相关的信息,我们需要使用JOIN操作。JOIN操作允许我们通过某种关系(如相同的列)将多张表的数据结合起来。它是SQL中非常重要的操作,广泛应用于实际开发中。本…...
Android studio 签名加固后的apk文件
Android studio打包时,可以选择签名类型v1和v2,但是在经过加固后,签名就不在了,或者只有v1签名,这样是不安全的。 操作流程: 1、Android studio 对项目进行打包,生成有签名的apk文件ÿ…...
Mybatis-基础操作
Mybatis的基础操作就是通过Mybatis完成对数据的增删改查。我们通过例子来引入这些操作,之前的项目较久远,因此我们从零开始进行准备工作: 搭建项目 一、创建数据库user_list并插入数据: -- 创建数据库 create table user_list …...
【工具】JS解析XML并且转为json对象
【工具】JS解析XML并且转为json对象 <?xml version1.0 encodingGB2312?> <root><head><transcode>hhhhhhh</transcode></head><body><param>ccccccc</param><param>aaaaaaa</param><param>qqqq<…...
软件测试技术面试题及参考答案整理
一、什么是兼容性测试?兼容性测试侧重哪些方面? 参考答案: 兼容测试主要是检查软件在不同的硬件平台、软件平台上是否可以正常的运行,即是通常说的软件的可移植性。 兼容的类型,如果细分的话,有平台的兼容,网络兼…...
Python学习36天
面向对象编程综合 # 创建父类 class Employee:# 创建私有属性__name None__salary None# 创建构造器初始化属性def __init__(self, __name, __salary):self.__name __nameself.__salary __salarydef get_annual(self):# 返回员工年薪return self.__salary * 12# 创建公共方…...
C语言——海龟作图(对之前所有内容复习)
一.问题描述 海龟作图 设想有一只机械海龟,他在C程序控制下在屋里四处爬行。海龟拿了一只笔,这支笔或者朝上,或者朝下。当笔朝下时,海龟用笔画下自己的移动轨迹;当笔朝上时,海龟在移动过程中什么也不画。 …...
关于如何在k8s中搭建一个nsfw黄图鉴定模型
随着现在应用内图片越来越多,安全审查也是必不可少的一个操作了 下面手把手教你如何将huggingface中的黄图检测模型部署到自己的服务器上去 1.找到对应的模型 nsfw_image_detection 2.在本地先验证如何使用 首先安装transformers python库 pip install transform…...
istio结合wasm插件的实际应用
在 Istio 中,WASM 插件的常见使用场景和功能包括以下几个方面: 1. 流量管理与请求修改 请求与响应头处理:动态添加、删除或修改 HTTP 请求或响应头。URL 重写:根据特定规则调整请求的路径或参数。请求路由增强:实现复…...
日志logrus
https://blog.csdn.net/m0_70982551/article/details/143095729 https://blog.csdn.net/wslyk606/article/details/81670713 https://www.bilibili.com/opus/1002468521099132928 地鼠文档:https://www.topgoer.cn/docs/goday/goday-1crg2adjknouc 极客文档…...
11.29 代码随想录Day45打卡(动态规划)
115.不同的子序列 题目:给你两个字符串 s 和 t ,统计并返回在 s 的 子序列 中 t 出现的个数。 题解: class Solution:def numDistinct(self, s: str, t: str) -> int:dp [[0] * (len(t) 1) for _ in range(len(s) 1)]for i in range…...
springboot336社区物资交易互助平台pf(论文+源码)_kaic
毕 业 设 计(论 文) 社区物资交易互助平台设计与实现 摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此ÿ…...
【Maven】Nexus私服
6. Maven的私服 6.1 什么是私服 Maven 私服是一种特殊的远程仓库,它是架设在局域网内的仓库服务,用来代理位于外部的远程仓库(中央仓库、其他远程公共仓库)。一些无法从外部仓库下载到的构件,如项目组其他人员开发的…...
【python量化教程】如何使用必盈API的股票接口,获取最新分时KDJ数据
分时KDJ数据简介 股票分时 KDJ 数据是用于分析股票盘中短期走势的指标。它由未成熟随机指标 RSV 计算出 K 值、D 值、J 值。取值范围上,K 和 D 是 0 - 100,J 值可超出此范围。20 以下为超卖区、80 以上是超买区。关键信号有金叉(预示上涨&am…...
DI依赖注入详解
DI依赖注入 声明了一个成员变量(对象)之后,在该对象上面加上注解AutoWired注解,那么在程序运行时,该对象自动在IOC容器中寻找对应的bean对象,并且将其赋值给成员变量,完成依赖注入。 AutoWire…...
mysql sql语句 between and 是否边界值
在 MySQL 中,使用 BETWEEN 运算符时,边界值是包括在内的。这意味着 BETWEEN A AND B 查询会返回 A 和 B 之间的所有值,包括 A 和 B 自身。 示例 假设有一个表 employees,其中有一个 salary 列,您可以使用以下查询&am…...
飞塔防火墙只允许国内IP访问
飞塔防火墙只允许国内IP访问 方法1 新增地址对象,注意里面已经细分为中国内地、中国香港、中国澳门和中国台湾 方法2 手动新增国内IP的对象组,目前好像一共有8632个,每个对象最多支持600个IP段...
宠物之家:基于SpringBoot的领养平台
第1章 绪论 1.1 课题背景 二十一世纪互联网的出现,改变了几千年以来人们的生活,不仅仅是生活物资的丰富,还有精神层次的丰富。时代进步的标志,就是让人们过上更好的生活。在互联网诞生之前,地域位置往往是人们思想上不…...
golang 实现比特币内核:如何接入 RPC 后端获得特定交易的二进制数据
我们非常关注解析比特币的二进制数据,这使得我们的工作看起来是可行的。比特币是一个分布式网络系统,这意味着它需要全球各地的节点协同工作,甚至比特币核心库也需要连接其他节点来帮助它,就像查询交易费一样。 世界上没有免费的午餐。当你使用比特币系统进行交易时,你需…...
QML学习 —— 34、视频媒体播放器(附源码)
效果 说明 您可以单独使用MediaPlayer播放音频内容(如音频),也可以将其与VideoOutput结合使用以渲染视频。VideoOutput项支持未转换、拉伸和均匀缩放的视频演示。有关拉伸均匀缩放演示文稿的描述,请参见fillMode属性描述。 播放可能出错问题 出现的问题: DirectS…...
宝塔Linux面板上传PHP文件或者修改PHP文件,总是转圈圈,其他文件正常,解决办法
目录 问题描述 寻找解决方案 1.重启宝塔面板 2.清理宝塔缓存 3.升级面板 4.ssh远程 5.清空回收站 6.换网络 7. IDE远程编辑 总结: 问题描述 一直用宝塔linux面板,感觉非常好用,点点就能搞定,环境也很好配置。 公司搬家&…...
Flink——进行数据转换时,报:Recovery is suppressed by NoRestartBackoffTimeStrategy
热词统计案例: 用flink中的窗口函数(apply)读取kafka中数据,并对热词进行统计。 apply:全量聚合函数,指在窗口触发的时候才会对窗口内的所有数据进行一次计算(等窗口的数据到齐,才开始进行聚合…...
贪心算法题目合集
贪心算法题目合集 1319:【例6.1】排队接水 贪心策略思想 1319:【例6.1】排队接水 贪心策略思想 1319:【例6.1】排队接水 贪心算法与其说是算法,不如说是一种风格:每次做事情都选择自己认为的最优解。 贪心算法的题很…...
NSSCTF-做题笔记
[羊城杯 2020]easyre 查壳,无壳,64位,ida打开 encode_one encode_tow encode_three 那么我们开始一步一步解密,从最外层开始 def decode_three(encrypted_str):decrypted_str ""for char in encrypted_str:char_code …...
SpringBoot源码-spring boot启动入口ruan方法主线分析(一)
一、SpringBoot启动的入口 1.当我们启动一个SpringBoot项目的时候,入口程序就是main方法,而在main方法中就执行了一个run方法。 SpringBootApplication public class StartApp {public static void main(String[] args) {// testSpringApplication.ru…...
python json.dump()和json.dumps()的区别
用人话总结一下 json.dump()是针对文件的json和python的转换 json.dumps()主要是针对内容数据 json.dumps(obj, skipkeysFalse, ensure_asciiTrue, check_circularTrue, allow_nanTrue, clsNone, indentNone, separatorsNone, encoding“utf-8”, defaultNone, sort_keysFalse…...
快速排序hoare版本和挖坑法(代码注释版)
hoare版本 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h>// 交换函数 void Swap(int* p1, int* p2) {int tmp *p1;*p1 *p2;*p2 tmp; }// 打印数组 void _printf(int* a, int n) {for (int i 0; i < n; i) {printf("%d ", a[i]);}printf("…...
ELK(Elasticsearch + logstash + kibana + Filebeat + Kafka + Zookeeper)日志分析系统
文章目录 前言架构软件包下载 一、准备工作1. Linux 网络设置2. 配置hosts文件3. 配置免密登录4. 设置 NTP 时钟同步5. 关闭防火墙6. 关闭交换分区7. 调整内存映射区域数限制8. 调整文件、进程、内存资源限制 二、JDK 安装1. 解压软件2. 配置环境变量3. 验证软件 三、安装 Elas…...
SpringBoot中忽略实体类中的某个属性不返回给前端的方法
使用Jackson的方式: //第一种方式,使用JsonIgnore注解标注在属性上,忽略指定属性 public class PropertyDTO {JsonProperty("disable")private Integer disable;JsonProperty("placeholder")private String placeholde…...
Flink中普通API的使用
本篇文章从Source、Transformation(转换因子)、sink这三个地方进行讲解 Source: 创建DataStream本地文件SocketKafka Transformation(转换因子): mapFlatMapFilterKeyByReduceUnion和connectSide Outpu…...
【人工智能】从零构建一个文本分类器:用Python和TF-IDF实现
《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 文本分类是自然语言处理(NLP)领域的基础任务之一,广泛应用于垃圾邮件检测、情感分析和新闻分类等场景。本篇文章从零开始,通过详细讲解 TF-IDF 特征提取方法,以及如何将其与机器学习算法结合,实现一…...
原型模式
功能:复制一个运行时的对象,包括对象各个成员当前的值。并且能够通过父类的指针来克隆出子类的对象 主要解决:在运行期建立原型 优点:性能提高、避免了构造函数的约束 步骤: 1、定义抽象原型,声明纯虚接…...
基于FPGA的FM调制(载波频率、频偏、峰值、DAC输出)-带仿真文件-上板验证正确
基于FPGA的FM调制-带仿真文件-上板验证正确 前言一、FM调制储备知识载波频率频偏峰值个人理解 二、代码分析1.模块分析2.波形分析 总结 前言 FM、AM等调制是学习FPGA信号处理一个比较好的小项目,通过学习FM调制过程熟悉信号处理的一个简单流程,进而熟悉…...
open-instruct - 训练开放式指令跟随语言模型
文章目录 关于 open-instruct设置训练微调偏好调整RLVR 污染检查开发中仓库结构 致谢 关于 open-instruct github : https://github.com/allenai/open-instruct 这个仓库是我们对在公共数据集上对流行的预训练语言模型进行指令微调的开放努力。我们发布这个仓库,并…...
Java爬虫:获取1688商品详情接口的技术实现与代码示例
引言 1688作为中国领先的B2B电子商务平台,拥有海量的商品信息。对于商家和市场研究人员来说,能够从1688获取商品详情信息,对于市场分析、竞品研究等具有重要价值。本文将介绍如何使用Java编写爬虫,以合法、高效的方式获取1688商品…...
详解Rust泛型用法
文章目录 基础语法泛型与结构体泛型约束泛型与生命周期泛型与枚举泛型和Vec静态泛型(const 泛型)类型别名默认类型参数Sized Trait与泛型常量函数与泛型泛型的性能 Rust是一种系统编程语言,它拥有强大的泛型支持,泛型是Rust中用于实现代码复用和类型安全…...
Spring Boot拦截器(Interceptor)详解
拦截器Interceptor 拦截器我们主要分为三个方面进行讲解: 介绍下什么是拦截器,并通过快速入门程序上手拦截器拦截器的使用细节通过拦截器Interceptor完成登录校验功能 1. 快速入门 什么是拦截器? 是一种动态拦截方法调用的机制ÿ…...
STM32-- 看门狗--介绍、使用场景、失效场景
STM32 中的看门狗(Watchdog Timer,简称 WDG)有两种主要类型:独立看门狗(IWDG) 和 窗口看门狗(WWDG)。它们的喂狗机制各有特点,主要区别如下: 1. 独立看门狗&a…...
Perplexica - AI 驱动的搜索引擎
更多AI开源软件: AI开源 - 小众AIhttps://www.aiinn.cn/sources Perplexica 是一个开源的 AI 驱动的搜索工具或 AI 驱动的搜索引擎,可以深入互联网寻找答案。受 Perplexity AI 的启发,它是一个开源选项,不仅可以搜索网络…...
Linux笔记--基于OCRmyPDF将扫描件PDF转换为可搜索的PDF
1--官方仓库 https://github.com/ocrmypdf/OCRmyPDF 2--基本步骤 # 安装ocrmypdf库 sudo apt install ocrmypdf# 安装简体中文库 sudo apt-get install tesseract-ocr-chi-sim# 转换 # -l 表示使用的语言 # --force-ocr 防止出现以下错误:ERROR - PriorOcrFoundE…...
MySQL聚合查询分组查询联合查询
#对应代码练习 -- 创建考试成绩表 DROP TABLE IF EXISTS exam; CREATE TABLE exam ( id bigint, name VARCHAR(20), chinese DECIMAL(3,1), math DECIMAL(3,1), english DECIMAL(3,1) ); -- 插入测试数据 INSERT INTO exam (id,name, chinese, math, engli…...
ffmpeg 预设的值 加速
centos 安装ffmpeg 编译安装 官网获取最新的linux ffmpeg 代码 https://ffmpeg.org//releases/ mkdir -p /data/app/ffmpeg cd /data/app/ffmpeg wget http://www.ffmpeg.org/releases/ffmpeg-7.1.tar.gz tar -zxvf ffmpeg-7.1.tar.gz#安装所需的编译环境 yum install -y \…...