PyTorch 基础要点详解:从模型构建到评估
在深度学习领域,PyTorch 作为一款广受欢迎的开源框架,为开发者提供了便捷高效的工具。今天,我们就深入探讨一下 PyTorch 中的几个关键要点:torch.nn.Linear
、torch.nn.MSELoss
、model.train()
以及 model.eval()
,了解它们如何助力模型开发与评估。
一、torch.nn.Linear
:神经网络的基石 —
— 全连接层
全连接层是构建神经网络的基础组件之一,而 torch.nn.Linear
类在 PyTorch 中就是用于创建全连接层的关键工具。
从功能上看,它实现了对输入数据的线性变换。给定输入向量 x
,权重矩阵 W
和偏置向量 b
,通过公式 y = xW^T + b
得到输出向量 y
。这看似简单的操作,却是复杂神经网络架构中的核心步骤,能够将输入特征进行整合与转换。
使用起来也相当便捷,例如创建一个输入维度为 10
,输出维度为 5
的线性层:
import torch
import torch.nn as nnlinear_layer = nn.Linear(in_features=10, out_features=5)input_tensor = torch.randn(3, 10)
output_tensor = linear_layer(input_tensor)print("输入张量形状:", input_tensor.shape)
print("输出张量形状:", output_tensor.shape)
这里,我们定义了 linear_layer
,当输入形状为 (3, 10)
的张量时,它能按照设定的线性变换规则输出形状为 (3, 5)
的张量。
在参数方面,in_features
指明输入特征数量,要与输入张量最后一维匹配;out_features
设定输出特征数量;bias
默认为 True
,决定是否添加偏置项。
值得注意的是,权重和偏置会自动初始化,当然也能按需手动调整。并且输入张量的最后一维必须符合 in_features
要求,它还支持批量处理,只要最后一维正确,前面的维度可用于表示批量大小。
全连接层在多层感知机(MLP)、图像分类的 CNN 后续层以及自然语言处理的各类
模型中都有广泛应用,是实现复杂任务的重要基石。
二、torch.nn.MSELoss
:回归问题的 “裁判”
在处理回归任务时,我们需要一个标准来衡量模型预测值与真实值之间的偏差,torch.nn.MSELoss
就是这样一个常用的损失函数。
它基于均方误差(Mean Squared Error,MSE)概念,计算预测值与真实值误差平方的平均值。直观地说,MSE 值越小,模型预测就越接近真实值,反映出模型的拟合效果越好。
使用示例如下:
import torch
import torch.nn as nnmse_loss = nn.MSELoss()y_true = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
y_pred = torch.tensor([1.2, 1.8, 3.1], dtype=torch.float32)loss = mse_loss(y_pred, y_true)
print("均方误差损失值:", loss.item())
构造函数中的 reduction
参数决定损失计算方式:'none'
不缩减,返回每个样本损失;'mean'
求平均值,是默认值;'sum'
则求和。
其数学原理依循经典的 MSE 计算公式,根据 reduction
取值不同有不同形式,在回归任务如房价预测、股票价格预测中广泛应用,同时也用于评估训练后模型在测试集上的性能。
使用时要确保输入的真实值和预测值张量数据类型一致,通常为 torch.float32
或 torch.float64
,且形状必须相同,否则会报错。
三、model.train()
:开启模型训练之旅
当我们着手训练模型时,model.train()
就是那个 “启动开关”。
它的核心作用是告知模型当前进入训练阶段,使得模型中的特定层能遵循训练规则运作。以 Dropout
层为例,在训练模式下,它会按照设定概率随机丢弃神经元,防止模型过拟合。假设设置 Dropout
概率为 0.5,每次前向传播都有一半神经元可能被暂时 “弃用”,迫使模型学习更具鲁棒性的特征。
Batch Normalization
层在训练时,会依据当前批次数据动态计算均值和方差,以此对输入归一化,加速收敛并缓解梯度问题。
以下是简单的训练示例:
import torch
import torch.nn as nn
import torch.optim as optimclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(10, 20)self.dropout = nn.Dropout(0.5)self.fc2 = nn.Linear(20, 1)def forward(self, x):x = self.fc1(x)x = self.dropout(x)x = self.fc2(x)return xmodel = SimpleModel()
model.train()criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)inputs = torch.randn(32, 10)
labels = torch.randn(32, 1)for epoch in range(10):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
在训练过程中,务必记得调用 model.train()
,它只会影响像 Dropout
和 Batch Normalization
这类在训练、评估行为有别的层,其他层如常运作。
四、model.eval()
:精准评估模型表现
模型训练完毕,进入评估环节,model.eval()
就派上用场了。
它的使命是将模型切换到评估模式,确保评估结果的准确性与稳定性。对于 Dropout
层,评估时不再随机丢弃神经元,而是让所有神经元参与计算,毕竟此时需要完整模型的输出。
Batch Normalization
层则使用训练过程中统计积累的全局均值和方差进行归一化,避免因批次不同带来的波动。
使用场景多为在验证集或测试集上预测,常结合 torch.no_grad()
一起使用,避免不必要的梯度计算,示例如下:
import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(10, 20)self.bn = nn.BatchNorm1d(20)self.dropout = nn.Dropout(0.5)self.fc2 = nn.Linear(20, 1)def forward(self, x):x = self.fc1(x)x = self.bn(x)x = self.dropout(x)x = self.fc2(x)return xmodel = SimpleModel()
model.eval()input_data = torch.randn(1, 10)with torch.no_grad():output = model(input_data)print(output)
与 model.train()
相对应,评估时务必调用 model.eval()
,否则可能导致评估结果偏差。它同样只作用于特定层,保障评估过程的精准。
综上,掌握 torch.nn.Linear
、torch.nn.MSELoss
、model.train()
和 model.eval()
这些要点,就如同握住了 PyTorch 模型开发与评估的关键钥匙,能帮助我们构建更强大、精准的深度学习模型,开启深度学习的探索之旅。
相关文章:
PyTorch 基础要点详解:从模型构建到评估
在深度学习领域,PyTorch 作为一款广受欢迎的开源框架,为开发者提供了便捷高效的工具。今天,我们就深入探讨一下 PyTorch 中的几个关键要点:torch.nn.Linear、torch.nn.MSELoss、model.train() 以及 model.eval(),了解它…...
Dockerfile中CMD命令未生效
今天在使用dockerfile构建容器镜像时,最后一步用到CMD命令启动start.sh,但是尝试几遍都未能成功执行脚本。最后查阅得知:Dockerfile中可以有多个cmd指令,但只有最后一个生效,CMD会被docker run之后的参数替换。 CMD会…...
Linux平台MQTT测试抓包分析
Linux平台搭建MQTT测试环境-CSDN博客基于这里的测试代码抓包 sudo tcpdump -i any -w mqtt1.cap 上述源码中 tcp://localhost:1883 配置连接: Broker Address: localhostPort: 1883 整体通信流程 1. Subscriber和Broker(代理服务器)建立…...
Docker全方位指南
目录 前言 第一部分:Docker基础与安装 1.1 什么是Docker? 1.2 Docker的适用场景 1.3 全平台安装指南 1.4 配置优化 第二部分:Docker核心操作与原理 2.1 镜像管理 2.2 容器生命周期 2.3 网络模型 2.4 Docker Compose 第三部分&…...
【经典DP】三步问题 / 整数拆分 / 不同路径II / 过河卒 / 下降路径最小和 / 地下城游戏
⭐️个人主页:小羊 ⭐️所属专栏:动态规划 很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~ 目录 动态规划总结Fibonacci数列BC140 杨辉三角杨辉三角三步问题最小花费爬楼梯孩子们的游戏解码方法整数拆分不同路径不同路径II过…...
Koji/OBS编译节点OS版本及工具版本管理深度实践指南
引言 在分布式编译框架Koji/OBS中,有效管理编译节点的操作系统(OS)版本及工具版本是确保构建环境稳定性、兼容性和安全性的关键。本文将从多版本共存、自动化更新、兼容性管理等多个维度,系统阐述如何高效管理编译节点的OS版本及…...
39、web前端开发之Vue3保姆教程(三)
四、Vue3中集成Element Plus 1、什么是Element Plus Element Plus 是一款基于 Vue 3 的开源 UI 组件库,旨在为开发者提供一套高质量、易用的组件,用于快速构建现代化的 web 应用程序。 Element Plus 提供了大量的 UI 组件,包括但不限于: 表单组件:输入框、选择器、开关…...
多类型医疗自助终端智能化升级路径(代码版.下)
医疗人机交互层技术实施方案 一、多模态交互体系 1. 医疗语音识别引擎 # 基于Wav2Vec2的医疗ASR系统 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC import torchaudioclass MedicalASR:def __init__(self):self.processor = Wav2Vec2Processor.from_pretrai…...
Git代码管理
这里写目录标题 分支管理策略TrunkBased🌱 核心理念✅优点❌缺点适用场景 GitFlow✅ GitFlow 的优点❌ GitFlow 的缺点适用场景 AOneFlow✅ AOneFlow 的优点❌缺点适用场景 如何选择分支策略?代码提交规范🌱分支管理🔄代码更新⚔️…...
CubeMX配置STM32F103PWM连续频率输出
要求: 输出2-573Hz频率,输出频率步长小于1Hz 一、CubeMX配置 auto-reload preload在下个周期加载ARR Output compare preload 在下个周期加载CCR 二、 程序 1.启动PWM输出 HAL_TIM_PWM_Start(&htim2,TIM_CHANNEL_1); 2.根据频率调整PSC、ARR、…...
举例说明计算机视觉(CV)技术的优势和挑战。
计算机视觉(CV)技术是人工智能领域的一个重要分支,通过让计算机“看”和“理解”图像或视频,可以实现许多实际应用。以下是计算机视觉技术的优势和挑战的例子: 优势: 自动化处理:CV技术可以自动化地处理大量图像或视频数据,实现快速而准确的分析和识别。提高效率:在许…...
工程师 - FTDI SPI converter
中国网站:FTDIChip- 首页 UMFT4222EV-D UMFT4222EV-D - FTDI 可以下载Datasheet。 UMFT4222EVUSB2.0 to QuadSPI/I2C Bridge Development Module Future Technology Devices International Ltd. The UMFT4222EV is a development module which uses FTDI’s FT4222H…...
河畔石上数(C++)
在 C 里,std::set 是标准模板库(STL)提供的一种关联容器,它能高效地存储唯一元素,并且元素会按照特定的顺序排列,默认是升序。下面从多个方面为你详细介绍 std::set。 1. 头文件包含 若要使用 std::set&a…...
《线性表、顺序表与链表》教案(C语言版本)
🌟 各位看官好,我是maomi_9526! 🌍 种一棵树最好是十年前,其次是现在! 🚀 今天来学习C语言的相关知识。 👍 如果觉得这篇文章有帮助,欢迎您一键三连,分享给更…...
【用Cursor 进行Coding 】
「我」:“添加 XXX 功能” [Claude-3.7]:“好的,我完成了,还顺手做了 19个你没要求不需要的功能、甚至还修改了原有999行正常代码 ~ 不用谢” [Gemini-2.5]:“好的,我会…...
vue2 打包时增加时间戳防止浏览器缓存,打包后文件进行 js、css 压缩
文章目录 前言一、什么是浏览器缓存二、展示效果三、vue.config.js 代码四、代码压缩部分服务器不支持五、感谢 前言 vue 开发过程中,项目前端代码需要更新,更新后由于浏览器缓存导致代码没有及时更新所产生错误,所以在打包时增加时间戳防止…...
TIM定时器
一、TIM定时器 STM32高级定时器实战:PWM、捕获与死区控制详解-CSDN博客 二、相关函数 1.TIM_TimeBaseInitTypeDef结构体讲解 typedef struct {uint16_t TIM_Prescaler; // 预分频器,用于设置定时器计数频率uint16_t TIM_CounterMode; /…...
S130N-ISI 全栈方案与云平台深度协同:重构 PLC 开发新范式
一、什么是 PLC? 1.技术定义 PLC(Power Line Communication)是一种创新的通信技术,它以电力线作为天然的传输介质,通过先进的信号调制技术将高频数据信号叠加于工频电流之上,实现电力输送与数据通信的双频共…...
Jenkins 插件文件优先使用 .jpi 后缀
.hpi 和 .jpi 文件本质上是 Jenkins 插件的打包格式,两者的区别主要体现在历史和命名习惯上: ✅ .hpi(Hudson Plugin) 来源:最初是 Hudson 项目的插件格式。含义:Hudson Plugin 的缩写。用途:早…...
# 决策树与PCA降维在电信客户流失预测中的应用
决策树与PCA降维在电信客户流失预测中的应用 在数据分析和机器学习领域,电信客户流失预测是一个经典的案例。本文将通过Python代码实现,探讨决策树模型在电信客户流失预测中的应用,并结合PCA降维技术优化模型性能,同时对比降维前…...
go语言的语法糖以及和Java的区别
1. Go 语言的语法糖及简化语法 Go 语言本身设计理念是简洁、清晰,虽然不像某些动态语言那样“花哨”,但它提供了几种便捷语法,使代码更简洁: 1.1 短变量声明(Short Variable Declaration) 语法࿱…...
WebRtc 视频流卡顿黑屏解决方案
// node webrtc视频转码服务 const url "http://10.169.xx.xx:8000" <video :ref"videoRefs${index}" :id"videoRefs4_${index}" :src"item" controls:key"item" autoplay muted click"preventDefaultClick"…...
信息安全测评中心-国产化!
项目上使用产品,必须通过国家信息安全测评/ 信息技术产品安全测评,有这个需求的话,可以到CN信息安全测评中心官网中的--测评公告一栏中,找符合要求的产品。 测评公告展示的包括硬件产品、系统、服务资质等。 网址及路径…...
MySQL学习笔记九
第十一章使用数据处理函数 11.1函数 SQL支持函数来处理数据但是函数的可移植性没有SQL强。 11.2使用函数 11.2.1文本处理函数 输入: SELECT vend_name,UPPER(vend_name) AS vend_name_upcase FROM vendors ORDER BY vend_name; 输出: 说明&#…...
DFS 蓝桥杯
最大数字 问题描述 给定一个正整数 NN 。你可以对 NN 的任意一位数字执行任意次以下 2 种操 作: 将该位数字加 1 。如果该位数字已经是 9 , 加 1 之后变成 0 。 将该位数字减 1 。如果该位数字已经是 0 , 减 1 之后变成 9 。 你现在总共可以执行 1 号操作不超过 A…...
动态规划dp专题-(上)
目录 dp理论知识🔥🔥 🎯一、线性DP (1)🚀斐波那契数 -入门级 (2)🚀898. 数字三角形-acwing ---入门级 (3)往期题目 ①选数异或:在…...
正则表达式(一)
一、模式(Patterns)和修饰符(flags) 通过正则表达式,我们可以在文本中进行搜索和替换操作,也可以和字符串方法结合使用。 正则表达式 正则表达式(可叫作 “regexp”,或 “reg”&…...
需求变更导致成本超支,如何止损
需求变更导致成本超支时,可以通过加强需求管理、严格的变更控制流程、优化资源配置、实施敏捷开发、提高风险管理意识等方法有效止损。其中,加强需求管理是止损的核心措施之一。需求管理涉及需求明确化、需求跟踪和变更的管理,有效的需求管理…...
《数据分析与可视化》(清华)ch5-实训代码
小费数据集预处理——求思考题_有问必答-CSDN问答 以上代码在Jupyter Notebook中可以运行,但是在python中就会出如下问题: 这个错误表明在尝试计算均值填充缺失值时,数据中包含非数值类型的列(如文本列),…...
E: The package APP needs to be reinstalled, but I can‘t find an archive for it.
要解决错误 “E: The package mytest needs to be reinstalled, but I can’t find an archive for it”,通常是因为系统中存在损坏的软件包记录或安装过程中断导致 /var/lib/dpkg/status 文件异常。以下是综合多篇搜索结果的解决方案: 解决步骤 备份关…...
若依startPage()详解
背景 startPage基于PageHelper来进行强化,在用户传入pagesize,pageNum等标准参数的时候不需要进行解析 步骤 1.通过ServletUtils工具类getRequestAttributes来获取当前线程的上下文信息 public static ServletRequestAttributes getRequestAttributes() {try {R…...
Oracle AQ
Oracle AQ(Advanced Queuing) 是 Oracle 数据库内置的一种消息队列(Message Queue)技术,用于在应用或系统之间实现异步通信、可靠的消息传递和事件驱动架构。它是 Oracle 数据库的核心功能之一,无需依赖外部…...
npm报错CERT_HAS_EXPIRED解决方案
npm报错解决方案 npm ERR! code CERT_HAS_EXPIRED npm ERR! errno CERT_HAS_EXPIRED方案1:尝试切换镜像 # 使用腾讯云镜像 npm config set registry https://mirrors.cloud.tencent.com/npm/# 或使用官方npm源(科学上网) npm config set registry http…...
pnpm 中 Next.js 模块无法找到问题解决
问题概述 项目在使用 pnpm 管理依赖时,出现了 “Cannot find module ‘next/link’ or its corresponding type declarations” 的错误。这是因为 pnpm 的软链接机制在某些情况下可能导致模块路径解析问题。 问题诊断 通过命令 pnpm list next 确认项目已安装 Next.js 15.2.…...
急速实现Anaconda/Miniforge虚拟环境的克隆和迁移
目录 参考资料 点击Anaconda Prompt (anaconda_base) 查看现有环境 开始克隆,以克隆pandas_env为例,新的环境名字为image (base) C:\Users\hello>conda create -n image --clone pandas_env查看克隆结果,image环境赫然在列。 然后粘贴…...
OpenCv高阶(二)——图像的掩膜
目录 掩膜 bitwise_and原理 掩膜的实现 1、基于像素操作 2、使用形态学操作 3、基于阈值处理 案例 1、读取原图并绘制掩膜 2、掩膜的实现 3、绘制掩膜的直方图 应用 掩膜 OpenCV 中图像掩膜(Mask)实现的原理是通过一个与原始图像大小相同的二…...
数据结构和算法(十二)--最小生成树
一、最小生成树 定义:图的生成树是它的一颗含有其所有顶点的无环连通子图,一副加权无向图的最小生成树它的一颗权值(树中所有边的权重之和)最小的生成树。 约定:只考虑连通图。最小生成树的定义说明它只能存在于连通图…...
开源酷炫的Linux监控工具:sampler
sampler是一个开源的监控工具,来自GitHub用户sqshq(Alexander Lukyanchikov)的匠心之作。 简单来说,sampler能干这些事儿: 实时监控:CPU、内存、磁盘、网络,甚至应用程序的状态,它…...
InternVideo2.5:Empowering Video MLLMs with Long and Rich Context Modeling
一、TL;DR InternVideo2.5通过LRC建模来提升MLLM的性能。层次化token压缩和任务偏好优化(mask时空 head)整合到一个框架中,并通过自适应层次化token压缩来开发紧凑的时空表征MVBench/Perception Test/EgoSchema/MLVU数据benchmar…...
OSPF基础与特性
一.OSPF 的技术背景 OSPF出现是因为RIP协议无法满足大型网络的配置 RIP协议中存在的问题 RIP中存在最大跳数为15的限制,不能适应大规模组网 RIP周期性发送全部路由信息,占用大量的带宽资源 路由收敛速度慢 以跳数作为度量衡,选路可能会不优 存在路由环路的可能性 每隔30秒更新…...
[Linux]从零开始的ARM Linux交叉编译与.so文件链接教程
一、前言 最近在项目需要将C版本的opencv集成到原本的代码中从而进行一些简单的图像处理。但是在这其中遇到了一些问题,首先就是原本的opencv我们需要在x86的架构上进行编译然后将其集成到我们的项目中,这里我们到底应该将opencv编译为x86架构的还是编译…...
golang 中 make 和 new 的区别?
在Go语言中,make 和 new 都是用于内存分配的关键字,但它们在使用场景、返回值和初始化方式等方面存在一些区别,以下是具体分析: 使用场景 make 只能用于创建 map、slice 和 channel 这三种引用类型,用于初始化这些类型…...
碧螺春是绿茶还是红茶
碧螺春是绿茶,不是红茶。 碧螺春的特点: 类别: 碧螺春属于中国六大茶类中的绿茶类。产地: 它产自中国江苏省苏州市太湖的东山和西山(现称金庭镇),是中国十大名茶之一。外形: 碧螺春茶叶外形卷曲如螺,色泽…...
Linux平台搭建MQTT测试环境
Paho MQTT Paho MQTT 是 Eclipse 基金会下的一个开源项目,旨在为多种编程语言提供 MQTT 协议的客户端实现。MQTT(Message Queuing Telemetry Transport)是一种轻量级的发布/订阅(Pub/Sub)消息传输协议ÿ…...
【AI学习】AI Agent(人工智能体)
1,AI agent 1)定义 是一种能够感知环境、基于所感知到的信息进行推理和决策,并通过执行相应动作来影响环境、进而实现特定目标的智能实体。 它整合了多种人工智能技术,具备自主学习、自主行动以及与外界交互的能力,旨…...
克魔助手(Kemob)安装与注册完整教程 - Windows/macOS双平台指南
iOS设备管理工具克魔助手便携版使用全指南 前言:为什么需要专业的iOS管理工具 在iOS开发和设备管理过程中,开发者经常需要突破系统限制,实现更深层次的控制和调试。本文将详细介绍一款实用的便携式工具的使用方法,帮助开发者快速…...
了解GPIO对应的主要功能
GPIO GPIO是通用输入输出端口的简称,芯片上的GPIO引脚与外部设备连接实现通讯、控制以及数据采集等功能,最基本的输出功能是通过控制引脚输出高低电平继而实现开关控制,比如引脚接入LED灯可控制LED灯的亮灭,接入继电器或三极管可…...
Dubbo 注册中心与服务发现
注册中心与服务发现 注册中心概述 注册中心是dubbo服务治理的核心组件,Dubbo依赖注册中心的协调实现服务发现,自动化的服务发现是微服务实现动态扩容、负载均衡、流量治理的基础。 Dubbo的服务发现机制经历了Dubbo2时代的接口级服务发现、Dubbo3时代的…...
一文详解LibTorch环境搭建:Ubuntu20.4配置LibTorch CUDA与cuDNN开发环境
随着深度学习技术的迅猛发展,越来越多的应用程序开始集成深度学习模型以提供智能化服务。为了满足这一需求,开发者们不仅依赖于Python等高级编程语言提供的便捷框架,也开始探索如何将这些模型与C应用程序相结合,以便在性能关键型应…...
micro ubuntu 安装教程
micro ubuntu 安装教程 官网地址 : https://micro-editor.github.io 以下是在 Ubuntu 系统中安装 micro 编辑器 的详细教程: 方法 1:通过 apt 直接安装(推荐) 适用于 Ubuntu 20.04 及以上版本(官方仓库已收录…...