深入理解 PyTorch 的 nn.Embedding:词向量映射及变量 weight 的更新机制
文章目录
- 前言
- 一、直接使用 `nn.Embedding` 获得变量
- 1、典型场景
- 2、示例代码:
- 3、特点
- 二、使用 `iou_token = nn.Embedding(1, transformer_dim)` 并访问 `iou_token.weight`
- 1、典型场景
- 2、示例代码:
- 3、特点
- 三、第一种方法在模型更新中会更新其值吗?
- 1、默认行为
- 2、示例代码:
- 3、控制权重更新的方法
- 方法 1:设置 `requires_grad = False`
- 方法 2:加载预训练权重并冻结
- 方法 3:在优化器中排除某些参数
- 四、总结
前言
在深度学习领域,特别是在自然语言处理(NLP)中,nn.Embedding
是一个非常重要的模块,用于将离散的词汇(如单词或标记)映射为连续的向量表示。本文详细讲解了 nn.Embedding
的使用方法、其权重是否会在模型更新过程中被更新的问题,以及如何控制这些权重是否参与训练。
一、直接使用 nn.Embedding
获得变量
1、典型场景
这种用法通常用于处理离散的词汇表(如单词、token等),将这些离散的 token 映射为连续的向量表示。例如,在 NLP 任务中,输入是一批句子或标记序列,每个标记都有一个唯一的索引(ID)。通过 nn.Embedding
,可以将这些索引映射为对应的词向量。
2、示例代码:
import torch
import torch.nn as nn# 假设词汇表大小为 10,每个词嵌入维度为 5
vocab_size = 10
embedding_dim = 5# 创建 Embedding 层
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)# 输入是一个批次的 token 索引
input_tokens = torch.tensor([2, 3, 5]) # 示例输入索引
embedded_vectors = embedding_layer(input_tokens) # 获取词向量print(embedded_vectors)
3、特点
nn.Embedding
是一个可训练的参数层。- 输入是离散的 token 索引,输出是对应的连续向量表示。
- 这种用法适用于需要批量处理 token 的场景,比如文本分类、机器翻译等任务。
二、使用 iou_token = nn.Embedding(1, transformer_dim)
并访问 iou_token.weight
1、典型场景
这种用法通常用于定义一些特殊的、全局共享的向量,而不是处理整个词汇表中的 token。常见的例子包括在目标检测任务中,定义一个可学习的 “特殊 token” 来表示某些特定的对象或区域(如 IoU 预测中的 token)。
2、示例代码:
import torch
import torch.nn as nn# 定义一个特殊的 token,维度为 transformer_dim
transformer_dim = 64
iou_token = nn.Embedding(num_embeddings=1, embedding_dim=transformer_dim)# 访问这个特殊 token 的权重
special_token_vector = iou_token.weight # 形状为 [1, transformer_dim]print("Special Token Vector:", special_token_vector)
3、特点
iou_token
是一个nn.Embedding
实例,但它的词汇表大小为 1(即只有一个 token)。iou_token.weight
是这个特殊 token 的实际值,形状为[1, embedding_dim]
。- 这种用法适用于需要定义一个可学习的、全局共享的向量的场景,而不是处理多个离散 token。
三、第一种方法在模型更新中会更新其值吗?
1、默认行为
默认情况下,nn.Embedding
的权重(即词向量)是模型的可训练参数,默认情况下会被优化器更新。
2、示例代码:
import torch
import torch.nn as nn
import torch.optim as optim# 创建一个 Embedding 层
vocab_size = 10
embedding_dim = 5
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)# 定义输入和目标
input_tokens = torch.tensor([2, 3, 5]) # 输入 token 索引
target = torch.randn(3, embedding_dim) # 假设的目标向量# 定义优化器
optimizer = optim.SGD(embedding_layer.parameters(), lr=0.01)# 前向传播
embedded_vectors = embedding_layer(input_tokens)# 计算损失
loss_fn = nn.MSELoss()
loss = loss_fn(embedded_vectors, target)# 反向传播和更新
optimizer.zero_grad()
loss.backward()
optimizer.step()# 查看更新后的权重
print("Updated Embedding Weights:", embedding_layer.weight)
3、控制权重更新的方法
有时我们希望固定某些权重,不让它们参与训练。这可以通过以下方式实现:
方法 1:设置 requires_grad = False
将 embedding_layer.weight.requires_grad
设置为 False
,可以阻止这些权重被更新。
embedding_layer.weight.requires_grad = False
方法 2:加载预训练权重并冻结
如果我们使用预训练的词向量(如 GloVe 或 Word2Vec),可以选择加载这些权重并冻结它们。
# 加载预训练权重
pretrained_weights = torch.load('glove_embeddings.pth')# 创建 Embedding 层并加载权重
embedding_layer = nn.Embedding.from_pretrained(pretrained_weights, freeze=True)
方法 3:在优化器中排除某些参数
我们可以在定义优化器时,排除某些参数,从而避免更新它们。
# 排除 embedding_layer 的权重
optimizer = optim.SGD([param for param in model.parameters() if param is not embedding_layer.weight],lr=0.01
)
四、总结
- 默认情况下,
nn.Embedding
的权重是可训练的,会在每次反向传播后被更新。 - 如果需要固定权重,可以通过设置
requires_grad = False
、使用from_pretrained
并设置freeze=True
或在优化器中排除这些参数来实现。 - 选择是否更新权重取决于任务需求:如果你希望模型从头学习词向量(如随机初始化的场景),让权重可训练;如果你使用预训练的词向量并希望保持它们不变,则固定权重。
相关文章:
深入理解 PyTorch 的 nn.Embedding:词向量映射及变量 weight 的更新机制
文章目录 前言一、直接使用 nn.Embedding 获得变量1、典型场景2、示例代码:3、特点 二、使用 iou_token nn.Embedding(1, transformer_dim) 并访问 iou_token.weight1、典型场景2、示例代码:3、特点 三、第一种方法在模型更新中会更新其值吗?…...
go语言内存泄漏的常见形式
go语言内存泄漏 子字符串导致的内存泄漏 使用自动垃圾回收的语言进行编程时,通常我们无需担心内存泄漏的问题,因为运行时会定期回收未使用的内存。但是如果你以为这样就完事大吉了,哪里就大错特措了。 因为,虽然go中并未对字符串…...
操作系统
操作系统 操作系统(OperatingSystem,OS)是指控制和管理整个计算机系统的硬件和软件资源,并合理地组织调度计算机的工作和资源的分配;以提供给用户和其他软件方便的接口和环境;它是计算机系统中最基本的系统…...
《JVM考古现场(十八):造化玉碟·用字节码重写因果律的九种方法》
"鸿蒙初判!当前因果链突破十一维屏障——全体码农修士注意,《JVM考古现场(十八)》即将渡劫飞升!" 目录 上卷阴阳交缠 第一章:混沌初开——JVM因果律的量子纠缠 第二章:诛仙剑阵改—…...
【2】k8s集群管理系列--包应用管理器之helm(Chart语法深入应用)
一、Chart模板:函数与管道 常用函数: • quote:将值转换为字符串,即加双引号 • default:设置默认值,如果获取的值为空则为默认值 • indent和nindent:缩进字符串 • toYaml:引用一…...
汇编获取二进制
mov_.S mov %r8d,0 nop执行命令: gcc -c mov_.S 会输出 mov_.o 文件:objdump -D mov_.o : mov_.o: 文件格式 elf64-x86-64Disassembly of section .text:0000000000000000 <.text>:0: 44 89 04 25 00 00 00 mov %r8d,0x0…...
《嵌套调用与链式访问:C语言中的函数调用技巧》
🚀个人主页:BabyZZの秘密日记 📖收入专栏:C语言 🌍文章目入 一、嵌套调用(一)定义(二)实现方式(三)优点(四)缺点 二、链式…...
txt、Csv、Excel、JSON、SQL文件读取(Python)
txt、Csv、Excel、JSON、SQL文件读取(Python) txt文件读写 创建一个txt文件 fopen(rtext.txt,r,encodingutf-8) sf.read() f.close() print(s)open( )是打开文件的方法 text.txt’文件名 在同一个文件夹下所以可以省略路径 如果不在同一个文件夹下 ‘…...
前端工程化之新晋打包工具
新晋打包工具 新晋打包工具前端模块工具的发展历程分类初版构建工具grunt使用场景 gulp采用管道机制任务化配置与api简洁 现代打包构建工具基石--webpack基于webpack改进的构建工具rollup 推荐举例说明package.jsonrollup.config.mjsmy-extract-css-rollup-plugin.mjssrc/index…...
Python语言介绍
Python 是一种高级、通用、解释型的编程语言,由 Guido van Rossum 于 1991 年首次发布。其设计哲学强调代码的可读性和简洁性。 Python通过简洁的语法和强大的生态系统,成为当今最受欢迎的编程语言之一。 一、核心特点 Python 是一种解释型、面向对象、…...
关于 Spring Boot 部署到 Docker 容器的详细说明,涵盖核心概念、配置步骤及关键命令,并附上表格总结
以下是关于 Spring Boot 部署到 Docker 容器的详细说明,涵盖核心概念、配置步骤及关键命令,并附上表格总结: 1. Docker 核心概念 概念描述关系镜像(Image)预定义的只读模板,包含运行环境和配置(…...
Tomcat 服务频繁崩溃的排查方法
# Tomcat 服务频繁崩溃排查方法 当Tomcat服务频繁崩溃时,可以按照以下步骤进行系统化排查: ## 1. 检查日志文件 **关键日志位置**: - catalina.out (标准输出和错误) - catalina.log (主日志) - localhost.log (应用相关日志) - host-mana…...
分布式系统-脑裂,redis的解决方案
感谢你的反馈!很高兴能帮到你。关于你提到的“脑裂”(split-brain),这是一个分布式系统中的常见术语,尤其在像 Redis Cluster 这样的高可用集群中会涉及。既然你问到了,我会从头解释“脑裂”的含义、Redis …...
MySQL InnoDB 索引与B+树面试题20道
1. B树和B+树的区别是什么? 数据存储位置: B树:所有节点(包括内部节点和叶子节点)均存储数据。 B+树:仅叶子节点存储数据,内部节点仅存储键值(索引)。 叶子节点结构: B+树:叶子节点通过双向链表连接,支持高效的范围查询。 查询稳定性: B+树:所有查询必须走到叶子…...
深入解析 Spring AI Alibaba 多模态对话模型:构建下一代智能应用的实践指南
一、多模态对话模型的技术演进 1.1 从单一文本到多模态交互 现代AI应用正经历从单一文本交互到多模态融合的革命性转变。根据Gartner预测,到2026年将有超过80%的企业应用集成多模态AI能力。Spring AI Alibaba 对话模型体系正是为这一趋势量身打造,其技…...
2025年ESWA SCI1区TOP:动态分类麻雀搜索算法DSSA,深度解析+性能实测
目录 1.摘要2.麻雀搜索算法SSA原理3.孤立微电网经济环境调度4.改进策略5.结果展示6.参考文献7.代码获取 1.摘要 污染物排放对环境造成负面影响,而可再生能源的不稳定性则威胁着微电网的安全运行。为了在保障电力供应可靠性的同时实现环境和经济目标的平衡ÿ…...
MySQL Error Log
MySQL Error Log Error Log 的开启Error Log 查看Error Log 滚动 MySQL Error Log MySQL主从复制:https://blog.csdn.net/a18792721831/article/details/146117935 MySQL Binlog:https://blog.csdn.net/a18792721831/article/details/146606305 MySQL Ge…...
让DeepSeek API支持联网搜索
引子 DeepSeek官网注册的API token是不支持联网搜索的,这导致它无法辅助分析一些最新的情况或是帮忙查一下互联网上的资料。本文从实战角度提供一种稳定可靠的方法使得DeepSeek R1支持联网搜索分析。 正文 首先登录火山方舟控制台,https://www.volcen…...
SQL 语句说明
目录 数据库和数据表什么是 SQL 语言数据操作语言(DML)1、SELECT 单表查询通过 WHERE 对原始数据进行筛选通过 聚合函数 获取汇总信息通过 ORDER BY 对结果排序通过 GROUP BY 对数据进行分组通过 HAVING 对分组结果进行筛选 2、SELECT 多表查询3、INSERT…...
PostgreSQL内幕探索—基础知识
PostgreSQL内幕探索—基础知识 PostgreSQL(以下简称PG) 起源于 1986 年加州大学伯克利分校的 POSTGRES 项目,最初以对象关系模型为核心,支持高级数据类型和复杂查询功能。 1996 年更名为 PostgreSQL 并开源,逐…...
Springboot项目正常启动,访问资源却出现404错误如何解决?
我在自己的springboot项目中的启动类上同时使用了SprinBootApplication和ComponentScan注解, 虽然项目能够正常启动,但是访问资源后,返回404错误,随后在启动类中输出bean,发现controller创建失败: 而后我将ComponentScan去掉后资源就能访问到了. 原因 SprinBootApplication本身…...
MaxPooling层的作用(通俗解释)
MaxPooling层的作用(通俗解释) MaxPooling层是卷积神经网络中非常重要的组成部分,它的主要作用可以用以下几个简单的比喻来理解: 1. 信息压缩器(降维作用) 就像把一张高清照片缩小尺寸一样,M…...
0.DockerCE起步之Linux相关【完善中】
ubuntu用户组&权限&文件/目录 服务启停操作 sudo systemctl start docker # 启动服务3,4 sudo systemctl stop docker # 停止服务 sudo systemctl restart docker ps top 以下内容参考 Vim编辑器 Linux系统常用命令 管理Linux实例软件源 Cron定时任务 在Linux系统上…...
树莓派Pico C/C++ OpenOCD调试环境搭建(Windows)
树莓派Pico C/C OpenOCD调试环境搭建(Windows) 参考资料和背景 从上次树莓派Pico C/C 开发环境搭建(一键完成版)后,一直想找个合适调试器,最后测试了多种方案,还是使用另一块树莓派pico作为picoprobe 来调试比较方便,其中参考的…...
【图像生成之21】融合了Transformer与Diffusion,Meta新作Transfusion实现图像与语言大一统
论文:Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model 地址:https://arxiv.org/abs/2408.11039 类型:理解与生成 Transfusion模型是一种将Transformer和Diffusion模型融合的多模态模型,旨…...
《人件》第二章 办公环境
二、办公环境 电话铃不停的响,打印机维修人员顺道过来聊聊天,复印机不工作了,人事部不停催促更新的能力调查表,下午3点之前就要提交时间表…然后一天就这样过去了。 2.1 家具警察 人们怎么使用空间、需要的桌子空间多大、花多少小…...
哈希表系列一>存在重复元素II 存在重复元素I
目录 题目:解析:存在重复元素 II-->代码:存在重复元素-->代码: 题目: 链接: link 链接: link 解析: 存在重复元素 II–>代码: class Solution {public boolean containsNearbyDuplic…...
文献总结:AAAI2025-UniV2X-End-to-end autonomous driving through V2X cooperation
UniV2X 一、文章基本信息二、文章背景三、UniV2X框架1. 车路协同自动驾驶问题定义2. 稀疏-密集混合形态数据3. 交叉视图数据融合(智能体融合)4. 交叉视图数据融合(车道融合)5. 交叉视图数据融合(占用融合)6…...
LeetCode --- 444 周赛
题目列表 3507. 移除最小数对使数组有序 I 3508. 设计路由器 3509. 最大化交错和为 K 的子序列乘积 3510. 移除最小数对使数组有序 II 一、移除最小数对使数组有序 I & II 由于数组是给定的,所以本题的操作步骤是固定的,我们只要能快速模拟操作的过…...
单片机Day05---静态数码管
目录 一、原理图:编辑 二、思路梳理: 三:一些说明: 1.点亮方式: 2.数组: 3.数字与段码对应: 四:程序实现: 一、原理图: 二、思路梳理: …...
kernel32!GetQueuedCompletionStatus函数分析之返回值得有效性
第一部分://#define STATUS_SUCCESS 0x0返回值为0 } else { // // Set the completion status, capture the completion // information, deallocate the associated IRP, and // attempt to write the…...
gazebo 启动卡死的解决方法汇总
1. 排查显卡驱动是否正常安装 nvidia-smi # 英伟达显卡--------------------------------------------------------------------------------------- | NVIDIA-SMI 535.230.02 Driver Version: 535.230.02 CUDA Version: 12.2 | |------------------------…...
硬件设计-MOS管快速关断的原因和原理
目录 简介: 来源: MOS管快关的原理 先简单介绍下快关的原理: 同电阻时为什么关断时间会更长 小结 简介: 本章主要介绍MOS快速关断的原理和原因。 来源: 有人会问,会什么要求快速关断,而…...
塔能科技解节能密码,工厂成本“效益方程式”精准破题
在全球积极推进可持续发展战略的当下,各行业都在努力探索节能减排、绿色发展的新路径,对于工厂而言,节能早已不是锦上添花的选择,而已成为关乎企业生死存亡与长远发展的核心要素,是实现可持续运营的必由之路。塔能科技…...
swift ui基础
一个朴实无华的目录 今日学习内容:1.三种布局(可以相互包裹)1.1 vstack(竖直):先写的在上面1.1 hstack(水平):先写的在左边1.1 zstack(前后)&…...
格式工厂 v5.18最新免安装绿色便携版
前言 用它来转视频的时候,还能顺便给那些有点小瑕疵的视频修修补补,保证转出来的视频质量杠杠的。更厉害的是,它不只是转换那么简单,还能帮你把PDF合并成一本小册子,视频也能合并成大片,还能随心所欲地裁剪…...
CSPM认证对项目论证的范式革新:从合规审查到价值创造的战略跃迁
引言 在数字化转型浪潮中,全球企业每年因项目论证缺陷导致的损失高达1.7万亿美元(Gartner 2023)。CSPM(Certified Strategic Project Manager)认证体系通过结构化方法论,将传统的项目可行性评估升级为战略…...
TcxCustomCheckComboBoxProperties.EditValueFormat 值说明
TcxCheckStatesValueFormat 类枚举复选框状态对 edit 值的可能解释。以下选项可用。 价值 意义 cvf字幕 编辑值是一个字符串,其中包含两个由分号分隔的子字符串。分号前的子字符串包含灰显项目的标题列表。分号后面的子字符串包含已选中项目的标题列表。请注意&a…...
Spring Boot 测试详解,包含maven引入依赖、测试业务层类、REST风格测试和Mock测试
Spring Boot 测试详解 1. 测试依赖引入 Spring Boot 默认通过以下 Maven 依赖引入测试工具: <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</s…...
【C语言】预处理(下)(C语言完结篇)
一、#和## 1、#运算符 这里的#是一个运算符,整个运算符会将宏的参数转换为字符串字面量,它仅可以出现在带参数的宏的替换列表中,我们可以将其理解为字符串化。 我们先看下面的一段代码: 第二个printf中是由两个字符串组成的&am…...
IIC通信协议
一、概述 IIC协议:是一种各种电子设备之间进行数据交换和通信的串行,半双工通信协议,主要用于近距离,低速的芯片之间的通信。 I2C协议采用双线结构传输数据,由一个数据线&#…...
SpringBoot原生实现分布式MapReduce计算(无第三方中间件版)
一、架构设计调整 核心组件替换方案: 注册中心 → 数据库注册表任务队列 → 数据库任务表分布式锁 → 数据库行级锁节点通信 → HTTP REST接口 二、数据库表结构设计 -- 节点注册表 CREATE TABLE compute_nodes (node_id VARCHAR(36) PRIMARY KEY,last_heartbea…...
02-libVLC的视频播放器:播放音视频文件以及网络流
libvlc_new(0, nullptr)功能:创建并初始化libVLC的核心实例,是使用所有libVLC功能的前提。 参数:第一个参数:参数数量(通常设为0)第二个参数:参数列表(通常为nullptr,表示使用默认配置)返回值:成功返回libvlc_instance_t*指针,失败返回nullptr。注意事项:可通过参…...
Autoware源码总结
Autoware源码网站 项目简介 教程 Autoware的整体架构如下图,主要包括传感器sensing、高精地图map data、车辆接口vehicle interface、感知perception(动态障碍物检测detection、跟踪tracking、预测prediction;交通信号灯检测detection、分类c…...
PowerBI 条形图显示数值和百分比
数据表: 三个度量值 销售额 SUM(销量表[销售量])//注意, 因为Y轴显示的产品,会被筛选,所以用ALLSELECTED来获取当前筛选条件下,Y轴显示的产品 百分比 FORMAT(DIVIDE([销售额],CALCULATE([销售额],ALLSELECTED(销量表[产品编码]))),"0…...
Sa-Token 自定义插件 —— SPI 机制讲解(一)
前言 博主在使用 Sa-Token 框架的过程中,越用越感叹框架设计的精妙。于是,最近在学习如何给 Sa-Token 贡献自定义框架。为 Sa-Token 的开源尽一份微不足道的力量。我将分三篇文章从 0 到 1 讲解如何为 Sa-Token 自定义一个插件,这一集将是前沿…...
基于 Termux 在移动端配置 Ubuntu 系统并搭建工作环境
本套方案主要参考了以下内容,并根据自身体验进行了修改。 【教程】用Termux搭建桌面级生产力环境Termux安装完整版Linux(Ubuntu)详细步骤 前言 自己的电脑太重,有时候外出不想带,平板生产力有有限。所以一直在折腾用平板替代电脑的事情。之前…...
JAVA SDK通过proxy对接google: GCS/FCM
前言:因为国内调用google相关api需要通过代理访问(不想设置全局代理),所以在代理这里经常遇到问题,先说一下结论 GCS 需要设置全局代理或自定义代理选择器, FCM sdk admin 在初始化firebaseApp时是支持设置的。 GCS: 开始时尝试在…...
JAVA EE_多线程-初阶(三)
我对未来没有底气 我也不知道当下该如何做 那就活着,活着就能把日子过下去 ---------陳長生. 1.多线程案例 1.1.单例模式 单例模式是常见的设计模式之一 设计模式:一些编程大佬制定的一些通用代码,再特定的场景下能套用进去,即…...
@PKU秋招互联网产品经理求职分享
从校园到职场 非常荣幸能够在毕业后两年半再次回到燕园。今天,我主要想和大家分享一下我在互联网行业的求职和工作经验。从最初面对职场的迷茫,到现在能够从容应对职场各种挑战,这一路走来积累了不少心得。互联网行业变化迅速,持续…...