自定义多头注意力模型:从代码实现到训练优化
引言
在自然语言处理和序列生成任务中,自注意力机制(Self-Attention)是提升模型性能的关键技术。本文将通过一个自定义的PyTorch模型实现,展示如何构建一个结合多头注意力与前馈网络的序列生成模型(如文本或字符生成)。该模型通过创新的 MaxStateSuper
模块实现动态特征融合,适用于字体生成、文本预测等场景。
技术背景
1. 模型结构解析
核心组件
-
MaxStateSuper(自注意力模块)
- 功能:通过多头注意力机制提取序列中的关键特征,并结合累积最大值操作增强长期依赖建模。
- 实现亮点:
- 合并三个线性层为一个
combined
层,减少参数冗余。 - 使用
torch.cummax
实现动态状态积累,提升序列记忆能力。
- 合并三个线性层为一个
-
FeedForward(前馈网络)
- 结构:两层全连接网络,中间夹杂
ReLU
激活函数和门控机制(gate
)。 - 作用:非线性变换,增强模型表达能力。
- 结构:两层全连接网络,中间夹杂
-
DecoderLayer(解码器层)
- 创新点:
- 引入
alpha
参数平衡前馈网络输出与原始输入的权重,实现动态特征融合。 - 层归一化(
LayerNorm
)确保梯度稳定性。
- 引入
- 创新点:
-
SamOut(整体模型)
- 输入:字符或token的Embedding向量。
- 输出:预测的下一时刻token概率分布。
2. 关键技术
- 多头注意力机制:通过
heads
参数将特征空间划分为多个子空间,提升模型对不同模式的捕捉能力。 - 累积最大值操作:
out2 = torch.cummax(out2, dim=2)[0]
保留序列中的关键特征轨迹。 - 动态参数平衡:
alpha
参数通过梯度下降自动学习前馈网络与原始输入的权重比例。
代码实现
完整代码
import torch
import torch.nn as nn
import torch.optim as optimclass MaxStateSuper(nn.Module):def __init__(self, dim_size, heads):super().__init__()self.heads = headsassert dim_size % heads == 0, "Dimension size must be divisible by head size."self.combined = nn.Linear(dim_size, 3 * dim_size, bias=False) # 合并QKV线性层def forward(self, x):b, s, d = x.shape# 合并后的线性变换并分割为QKVqkv = self.combined(x).chunk(3, dim=-1)q, k, v = qkv# 调整形状并执行注意力计算# ...(此处省略具体注意力计算逻辑,参考标准多头注意力实现)...return out, stateclass FeedForward(nn.Module):def __init__(self, hidden_size):super().__init__()self.ffn1 = nn.Linear(hidden_size, hidden_size)self.ffn2 = nn.Linear(hidden_size, hidden_size)self.gate = nn.Linear(hidden_size, hidden_size)self.relu = nn.ReLU()def forward(self, x):x1 = self.ffn1(x)x2 = self.relu(self.gate(x))xx = x1 * x2return self.ffn2(xx)class DecoderLayer(nn.Module):def __init__(self, hidden_size, num_heads):super().__init__()self.self_attn = MaxStateSuper(hidden_size, num_heads)self.ffn = FeedForward(hidden_size)self.norm = nn.LayerNorm(hidden_size)self.alpha = nn.Parameter(torch.tensor(0.5)) # 动态平衡参数def forward(self, x):attn_out, _ = self.self_attn(x)ffn_out = self.ffn(attn_out)x = self.norm(self.alpha * ffn_out + (1 - self.alpha) * x)return xclass SamOut(nn.Module):def __init__(self, voc_size, hidden_size, num_heads, num_layers):super().__init__()self.embedding = nn.Embedding(voc_size, hidden_size, padding_idx=3)self.layers = nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])self.head = nn.Linear(hidden_size, voc_size, bias=False)def forward(self, x):x = self.embedding(x)for layer in self.layers:x = layer(x)return self.head(x)# 训练流程(简化版)
if __name__ == '__main__':voc_size = 10000 # 假设词汇表大小model = SamOut(voc_size, hidden_size=256, num_heads=8, num_layers=6)criterion = nn.CrossEntropyLoss(ignore_index=3)optimizer = optim.Adam(model.parameters(), lr=1e-3)for epoch in range(10):# 假设 input_tensor 和 target_tensor 已准备output = model(input_tensor)loss = criterion(output.view(-1, voc_size), target_tensor.view(-1))loss.backward()optimizer.step()
关键步骤解析
1. MaxStateSuper
模块的创新点
# 合并QKV层
qkv = self.combined(x<
相关文章:
自定义多头注意力模型:从代码实现到训练优化
引言 在自然语言处理和序列生成任务中,自注意力机制(Self-Attention)是提升模型性能的关键技术。本文将通过一个自定义的PyTorch模型实现,展示如何构建一个结合多头注意力与前馈网络的序列生成模型(如文本或字符生成)。该模型通过创新的 MaxStateSuper 模块实现动态特征…...
vue部署到nginx服务器 启用gzip
要在使用Vue.js构建的应用程序上启用Nginx的Gzip压缩,你可以通过配置Nginx来实现这一功能,这样可以显著减少传输到客户端的数据量,从而加快页面加载速度。以下是如何配置Nginx以启用Gzip压缩的步骤: 1. 确认你的Vue.js应用已经构…...
Node.js和js到底什么关系
Node.js 和 JavaScript(JS)是紧密关联但本质不同的技术,它们的关系可以从以下几个关键维度进行解析: 1. 定义与角色 JavaScript: 一种高级、解释型的编程语言,最初设计用于浏览器端,负责网页的…...
如何开发一套TRS交易系统:架构设计、核心功能与风险控制
TRS(总收益互换)作为场外衍生品的重要工具,近年来在跨境投资、杠杆交易和风险对冲领域备受关注。2021年Archegos资本因TRS交易爆仓导致百亿美元级市场震荡,凸显了TRS系统设计的关键性。本文将从技术实现角度,解析TRS交…...
基于SpringBoot的高校体育馆场地预约管理系统-项目分享
基于SpringBoot的高校体育馆场地预约管理系统-项目分享 项目介绍项目摘要目录总体功能图用户实体图赛事实体图项目预览用户个人中心医生信息管理用户管理场地信息管理登录 最后 项目介绍 使用者:管理员 开发技术:MySQLJavaSpringBootVue 项目摘要 随着…...
MMIO、IOMAP 和 IOMMU 总结
MMIO、IOMAP 和 IOMMU 全面解析 📌 本文将深入浅出地梳理 Linux 驱动开发中常见的三大术语:MMIO、iomap、IOMMU。它们看似相似,其实职责完全不同,是理解 SoC 系统架构、DMA 安全性和驱动开发的基础。 一、MMIO(Memory-…...
Vscode开发STM32标准库
Vscode开发STM32 文章目录 引用一、文档介绍二、实际操作(基于标准库)总结 使用VScode开发STM32(keil),基础江科大标准库的串口接收和发送。 引用 VSCodeEIDE开发STM32,支持标准库、HAL库、LL库,可以在VSCode里进行调…...
Lateral 查询详解:概念、适用场景与普通 JOIN 的区别
1. 什么是Lateral查询? Lateral查询(也称为横向关联查询)是一种特殊的子查询,允许子查询中引用外层查询的列(即关联引用),并在执行时逐行对外层查询的每一行数据执行子查询。 语法上通常使用关…...
智能视频监控平台EasyCVR常见安防监控问题:录像机添加摄像头后无画面是什么原因
在智能安防场景中,室外安防监控摄像头承担着保障区域安全的重任,但画面无法显示、显示异常等问题却时常干扰正常监控工作,按照以下系统化步骤,即可高效定位并解决问题,让监控系统迅速恢复稳定运行。 一般出现这个问题…...
【Spring】深入解析 Spring AOP 核心概念:切点、连接点、通知、切面、通知类型和使用 @PointCut 定义切点的方法
Spring AOP 下面我们再来详细学习 AOP,主要是以下几部分: Spring AOP 核心概念 切点(Pointcut) 切点(Pointcut),也称之为“切入点”。 Pointcut 的作用就是提供一组规则(使用 Aspe…...
Uniapp:view容器(容器布局)
目录 一、基本概述二、属性说明三、常用布局3.1 横向布局3.2 纵向布局3.3 更多布局3.3.1 纵向布局-自动宽度3.3.2 纵向布局-固定宽度3.3.3 横向布局-自动宽度3.3.4 横向布局-居中3.3.5 横向布局-居右3.3.6 横向布局-平均分布3.3.7 横向布局-两端对齐3.3.8 横向布局-自动填充3.3…...
C# 运算符:?.(null 条件运算符)和 ??(null 合并运算符)
在 WinForms 中,comboBox1.SelectedValue?.ToString() ?? "" 这行代码使用了两个特殊的 C# 运算符:?.(null 条件运算符)和 ??(null 合并运算符)。让我分别解释它们的作用: ?.&…...
java/python——两个行为(操作)满足原子性的实现
目录 JAVA方法 1:使用 synchronized 同步块示例代码 方法 2:使用 ReentrantLock锁示例代码 方法 3:使用 AtomicReference 或其他原子类示例代码 方法 4:使用数据库事务(如果涉及数据库操作)示例代码&#x…...
SpringBoot中配置文件的加载顺序
下面的优先级由高到低 命令行参数java系统属性java系统环境变量外部config文件夹的application-{profile}.ym文件外部的application-{profile}.ym文件内部config文件夹的application-{profile}.ym文件内部的application-{profile}.ym文件外部config文件夹的application.ym文件外…...
Nginx下搭建rtmp流媒体服务 并使用HLS或者OBS测试
所需下载地址: 通过网盘分享的文件:rtmp 链接: https://pan.baidu.com/s/1t21J7cOzQR1ASLrsmrYshA?pwd0000 提取码: 0000 window: 解压 win目录下的 nginx-rtmp-module-1.2.2.zip和nginx 1.7.11.3 Gryphon.zip安装包,解压时选…...
在线查看【免费】 txt, xml(渲染), md(渲染), java, php, py, js, css 文件格式网站
可以免费在线查看 .docx/wps/Office/wmf/ psd/ psd/eml/epub/dwg, dxf/ txt/zip, rar/ jpg/mp3 m.gszh.xyz m.gszh.xyz 免费支持以下格式文件在线查看类型 支持 doc, docx, xls, xlsx, xlsm, ppt, pptx, csv, tsv, dotm, xlt, xltm, dot, dotx, xlam, xla, pages 等 Office 办…...
RIP动态路由(三层交换机+单臂路由)
RIP动态路由(三层交换机单臂路由) J1 (配置VLAN,修改端口) Switch>en Switch>en Switch# Switch#conf t Enter configuration commands, one per line. End with CNTL/Z. Switch(config)#int f0/1 Switch(config-if)#sw Switch(confi…...
Docker 基本概念与安装指南
Docker 基本概念与安装指南 一、Docker 核心概念 1. 容器(Container) 容器是 Docker 的核心运行单元,本质是一个轻量级的沙盒环境。它基于镜像创建,包含应用程序及其运行所需的依赖(如代码、库、环境变量等…...
Oracle DBA培训一般多长时间?
Oracle DBA培训的时间通常在2个月到6个月之间,具体看课程类型和你的学习目标。不过别只看总时长,关键得看每天学什么、练什么——有些机构把时间拖到半年,结果全是理论;有些课程压缩到2个月,但全是干货。下面分情况…...
【回眸】Linux 内核 (十七) 之 网络编程
前言 努力赶紧把Linux内核的内容更新完。 网络编程 协议的部分已经很成熟,只需要调用即可。 进程间通讯无法进行多机通信,网络通讯则解决了这一缺陷。 TCP/UDP协议对比 (1)TCP 面向连接(如打电话要先拨号建立连接…...
Batch Size
1. 什么是Batch Size? Batch Size(批大小)是指在深度学习模型训练过程中,每次前向传播和反向传播时输入到模型中的样本数量。具体来说,深度学习模型的训练通常基于梯度下降(Gradient Descent)算…...
Maven插件管理的基本原理
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…...
flutter 专题 六十三 Flutter入门与实战作者:xiangzhihong8Fluter 应用调试
Fluter 应用调试 Flutter 构建模式 目前,Flutter一共提供了三种运行模式,分别是Debug、Release和Profile模式。其中,Debug模式主要用在软件编写过程中,Release模式主要用于应用发布过程中,而Profile模式则主要用于应…...
MySQL-存储过程--游标
存储过程 游标 什么是游标 一个游标是一个SQL语句执行时系统内存创建的一个临时工作区域。一个游标包含一个查询语句的信息和它操作的数据行的信息。 mysql游标的特点 只读: 无法通过游标更新基础表中的数据不可滚动: 只能根据select中确定的顺序来…...
Spring AOP 事务
目录 一,引入依赖: 二,切面 1,基本概念 2, 通知类型: 3,Pointcut 4, 切面优先级: 5 ,自定义优先级Order 6,切点表达式 7, 自定义注解 总结: AOP有几种创建方式 三, Spring AOP原理 1, 代理模式 (1)静态代理 (2)动态代理 △JDK动态代理 △CGLIB动态代理 JDB和c…...
Itext进行PDF的编辑开发
这周写了一周的需求,是制作一个PDF生成功能,其中用到了Itext来制作PDF的视觉效果。其中一些功能不是很懂,仅作记录,若要学习请仔细甄别正确与否。 开始之前,我还是想说,这傻福需求怎么想出来的,…...
Python 中消费者 - 生产者模式详解
目录 引言 消费者 - 生产者模式原理 示例场景 Python 实现消费者 - 生产者模式 使用队列(Queue)实现 代码解释 使用协程实现 代码解释 应用场景 总结 引言 在软件开发里,消费者 - 生产者模式是一种常见且重要的设计模式。这种模式让…...
基于Hadoop的音乐推荐系统(源码+lw+部署文档+讲解),源码可白嫖!
摘要 本毕业生数据分析与可视化系统采用B/S架构,数据库是MySQL,网站的搭建与开发采用了先进的Java语言、爬虫技术进行编写,使用了Spring Boot框架。该系统从两个对象:由管理员和用户来对系统进行设计构建。主要功能包括ÿ…...
移动端动态滑动拨盘选择器【Axure元件库】
模拟移动端底部对话框效果,制作的年份、日期滑动拨盘选择器,支持日期动态滑动选择,提升原型制作强度。 该模板主要使用中继器、动态面板和矩形制作,使用简单,复用性强。只需对中继器数据表格中的数据项进行修改、增删…...
7. 深入Spring AI:刨析 Advisors 机制
目录 1、序言2、什么是Advisor?3、源码分析Advisor3.1、Advisor接口3.2、Advisor Ordered3.3、CallAroundAdvisor & StreamAroundAdvisor3.4、BaseAdvisor4、内置的Advisor类型4.1、MessageChatMemoryAdvisor4.2、PromptChatMemoryAdvisor4.3、VectorStoreChatMemoryAdvis…...
高保真动态项目管理图表集
本作品为项目管理图表类原型,以关系图谱、甘特图、流程图、泳道图为核心,提供基础的图表设计风格和交互案例,再进阶到高级的动态交互设计,由浅入深诠释Axure设计高端复杂的动态交互设计的魅力。 作品介绍 原型名称:Ax…...
MCP:AI时代的“万能插座”,开启大模型无限可能
摘要:Model Context Protocol(MCP)由Anthropic在2024年底开源,旨在统一大模型与外部工具、数据源的通信标准。采用客户端-服务器架构,基于JSON-RPC 2.0协议,支持stdio、SSE、Streamable HTTP等多种通信方式…...
IDEA打不开、打开报错
目录 场景异常原因解决 场景 1、本机已经安装了IDEA 2、再次安装另外一个版本的IDEA后打不开、打开报错 异常 这里忘记截图了。。。 原因 情况1-打不开:在同一台电脑安装多个IDEA是需要对idea的配置文件进行调整的,否则打不开 情况2-打开报错&#…...
TM1640学习手册及示例代码
数据手册 TM1640数据手册 数据手册解读 这里我们看管脚定义DIN和SCLK,一个数据线一个时钟线 SEG1~SEG8为段码,GRID1~GRID16为位码(共阴极情况下) 这里VDD给5V 数据指令 数据命令设置 地址命令设置 显示控制命令 共阴极硬件连接图…...
动态规划-零钱兑换
332.零钱兑换 给你一个整数数组 coins ,表示不同面额的硬币;以及一个整数 amount ,表示总金额。计算并返回可以凑成总金额所需的 最少的硬币个数 。如果没有任何一种硬币组合能组成总金额,返回 -1 。你可以认为每种硬币的数量是无…...
leetcode50.pow(x,n)
class Solution {private double f(double x,long n){if(n0)return 1.0;else {double tempf(x,n/2);return n%21?temp*temp*x:temp*temp;}}public double myPow(double x, int n) {long Nn;return n>0?f(x,N):1.0/f(x,-N);} }...
ECA 注意力机制:让你的卷积神经网络更上一层楼
ECA 注意力机制:让你的卷积神经网络更上一层楼 在深度学习领域,注意力机制已经成为提升模型性能的重要手段。从自注意力(Self-Attention)到各种变体,研究人员不断探索更高效、更有效的注意方法。今天我们要介绍一种轻…...
《谷歌Gemini 1.5:长语境理解重塑文档分析与检索新格局》
在人工智能的快速发展进程中,大语言模型不断突破边界,为各个领域带来变革性影响。谷歌Gemini 1.5的问世,凭借其卓越的长语境理解能力,在文档分析和检索任务方面掀起了一阵技术革新的浪潮。 以往的大语言模型在处理长文本时&#…...
量变与质变的辩证关系
量变和质变是唯物辩证法中揭示事物发展状态和形式的一对重要范畴,二者之间存在着密切的辩证关系。 一、量变是质变的必要准备 含义 量变是指事物数量的增减和场所的变更,是一种渐进的、不显著的变化。例如,水的温度升高,从0℃逐…...
讯联桌面TV版apk下载-讯联桌面安卓电视版免费下载安装教程
在智能电视的使用过程中,一款好用的桌面应用能极大提升我们的使用体验。讯联桌面 TV 版就是这样一款备受关注的应用,它可以让安卓电视拥有更个性化、便捷的操作界面。今天,就为大家详细介绍讯联桌面 TV 版 apk 的免费下载安装教程。 一、下载…...
【Vue】组件基础
目录 🚀 Vue 非单文件组件 和 单文件组件 的区别与实践对比 ✨ 引言 一、非单文件组件 1. 基本使用 2. 注意: 3. 组件的嵌套 4. 关于VueComponent: 5. 一个重要的内置关系(有点难理解) 二、 单文件组件 那就…...
OpenCV---图像预处理(四)
OpenCV—图像预处理(四) 文章目录 OpenCV---图像预处理(四)九,图像掩膜9.1 制作掩膜9.2 与运算9.3 颜色替换9.3.19.3.2 颜色替换 十,ROI切割十 一,图像添加水印11.1模板输入11.2 与运算11.3 图像…...
《MySQL:MySQL表的基本查询操作CRUD》
CRUD:Create(创建)、Retrieve(读取)、Update(更新)、Delete(删除)。 Create into 可以省略。 插入否则更新 由于主键或唯一键冲突而导致插入失败。 可以选择性的进行同步…...
Psychology 101 期末测验(附答案)
欢呼 啦啦啦~啦啦啦~♪(^∇^*) 终于考过啦~ 开心(*^▽^*) 撒花✿✿ヽ(▽)ノ✿ |必须晒下证书: 判卷 记录下判卷,还是错了几道,填空题2道压根填不上。惭愧~ 答案我隐藏了,实在想不出答案的朋友可以留言,不定时回复。 建议还是认认真真的学习~认认真真的考试~,知识就…...
Linux:权限相关问题
文章目录 shell命令以及运行的原理Linux权限 shell命令以及运行的原理 操作系统分为内核和外壳程序,xshell是外壳程序,外壳程序包括我们windows桌面上的图形化界面,本质都是翻译给核心处理,再显示出来,而我们输入的命令…...
大模型应用开发大纲
AI大模型学习路径脑图结构 一、AI及LLM基础 学习目标:建立对AI和LLM的基础理解,了解主要的机器学习和神经网络模型,掌握API调用方法。 1.1 AI领域基础概念 AI, NL/NLU/NLG机器学习: 学习方法, 拟合评估神经网络: CNN, RNN, TransformerTra…...
【NCCL】transport建立(一)
transport建立 NCCL transport建立主要在ncclTransportP2pSetup函数中实现。 概况 先简单概括一下ncclTransportP2pSetup函数做了哪些事,方便理解代码流程。 recvpeer 表示本卡作为接收端的对端,sendpeer 表示本卡作为发送端的对端。假设8个rank全连接…...
智慧能源安全新纪元:当能源监测遇上视频联网的无限可能
引言:在数字化浪潮席卷全球的今天,能源安全已成为国家安全战略的重要组成部分。如何构建更加智能、高效的能源安全保障体系?能源安全监测平台与视频监控联网平台的深度融合,正为我们开启一扇通向未来能源管理新世界的大门。这种创…...
腾讯一面-软件开发实习-PC客户端开发方向
1.自我介绍就不多赘述了 2. 请介绍一下你的项目经历 - 介绍了专辑鉴赏项目,前端使用html语言编写,后端基于http协议使用C语言进行网页开发。此外,还提及项目中涉及处理多线程问题以及做过内存池管理项目。 3. 项目中HTTP协议是使用库实现的…...
Cad c# 射线法判断点在多边形内外
1、向量叉乘法 2、射线法原理 射线法是判断点与多边形位置关系的经典算法,核心思想是: 从目标点发出一条水平向右的射线(数学上可视为 y p_y, x \geq p_x 的射线),统计该射线与多边形边的交点数量: - 偶…...