从代码学习深度学习 - Transformer PyTorch 版
文章目录
- 前言
- 1. 位置编码(Positional Encoding)
- 2. 多头注意力机制(Multi-Head Attention)
- 3. 前馈网络与残差连接(Position-Wise FFN & AddNorm)
- 3.1 基于位置的前馈网络(PositionWiseFFN)
- 3.2 残差连接和层规范化(AddNorm)
- 4. 编码器(Encoder)
- 4.1 编码器块(EncoderBlock)
- 4.2 Transformer 编码器(TransformerEncoder)
- 5. 解码器(Decoder)
- 5.1 解码器块(DecoderBlock)
- 5.2 Transformer 解码器(TransformerDecoder)
- 6. 完整 Transformer 模型
- 使用示例
- 总结
前言
Transformer 模型自 2017 年在论文《Attention is All You Need》中提出以来,彻底改变了自然语言处理(NLP)领域,并在计算机视觉等其他领域展现了强大的潜力。与传统的 RNN 和 LSTM 相比,Transformer 通过自注意力机制(Self-Attention)实现了并行计算,极大地提高了训练效率和模型性能。本博客将通过 PyTorch 实现的 Transformer 模型代码,深入剖析其核心组件,包括多头注意力机制、位置编码、编码器和解码器等。我们将结合代码和文字说明,逐步拆解 Transformer 的实现逻辑,帮助读者从代码层面理解这一经典模型的精髓。
本文基于提供的代码文件(PE.py
、EnDecoder.py
、MHA.py
和 Transformer.ipynb
),完整呈现 Transformer 的 PyTorch 实现,并通过清晰的目录结构和代码注释,带领大家从零开始学习 Transformer 的构建过程。关于训练和可视化部分,这里忽略掉,但是你仍然可以在下面的链接里找到所有的源代码,其中提供了丰富的注释。无论你是深度学习初学者还是希望深入理解 Transformer 的开发者,这篇博客都将为你提供一个清晰的学习路径。
完整代码:下载链接
1. 位置编码(Positional Encoding)
Transformer 的自注意力机制不包含序列的位置信息,因此需要通过位置编码(Positional Encoding)为每个词元添加位置信息。以下是 PE.py
中实现的位置编码类,它通过正弦和余弦函数生成固定位置编码。
import torch
import torch.nn as nnclass PositionalEncoding(nn.Module):"""位置编码在Transformer中,由于自注意力机制不含位置信息,需要额外添加位置编码在位置嵌入矩阵P中,行代表词元在序列中的位置,列代表位置编码的不同维度"""def __init__(self, num_hiddens, dropout, max_len=1000):"""初始化位置编码参数:num_hiddens (int): 隐藏层维度,即位置编码的维度dropout (float): dropout概率max_len (int, 可选): 最大序列长度,默认为1000"""super(PositionalEncoding, self).__init__()# 初始化丢弃层self.dropout = nn.Dropout(dropout)# 创建位置编码矩阵P,形状为(1, max_len, num_hiddens)self.P = torch.zeros((1, max_len, num_hiddens))# 计算位置编码的正弦和余弦函数输入# X形状: (max_len, num_hiddens/2)X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)# 偶数维度赋值正弦,奇数维度赋值余弦self.P[:, :, 0::2] = torch.sin(X)self.P[:, :, 1::2] = torch.cos(X)def forward(self, X):"""前向传播参数:X (torch.Tensor): 输入张量,形状为(batch_size, seq_len, embed_dim)返回:torch.Tensor: 添加位置编码后的张量,形状为(batch_size, seq_len, embed_dim)"""# 将位置编码加到输入X上,截取与X长度匹配的部分X = X + self.P[:, :X.shape[1], :].to(X.device)# 应用丢弃并返回结果return self.dropout(X)
代码解析:
- 初始化:
PositionalEncoding
类根据隐藏层维度(num_hiddens
)和最大序列长度(max_len
)生成一个位置编码矩阵P
。该矩阵的每一行表示一个位置,每一列对应一个编码维度。 - 正弦和余弦编码:通过正弦(
sin
)和余弦(cos
)函数为不同位置和维度生成编码值,公式为:
P E ( p o s , 2 i ) = sin ( p o s 1000 0 2 i / d ) , P E ( p o s , 2 i + 1 ) = cos ( p o s 1000 0 2 i / d ) PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d}}\right) PE(pos,2i)=sin(100002i/dpos),PE(pos,2i+1)=cos(100002i/dpos)
其中pos
是位置索引,i
是维度索引,d
是隐藏层维度。 - 前向传播:将输入张量
X
与位置编码矩阵P
相加,并应用 dropout 以增强模型的鲁棒性。
位置编码的作用是将序列的位置信息嵌入到词嵌入中,使得 Transformer 能够区分相同词元在不同位置的语义。
2. 多头注意力机制(Multi-Head Attention)
多头注意力机制是 Transformer 的核心组件,允许模型并行计算多个注意力头,从而捕获序列中不同方面的依赖关系。以下是 MHA.py
中实现的多头注意力机制。
import math
import torch
from torch import nn
import torch.nn.functional as Fdef sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相关的项,使超出有效长度的位置被设置为指定值"""maxlen = X.size(1)mask = torch.arange(maxlen, dtype=torch
相关文章:
从代码学习深度学习 - Transformer PyTorch 版
文章目录 前言1. 位置编码(Positional Encoding)2. 多头注意力机制(Multi-Head Attention)3. 前馈网络与残差连接(Position-Wise FFN & AddNorm)3.1 基于位置的前馈网络(PositionWiseFFN)3.2 残差连接和层规范化(AddNorm)4. 编码器(Encoder)4.1 编码器块(Enco…...
多模态大语言模型arxiv论文略读(二十五)
ManipLLM: Embodied Multimodal Large Language Model for Object-Centric Robotic Manipulation ➡️ 论文标题:ManipLLM: Embodied Multimodal Large Language Model for Object-Centric Robotic Manipulation ➡️ 论文作者:Xiaoqi Li, Mingxu Zhang…...
LVS+Keepalived+dns高可用项目架构
一、搭建DNS服务 配置主服务器 1.修改核心配置文件 [rootDNS-master ~]# vim /etc/named.conf options { listen-on port 53 { 192.168.111.107;192.168.111.100; }; directory "/var/named"; }; zone "haha.com" IN { ty…...
C#日志辅助类(Log4Net)实现
一、Log4Net类库安装 在解决方案中项目上右键单击,选择“管理NuGet程序包”,在浏览窗口的搜索框输入log4net进行搜索,安装搜索出的第一项,如下图。 二、辅助类实现(Log4NetHelper) using log4net.Appender; using log4net.Config; using log4net.Layout; using log4net…...
【FFmpeg从入门到精通】第二章-FFmpeg工具使用基础
1 ffmpeg常用命令 ffmpeg在做音视频编解码时非常方便,所以在很多场景下转码使用的是ffmpeg,通过 ffmpeg --help可以看到 ffmpeg 常见的命令大概分为6个部分,具体如下。 ffmpeg信息查询部分公共操作参数部分文件主要操作参数部分视频操作参数…...
论文阅读:2022 ACL TruthfulQA: Measuring How Models Mimic Human Falsehoods
总目录 大模型安全相关研究:https://blog.csdn.net/WhiffeYF/article/details/142132328 TruthfulQA: Measuring How Models Mimic Human Falsehoods https://arxiv.org/pdf/2109.07958 https://www.doubao.com/chat/3130551217163266 https://github.com/sylin…...
基于C++(MFC)实现的文件管理系统
基于 MFC 的文件管理系统 第一章 题目解读与要求分析 1 实习题目 实现一个文件系统。 2 功能要求 界面上显示树形目录结构 a)根节点是“我的电脑” b)“我的电脑”下有几个盘符(C、D、E 等)就有几个子节点,递归…...
selenium 实现模拟登录中的滑块验证功能
用python在做数据采集过程中,经常需要用到模拟登录,经常遇到各种图片、文字甚至短信等验证,如果能通过脚本的方便实现验证,就可以自动帮我更高效地收集数据。Selenium 是一个开源的 Web 自动化测试工具,最初是为网站自…...
Oracle 19c部署之数据库软件安装(二)
在完成了Oracle Linux 9的初始化配置之后,我们准备安装Oracle 19c数据库软件。 Oracle数据库支持两种主要的安装方式:图形化安装和静默安装。这两种方法各有优缺点,选择哪种取决于你的具体需求、环境配置以及个人偏好。 图形化安装 图形化安…...
Paramiko 使用教程
目录 简介安装 Paramiko连接到远程服务器执行远程命令文件传输示例 简介 Paramiko 是一个基于 Python 的 SSH 客户端库,它提供了在网络上安全传输文件和执行远程命令的功能。本教程将介绍 Paramiko 的基本用法,包括连接到远程服务器、执行命令、文件传输…...
从EOF到REOF:如何用旋转经验正交函数提升时空数据分析精度?
目录 1. 基本概念与原理2. 应用场景3. 与传统EOF的区别4. 技术实现5. 其他领域中的“REOF”参考资料 REOF 的输入是多个地区在不同时间的气候数据(如温度或降雨量),它的作用是通过旋转计算找出这些数据中最主要的变化规律,输出则是…...
VS-Code创建Vue3项目
1 创建工程文件 创建一个做工程项目的文件夹 如:h5vue 2 cmd 进入文件 h5vue 3 输入如下命令 npm create vuelatest 也可以输入 npm create vitelatest 4 输入项目名称 项目名称:自已输入 回车 可以按键盘 a (全选) 回车: Playwright…...
JESD204B接收器核心实现和系统级关键细节
目录 1.通道偏移 2.弹性缓冲器的实现 3.接受延迟 4.RX端到端延迟 5.计算端到端延迟 6.实现可重复的延迟 1.通道偏移 JESD204B接收器核心已经过验证,其功能具有高达8个字节的通道到通道偏斜。 2.弹性缓冲器的实现 在JESD204B设备中,接收通道对齐弹性缓冲区是在分布式…...
NLP高频面试题(四十七)——探讨Transformer中的注意力机制:MHA、MQA与GQA
MHA、MQA和GQA基本概念与区别 1. 多头注意力(MHA) 多头注意力(Multi-Head Attention,MHA)通过多个独立的注意力头同时处理信息,每个头有各自的键(Key)、查询(Query)和值(Value)。这种机制允许模型并行关注不同的子空间上下文信息,捕捉复杂的交互关系。然而,MHA…...
k230学习笔记-疑难点(1)
1.出现boot failed with exit code 19: 需要将k230开发板的btoot0拨到ON 2.出现boot failed with exit code 13: 说明k230开发板的固件烧录已经丢失,需要重新烧录 *** 注意重新烧录时需要将btoot0重新拨到OFF,才会弹出加载固件需要的通用串行总线&…...
JavaScript性能优化实战:让你的Web应用飞起来
JavaScript性能优化实战:让你的Web应用飞起来 在前端开发中,JavaScript性能优化是提升用户体验的关键。一个性能良好的应用不仅能吸引用户,还能提高转化率和用户留存率。今天,我们就来深入探讨JavaScript性能优化的实战技巧&…...
金融数据库转型实战读后感
荣幸收到老友太保科技有限公司数智研究院首席专家林春的签名赠书。 这是国内第一本关于OceanBase数据库实际替换过程总结的的实战书。打个比方可以说是从战场上下来分享战斗经验。读后感受颇深。我在这里讲讲我的感受。 第三章中提到的应用改造如何降本。应用改造是国产化替换…...
血脂代谢通路(医学-计算机系统对照方式)
血脂代谢通路(医学-计算机系统对照方式) 整合所有类比,用医学-计算机系统对照的方式完整描述血脂代谢通路,采用分步骤的对照结构: 1. 食物摄入(数据输入层) # 医学术语: 膳食脂肪摄入 → 计算机类比: 原始数据输入 …...
git更新的bug
文章目录 1. 问题2. 分析 1. 问题 拉取了一个项目后遇到了这个问题, nvocation failed Server returned invalid Response. java.lang.RuntimeException: Invocation failed Server returned invalid Response. at git4idea.GitAppUtil.sendXmlRequest(GitAppUtil…...
直流电源基本原理
整流电路 在构建整流电路时,要选择合适参数的二极管 If是二极管能够通过电流的能力,也是最大整流的平均电流。 还要考虑二极管的反向截至电压。 脉动系数电压交流幅值/直流平均电压(越小越好) 三相整流电路优点: …...
Git -> git merge --no-ff 和 git merge的区别
git merge --no-ff <branch> 与 git merge <branch> 的区别 核心区别 git merge <branch>: 默认使用Fast-forward模式(若可行)不创建额外的合并提交记录合并后看不出曾经存在过分支 git merge --no-ff <branch>:强制创建一个…...
名胜古迹传承与保护系统(springboot+ssm+vue+mysql)含运行文档
名胜古迹传承与保护系统(springbootssmvuemysql)含运行文档 名胜古迹传承与保护系统是一个专注于文化遗产保护和管理的综合性平台。系统提供了一系列功能模块,包括名胜古迹管理、古迹预约管理、古迹故事管理、举报信息管理、保护措施管理、古迹讨论、管理员管理、版…...
windows资源管理器左边导航窗格增加2个项,windows10/11有效
下面文档存为.reg文件, Windows Registry Editor Version 5.00; 根 CLSID —— 名称、图标、固定到导航窗格 [HKEY_CURRENT_USER\Software\Classes\CLSID\{C1A3F2D2-BD2D-4D60-82C5-394F01753A5F}] "手机系统" "System.IsPinnedToNamespaceTree&quo…...
【八股文】基于源码聊聊ConcurrentHashmap的设计
版本演进 jdk 1.7中是分段锁的设计,将哈希表划分为多个segment,每个段独立加锁,锁粒度为段级别。 操作需两次哈希,第一次定位段,第二次定位桶内链表。这种实现方式的缺点就是段数量固定,扩容复杂…...
Mysql--基础知识点--93--两阶段提交
1 两阶段提交 以update语句的具体执行过程为例: 具体更新一条记录 UPDATE t_user SET name ‘xiaolin’ WHERE id 1;的流程如下: 1.执行器负责具体执行,会调用存储引擎的接口,通过主键索引树搜索获取 id 1 这一行记录&#…...
数字化招标采购系统怎么让招采协同更高效?
招标采购领域的数智化转型正在引发行业革命性变革。从传统线下模式到全流程电子化,再到当前数智化阶段的超时空协同,行业的演进路径清晰展现了技术与管理的深度融合。郑州信源信息数智化招采系统作为行业标杆,其创新实践为未来协同工作方式的…...
池塘计数(BFS)
题目描述 由于最近的降雨,光头强的田地里的各个地方都积水了,用 NM(1≤N≤100;1≤M≤100)NM(1≤N≤100;1≤M≤100) 的正方形的矩形表示。每个广场都有水 W 或旱地 .。光头强想知道他的田地里形成了多少池塘。池塘是指一组相邻的有…...
《Science》观点解读:AI无法创造真正的智能体(AI Agent)
无论是想要学习人工智能当做主业营收,还是像我一样作为开发工程师但依然要运用这个颠覆开发的时代宠儿,都有必要了解、学习一下人工智能。 近期发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,入行门槛低&#x…...
从零开始学A2A四:A2A 协议的安全性与多模态支持
文章目录 A2A 协议的安全性与多模态支持一、A2A 协议安全机制1. 认证机制2. 授权机制3. 数据加密 二、多模态交互支持1. 文本交互2. 音频支持3. 视频与图像处理4. 复合数据格式 三、安全与多模态最佳实践1. 安全性实践2. 多模态实践 四、与 MCP 的对比1. 安全机制对比2. 多模态…...
一种大位宽加减法器的时序优化
平台:vivado2018.3 芯片:xc7a100tfgg484-2 (active) 在FPGA中实现超高位宽加减法器(如256)时,时序收敛常成为瓶颈。由于进位链(Carry Chain)跨越多级逻辑单元,关键路径延迟会随位宽…...
【专业解读:Semantic Kernel(SK)】大语言模型与传统编程的桥梁
目录 Start:什么是Semantic Kernel? 一、Semantic Kernel的本质:AI时代的操作系统内核 1.1 重新定义LLM的应用边界 1.2 技术定位对比 二、SK框架的六大核心组件与技术实现 2.1 内核(Kernel):智能任务调度中心 2…...
InfiniBand与RoCEv2负载均衡机制的技术梳理与优化实践
AI技术的高速迭代正驱动全球算力格局进入全新纪元。据IDC预测,未来五年中国智能算力规模将以超50%的年复合增长率爆发式扩张,数据中心全面迈入“智能算力时代”。然而,海量AI训练、实时推理等高并发场景对底层网络提出了更严苛的挑战——超大…...
Vue与React组件化设计对比
组件化是现代前端开发的核心思想之一,而Vue和React作为两大主流框架,在组件化设计上既有相似之处,也存在显著差异。本文将从语法设计、数据管理、组件通信、性能优化、生态系统等多个方向,结合实例详细对比两者的特点。 一、模板…...
UE中通过AAIController::MoveTo函数巡逻至目标点后没法正常更新巡逻目标点
敌人巡逻的逻辑如下: 敌人在游戏一开始的时候就通过moveto函数先前往首先设定的patroltarget目标,在距离patroltarget距离为patroradius(200unit)之内时就可以通过checkpatroltarget函数更新新的patroltarget目标,随后前往新的pat…...
Python-细节知识点range函数的详解
在 Python 中,range 是一个内置函数,用于生成一个不可变的整数序列,通常用于控制循环次数或生成数值范围。以下是详细说明: 基本语法 range(stop) # 生成 [0, stop) 的整数,步长为1 range(start, stop) …...
git rebase的使用
我的使用 git checkout feature # 本地分支 git pull origin main --rebase # 目标分支 git pull origin feature --rebase git push origin featuregit rebase 是 Git 中用于重写提交历史的强大工具,可将分支的提交移动到新的基点上,使历史更线性。以…...
CMake Error at build/_deps/glog-src/CMakeLists.txt:1 (cmake_minimum_required):
这个错误提示意思是你当前系统上安装的 CMake 版本过低,不满足项目的要求。根据错误信息: CMake Error at build/_deps/glog-src/CMakeLists.txt:1 (cmake_minimum_required): CMake 3.22 or higher is required. You are running version 3.16.3 项目…...
MATLAB 控制系统设计与仿真 - 34
多变量系统知识回顾 - MIMO system 这一章对深入理解多变量系统以及鲁棒分析至关重要 首先,对于如下系统: 当G(s)为单输入,单输出系统时: 如果: 则: 所以 当G(s)为MIMO时,例如2X2时ÿ…...
【Unity】JSON数据的存取
这段代码的结构是为了实现 数据的封装和管理,特别是在 Unity 中保存和加载玩家数据时。以下是对代码设计的逐步解释: 1. PlayerCoin 类 PlayerCoin 是一个简单的数据类,用于表示单个玩家的硬币信息。它包含以下字段: count&…...
利用 Java 爬虫按关键字搜索淘宝商品
在电商领域,通过关键字搜索商品是常见的需求。淘宝作为国内知名的电商平台,提供了丰富的商品搜索功能。本文将详细介绍如何使用 Java 爬虫技术按关键字搜索淘宝商品,并获取搜索结果的详细信息。 一、准备工作 1. 注册淘宝开放平台账号 要使…...
【C】初阶数据结构11 -- 选择排序
本篇文章主要讲解经典排序算法 -- 选择排序 目录 1 算法思想 2 代码 3 时间复杂度与空间复杂度分析 1) 时间复杂度 2) 空间复杂度 1 算法思想 选择排序是一种在一段区间里面选择最小的元素和最大的元素的一种排序算法。假设这里排升序&#…...
【Semantic Kernel核心组件】Plugin:连接AI与业务逻辑的桥梁
目录 一、Plugin是什么?为什么它是SK的核心? 一、Plugin的核心机制与Python实现 1. 插件类型:语义函数与本地函数 语义函数(Semantic Function) 本地函数(Native Function) 2. Plugin的注…...
《基于神经网络实现手写数字分类》
《基于神经网络实现手写数字分类》 一、主要内容: 1、通过B站陈云霁老师的网课,配合书本资料,了解神经网络的基本组成和数学原理。 2、申请云平台搭建实验环境 3、基于5个不同的实验模块逐步理解实验操作步骤,并实现不同模块代码…...
车载诊断架构 --- 车载诊断概念的深度解读
我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 周末洗了一个澡,换了一身衣服,出了门却不知道去哪儿,不知道去找谁,漫无目的走着,大概这就是成年人最深的孤独吧! 旧人不知我近况,新人不知我过…...
四、探索LangChain:连接语言模型与外部世界的桥梁
一、什么是 LangChain LangChain 是一个开源的软件框架,旨在帮助开发者高效构建和部署基于**大型语言模型(LLM)**的应用程序。它通过提供一系列工具、组件和接口,简化了从模型调用、提示工程到复杂应用开发的全流程,使得开发者能够快速将 LLM 集成到实际场景中。 1. 核心…...
LangChain4j中的Chat与语言模型API详解:构建高效对话系统的利器
LangChain4j中的Chat与语言模型API详解:构建高效对话系统的利器 引言:大模型时代的开发利器 在人工智能快速发展的今天,大型语言模型(LLM)已成为构建智能应用的核心组件。LangChain4j作为Java生态中领先的LLM集成框架…...
C++中const与constexpr的区别
在C中,const和constexpr都用于定义常量,但它们的用途和行为有显著区别: ### 1. **初始化时机** - **const**:表示变量是只读的,但其值可以在**编译时或运行时**初始化。 cpp const int a 5; // 编译…...
长亭2月公开赛Web-ssrfme
环境部署 拉取环境报错: 可以尝试拉取一下ubuntu:16.04,看是否能拉取成功 将wersion:"3"删掉 我拉去成功之后,再去拉取环境,成功! 访问环境 测试ssrf 源码 <?php highlight_file(__file__…...
AI日报 - 2025年4月18日
🌟 今日概览(60秒速览) ▎🤖 AGI探讨 | 专家激辩AGI定义与实现时间点,Causal AI被视为关键一步,o3模型预测2027年实现引关注。 Causal AI强调因果关系而非模式;专家清单推荐不同模型适用场景;AGI定义及何时…...
Spring IoC 详解
在 Spring IoC& DI 详解 中对 IoC已经有了介绍,下面对 IoC 进行详细介绍。 IoC,即控制反转,在之前我们编写程序的时候,我们都是自己 new 出来一个对象,然后自己去管理这个对象,但是这有时候有些麻烦&a…...