从代码学习深度学习 - 多头注意力 PyTorch 版
文章目录
- 前言
- 一、多头注意力机制介绍
- 1.1 工作原理
- 1.2 优势
- 1.3 代码实现概述
- 二、代码解析
- 2.1 导入依赖
- 序列掩码函数
- 2.2 掩码 Softmax 函数
- 2.3 缩放点积注意力
- 2.4 张量转换函数
- 2.5 多头注意力模块
- 2.6 测试代码
- 总结
前言
在深度学习领域,注意力机制(Attention Mechanism)是自然语言处理(NLP)和计算机视觉(CV)等任务中的核心组件之一。特别是多头注意力(Multi-Head Attention),作为 Transformer 模型的基础,极大地提升了模型对复杂依赖关系的捕捉能力。本文通过分析一个完整的 PyTorch 实现,带你深入理解多头注意力的原理和代码实现。我们将从代码入手,逐步解析每个函数和类的功能,结合文字说明,让你不仅能运行代码,还能理解其背后的设计逻辑。无论你是初学者还是有一定经验的开发者,这篇博客都将帮助你更直观地掌握多头注意力机制。
完整代码:下载链接
一、多头注意力机制介绍
多头注意力(Multi-Head Attention)是 Transformer 模型的核心组件之一,广泛应用于自然语言处理(NLP)、计算机视觉(CV)等领域。它通过并行运行多个注意力头(Attention Heads),允许模型同时关注输入序列中的不同部分,从而捕捉更丰富的语义和上下文依赖关系。相比单一的注意力机制,多头注意力极大地增强了模型的表达能力,能够处理复杂的模式和长距离依赖。
1.1 工作原理
多头注意力的核心思想是将输入的查询(Queries)、键(Keys)和值(Values)通过线性变换映射到多个子空间,每个子空间由一个独立的注意力头处理。具体步骤如下:
- 线性变换:对输入的查询、键和值分别应用线性层,将其映射到隐藏维度(
num_hiddens
),并分割为多个头的表示。 - 缩放点积注意力:每个注意力头独立计算缩放点积注意力(Scaled Dot-Product Attention),即通过查询和键的点积计算注意力分数,再与值加权求和。
- 并行计算:多个注意力头并行运行,每个头关注输入的不同方面,生成各自的输出。
- 合并与变换:将所有头的输出拼接起来,并通过一个线性层融合,得到最终的多头注意力输出。
这种设计允许模型在不同子空间中学习不同的特征,例如在 NLP 任务中,一个头可能关注句法结构,另一个头可能关注语义关系。
1.2 优势
- 多样性:多头机制使模型能够从多个角度理解输入,捕捉多样化的模式。
- 并行性:多头计算可以高效并行化,提升计算效率。
- 稳定性:通过缩放点积(除以特征维度的平方根),缓解了高维点积导致的数值不稳定问题。
1.3 代码实现概述
在本文的实现中,我们使用 PyTorch 构建了一个完整的多头注意力模块,包含以下关键部分:
- 序列掩码:处理变长序列,屏蔽无效位置。
- 缩放点积注意力:实现单个注意力头的计算逻辑。
- 张量转换:通过
transpose_qkv
和transpose_output
函数实现多头分割与合并。 - 多头注意力类:整合所有组件,完成并行计算和输出融合。
接下来的代码解析将详细展示这些部分的实现,帮助你从代码层面深入理解多头注意力的每一步计算逻辑。
二、代码解析
以下是代码的完整实现和详细解析,代码按照 Jupyter Notebook(在最开始给出了完整代码下载链接) 的结构组织,并附上文字说明,帮助你理解每个部分的逻辑。
2.1 导入依赖
首先,我们导入必要的 Python 包,包括数学运算库 math
和 PyTorch 的核心模块 torch
和 nn
。
# 导入包
import math
import torch
from torch import nn
- math:用于计算缩放点积注意力中的归一化因子(即特征维度的平方根)。
- torch:PyTorch 的核心库,提供张量运算和自动求导功能。
- nn:PyTorch 的神经网络模块,包含
nn.Module
和nn.Linear
等工具,用于构建神经网络层。
序列掩码函数
在处理序列数据(如句子)时,不同序列的长度可能不同,我们需要通过掩码(Mask)来屏蔽无效位置,防止模型关注这些填充区域。以下是 sequence_mask
函数的实现:
def sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相关的项,使超出有效长度的位置被设置为指定值参数:X: 输入张量,形状 (batch_size, 最大序列长度, 特征维度) 或 (batch_size, 最大序列长度)valid_len: 有效长度张量,形状 (batch_size,),表示每个序列的有效长度value: 屏蔽值,标量,默认值为 0,用于填充无效位置返回:输出张量,形状与输入 X 相同,无效位置被设置为 value"""maxlen = X.size(1) # 最大序列长度,标量# 创建掩码,形状 (1, 最大序列长度),与 valid_len 比较生成布尔张量,形状 (batch_size, 最大序列长度)mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]# 将掩码取反后,X 的无效位置被设置为 valueX[~mask] = valuereturn X
解析:
- 输入:
X
:输入张量,通常是序列数据,可能包含填充(padding)部分。valid_len
:每个样本的有效长度,例如[3, 2]
表示第一个样本有 3 个有效 token,第二个样本有 2 个。value
:用于填充无效位置的值,默认为 0。
- 逻辑:
maxlen
获取序列的最大长度(即张量的第二维)。torch.arange(maxlen)
创建一个从 0 到maxlen-1
的序列,形状为(1, maxlen)
。- 通过广播机制,与
valid_len
(形状(batch_size, 1)
)比较,生成布尔掩码mask
,形状为(batch_size, maxlen)
。 mask
表示哪些位置是有效的(True),哪些是无效的(False)。- 使用
~mask
选择无效位置,将其值设置为value
。
- 输出:修改后的张量
X
,无效位置被设置为value
,形状不变。
作用:该函数用于在注意力计算中屏蔽填充区域,确保模型只关注有效 token。
2.2 掩码 Softmax 函数
在注意力机制中,我们需要对注意力分数应用 Softmax 操作,将其转换为概率分布。但由于序列长度不同,需要屏蔽无效位置的贡献。以下是 masked_softmax
函数的实现:
import torch
import torch.nn.functional as Fdef masked_softmax(X, valid_lens):"""通过在最后一个轴上掩蔽元素来执行softmax操作,忽略无效位置参数:X: 输入张量,形状 (batch_size, 查询个数, 键-值对个数),3D张量valid_lens: 有效长度张量,形状 (batch_size,) 或 (batch_size, 查询个数),1D或2D张量,表示每个序列的有效长度,即每个查询可以参考的有效键值对长度返回:输出张量,形状 (batch_size, 查询个数, 键-值对个数),softmax后的注意力权重"""if valid_lens is None:# 如果没有有效长度,直接在最后一个轴上应用softmaxreturn F.softmax(X, dim=-1)shape
相关文章:
从代码学习深度学习 - 多头注意力 PyTorch 版
文章目录 前言一、多头注意力机制介绍1.1 工作原理1.2 优势1.3 代码实现概述二、代码解析2.1 导入依赖序列掩码函数2.2 掩码 Softmax 函数2.3 缩放点积注意力2.4 张量转换函数2.5 多头注意力模块2.6 测试代码总结前言 在深度学习领域,注意力机制(Attention Mechanism)是自然…...
通过扣子平台工作流将数据写入飞书多维表格
1. 进入扣子平台,并创建工作流扣子 扣子是新一代 AI 大模型智能体开发平台。整合了插件、长短期记忆、工作流、卡片等丰富能力,扣子能帮你低门槛、快速搭建个性化或具备商业价值的智能体,并发布到豆包、飞书等各个平台。https://www.coze.cn…...
python专题2-----用python生成多位,值均是数字的随机数
有很多方法可以用 Python 生成 多位随机数。我将向您介绍几个常用的方法,并解释它们的优缺点(此处以4位随机数为例): 1. 使用 random.randint() 这是最简单直接的方法: import randomrandom_number random.randint…...
Mybatis的简单介绍
文章目录 MyBatis 简介 1. MyBatis 核心特点2. MyBatis 核心组件3. MyBatis 基本使用示例(1) 依赖引入(Maven)(2) 定义 Mapper 接口(3) 定义实体类(4) 在 Service 层调用 4. MyBatis 与 JPA/Hibernate 对比 MyBatis 简介 MyBatis 是一款优秀的 持久层框…...
山东大学软件学院创新项目实训(11)之springboot+vue项目接入deepseekAPI
因为该阶段是前后端搭建阶段,所以没有进大模型的专项训练,所以先用老师给的deepseek接口进行代替 且因为前端设计部分非本人负责且还没有提交到github上,所以目前只能先编写一个简易的界面进行功能的测试 首先进行创建model类 然后创建Cha…...
Qt绘图事件
目录 1.绘图事件 1.1绘图事件 1.2声明一个画家对象 2.画线、画圆、画矩形、画文字 2.1画线 编辑 2.2画圆 2.3画矩形 2.4画文字 3.设置画笔 3.1设置画笔颜色 3.2设置画笔宽度 3.3设置画笔风格 4.设置画刷 4.1填色 4.2设置画刷风格 5.绘图高级设置 5.1设置抗锯…...
Linux 内核 BUG: Android 手机 USB 网络共享 故障
众所周知, 窝日常使用 ArchLinux 操作系统, 而 ArchLinux 是一个滚动发行版本, 也就是各个软件包更新很快. 然而, 突然发现, Android 手机的 USB 网络共享功能 BUG 了. 经过一通排查, 发现是 Linux 内核造成的 BUG. 哎, 没办法, 只能自己动手修改内核代码, 修复 BUG 了. 本文…...
程序化广告行业(82/89):解锁行业术语,开启专业交流之门
程序化广告行业(82/89):解锁行业术语,开启专业交流之门 在程序化广告这个充满活力与挑战的行业里,持续学习是我们不断进步的动力源泉。一直以来,我都期望能和大家一起深入探索这个领域,共同成长…...
Linux的网络配置的资料
目前有两种方式,network和NetworkManager。 network方式是在CentOS 6及更早版本中引入的配置方法,支持使用配置文件的方式管理网卡的配置。 NetworkManager是在CentOS 7及后续的版本中使用的配置方法,支持使用命令行和图形化界面的方式来管理…...
linux: 文件描述符fd
目录 1.C语言文件操作复习 2.底层的系统调用接口 3.文件描述符的分配规则 4.重定向 1.C语言文件操作复习 文件 内容 属性。所有对文件的操作有两部分:a.对内容的操作;b.对属性的操作。内容是数据,属性其实也是数据-存储文件,…...
每天学一个 Linux 命令(16):mkdir
每天学一个 Linux 命令(16):mkdir 命令简介 mkdir(Make Directory)是 Linux 和类 Unix 系统中用于创建新目录的基础命令。它允许用户快速创建单个目录、多级嵌套目录,并能灵活设置目录权限。 主要用途 创建单个目录:快速生成新的空目录。递归创建多级目录:自动生成缺…...
Java微服务注册中心深度解析:环境隔离、分级模型与Eureka/Nacos对比
在微服务架构体系中,注册中心如同神经系统般承担着服务发现与健康管理的核心职能。本文将从生产环境实践出发,系统剖析注册中心的环境隔离策略、分级部署模型,并通过Eureka与Nacos两大主流组件的全方位对比,帮助开发者构建高可用服…...
c++:new关键字
目录 基本语法 使用举例 应用场景 使用 new 时的注意事项 基本语法 Type* ptr new Type;Type 是你要创建的类型(可以是基本类型、结构体、类等) new Type 表示在堆上创建一个 Type 类型的对象 ptr 是一个指针,指向这个对象 使用举例 …...
理解 MCP 协议的数据传递:HTTP 之上的一层“壳子
以下是以 CSDN 博客的风格记录你对 MCP 协议数据传递的理解和发现,内容涵盖了 MCP 协议基于 HTTP 的本质、JSON-RPC 的“壳子”作用,以及为什么熟悉 HTTP 协议就足以理解 MCP 的数据传递。文章面向技术社区,结构清晰,适合分享。 理…...
word中的mathtype公式编辑时,大(中)括号会存在很大的空白
如下所示,公式编辑的时候发现总会多一个空白,怎么删也删不掉 这主要是公式的分隔符问题,选择:“格式”-“分隔符对齐”,选择第三个可以消除下面的空白 点击“确认”,效果如下所示:...
【Java】查看当前 Java 使用的垃圾回收器
一、查询 Code import java.lang.management.GarbageCollectorMXBean; import java.lang.management.ManagementFactory; import java.util.Arrays; import java.util.List;public class GCTypeDetector {public static void main(String[] args) {List<GarbageCollectorMX…...
Linux编程c/c++程序
前言 我们Windows系统下的idea可以说是非常智能了,集成了各种开发工具,包括并不限于编辑器/编译器/调试器/自动化构建工具/版本控制工具……而在Linux系统中,每个组件之间是相互独立的,每个组件各司其职,共同协作完成…...
【前端网络请求入门】XMLHttpRequest与Fetch保姆级教程
新手学前端时,经常会被「如何让网页和服务器说话」难住。今天我们用最通俗的语言,把浏览器最常用的两种网络请求方式——XMLHttpRequest和Fetch讲清楚,还会带完整的代码示例,跟着敲一遍就能上手! 一、先搞懂「网络请求…...
redis单机安装
redis单机安装 下载地址 官网:https://redis.io/下载列表页面:https://download.redis.io/releases/ 说明 版本选择:redis-7.0.0.tar.gz下载地址:https://download.redis.io/releases/redis-7.0.0.tar.gz 安装前准备 在linu…...
从零手写RPC-version0
参考文档 https://github.com/he2121/MyRPCFromZero Version-0 0、写项目第一步,先添加远程仓库 先在 github 上新建仓库,然后将本地新建的项目推送到远程仓库中 由于网上很多教程,所以本节不再赘述(可以让 chatGPT给出一个完…...
MOM成功实施分享(八)汽车活塞生产制造MOM建设方案(第二部分)
在制造业数字化转型的浪潮中,方案对活塞积极探索,通过实施一系列数字化举措,在生产管理、供应链协同、质量控制等多个方面取得显著成效,为行业提供了优秀范例。 1.转型背景与目标:活塞在数字化转型前面临诸多挑战&…...
二、Android Studio环境安装
一、下载安装 下载 Android Studio 和应用工具 - Android 开发者 | Android DevelopersAndroid Studio 提供了一些应用构建器以及一个已针对 Android 应用进行优化的集成式开发环境 (IDE)。立即下载 Android Studio。https://developer.android.google.cn/studio?hlzh-c…...
构件与中间件技术:概念、复用、分类及标准全解析
以下是对构件与中间件技术相关内容更详细的介绍: 一、构件与中间件技术的概念 1.构件技术 定义:构件是具有特定功能、可独立部署和替换的软件模块,它遵循一定的规范和接口标准,能够在不同的软件系统中被复用。构件技术就是以构…...
亲手打造可视化故事线管理工具:开发全流程、难点突破与开发过程经验总结
亲手打造可视化故事线管理工具:开发全流程、难点突破与开发过程经验总结 作为还没入门的业余编程爱好者,奋战了2天,借助AI开发一款FLASK小工具,功能还在完善中(时间轴可以跟随关联图缩放,加了一个用C键控制…...
CSS 字体学习笔记
在网页设计中,字体的使用对于提升用户体验和页面美观性至关重要。CSS 提供了一系列字体属性,用于控制文本的显示效果。以下是对 CSS 字体属性的详细学习笔记。 一、字体系列(font-family) 1. 字体系列的分类 在 CSS 中…...
通过 spring ai 对接 deepseek ai 大模型通过向量数据库,完成 AI 写标书及构建知识库功能的设计与开发
AI写标书及知识库构建详细设计方案 一、系统架构设计 +-------------------+ +-------------------+ +-------------------+ | 用户交互层 | | AI服务层 | | 知识库存储层 | | (Web/API) |---->| (Spring AI) |---…...
cropperjs 2.0裁剪图片后转base64提示“Tainted canvases may not be exported”跨域问题的解决办法。
目录 为什么会有这边文章 辛酸历程,不看也罢 想解决问题,看这里就够了 问题已解决,后边还是废话 为什么会有这边文章 最近,做一个项目需要用在前端实现图片裁剪功能,毋庸置疑,cropperjs是不二选择。当在…...
2、JSX:魔法世界的通行证——用魔法符号编织动态界面
一、开篇:魔法符号的觉醒 "看呐,赫敏!这根魔杖(React组件)为何能自动绘制出动态界面?"年轻的巫师学徒罗恩指着闪烁的屏幕惊呼。 "这就是JSX魔法阵的威力,"邓布利多校长挥舞…...
八大排序算法
目录 八大排序算法排序算法的稳定性比较排序插入排序直接插入排序希尔排序希尔排序的时间复杂度计算 选择排序直接选择排序堆排序 交换排序冒泡排序快速排序递归hoare版本挖坑法lomuto前后指针 非递归 归并排序排序性能对比 非比较排序计数排序 比较排序算法总结 八大排序算法 …...
搭建一个Spring Boot聚合项目
1. 创建父项目 打开IntelliJ IDEA,选择 New Project。 在创建向导中选择 Maven,确保选中 Create from archetype,选择 org.apache.maven.archetypes:maven-archetype-quickstart。 填写项目信息: GroupId:com.exampl…...
Google A2A协议解析:构建分布式异构多Agent系统
一、A2A 是什么?有什么用? 1.1 A2A 是什么? A2A(Agent-to-Agent Protocol)是Google最新推出的一项开源协议,旨在为AI智能体(Agents)提供标准化的通信方式。它允许不同框架…...
【Android读书笔记】读书笔记记录
文章目录 一. Android开发艺术探索1. Activity的生命周期和启动模式1.1 生命周期全面分析 一. Android开发艺术探索 1. Activity的生命周期和启动模式 1.1 生命周期全面分析 onPause和onStop onPause后会快速调用onStop,极端条件下直接调用onResume 当用户打开新…...
支持selenium的chrome driver更新到135.0.7049.84
最近chrome释放新版本:135.0.7049.84 如果运行selenium自动化测试出现以下问题,是需要升级chromedriver才可以解决的。 selenium.common.exceptions.SessionNotCreatedException: Message: session not created: This version of ChromeDriver only su…...
【玩泰山派】MISC(杂项)- 使用vscode远程连接泰山派进行开发
文章目录 前言流程1、安装、启动sshd2、配置一下允许root登录3、vscode中配置1、安装remote插件2、登录 **注意** 前言 有时候要在开发板中写一写代码,直接在终端中使用vim这种工具有时候也不是很方便。这里准备使用vscode去通过ssh远程连接泰山派去操作࿰…...
利用阿里云企业邮箱服务实现Python群发邮件
目录 一、阿里云企业邮箱群发邮件全流程实现 1. 准备工作与环境配置 2. 收件人列表管理 3. 邮件内容构建 4. 附件添加实现 5. 邮件发送核心逻辑 二、开发过程中遇到的问题与解决方案 1. 附件发送失败问题 2. 中文文件名乱码问题 3. 企业邮箱认证失败 三、完整工作流…...
中文编码,GB系列,UTF
图片来源:https://zhuanlan.zhihu.com/p/701690894 文章目录 ASCIIGB系列编码UTF编码 ASCII American Standard Code for Information Interchange 一个字节,但其实只用了一半: 128个字符 GB系列编码 “国标” 和ASCII是兼容的。 GB2312…...
车载以太网-TLS
文章目录 车载以太网与TLS的技术背景核心定位车载以太网TLS的技术架构车载TLS的核心安全机制TLS报文结构详解TLS工作机制密钥交换与计算流程标题完整握手流程(1-RTT)数据传输加密流程车载TLS的独特优化策略车载TLS的安全威胁相关标准车载以太网TLS(Transport Layer Security…...
【大英赛】大英赛准备笔记
听力 总结 提醒专注 一题一个听力时,听是重点 抓紧时间往后审题 比较容易的部分:secA & secD中的dictation,在大致审当前的基础上,分别利用这个时间提前看后面的secB√& summery secA 听之前应当大致审选项&#x…...
有序数组的平方
暴力排序 每个数平方以后排个序 class Solution { public:vector<int> sortedSquares(vector<int>& nums) {int slow0,fast0;int nnums.size();while(fast<n){nums[slow]nums[fast]*nums[fast];fast;slow;}sort(nums.begin(),nums.end());return nums;} }…...
Python基于Django的房屋信息可视化及价格预测系统(附源码,文档说明)
博主介绍:✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇dz…...
【5G-A学习】ISAC通信感知一体化学习小记
通信感知一体化(Integrated Sensing and Communication, ISAC)是一种将无线通信与环境感知功能深度融合的技术,通过共享硬件、频谱和信号处理流程,实现通信与感知的协同增效。其核心原理及无人机与飞鸟的识别方式如下:…...
深入解析@Validated注解:Spring 验证机制的核心工具
一、注解出处与核心定位 1. 注解来源 • 所属框架:Validated 是 Spring Framework 提供的注解(org.springframework.validation.annotation 包下)。 • 核心定位: 作为 Spring 对 JSR-380(Bean Validation 2.0&#…...
学生考勤管理系统(jsp+ssh+mysql5.x)含文档
学生考勤管理系统(jspsshmysql5.x)含万字详细文档 学生考勤管理系统是一个用于管理学生出勤和请假的系统,系统登录页面提供账号和密码输入框,用户可以选择角色进行登录。系统主菜单包括班级管理、用户管理、课程表管理和考勤情况…...
【响应式编程】Reactor 常用操作符与使用指南
文章目录 一、创建操作符1. just —— 创建包含指定元素的流2. fromIterable —— 从集合创建 Flux3. empty —— 创建空的 Flux 或 Mono4. fromArray —— 从数组创建 Flux5. fromStream —— 从 Java 8 Stream 创建 Flux6. create —— 使用 FluxSink 手动发射元素7. generat…...
为什么我们需要if __name__ == __main__:
[目录] 0.前言 1.什么是 __name__? 2.if __name__ __main__: 的作用 3.为何Windows更需if __name__ ?前言 if __name__ __main__: 是 Python 中一个非常重要的惯用法,尤其在使用 multiprocessing 模块或编写可导入的模块时。它的作用是区分…...
Week 1: Time Complexity, Rectangle Geometry
问题集 Square PastureBucket BrigadeBlocked BillboardBlocked Billboard IIWord ProcessorDo You Know Your ABCs?The Cow-SignalD3C - White Sheet 视频解析 Square Pasture Bucket Brigade Blocked Billboard Blocked Billboard II Word Processor Do You Know Your AB…...
论文学习:《通过基于元学习的图变换探索冷启动场景下的药物-靶标相互作用预测》
原文标题:Exploring drug-target interaction prediction on cold-start scenarios via meta-learning-based graph transformer 原文链接:https://www.sciencedirect.com/science/article/pii/S1046202324002470 药物-靶点相互作用(DTI&…...
STM32 HAL库 OLED驱动实现
一、概述 1.1 OLED 显示屏简介 OLED(Organic Light - Emitting Diode)即有机发光二极管,与传统的 LCD 显示屏相比,OLED 具有自发光、视角广、响应速度快、对比度高、功耗低等优点。在嵌入式系统中,OLED 显示屏常被用…...
UE5蓝图之间的通信------接口
一、创建蓝图接口 二、双击创建的蓝图接口,添加函数,并重命名新函数。 三、在一个蓝图(如玩家角色蓝图)中实现接口,如下图: 步骤一:点击类设置 步骤二:在细节面板已经实现的接口中…...
封装Tcp Socket
封装Tcp Socket 0. 前言1. Socket.hpp2. 简单的使用介绍 0. 前言 本文中用到的Log.hpp在笔者的历史文章中都有涉及,这里就不再粘贴源码了,学习地址如下:https://blog.csdn.net/weixin_73870552/article/details/145434855?spm1001.2014.3001…...