Main components:
- Multi-Head Self-Attention
- Positional Encoding
- Feed-Forward Network
- Encoder/Decoder Layers
- Complete Transformer Model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class MultiHeadAttention(nn.Module):
"""
多头自注意力机制 (Multi-Head Self-Attention)
核心思想:
- 将输入投影到Q、K、V三个矩阵
- 计算注意力权重:Attention(Q,K,V) = softmax(QK^T/√d_k)V
- 多个注意力头并行计算,捕获不同位置和表示子空间的信息
"""def __init__(self, d_model, n_heads, dropout=0.1):"""Args:d_model: 模型维度 (通常512或768)n_heads: 注意力头数 (通常8或12)dropout: dropout概率"""super(MultiHeadAttention, self).__init__()# 确保d_model能被n_heads整除assert d_model % n_heads == 0self.d_model = d_modelself.n_heads = n_headsself.d_k = d_model // n_heads # 每个头的维度# 线性变换层:将输入投影到Q、K、V# 注意:这里用一个大矩阵同时计算所有头的QKV,更高效self.w_q = nn.Linear(d_model, d_model, bias=False)self.w_k = nn.Linear(d_model, d_model, bias=False) self.w_v = nn.Linear(d_model, d_model, bias=False)# 输出投影层self.w_o = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)# 缩放因子,防止softmax饱和self.scale = math.sqrt(self.d_k)def forward(self, query, key, value, mask=None):"""Args:query: [batch_size, seq_len, d_model]key: [batch_size, seq_len, d_model] value: [batch_size, seq_len, d_model]mask: [batch_size, seq_len, seq_len] 注意力掩码Returns:output: [batch_size, seq_len, d_model]attention_weights: [batch_size, n_heads, seq_len, seq_len]"""batch_size, seq_len, d_model = query.size()# 1. 线性变换得到Q、K、VQ = self.w_q(query) # [batch_size, seq_len, d_model]K = self.w_k(key) # [batch_size, seq_len, d_model]V = self.w_v(value) # [batch_size, seq_len, d_model]# 2. 重塑为多头形式# 注意:Q, K, V的序列长度可能不同(特别是在交叉注意力中)q_seq_len = query.size(1)k_seq_len = key.size(1)v_seq_len = value.size(1)Q = Q.view(batch_size, q_seq_len, self.n_heads, self.d_k)K = K.view(batch_size, k_seq_len, self.n_heads, self.d_k)V = V.view(batch_size, v_seq_len, self.n_heads, self.d_k)# 转置以便矩阵乘法: [batch_size, n_heads, seq_len, d_k]Q = Q.transpose(1, 2)K = K.transpose(1, 2)V = V.transpose(1, 2)# 3. 计算注意力attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask, self.scale)# 4. 拼接多头结果# [batch_size, n_heads, q_seq_len, d_k] -> [batch_size, q_seq_len, n_heads, d_k]attention_output = attention_output.transpose(1, 2).contiguous()# [batch_size, q_seq_len, n_heads, d_k] -> [batch_size, q_seq_len, d_model]attention_output = attention_output.view(batch_size, q_seq_len, d_model)# 5. 输出投影output = self.w_o(attention_output)return output, attention_weightsdef scaled_dot_product_attention(self, Q, K, V, mask, scale):"""缩放点积注意力核心计算公式:Attention(Q,K,V) = softmax(QK^T/√d_k)V"""# 计算注意力分数:QK^T# [batch_size, n_heads, seq_len, d_k] × [batch_size, n_heads, d_k, seq_len]# = [batch_size, n_heads, seq_len, seq_len]attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale# 应用掩码(如果提供)if mask is not None:# 将掩码位置的分数设为很小的负数,softmax后接近0attention_scores = attention_scores.masked_fill(mask == 0, -1e9)# 计算注意力权重attention_weights = F.softmax(attention_scores, dim=-1)# 只在训练时应用dropoutif self.training:attention_weights = self.dropout(attention_weights)# 应用注意力权重到V# [batch_size, n_heads, seq_len, seq_len] × [batch_size, n_heads, seq_len, d_k]# = [batch_size, n_heads, seq_len, d_k]attention_output = torch.matmul(attention_weights, V)return attention_output, attention_weights
class PositionalEncoding(nn.Module):
"""
位置编码 (Positional Encoding)
由于Transformer没有循环或卷积结构,需要显式地给序列添加位置信息
使用sin/cos函数生成固定的位置编码公式:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
"""def __init__(self, d_model, max_seq_len=5000):"""Args:d_model: 模型维度max_seq_len: 支持的最大序列长度"""super(PositionalEncoding, self).__init__()# 创建位置编码矩阵pe = torch.zeros(max_seq_len, d_model)# 位置索引 [0, 1, 2, ..., max_seq_len-1]position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)# 计算除数项:10000^(2i/d_model)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))# 应用sin和cospe[:, 0::2] = torch.sin(position * div_term) # 偶数位置用sinpe[:, 1::2] = torch.cos(position * div_term) # 奇数位置用cos# 添加batch维度并注册为buffer(不参与梯度更新)pe = pe.unsqueeze(0) # [1, max_seq_len, d_model]self.register_buffer('pe', pe)def forward(self, x):"""Args:x: [batch_size, seq_len, d_model]Returns:x + positional_encoding: [batch_size, seq_len, d_model]"""# 取出对应长度的位置编码并加到输入上seq_len = x.size(1)return x + self.pe[:, :seq_len, :]
class FeedForward(nn.Module):
"""
前馈神经网络 (Feed Forward Network)
结构:Linear -> ReLU -> Linear
通常中间层维度是输入的4倍(如512->2048->512)
"""def __init__(self, d_model, d_ff, dropout=0.1):"""Args:d_model: 输入/输出维度d_ff: 中间层维度(通常是d_model的4倍)dropout: dropout概率"""super(FeedForward, self).__init__()self.linear1 = nn.Linear(d_model, d_ff)self.linear2 = nn.Linear(d_ff, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):"""Args:x: [batch_size, seq_len, d_model]Returns:output: [batch_size, seq_len, d_model]"""# Linear -> ReLU -> Dropout -> Linearreturn self.linear2(self.dropout(F.relu(self.linear1(x))))
class EncoderLayer(nn.Module):
"""
Transformer编码器层
结构:
1. Multi-Head Self-Attention + 残差连接 + LayerNorm
2. Feed Forward + 残差连接 + LayerNorm
"""def __init__(self, d_model, n_heads, d_ff, dropout=0.1):super(EncoderLayer, self).__init__()self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)self.feed_forward = FeedForward(d_model, d_ff, dropout)# Layer Normalizationself.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):"""Args:x: [batch_size, seq_len, d_model]mask: 注意力掩码Returns:output: [batch_size, seq_len, d_model]"""# 1. Self-Attention + 残差连接 + LayerNormattn_output, _ = self.self_attention(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))# 2. Feed Forward + 残差连接 + LayerNorm ff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return x
class DecoderLayer(nn.Module):
"""
Transformer解码器层
结构:
1. Masked Multi-Head Self-Attention + 残差 + LayerNorm
2. Multi-Head Cross-Attention + 残差 + LayerNorm
3. Feed Forward + 残差 + LayerNorm
"""def __init__(self, d_model, n_heads, d_ff, dropout=0.1):super(DecoderLayer, self).__init__()self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)self.feed_forward = FeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):"""Args:x: 解码器输入 [batch_size, tgt_len, d_model]encoder_output: 编码器输出 [batch_size, src_len, d_model]src_mask: 源序列掩码tgt_mask: 目标序列掩码(下三角掩码)Returns:output: [batch_size, tgt_len, d_model]"""# 1. Masked Self-Attention(防止看到未来信息)self_attn_output, _ = self.self_attention(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(self_attn_output))# 2. Cross-Attention(解码器attend到编码器输出)cross_attn_output, _ = self.cross_attention(x, encoder_output, encoder_output, src_mask)x = self.norm2(x + self.dropout(cross_attn_output))# 3. Feed Forwardff_output = self.feed_forward(x)x = self.norm3(x + self.dropout(ff_output))return x
class Transformer(nn.Module):
"""
完整的Transformer模型
包含:
- 输入嵌入 + 位置编码
- N层编码器
- N层解码器
- 输出线性层
"""def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8, n_layers=6, d_ff=2048, max_seq_len=5000, dropout=0.1):"""Args:src_vocab_size: 源语言词汇表大小tgt_vocab_size: 目标语言词汇表大小d_model: 模型维度n_heads: 注意力头数n_layers: 编码器/解码器层数d_ff: 前馈网络中间层维度max_seq_len: 最大序列长度dropout: dropout概率"""super(Transformer, self).__init__()self.d_model = d_model# 词嵌入层self.src_embedding = nn.Embedding(src_vocab_size, d_model)self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)# 位置编码self.positional_encoding = PositionalEncoding(d_model, max_seq_len)# 编码器层self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])# 解码器层self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout)for _ in range(n_layers)])# 输出投影层self.output_projection = nn.Linear(d_model, tgt_vocab_size)self.dropout = nn.Dropout(dropout)# 参数初始化self.init_parameters()def init_parameters(self):"""Xavier初始化"""for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def encode(self, src, src_mask=None):"""编码器前向传播Args:src: 源序列 [batch_size, src_len]src_mask: 源序列掩码Returns:encoder_output: [batch_size, src_len, d_model]"""# 词嵌入 + 位置编码src_emb = self.src_embedding(src) * math.sqrt(self.d_model)src_emb = self.positional_encoding(src_emb)src_emb = self.dropout(src_emb)# 通过编码器层encoder_output = src_embfor encoder_layer in self.encoder_layers:encoder_output = encoder_layer(encoder_output, src_mask)return encoder_outputdef decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):"""解码器前向传播Args:tgt: 目标序列 [batch_size, tgt_len]encoder_output: 编码器输出 [batch_size, src_len, d_model]src_mask: 源序列掩码tgt_mask: 目标序列掩码Returns:decoder_output: [batch_size, tgt_len, d_model]"""# 词嵌入 + 位置编码tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)tgt_emb = self.positional_encoding(tgt_emb)tgt_emb = self.dropout(tgt_emb)# 通过解码器层decoder_output = tgt_embfor decoder_layer in self.decoder_layers:decoder_output = decoder_layer(decoder_output, encoder_output, src_mask, tgt_mask)return decoder_outputdef forward(self, src, tgt, src_mask=None, tgt_mask=None):"""完整前向传播Args:src: 源序列 [batch_size, src_len]tgt: 目标序列 [batch_size, tgt_len]src_mask: 源序列掩码tgt_mask: 目标序列掩码Returns:output: [batch_size, tgt_len, tgt_vocab_size]"""# 编码encoder_output = self.encode(src, src_mask)# 解码decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)# 输出投影output = self.output_projection(decoder_output)return output
def create_padding_mask(seq, pad_idx=0):
"""
创建padding掩码,遮蔽padding位置
Args:seq: [batch_size, seq_len]pad_idx: padding token的索引Returns:mask: [batch_size, 1, 1, seq_len]
"""
mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
return mask
def create_look_ahead_mask(seq_len):
"""
创建下三角掩码,防止解码器看到未来信息
Args:seq_len: 序列长度Returns:mask: [1, 1, seq_len, seq_len]
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask.unsqueeze(0).unsqueeze(0)
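# A small illustration (hypothetical helper, not in the original file) of what the two
# masks look like and how they combine. For seq_len=4 the look-ahead mask is the
# lower-triangular matrix
#   [[1, 0, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 0],
#    [1, 1, 1, 1]]
# and AND-ing it with a padding mask additionally blocks the columns of pad tokens.
def _demo_masks():
    seq = torch.tensor([[5, 7, 9, 0]])                 # last position is padding (pad_idx=0)
    pad_mask = create_padding_mask(seq)                # [1, 1, 1, 4]
    look_ahead = create_look_ahead_mask(seq.size(1))   # [1, 1, 4, 4]
    combined = look_ahead & pad_mask                   # [1, 1, 4, 4], pad_mask broadcasts over rows
    assert combined.shape == (1, 1, 4, 4)
    assert not combined[0, 0, 3, 3]                    # the pad column is masked even on the diagonal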
# Usage example and training code
if __name__ == "__main__":
    # Model hyperparameters
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    d_model = 512
    n_heads = 8
    n_layers = 6
    d_ff = 2048
    max_seq_len = 100
    # Build the model
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        n_layers=n_layers,
        d_ff=d_ff,
        max_seq_len=max_seq_len
    )

    print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Dummy data
    batch_size = 32
    src_len = 20
    tgt_len = 25

    src = torch.randint(1, src_vocab_size, (batch_size, src_len))
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len))

    # Build the masks
    src_mask = create_padding_mask(src)
    tgt_mask = create_look_ahead_mask(tgt_len) & create_padding_mask(tgt)

    # Forward pass
    with torch.no_grad():
        output = model(src, tgt, src_mask, tgt_mask)
        print(f"Output shape: {output.shape}")  # [batch_size, tgt_len, tgt_vocab_size]

    # Simple training-loop example
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    model.train()
    for epoch in range(3):
        # Forward pass (teacher forcing: feed tgt[:, :-1], predict tgt[:, 1:])
        output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :, :-1, :-1])

        # Compute the loss (predict the next token)
        target = tgt[:, 1:].contiguous().view(-1)
        output = output.contiguous().view(-1, tgt_vocab_size)
        loss = criterion(output, target)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (guards against exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

    print("\n=== Transformer implementation complete ===")