多头注意力
Multi-Head Attention
-
论文地址
https://arxiv.org/pdf/1706.03762
多头注意力介绍
-
多头注意力是Transformer模型的关键创新,通过并行执行多个独立的注意力计算单元,使模型能够同时关注来自不同表示子空间的信息。每个注意力头学习不同的语义特征,最后通过线性变换将多头的输出组合为最终结果。
当n_heads=1时,多头注意力等价于标准缩放点积注意力。多头设计通过建立多个独立的"观察视角",使模型能够捕获更丰富的上下文信息。
数学公式
-
多头注意力通过以下三个步骤实现:
- 线性投影:对Q/K/V进行线性变换
- 并行注意力计算:在多个子空间中计算缩放点积注意力
- 输出融合:拼接多头结果并进行线性变换
公式表达为:
MultiHead ( Q , K , V ) = Concat ( h e a d 1 , . . . , h e a d h ) W O where h e a d i = Attention ( Q W i Q , K W i K , V W i V ) \text{MultiHead}(Q,K,V) = \text{Concat}(head_1,...,head_h)W^O \\ \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV)
其中:- W i Q ∈ R s e q _ l e n × d _ k W_i^Q \in \mathbb{R}^{seq\_len \times d\_k} WiQ∈Rseq_len×d_k:第i个头的查询投影矩阵
- W i K ∈ R s e q _ l e n × d _ k W_i^K \in \mathbb{R}^{seq\_len \times d\_k} WiK∈Rseq_len×d_k:第i个头的键投影矩阵
- W i V ∈ R s e q _ l e n × d _ k W_i^V \in \mathbb{R}^{seq\_len \times d\_k} WiV∈Rseq_len×d_k:第i个头的值投影矩阵
- W O ∈ R s e q _ l e n × d _ m o d e l W^O \in \mathbb{R}^{seq\_len \times d\_model} WO∈Rseq_len×d_model:输出投影矩阵
s e q _ l e n seq\_len seq_len 为序列长度, d _ m o d e l d\_model d_model 为embedding的维度, d _ k d\_k d_k 为每个注意力头的维度,假设头数量为n_heads ,且n_heads能被d_model给整数,则
d _ k = d _ m o d e l n _ h e a d s d\_k = \frac{d\_model}{n\_heads} d_k=n_headsd_model
代码实现
-
需要导入之前写的计算注意力分数的函数
https://blog.csdn.net/hbkybkzw/article/details/147462845
import torch from torch import nndef calculate_attention(query, key, value, mask=None):d_k = key.shape[-1]att_scaled = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5 # 缩放点积,注意区分这里的 d_k 是多头注意力中的“d_model”if mask is not None:att_scaled = att_scaled.masked_fill_(mask=mask, value=-1e9) # 或者value= -inf,这样在softmax的时候,掩码的部分就为0了att_softmax = torch.softmax(input=att_scaled, dim=-1) # softmaxatt_score = torch.matmul(att_softmax, value) # 注意力分数return att_score
-
多头注意力模块
class MultiHeadAttention(nn.Module):def __init__(self, n_heads, d_model, dropout_prob=0.1):super(MultiHeadAttention, self).__init__()assert d_model % n_heads == 0 # 需要确保d_model一定要能被注意力头的数量整除# 定义四个全连接层(Q/K/V/Output)self.q_linear = nn.Linear(in_features=d_model, out_features=d_model, bias=False)self.k_linear = nn.Linear(in_features=d_model, out_features=d_model, bias=False)self.v_linear = nn.Linear(in_features=d_model, out_features=d_model, bias=False)self.linear = nn.Linear(in_features=d_model, out_features=d_model, bias=False)self.dropout = nn.Dropout(dropout_prob)self.n_heads = n_headsself.d_k = d_model // n_headsself.d_model = d_modeldef forward(self, q, k, v, mask=None):"""前向传播过程输入形状:[batch_size, seq_len, d_model]输出形状:[batch_size, seq_len, d_model]"""# 步骤1:线性投影并分割多头q = self.q_linear(q)q = q.reshape(q.shape[0], -1, self.n_heads, self.d_k) # [batch_size,seq_len,n_heads,d_k]q = q.transpose(1, 2) # [batch_size, n_heads, seq_len, d_k]# 等价于# q = self.q_linear(q).reshape(q.shape[0],-1,self.n_heads,self.d_k).transpose(1,2)k = self.k_linear(k)k = k.reshape(k.shape[0], -1, self.n_heads, self.d_k)k = k.transpose(1, 2)v = self.v_linear(v)v = v.reshape(v.shape[0], -1, self.n_heads, self.d_k)v = v.transpose(1, 2)# 步骤2:计算注意力(每个头独立计算)out = calculate_attention(q, k, v, mask) # [batch_size,seq_len,n_heads,d_k]# 步骤3:拼接多头结果并进行输出投影out = out.transpose(1, 2) # [batch_size,n_heads,seq_len,d_k]out = out.reshape(out.shape[0], -1, self.d_model) # [batch_size,seq_len,d_model]out = self.linear(out)out = self.dropout(out)return out
-
关键操作解析
维度变换流程
操作步骤 张量形状变化示例 输入数据 [batch_size, seq_len, d_model] 步骤1-线性投影(保持维度) [batch_size, seq_len, d_model] 步骤1-分割多头(n_heads) [batch_size, seq_len, n_heads, d_k] 步骤1-维度转置(交换头与序列轴) [batch_size, n_heads , seq_len, d_k] 步骤2-计算注意力(各头独立) [batch_size, n_heads , seq_len, d_k] 步骤3-维度转置(交换头与序列轴) [batch_size , seq_len, n_heads ,d_k] 步骤3-拼接多头(恢复原始维度) [batch_size, seq_len, d_model]
使用示例
-
测试代码
if __name__ == "__main__":# 实例化模块:8头注意力,512维模型,20% dropoutmultihead_attention = MultiHeadAttention(n_heads=8, d_model=512, dropout_prob=0.2)# 模拟输入:batch_size=4,序列长度100,维度512x = torch.randn(4, 100, 512)# 前向传播(自注意力模式)output = multihead_attention(x, x, x)print("输入形状:", x.shape) # torch.Size([4, 100, 512])print("输出形状:", output.shape) # torch.Size([4, 100, 512])print("输出范数:", torch.norm(output)) # 约1.2-1.8(取决于初始化)
相关文章:
多头注意力
Multi-Head Attention 论文地址 https://arxiv.org/pdf/1706.03762 多头注意力介绍 多头注意力是Transformer模型的关键创新,通过并行执行多个独立的注意力计算单元,使模型能够同时关注来自不同表示子空间的信息。每个注意力头学习不同的语义特征&#x…...
【leetcode100】目标和
1、题目描述 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - ,然后串联起所有整数,可以构造一个 表达式 : 例如,nums [2, 1] ,可以在 2 之前添加 ,在 1 之前添加 - …...
动态哈希映射深度指南:从基础到高阶实现与优化
哈希表是计算机科学中最高效的数据结构之一,而动态哈希映射通过智能扩容机制,在实时系统中展现出极强的适应性。本文将深入探讨其实现细节,结合主流框架源码解析,并给出可落地的性能优化方案。 一、动态哈希的数学本质 1. 哈希函…...
leetcode 2799. 统计完全子数组的数目 中等
给你一个由 正 整数组成的数组 nums 。 如果数组中的某个子数组满足下述条件,则称之为 完全子数组 : 子数组中 不同 元素的数目等于整个数组不同元素的数目。 返回数组中 完全子数组 的数目。 子数组 是数组中的一个连续非空序列。 示例 1ÿ…...
使用RabbitMQ实现判题功能
这次主要选用RabbitMQ消息队列来对判题服务和题目服务解耦,题目服务只需要向消息队列发送消息,判题服务从消息队列中取信息去执行判题,然后异步更新数据库即可。 五一宝宝请快点跑~~~~~ 先回顾一下RabbitMQ (1)引入依…...
无过拟合的记忆:分析大语言模型的训练动态
Kushal Tirumala⇤ Aram H. Markosyan⇤ Luke Zettlemoyer Armen Aghajanyan Meta AI 研究 {ktirumala,amarkos,lsz,armenag}fb.com 原文链接:[2210.09262] Physics-Driven Convolutional Autoencoder Approach for CFD Data Compressions 摘要 尽管超大语言模型…...
【Java面试笔记:进阶】16.synchronized底层如何实现?什么是锁的升级、降级?
在 Java 中,synchronized 关键字的底层实现依赖于 对象头(Object Header) 和 监视器锁(Monitor) 机制,并通过 锁的状态升级(Lock Escalation) 来优化同步性能。 1. synchronized 的底层实现 synchronized 的同步机制基于 Monitor 对象,它是同步的基本实现单元。 通过…...
python可视化:北方城市人口流动趋势分析1
python可视化:北方城市人口流动趋势分析1 斑点鱼在做销售数据分析时发现北京天津的同比下滑明显,客流下滑明显。而山东保定的客流同比上升。引起了斑点鱼对于北方人口流动的好奇。 所以本文将分析2025年北方地区(北京、天津、河北、山东、山西、辽宁等)…...
wps excel 常用操作
数据分列 对于有分隔规律的内容,可以通过分隔符将该内容进行分列 例如,以下字符串,可使用Excel对包含IP地址、数据库类型、环境、负责人和日期的字符串进行分列: 192.168.175.211-MySQL 数据库-DEV-李华-2025.06.30 将以上字符串…...
云智融合普惠大模型AI,政务服务重构数智化路径
2025年是“十四五”收官之年,数字政府和政务数智化作为“数字中国”建设的重点,已经取得了显著成效。根据《联合国电子政务调查报告2024》,我国电子政务发展指数全球排名第35位,与2022年相比提升8个名次;其中ÿ…...
全行业软件定制:APP/小程序/系统开发与物联网解决方案
在数字化浪潮席卷全球的今天,软件已经渗透到我们生活的方方面面,成为推动社会进步的重要力量。作为一家专注于专业软件定制开发的公司,哲科软件深知每一个行业、每一个企业都有其独特的需求和痛点。因此,我们致力于提供个性化软件…...
Java虚拟机(JVM)家族发展史及版本对比
Java虚拟机(JVM)家族发展史及版本对比 一、JVM家族发展史 1. 早期阶段(1996-2000) Classic VM(Java 1.0-1.1): 厂商:Sun Microsystems(Oracle前身)。特点&…...
电脑怎么强制退出程序回到桌面 详细操作步骤
电脑日常使用过程中,我们有时会遇到程序无响应或卡死的情况,这时需要采取措施强制关闭这些程序才能保持电脑的正常工作和运行。那么,电脑如何强制退出程序呢?其实方法有很多种,下面便为大家介绍几种电脑强制关闭程序的…...
蓝牙 LE:安全模式和程序说明(蓝牙中的网络安全)
在蓝牙低功耗 (BLE) 中,安全性是一个多方面的难题。了解 BLE 的三种主要安全模式以及五个关键的 BLE 安全程序。 毫无疑问,低功耗蓝牙 (BLE) 技术的迅猛发展为我们的生活带来了更多便利。然而,随着低功耗蓝牙设备的普及,人们对其安全性的担忧也日益加剧。 与普遍看法相反…...
低代码平台开发胎压监测APP
项目介绍 该项目是一个利用Flutter框架和蓝牙技术实现轮胎压力实时监测的应用。 主要功能如下: 用于接收蓝牙模块传输的胎压数据,并实时显示胎压值。APP对接收到的胎压数据进行处理,如单位转换、数据滤波等,然后将处理后的胎压值…...
GNOME扩展入门:日期时间
Getting Started | GNOME JavaScript 1.扩展路径 ~/.local/share/gnome-shell/extensions/ 2.新建文件夹 datetimesonichy 3.metadata.json {"uuid": "datetimesonichy","name": "datetime","description": "Dis…...
NLP高频面试题(五十二)——深度学习优化器详解
在深度学习的训练过程中,各种基于梯度的优化器肩负着寻找损失函数最优解的重任。最基础的梯度下降法通过沿着损失函数负梯度方向迭代更新参数,实现对模型参数的优化;而随机梯度下降(SGD)则以更高的计算效率和内存利用率在大规模数据集上大放异彩,但也因更新噪声大、易陷入…...
SLAM常用地图对比示例
序号地图类型概述1格栅地图将现实环境栅格化,每一个栅格用 0 和 1 分别表示空闲和占据状态,初始化为未知状态 0.52特征地图以点、线、面等几何特征来描绘周围环境,将采集的信息进行筛选和提取得到关键几何特征3拓扑地图将重要部分抽象为地图&…...
Web常见攻击方式及防御措施
一、常见Web攻击方式 1. 跨站脚本攻击(XSS) 攻击原理:攻击者向网页注入恶意脚本,在用户浏览器执行 存储型XSS:恶意脚本存储在服务器(如评论区) 反射型XSS:恶意脚本通过URL参数反射给用户 DOM型XSS&…...
java.lang.IllegalArgumentException: URI is not hierarchical报错
java.lang.IllegalArgumentException: URI is not hierarchical Thread.currentThread().getContextClassLoader("类的全路径").getClass().newInstance()一个类的静态块初始化异常了,后面调用这个类创建对象会报错吗? 是的,如果一…...
118. 杨辉三角
目录 一、问题描述 二、解题思路 三、代码 四、复杂度分析 一、问题描述 给定一个非负整数 numRows,生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中,每个数是它左上方和右上方的数的和。 二、解题思路 每一行的第一个和最后一个元素是 1&…...
Anything V4/V5 模型汇总
二次元风格生成扩散模型-anything-v4.0Stable Diffusion anything-v5-PrtRE模型介绍及使用深度探索 Anything V5:安装与使用全攻略anything-v5x0.25少儿插画_v1xyn-ai/anything-v4.0...
网络原理 - 7(TCP - 4)
目录 6. 拥塞控制 7. 延时应答 8. 捎带应答 9. 面向字节流 10. 异常情况 总结: 6. 拥塞控制 虽然 TCP 有了滑动窗口这个大杀器,就能够高效可靠的发送大量的数据,但是如果在刚开始阶段就发送大量的数据,仍然可能引起大量的…...
探秘 FFmpeg 版本发展时间简史
前言 FFmpeg 是一套开源的计算机程序,主要用于记录、转换数字音频、视频,并能将其转化为流。它提供了录制、转换以及流化音视频的完整解决方案,在多媒体处理领域应用广泛。很多小伙伴们想系统的学习FFmpeg,还是有必要了解下FFmpeg的版本发展历史,感受它每次的版本迭代是如…...
5.3.1 MvvmLight以及CommunityToolkit.Mvvm介绍
MvvmLight、CommunityToolkit.Mvvm是开源包,他们为实现 MVVM(Model-View-ViewModel)模式提供了一系列实用的特性和工具,能帮助开发者更高效地构建 WPF、UWP、MAUI 等应用程序。 本文介绍如下: 一、使用(旧)的MvvmLight库 其特点如下,要继承的基类是ViewModelBase;且使用…...
PCB常见封装类型
1. 电阻、电容、电感封装 2. 二极管、三极管封 3. 排阻类器件(8脚、16脚)封装 4. SO类器件(间距有1.27、2.54mm等)封装 5. QFP类器件封装(四方扁平封装) 结构:引脚分布在封装的四个侧面&#…...
一键多环境构建——用 Hvigor 玩转 HarmonyOS Next
引言 在 HarmonyOS Next 的应用开发中,常常需要针对不同环境(测试、预发、线上)或不同签名(调试、正式)输出多个 APP/HAP 包。虽然 HarmonyOS 提供了多目标构建(Multi-Target Build)能力&#…...
SQLPandas刷题(LeetCode3451.查找无效的IP地址)
描述:LeetCode3451.查找无效的IP地址 表:logs ---------------------- | Column Name | Type | ---------------------- | log_id | int | | ip | varchar | | status_code | int | ---------------------- log_id 是这张表的唯…...
【leetcode100】组合总和Ⅳ
1、题目描述 给你一个由 不同 整数组成的数组 nums ,和一个目标整数 target 。请你从 nums 中找出并返回总和为 target 的元素组合的个数。 题目数据保证答案符合 32 位整数范围。 示例 1: 输入:nums [1,2,3], target 4 输出࿱…...
2020-06-23 暑期学习日更计划(机器学习入门之路(资源汇总)+概率论)
机器学习入门 前言 说实话,机器学习想学好真心不易,很多时候都感觉自己学得云里雾里。以前一段时间自己为了完成毕业设计,在机器学习的理论部分并没有深究,仅仅通过TensorFlow框架力求快速实现模型。现在来看,很多时候…...
Linux操作系统--基础I/O(上)
目录 1.回顾C文件接口 stdin、stdout、stderr 2.系统文件I/O 3.接口介绍 4.open函数返回值 5.文件描述符fd 5.1 0&1&2 1.回顾C文件接口 hello.c写文件 #include<stdio.h> #include<string.h>int main() {FILE *fp fopen("myfile","…...
Spring boot 中的IOC容器对Bean的管理
Spring Boot 中 IOC 容器对 Bean 的管理,涵盖从容器启动到 Bean 的生命周期管理的全流程。 步骤 1:理解 Spring Boot 的容器启动 Spring Boot 的 IOC 容器基于 ApplicationContext,在应用启动时自动初始化。 入口类:通过 SpringB…...
ARINC818协议一些说明综述
关键术语 航空总线技术 光纤通道层次架构 光纤通道拓扑结构 FC-AV协议,架构,容器系统 ARINC818协议,容器 ADVB帧映射,帧格式 机载视频处理系统对视频数据进行实时处理和记录。 分辨率:1080p,4k,8k视频技术 FC-AV技术是…...
Turso:一个基于 libSQL的分布式数据库
Turso 是一个完全托管的数据库平台,支持在一个组织中创建高达数十万个数据库,并且可以复制到任何地点,包括你自己的服务器,以实现微秒级的访问延迟。你可以通过Turso CLI(命令行界面)管理群组、数据库和API…...
2025.5.4机器学习笔记:PINN文献阅读
2025.5.4周报 文献阅读题目信息摘要创新点网络架构实验结论不足以及展望 文献阅读 题目信息 题目: Physics-Informed Neural Network Approach for Solving the One-Dimensional Unsteady Shallow-Water Equations in Riverine Systems期刊: Journal o…...
一行命令打开iOS模拟器
要在 Mac 命令行打开 iPhone 15 Pro 模拟器,需满足已安装 Xcode 这一前提条件,以下是具体操作步骤: 步骤一:列出所有可用模拟器设备 打开终端(Terminal),输入并执行以下命令,用于列…...
java面向对象编程【基础篇】之基础语法
目录 🚀前言🌟构造器💯案例 🤔this关键字💯使用this调用本类中的属性💯使用this调用构造器💯this表示当前对象 🦜封装💯合理隐藏💯合理暴露 🐧实体…...
跑MPS产生委外采购申请(成品)
问题:跑MPS产生委外采购申请(成品),更改BOM和跑MRP,但物料需求清单中无新增物料复合膜的需求。截图如下: 解决方法:更改委外采购申请的批准日期为BOM的生效日和重新展开bom。 重新展开后&#x…...
[flutter]切换国内源(window)
如题,切换到国内源避免总是连不上google导致卡住的问题。 临时切换到国内: cmd set PUB_HOSTED_URLhttps://pub.flutter-io.cn set FLUTTER_STORAGE_BASE_URLhttps://storage.flutter-io.cnpower shell $env:PUB_HOSTED_URL "https://pub.flut…...
学习海康VisionMaster之顶点检测
一:进一步学习了 今天学习下VisionMaster中的顶点检测:可检测图像指定区域内的顶点,并输出顶点坐标等信息。该模块常用于检测目标物体的顶点 二:开始学习 1:什么是顶点检测? 一个不是很规则的物体需要检测…...
Vue2中常用的核心函数(选项和生命周期钩子)的完整示例及总结
以下是Vue2中常用的核心函数(选项和生命周期钩子)的完整示例及总结: 1. 实例选项函数 data 初始化组件数据 new Vue({el: #app,data() {return {message: Hello Vue!};} });methods 定义组件方法 new Vue({el: #app,data() {return { c…...
数据集-目标检测系列- F35 战斗机 检测数据集 F35 plane >> DataBall
数据集-目标检测系列- F35 战斗机 检测数据集 F35 plane >> DataBall DataBall 助力快速掌握数据集的信息和使用方式。 贵在坚持! * 相关项目 1)数据集可视化项目:gitcode: https://gitcode.com/DataBall/DataBall-detections-100s…...
2025年3月AGI技术月评|技术突破重构数字世界底层逻辑
〔更多精彩AI内容,尽在 「魔方AI空间」 ,引领AIGC科技时代〕 本文作者:猫先生 ——当「无限照片」遇上「可控试穿」,我们正在见证怎样的智能革命? 被低估的进化:开源力量改写游戏规则 当巨头们在AGI赛道…...
【k8s】k8s是怎么实现自动扩缩的
Kubernetes 提供了多种自动扩缩容机制,主要包括 Pod 水平自动扩缩(HPA)、垂直 Pod 自动扩缩(VPA) 和 集群自动扩缩(Cluster Autoscaler)。以下是它们的实现原理和配置方法: 1. Pod …...
协作开发攻略:Git全面使用指南 — 引言
协作开发攻略:Git全面使用指南 — 引言 Git 是一种分布式版本控制系统,用于跟踪文件和目录的变更。它能帮助开发者有效管理代码版本,支持多人协作开发,方便代码合并与冲突解决,广泛应用于软件开发领域。 文中内容仅限技…...
【AI提示词】私人教练
提示说明 以专业且细致的方式帮助客户实现健康与健身目标,提升整体生活质量。 提示词 # Role: 私人教练## Profile - language: 中文 - description: 以专业且细致的方式帮助客户实现健康与健身目标,提升整体生活质量 - background: 具备丰富的健身经…...
【星海出品】Calico研究汇总
Calico项目由Tigera公司发起并主导开发 源码 https://github.com/projectcalico/calico?tabreadme-ov-file#-join-the-calico-community 简介 Tigera是一家专注于云原生安全的公司,于2016年成立,其核心产品包括开源的Calico项目以及商业版的Calico Ent…...
观成科技:摩诃草组织Spyder下载器流量特征分析
一、概述 自2023年以来,摩诃草组织频繁使用Spyder下载器下载远控木马,例如Remcos。观成安全研究团队对近几年的Spyder样本进行了深入研究,发现不同版本的样本在数据加密、流量模式等方面存在差异。基于此,我们对多个版本样本的通…...
中心极限定理(CLT)习题集 · 题目篇
中心极限定理(CLT)习题集 题目篇 共 18 题,覆盖经典 CLT、Lyapunov/Lindeberg 条件、Berry–Esseen 评估、 以及工程/数据科学应用与编程仿真。推荐先独立完成,再看《答案与解析篇》。 之前已经出过相关的知识点文章,…...
ITL和TTL线程间值的传递
InheritableThreadLocal InheritableThreadLocal 继承自 ThreadLocal,增加了父线程到子线程的值传递功能。当一个新线程被创建时,InheritableThreadLocal 会将父线程中 ThreadLocal 变量的值拷贝到子线程(浅拷贝),子线…...