手撕Transformer -- Day6 -- DecoderBlock
手撕Transformer – Day6 – DecoderBlock
目录
- 手撕Transformer -- Day6 -- DecoderBlock
- Transformer 网络结构图
- DecoderBlock 代码
- Part1 库函数
- Part2 实现一个解码器Block,作为一个类
- Part3 测试
- 参考
Transformer 网络结构图
DecoderBlock 代码
Part1 库函数
# 这个是解码器的block,和编码器来说多了一个掩码注意力机制,但是其实就是把掩码换一下即可,同时还对于第二个多头注意力机制的k_v和q不同源了
# 主要构成要素,输入嵌入好的句子,经过1.掩码注意力机制+残差归一化 2. 交叉注意力+残差归一化 3. 前向+残差归一化。保证输入输出同纬度(batch_size,seq_len,emding)
'''
# Part1 引入库函数
'''
import torch
from torch import nn
from multihead_attn import MultiHeadAttention
# 应该是用于测试
from dataset import train_dataset,de_preprocess,de_vocab,en_preprocess,en_vocab,PAD_IDX
from emb import EmbeddingWithPosition
from encoder import Encoder
Part2 实现一个解码器Block,作为一个类
'''
# Part2 写个类,实现EncoderBlock
'''
class DecoderBlock(nn.Module):def __init__(self,head,emd_size,q_k_size,v_size,f_size):super().__init__()# 首先要进行掩码多头注意力机制self.mask_multi_atten=MultiHeadAttention(head=head,emd_size=emd_size,q_k_size=q_k_size,v_size=v_size)self.linear1=nn.Linear(head*v_size,emd_size)# 归一化(填写的是最后一个的那个维度大小)self.norm1=nn.LayerNorm(emd_size)# 交叉注意力机制self.cross_multi_atten=MultiHeadAttention(head=head,emd_size=emd_size,q_k_size=q_k_size,v_size=v_size)self.linear2 = nn.Linear(head * v_size, emd_size)# 归一化(填写的是最后一个的那个维度大小)self.norm2 = nn.LayerNorm(emd_size)# 前向self.feedforward=nn.Sequential(nn.Linear(emd_size,f_size),nn.ReLU(),nn.Linear(f_size, emd_size))self.norm3 = nn.LayerNorm(emd_size)def forward(self, x, encoder_z, mask_1, mask_2): # x(batch_size,q_seq_len,emd_size)# 掩码注意力机制z1=self.mask_multi_atten(x_q=x, x_k_v=x, mask_pad=mask_1) # (batch_size,q_seq_len,head*v_size)z1=self.linear1(z1) # (batch_size,q_seq_len,emd_size)# 第一个残差归一化,得到第一层的输出outputoutpu1=self.norm1(z1+x) # (batch_size,q_seq_len,emd_size)# 交叉注意力机制,把output作为q,编码器作为k_vz2=self.cross_multi_atten(x_q=outpu1, x_k_v=encoder_z, mask_pad=mask_2) # (batch_size,q_seq_len,head*v_size)# 第二个残差归一化z2 = self.linear1(z2) # (batch_size,q_seq_len,emd_size)output2=self.norm2(z2+outpu1) # (batch_size,q_seq_len,emd_size)# 前向z3=self.feedforward(output2) # (batch_size,q_seq_len,emd_size)# 第三个残差归一化output3 = self.norm3(z3 + output2) # (batch_size,q_seq_len,emd_size)return output3
Part3 测试
if __name__ == '__main__':# 取2个de句子转词ID序列,输入给encoderde_tokens1, de_ids1 = de_preprocess(train_dataset[0][0])de_tokens2, de_ids2 = de_preprocess(train_dataset[1][0])# 对应2个en句子转词ID序列,再做embedding,输入给decoderen_tokens1, en_ids1 = en_preprocess(train_dataset[0][1])en_tokens2, en_ids2 = en_preprocess(train_dataset[1][1])# de句子组成batch并padding对齐if len(de_ids1) < len(de_ids2):de_ids1.extend([PAD_IDX] * (len(de_ids2) - len(de_ids1)))elif len(de_ids1) > len(de_ids2):de_ids2.extend([PAD_IDX] * (len(de_ids1) - len(de_ids2)))enc_x_batch = torch.tensor([de_ids1, de_ids2], dtype=torch.long)print('enc_x_batch batch:', enc_x_batch.size())# en句子组成batch并padding对齐if len(en_ids1) < len(en_ids2):en_ids1.extend([PAD_IDX] * (len(en_ids2) - len(en_ids1)))elif len(en_ids1) > len(en_ids2):en_ids2.extend([PAD_IDX] * (len(en_ids1) - len(en_ids2)))dec_x_batch = torch.tensor([en_ids1, en_ids2], dtype=torch.long)print('dec_x_batch batch:', dec_x_batch.size())# Encoder编码,输出每个词的编码向量enc = Encoder(vocab_size=len(de_vocab), emd_size=128, q_k_size=256, v_size=512, f_size=512, head=8, nums_encoderblock=3)enc_outputs = enc(enc_x_batch)print('encoder outputs:', enc_outputs.size())# 生成decoder所需的掩码first_attn_mask = (dec_x_batch == PAD_IDX).unsqueeze(1).expand(dec_x_batch.size()[0], dec_x_batch.size()[1],dec_x_batch.size()[1]) # 目标序列的pad掩码first_attn_mask = first_attn_mask | torch.triu(torch.ones(dec_x_batch.size()[1], dec_x_batch.size()[1]),diagonal=1).bool().unsqueeze(0).expand(dec_x_batch.size()[0], -1,-1) # &目标序列的向后看掩码print('first_attn_mask:', first_attn_mask.size())# 根据来源序列的pad掩码,遮盖decoder每个Q对encoder输出K的注意力second_attn_mask = (enc_x_batch == PAD_IDX).unsqueeze(1).expand(enc_x_batch.size()[0], dec_x_batch.size()[1],enc_x_batch.size()[1]) # (batch_size,target_len,src_len)print('second_attn_mask:', second_attn_mask.size())first_attn_mask = first_attn_masksecond_attn_mask = second_attn_mask# Decoder输入做emb先emb = EmbeddingWithPosition(len(en_vocab), 128)dec_x_emb_batch = emb(dec_x_batch)print('dec_x_emb_batch:', dec_x_emb_batch.size())# 5个Decoder block堆叠decoder_blocks = []for i in range(5):decoder_blocks.append(DecoderBlock(emd_size=128, q_k_size=256, v_size=512, f_size=512, head=8))for i in range(5):dec_x_emb_batch = decoder_blocks[i](dec_x_emb_batch, enc_outputs, first_attn_mask, second_attn_mask)print('decoder_outputs:', dec_x_emb_batch.size())
参考
视频讲解:transformer-带位置信息的词嵌入向量_哔哩哔哩_bilibili
github代码库:github.com
相关文章:
手撕Transformer -- Day6 -- DecoderBlock
手撕Transformer – Day6 – DecoderBlock 目录 手撕Transformer -- Day6 -- DecoderBlockTransformer 网络结构图DecoderBlock 代码Part1 库函数Part2 实现一个解码器Block,作为一个类Part3 测试 参考 Transformer 网络结构图 Transformer 网络结构 DecoderBlock 代…...
Docker常用命令大全
Docker容器相关命令: 创建并启动容器: docker run:创建一个新的容器并运行一个命令。例如:docker run -d -p 8080:80 nginx这将后台(-d)运行一个Nginx容器,并映射宿主机的8080端口到容器的80端口。 列出容器&#x…...
【Linux探索学习】第二十五弹——动静态库:Linux 中静态库与动态库的详细解析
Linux学习笔记: https://blog.csdn.net/2301_80220607/category_12805278.html?spm1001.2014.3001.5482 前言: 在 Linux 系统中,静态库和动态库是开发中常见的两种库文件类型。它们在编译、链接、内存管理以及程序的性能和可维护性方面有着…...
Vue 实现当前页面刷新的几种方法
以下是 Vue 中实现当前页面刷新的几种方法: 方法一:使用 $router.go(0) 方法 通过Vue Router进行重新导航,可以实现页面的局部刷新,而不丢失全局状态。具体实现方式有两种: 实现代码: <template&g…...
python mysql库的三个库mysqlclient mysql-connector-python pymysql如何选择,他们之间的区别
三者的区别 1. mysqlclient 特点: 是一个用于Python的MySQL数据库驱动程序,用于与MySQL数据库进行交互。 依赖于MySQL的本地库,因此在安装时需要确保系统上已安装了必要的依赖项,如libmysqlclient-dev等。 性能较好,…...
【可持久化线段树】 [SDOI2009] HH的项链 主席树(两种解法)
文章目录 1.题目描述2.思路3.解法一解法一代码 4.解法二解法二代码(版本一)解法二代码(版本二) 1.题目描述 原题:https://www.luogu.com.cn/problem/P1972 [SDOI2009] HH的项链 题目描述 HH 有一串由各种漂亮的贝壳…...
【C语言】线程----同步、互斥、条件变量
目录 3. 同步 3.1 概念 3.2 同步机制 3.3 函数接口 1. 同步 1.1 概念 同步(synchronization)指的是多个任务(线程)按照约定的顺序相互配合完成一件事情 1.2 同步机制 通过信号量实现线程间的同步 信号量:通过信号量实现同步操作;由信号量来决定…...
15. 三数之和【力扣】--三指针
三数之和 已解答 中等 相关标签 相关企业 提示 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三元组。 注意&#x…...
大数据学习(35)- spark- action算子
&&大数据学习&& 🔥系列专栏: 👑哲学语录: 承认自己的无知,乃是开启智慧的大门 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言📝支持一下博主哦ᾑ…...
vim使用指南
🏝️专栏:计算机操作系统 🌅主页:猫咪-9527-CSDN博客 “欲穷千里目,更上一层楼。会当凌绝顶,一览众山小。” 目录 一、Vim 的基本概念 1.Vim 的主要模式: 1.1普通模式 (Normal Mode) 1.2插入…...
Docker 镜像制作原理 做一个自己的docker镜像
一.手动制作镜像 启动容器进入容器定制基于容器生成镜像 1.启动容器 启动容器之前我们首先要有一个镜像,这个镜像可以是从docker拉取,例如:现在pull一个ubuntu镜像到本机。 docker pull ubuntu:22.04 我们接下来可以基于这个容器进行容器…...
基于Java+SpringBoot+Vue的前后端分离的在线BLOG网
基于JavaSpringBootVue的前后端分离的在线BLOG网 前言 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN[新星计划]导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末附源码下载链接dz…...
Linux网络_套接字_UDP网络_TCP网络
一.UDP网络 1.socket()创建套接字 #include<sys/socket.h> int socket(int domain, int type, int protocol);domain (地址族): AF_INET网络 AF_UNIX本地 AF_INET:IPv4 地址族,适用于 IPv4 协议。用于网络通信AF_INET6:IPv6 地址族&a…...
Java学习教程,从入门到精通,JDBC驱动程序类型及语法知识点(91)
JDBC驱动程序类型及语法知识点 一、JDBC驱动程序类型 JDBC驱动程序主要有以下四种类型: 1. Type 1:JDBC - ODBC桥驱动程序(JDBC - ODBC Bridge Driver) 特点:这种驱动程序是Java与ODBC(Open Database C…...
YOLOv8从菜鸟到精通(二):YOLOv8数据标注以及模型训练
数据标注 前期准备 先打开Anaconda Navigator,点击Environment,再点击new(new是我下载anaconda的文件夹名称),然后点击创建 点击绿色按钮,并点击Open Terminal 输入labelimg便可打开它,labelimg是图像标注工具,在上篇…...
3D目标检测数据集——Nusence数据集
链接地址 [官网] nuScenes[arXiv] nuScenes: A multimodal dataset for autonomous driving[GitHub] nuScenes devkitnuScenes devkit教程数据集概述 2.1 数据采集 2.1.1 传感器配置 nuScenes的数据采集车辆为Renault Zoe迷你电动车,配备6个周视相机&#x...
网站收录入口提交的方法有哪些(网站收录的方式都有哪些)
网站被搜索引擎收录是获得流量和曝光的重要前提,以下为你介绍常见的网站收录方式: 搜索引擎提交入口 各大搜索引擎都设有专门的网站收录入口,供站长提交网站。例如百度搜索资源平台、谷歌搜索控制台等。以百度为例,在百度搜索资…...
移动端H5缓存问题
移动端页面缓存问题是指页面的静态资源(如图片、JS 和 CSS 文件)在浏览器中被缓存后,用户在下次访问时可以直接从本地获取缓存数据,而不需要每次都从服务器重新获取,不过这样可能会导致页面不能正确地更新或者加载最新…...
11-1.Android 项目结构 - androidTest 包与 test 包(单元测试与仪器化测试)
androidTest 包与 test 包 在 Android 项目中,androidTest 包与 test 包用于存放不同类型的测试代码的 1、测试类型 (1)androidTest 包 主要用于存放单元测试(Unit Tests)代码 单元测试是针对应用程序中的独立模块…...
计算机网络(五)——传输层
一、功能 传输层的主要功能是向两台主机进程之间的通信提供通用的数据传输服务。功能包括实现端到端的通信、多路复用和多路分用、差错控制、流量控制等。 复用:多个应用进程可以通过同一个传输层发送数据。 分用:传输层在接收数据后可以将这些数据正确分…...
ZCC9159 -7V 300mA 超低功耗高速 LDO
功能描述 ZCC9195是一款超低功耗并具有快速响应、关断快速放电功能的高速LDO。静态电流低至 0.8uA,输出电流最大为300mA。 ZCC9195具有输出过流保护、输出短路保护、温度保护等功能,确保芯片在异常工作条件 下不会损坏。 ZCC9195只需要1uF的陶瓷电容即…...
微信小程序实现个人中心页面
文章目录 1. 官方文档教程2. 编写静态页面3. 关于作者其它项目视频教程介绍 1. 官方文档教程 https://developers.weixin.qq.com/miniprogram/dev/framework/ 2. 编写静态页面 mine.wxml布局文件 <!--index.wxml--> <navigation-bar title"个人中心" ba…...
【C语言算法刷题】第7题
题目描述 一个XX产品行销总公司,只有一个boss,其有若干一级分销,一级分销又有若干二级分销,每个分销只有唯一的上级分销。 规定,每个月,下级分销需要将自己的总收入(自己的下级上交的…...
BERT与CNN结合实现糖尿病相关医学问题多分类模型
完整源码项目包获取→点击文章末尾名片! 使用HuggingFace开发的Transformers库,使用BERT模型实现中文文本分类(二分类或多分类) 首先直接利用transformer.models.bert.BertForSequenceClassification()实现文本分类 然后手动实现B…...
RocketMQ消息发送---源码解析
我们知道rocketMQ的消息发送支持很多特性,如同步发送,异步发送,oneWay发送,也支持超时机制,回调机制,并且能够保证消息的可靠性和消息发送的限流,底层使用netty框架等等,如此多的特性…...
机器学习06-正则化
机器学习06-正则化 文章目录 机器学习06-正则化0-核心逻辑脉络1-参考网址3-大模型训练中的正则化1.正则化的定义与作用2.常见的正则化方法及其应用场景2.1 L1正则化(Lasso)2.2 L2正则化(Ridge)2.3 弹性网络正则化(Elas…...
如何开放2375和2376端口供Docker daemon监听
Linux (以 Ubuntu 为例) 1. 修改 Docker 配置文件 打开 Docker 的配置文件 /etc/docker/daemon.json。如果该文件不存在,则可以创建一个新的。 bash sudo nano /etc/docker/daemon.json在配置文件中添加以下内容: json {"hosts": ["un…...
Vue.js组件开发-如何实现路由懒加载
在Vue.js应用中,路由懒加载是一种优化性能的技术,它允许在需要时才加载特定的路由组件,而不是在应用启动时加载所有组件。这样可以显著减少初始加载时间,提高用户体验。在Vue Router中,实现路由懒加载非常简单…...
rclone,云存储备份和迁移的瑞士军刀,千字常文解析,附下载链接和安装操作步骤...
一、什么是rclone? rclone是一个命令行程序,全称:rsync for cloud storage。是用于将文件和目录同步到云存储提供商的工具。因其支持多种云存储服务的备份,如Google Drive、Amazon S3、Dropbox、Backblaze B2、One Drive、Swift、…...
集成学习算法
目录 1.必要的导入 2.Bagging集成 3.基于matplotlib写一个函数对决策边界做可视化 4.总结图中结论 5.扩展说明 1.必要的导入 # To support both python 2 and python 3 from __future__ import division, print_function, unicode_literals# Common imports import numpy as np…...
vue3之pinia学习
最近查看了pinia这个状态管理管理,想跟大家一起学习下,下面是我的个人理解,希望对大家有帮助,我们开始吧! 第一步:安装pinia npm install pinia 第二步:创建pinia <script setup langts&…...
Flink (七): DataStream API (四) Watermarks
1. Event Time and Processing Time 1. 1 处理时间(Processing time) 处理时间是指执行相应操作的机器的系统时间。当流处理程序基于处理时间运行时,所有基于时间的操作(如时间窗口)将使用执行相应算子的机器的系统时…...
卷积神经05-GAN对抗神经网络
卷积神经05-GAN对抗神经网络 使用Python3.9CUDA11.8Pytorch实现一个CNN优化版的对抗神经网络 简单的GAN图片生成 CNN优化后的图片生成 优化模型代码对比 0-核心逻辑脉络 1)Anacanda使用CUDAPytorch2)使用本地MNIST进行手写图片训练3)…...
【原创】大数据治理入门(2)《提升数据质量:质量评估与改进策略》入门必看 高赞实用
提升数据质量:质量评估与改进策略 引言:数据质量的概念 在大数据时代,数据的质量直接影响到数据分析的准确性和可靠性。数据质量是指数据在多大程度上能够满足其预定用途,确保数据的准确性、完整性、一致性和及时性是数据质量的…...
GLM: General Language Model Pretraining with Autoregressive Blank Infilling论文解读
论文地址:https://arxiv.org/abs/2103.10360 参考:https://zhuanlan.zhihu.com/p/532851481 GLM混合了自注意力和masked注意力,而且使用了2D位置编码。第一维的含义是在PartA中的位置,如5 5 5。第二维的含义是在Span内部的位置&a…...
总结SpringBoot项目中读取resource目录下的文件多种方法
系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 例如:第一章 Python 机器学习入门之pandas的使用 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目…...
云原生第四次作业
下载 [rootopenEuler-1 ~]# wget https://archive.apache.org/dist/httpd/httpd-2.4.46.tar.gz 压缩 配置实验环境 [rootopenEuler-1 httpd-2.4.46]# yum -y install apr apr-devel cyrus-sasl-devel expat-devel libdb-devel openldap-devel apr-util-devel apr-util pcre-d…...
day10_Structured Steaming
文章目录 Structured Steaming一、结构化流介绍(了解)1、有界和无界数据2、基本介绍3、使用三大步骤(掌握)4.回顾sparkSQL的词频统计案例 二、结构化流的编程模型(掌握)1、数据结构2、读取数据源2.1 File Source2.2 Socket Source…...
设计模式-工厂模式/抽象工厂模式
工厂模式 定义 定义一个创建对象的接口,让子类决定实列化哪一个类,工厂模式使一个类的实例化延迟到其子类; 工厂方法模式是简单工厂模式的延伸。在工厂方法模式中,核心工厂类不在负责产品的创建,而是将具体的创建工作…...
【算法学习】——整数划分问题详解(动态规划)
🧮整数划分问题是一个较为常见的算法题,很多问题从整数划分这里出发,进行包装,形成新的题目,所以完全理解整数划分的解决思路对于之后的进一步学习算法是很有帮助的。 「整数划分」通常使用「动态规划」解决࿰…...
【新教程】Ubuntu 24.04 单节点安装slurm
背景 网上教程老旧,不适用。 详细步骤 1、安装slurm sudo apt install slurm-wlm slurm-wlm-doc -y检查是否安装成功: slurmd --version如果得到slurm-wlm 23.11.4,表明安装成功。 2、配置slurm。 使用命令: sudo vi /etc/s…...
window下用vim
Windows 默认不支持 vim 命令,需要手动安装后才能使用。以下是解决方案: 1. 安装 Vim 编辑器 方法 1:通过 Scoop 或 Chocolatey 安装 使用 Scoop: 安装 Scoop(如果尚未安装):iwr -useb get.sco…...
citrix netscaler13.1 重写负载均衡响应头(基础版)
在 Citrix NetScaler 13.1 中,Rewrite Actions 用于对负载均衡响应进行修改,包括替换、删除和插入 HTTP 响应头。这些操作可以通过自定义策略来完成,帮助你根据需求调整请求内容。以下是三种常见的操作: 1. Replace (替换响应头)…...
使用PWM生成模式驱动BLDC三相无刷直流电机
引言 在 TI 的无刷直流 (BLDC) DRV8x 产品系列使用的栅极驱动器应用中,通常使用一些控制模式来切换MOSFET 开关的输出栅极。这些控制模式包括:1x、3x、6x 和独立脉宽调制 (PWM) 模式。 不过,DRV8x 产品系列(例如 DRV8311&…...
常见的php框架有哪几个?
一直以来,PHP作为一种广泛使用的编程语言,拥有着许多优秀的框架来帮助开发人员快速构建稳定的Web应用程序。本文降为大家介绍几种常见的PHP的主流框架,以及它们相关的特点和使用场景。如有问题,欢迎指正! 1.Laravel&a…...
机器学习(2):线性回归Python实现
1 概念回顾 1.1 模型假设 线性回归模型假设因变量y yy与自变量x xx之间的关系可以用以下线性方程表示: y β 0 β 1 ⋅ X 1 β 2 ⋅ X 2 … β n ⋅ X n ε y 是因变量 (待预测值);X1, X2, ... Xn 是自变量(特征)β0, β1,…...
Unity-Mirror网络框架-从入门到精通之RigidbodyPhysics示例
文章目录 前言示例一、球体的基础配置二、三个球体的设置差异三、示例意图LatencySimulation前言 在现代游戏开发中,网络功能日益成为提升游戏体验的关键组成部分。本系列文章将为读者提供对Mirror网络框架的深入了解,涵盖从基础到高级的多个主题。Mirror是一个用于Unity的开…...
【Unity-Animator】通过 StateMachineBehaviour 实现回调
StateMachineBehaviour 简介 StateMachineBehaviour是一个基类,所有状态脚本都派生自该类。它可以在状态机进入、退出或更新状态时执行代码,而无需编写自己的逻辑来测试和检测状态的变化。这使得开发者可以更方便地处理状态转换时的逻辑,例…...
并行服务、远程SSH无法下载conda,报错404
原下载代码无效,报错404 wget -c https://repo.anaconda.com/archive/Anaconda3-2023.03-1-Linux-x86_64.sh 使用下面代码下载 wget --user-agent"User-Agent: Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.9.2.12) Gecko/20101026 Firefox/3.6.12…...
cuquantum 简介
1. 关于 cuquantum 概述 官方文档: https://docs.nvidia.com/cuda/cuquantum/latest/appliance/overview.html#prerequisites NVIDIA 的 cuQuantum 是一个专门用于量子计算的高性能库,旨在加速量子电路的模拟和量子算法的执行。cuQuantum 提供了一系列…...