深入理解 PyTorch 的 view() 函数:以多头注意力机制(Multi-Head Attention)为例 (中英双语)
深入理解 PyTorch 的 view()
函数:以多头注意力机制(Multi-Head Attention)为例
在深度学习模型的实现中,view()
是 PyTorch 中一个非常常用的张量操作函数,它能够改变张量的形状(shape)而不改变数据的内容。本文将结合多头注意力机制中的具体实现,详细解析 view()
的作用、使用场景及其与其他操作的结合。
一、view()
函数的基本概念
view()
是 PyTorch 提供的一个高效重塑张量形状的函数。其功能类似于 NumPy 的 reshape()
,但它要求张量的内存布局是连续的。如果张量不连续,需要先使用 .contiguous()
方法让张量变成连续的内存布局。
语法:
tensor.view(*shape)
tensor
:需要被重新调整形状的张量。*shape
:目标形状,-1
表示自动推导维度大小,确保数据总量不变。
使用注意事项:
- 数据总量(元素数量)必须保持不变:
- 如果原始张量的形状为
(a, b)
,则新形状中各维度的乘积必须等于a * b
。
- 如果原始张量的形状为
- 连续性要求:
- 如果张量在内存中不是连续存储的,调用
view()
会报错,需要先调用.contiguous()
。
- 如果张量在内存中不是连续存储的,调用
二、结合多头注意力机制理解 view()
的作用
在多头注意力机制(Multi-Head Attention, MHA)中,需要将输入的张量沿最后一维切分成多个“头”(head)。我们以以下代码为例,逐步分析 view()
的实际作用。
q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
假设输入张量:
x
的形状为 (B, T, C)
:
B
:Batch size,表示每个 batch 的样本数。T
:序列长度。C
:特征维度(通道数)。
多头注意力需要将最后一维 C
切分成 n_head
个头,每个头的维度是 head_size = C // n_head
,从而得到形状为 (B, T, n_head, head_size)
的张量。以下是具体的代码实现和解读。
三、代码实现与解析
1. 重新调整张量形状:切分多头
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-
view(B, T, self.n_head, C // self.n_head)
:- 使用
view()
将原始张量(B, T, C)
调整为(B, T, n_head, head_size)
,其中head_size = C // n_head
。 - 每个维度的具体含义:
B
:Batch size。T
:序列长度。n_head
:多头数量。head_size
:每个头的特征维度。
- 目的:切分出多头,每个头独立计算注意力。
- 使用
-
.transpose(1, 2)
:- 调整维度顺序,将形状从
(B, T, n_head, head_size)
转换为(B, n_head, T, head_size)
。 - 目的:为了后续计算注意力时,每个头可以独立计算。
- 调整维度顺序,将形状从
2. 计算注意力权重
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # (B, nh, T, T)
att = F.softmax(att, dim=-1)
q @ k.transpose(-2, -1)
:- 计算查询向量(query)与键向量(key)的点积。
k.transpose(-2, -1)
将k
的最后两维转置,从(B, nh, T, hs)
转换为(B, nh, hs, T)
,以便进行矩阵乘法。- 最终
att
的形状为(B, nh, T, T)
,表示每个头的注意力得分矩阵。
3. 添加 Mask
详细解释请看文末。
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, torch.finfo(x.dtype).min)
- 通过 Mask 确保每个位置只关注前面的序列。
4. 计算加权输出并恢复形状
y = att @ v # (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.o_proj(y)
-
att @ v
:- 使用注意力得分加权值向量
v
,输出形状为(B, nh, T, hs)
。
- 使用注意力得分加权值向量
-
y.transpose(1, 2)
:- 调整维度顺序,将形状从
(B, nh, T, hs)
转换为(B, T, nh, hs)
。
- 调整维度顺序,将形状从
-
.view(B, T, C)
:- 使用
view()
将多头的输出重新组合为单个张量,恢复到原始特征维度。
- 使用
四、总结:view()
的核心作用
-
切分特征维度:
view()
将张量沿最后一维切分成多头,为每个头的独立计算创造条件。
-
调整张量形状:
- 将
(B, T, C)
重塑为(B, T, n_head, head_size)
,然后通过transpose()
等操作方便后续矩阵运算。
- 将
-
恢复原始形状:
- 最终通过
view()
将多头输出重新组合成单个张量,便于后续网络层处理。
- 最终通过
view()
的使用贯穿整个多头注意力机制的实现,其灵活性和高效性使其成为 PyTorch 中不可或缺的操作函数。
五、view()
与其他操作的对比
reshape()
:更通用,不要求张量是连续的,但可能会引入额外开销。.contiguous()
:与view()
配合使用,确保张量的内存布局连续。
六、完整代码示例
以下是一个完整的代码示例,展示如何通过 view()
实现多头注意力机制:
import torch
import torch.nn.functional as F
import math# 假设输入数据
B, T, C = 4, 512, 128
n_head = 8
head_size = C // n_head
x = torch.randn(B, T, C)# 线性变换
q_proj = torch.nn.Linear(C, C)
k_proj = torch.nn.Linear(C, C)
v_proj = torch.nn.Linear(C, C)
o_proj = torch.nn.Linear(C, C)# 计算 Q, K, V
q = q_proj(x)
k = k_proj(x)
v = v_proj(x)# 切分多头
q = q.view(B, T, n_head, head_size).transpose(1, 2) # (B, n_head, T, head_size)
k = k.view(B, T, n_head, head_size).transpose(1, 2) # (B, n_head, T, head_size)
v = v.view(B, T, n_head, head_size).transpose(1, 2) # (B, n_head, T, head_size)# 注意力机制
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
att = F.softmax(att, dim=-1)
y = att @ v # (B, n_head, T, head_size)# 恢复形状
y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
y = o_proj(y)
希望本文对理解 PyTorch 的 view()
函数以及其在多头注意力机制中的应用有所帮助.
英文版
Understanding PyTorch’s view()
Function: An Example with Multi-Head Attention (MHA)
In PyTorch, the view()
function is a powerful tool for reshaping tensors. It is frequently used in deep learning to manipulate tensor shapes for specific tasks, such as in the implementation of Multi-Head Attention (MHA). This blog post will break down the purpose, functionality, and application of view()
in the context of MHA, using a concrete example.
1. What is view()
?
The view()
function in PyTorch is used to reshape a tensor without changing its data. It is analogous to NumPy’s reshape()
function, but with a key requirement: the tensor must have a contiguous memory layout.
Syntax:
tensor.view(*shape)
tensor
: The tensor to reshape.*shape
: The new shape for the tensor. A-1
can be used for one dimension to infer its size automatically, provided the total number of elements remains constant.
Key Points:
- The total number of elements must remain the same:
- For example, a tensor of shape
(4, 128)
can be reshaped into(8, 64)
but not into(5, 64)
because4 * 128 != 5 * 64
.
- For example, a tensor of shape
- The tensor must have contiguous memory:
- If the tensor isn’t contiguous, you must first call
.contiguous()
before usingview()
.
- If the tensor isn’t contiguous, you must first call
2. Why Use view()
in Multi-Head Attention?
Multi-Head Attention (MHA) splits the feature dimension of the input into multiple “heads.” Each head independently performs attention calculations, and the results are combined at the end. This requires reshaping tensors to group the feature dimension into multiple heads while preserving the other dimensions (like batch size and sequence length).
Input Shape:
Suppose the input tensor x
has a shape of (B, T, C)
:
B
: Batch size.T
: Sequence length.C
: Feature dimension.
If we want to use n_head
heads in the attention mechanism, the feature dimension C
is split into n_head
groups, where each group has a size of head_size = C // n_head
.
The tensor is reshaped to (B, T, n_head, head_size)
for this purpose. To facilitate calculations, the dimensions are then transposed to (B, n_head, T, head_size)
.
3. Code Implementation: Reshaping for MHA
Here’s how the reshaping is implemented in MHA:
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, n_head, T, head_size)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, n_head, T, head_size)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, n_head, T, head_size)
Breaking it Down:
-
view(B, T, self.n_head, C // self.n_head)
:- Reshapes the tensor from
(B, T, C)
to(B, T, n_head, head_size)
, where:B
is the batch size.T
is the sequence length.n_head
is the number of attention heads.head_size = C // n_head
is the size of each head.
- This effectively splits the feature dimension into
n_head
separate heads.
- Reshapes the tensor from
-
.transpose(1, 2)
:- Swaps the sequence length dimension (
T
) with the head dimension (n_head
), resulting in a shape of(B, n_head, T, head_size)
. - This format is required for the attention mechanism, as each head performs its operations independently on the sequence.
- Swaps the sequence length dimension (
4. Applying Attention and Masking
Once the input tensors (q
, k
, v
) are reshaped, attention scores are computed, masked, and the output is calculated as follows:
Attention Computation:
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # (B, n_head, T, T)
att = F.softmax(att, dim=-1)
q @ k.transpose(-2, -1)
:- Computes the dot product of the query (
q
) and the transposed key (k
), resulting in a shape of(B, n_head, T, T)
. This represents the attention scores for each head. k.transpose(-2, -1)
changesk
from(B, n_head, T, head_size)
to(B, n_head, head_size, T)
to align dimensions for the dot product.
- Computes the dot product of the query (
Masking:
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, torch.finfo(x.dtype).min)
- A mask is applied to ensure that positions cannot “see” future tokens in the sequence.
Output Calculation:
y = att @ v # (B, n_head, T, head_size)
y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
y = self.o_proj(y)
-
att @ v
:- Multiplies the attention scores with the value (
v
) tensor, resulting in a shape of(B, n_head, T, head_size)
.
- Multiplies the attention scores with the value (
-
transpose(1, 2)
:- Swaps the
n_head
dimension withT
to prepare for reshaping.
- Swaps the
-
.contiguous().view(B, T, C)
:- Flattens the heads back into a single feature dimension, restoring the original shape
(B, T, C)
.
- Flattens the heads back into a single feature dimension, restoring the original shape
5. The Role of view()
The view()
function is crucial for:
-
Splitting Dimensions:
- It divides the feature dimension (
C
) into multiple heads (n_head
) for independent attention calculations.
- It divides the feature dimension (
-
Restoring Dimensions:
- After attention calculations, it combines the outputs of all heads back into a single feature dimension.
6. Example Code
Below is the complete example of reshaping for MHA:
import torch
import torch.nn.functional as F
import math# Example input
B, T, C = 4, 512, 128
n_head = 8
head_size = C // n_head
x = torch.randn(B, T, C)# Linear projections
q_proj = torch.nn.Linear(C, C)
k_proj = torch.nn.Linear(C, C)
v_proj = torch.nn.Linear(C, C)
o_proj = torch.nn.Linear(C, C)# Compute Q, K, V
q = q_proj(x)
k = k_proj(x)
v = v_proj(x)# Reshape for multi-head attention
q = q.view(B, T, n_head, head_size).transpose(1, 2) # (B, n_head, T, head_size)
k = k.view(B, T, n_head, head_size).transpose(1, 2) # (B, n_head, T, head_size)
v = v.view(B, T, n_head, head_size).transpose(1, 2) # (B, n_head, T, head_size)# Attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
att = F.softmax(att, dim=-1)
y = att @ v # (B, n_head, T, head_size)# Restore original shape
y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
y = o_proj(y)
7. Summary
The view()
function plays a critical role in tensor manipulation for multi-head attention by enabling:
- Efficient splitting of dimensions into multiple heads.
- Seamless reshaping of tensor shapes for independent attention calculations.
- Reconstruction of the original shape after attention processing.
By combining view()
with operations like transpose()
, MHA becomes both efficient and modular, making it a cornerstone of modern NLP architectures.
【1】代码分析:att = att.masked_fill(self.bias[:,:,:T,:T] == 0, torch.finfo(x.dtype).min)
这行代码的目的是在 计算注意力分数(attention scores)后,对其进行遮掩(masking),以确保在某些情况下(如自回归模型的解码过程)当前位置无法访问未来的位置信息。
下面分步骤详细讲解这行代码的含义和作用:
1. self.bias[:,:,:T,:T]
含义:
self.bias
是一个用于遮掩的矩阵,通常是一个 上三角矩阵(triangular matrix),大小为(1, 1, max_length, max_length)
。- 它的作用是为注意力分数提供一种机制,来限制某些位置的访问。比如,在自回归任务中,每个时间步只允许看到当前及之前的时间步,不能看到未来的时间步。
max_length
是模型支持的最大序列长度,T
是当前序列的实际长度。通过self.bias[:,:,:T,:T]
截取一个大小为(1, 1, T, T)
的子矩阵,表示当前序列的遮掩规则。
举例:
假设 T = 4
,截取后的子矩阵形状为 (1, 1, 4, 4)
,其内容可能如下:
self.bias[:,:,:T,:T] =
[[[[1, 0, 0, 0],[1, 1, 0, 0],[1, 1, 1, 0],[1, 1, 1, 1]]]]
- 1 表示允许访问,0 表示禁止访问。
- 这种矩阵通常由
torch.tril()
函数生成(下三角部分为 1,上三角部分为 0)。
2. self.bias[:,:,:T,:T] == 0
含义:
== 0
将遮掩矩阵中的 0 位置标记为True
,表示这些位置需要被屏蔽。- 结果是一个布尔矩阵,形状仍然为
(1, 1, T, T)
。
举例:
对应上面的例子:
self.bias[:,:,:T,:T] == 0 =
[[[[False, True, True, True],[False, False, True, True],[False, False, False, True],[False, False, False, False]]]]
3. torch.finfo(x.dtype).min
含义:
torch.finfo(x.dtype).min
表示当前数据类型(x.dtype
)的最小值。- 例如,如果
x
的数据类型是float32
,那么torch.finfo(torch.float32).min
的值约为-3.4e38
。 - 这个极小值被用作屏蔽位置的填充值,因为在后续的 Softmax 操作中,极小值的指数将接近于 0,从而使这些位置的注意力权重为 0。
4. att.masked_fill(...)
含义:
masked_fill(mask, value)
是 PyTorch 中的一种操作,用于根据布尔掩码mask
将张量中对应位置填充为指定的值value
。- 在这段代码中,
att
是注意力分数矩阵,形状为(B, n_head, T, T)
,其中:B
是批量大小。n_head
是注意力头的数量。T
是序列长度。
- 通过
masked_fill()
操作,将遮掩矩阵中为True
的位置(即不允许访问的位置)填充为极小值torch.finfo(x.dtype).min
。
5. 完整作用
这行代码的作用是:
- 将注意力分数矩阵
att
中 不允许访问的位置 设置为极小值,以确保这些位置在 Softmax 计算时权重接近于 0,从而被忽略。
6. 举例说明
假设:
att
的形状为(1, 1, 4, 4)
,内容如下:
att =
[[[[0.1, 0.2, 0.3, 0.4],[0.5, 0.6, 0.7, 0.8],[0.9, 1.0, 1.1, 1.2],[1.3, 1.4, 1.5, 1.6]]]]
- 对应的遮掩矩阵:
self.bias[:,:,:4,:4] == 0 =
[[[[False, True, True, True],[False, False, True, True],[False, False, False, True],[False, False, False, False]]]]
- 极小值(例如
-1e9
)用于屏蔽。
执行 att = att.masked_fill(self.bias[:,:,:4,:4] == 0, -1e9)
后:
att =
[[[[ 0.1, -1e9, -1e9, -1e9],[ 0.5, 0.6, -1e9, -1e9],[ 0.9, 1.0, 1.1, -1e9],[ 1.3, 1.4, 1.5, 1.6]]]]
7. 总结
这行代码实现了遮掩逻辑,用于屏蔽注意力机制中不应该访问的位置,其主要作用如下:
- 限制注意力范围:确保当前时间步无法访问未来时间步的信息(例如语言模型的解码阶段)。
- 保留无效位置的注意力权重为 0:通过填充极小值,使这些位置在 Softmax 操作后被忽略。
这对于自回归任务(如 GPT 类模型)和其他需要时间步约束的任务至关重要。
【2】代码分析:q = q_proj(x)
是怎么做的?
q_proj(x)
是通过一个线性层(torch.nn.Linear
)对输入张量 x
进行线性变换,最终输出一个和输入形状相同的张量(除非特意改变输出维度)。
1. 线性层的作用
torch.nn.Linear
是 PyTorch 中的全连接层(fully connected layer),它的作用是:
y = x ⋅ W T + b \text{y} = \text{x} \cdot \text{W}^T + \text{b} y=x⋅WT+b
- 输入矩阵:
x
的形状为(B, T, C)
,其中:B
是批量大小(batch size)。T
是序列长度(sequence length)。C
是特征维度(embedding size)。
- 权重矩阵:
W
是线性层的权重,形状为(C, C)
。 - 偏置向量:
b
是线性层的偏置,形状为(C,)
。 - 输出矩阵:结果
y
的形状与输入x
的形状一致,即(B, T, C)
。
2. q_proj
的定义
q_proj
是一个线性层,初始化时:
q_proj = torch.nn.Linear(C, C)
- 该线性层将输入张量
x
的最后一维(大小为C
)映射到一个同样大小为C
的新表示。 q_proj
的内部参数:- 权重矩阵
W_q
,形状为(C, C)
。 - 偏置向量
b_q
,形状为(C,)
。
- 权重矩阵
3. q_proj(x)
的执行过程
当执行 q_proj(x)
时,会进行以下操作:
- 矩阵乘法:将输入张量
x
的最后一维与权重矩阵W_q
相乘。- 输入
x
的形状为(B, T, C)
,与W_q
的形状(C, C)
相乘,最后一维变换为新表示。 - 输出结果为形状
(B, T, C)
。
- 输入
- 加偏置:在矩阵乘法的结果上,加上偏置向量
b_q
,偏置会广播到每个位置。
4. 举例说明
假设:
- 输入张量
x
的形状为(4, 512, 128)
:
每个元素随机生成。x = torch.randn(4, 512, 128)
- 权重矩阵
W_q
的形状为(128, 128)
,偏置b_q
的形状为(128,)
。
执行 q_proj(x)
时:
- 矩阵乘法:每个时间步(
T=512
)和批次(B=4
)中的向量(大小为 128)都会与权重矩阵W_q
(大小为 128×128)相乘,得到一个新的大小为 128 的向量。 - 加偏置:在每个位置上,加上偏置向量
b_q
。
最终输出张量 q
的形状仍为 (4, 512, 128)
,但内容经过线性变换,表示的是对输入张量 x
的一种特征提取。
5. 小例子
假设:
-
输入
x
为(B=2, T=3, C=4)
的张量:x = torch.tensor([[[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0],[9.0, 10.0, 11.0, 12.0]],[[13.0, 14.0, 15.0, 16.0],[17.0, 18.0, 19.0, 20.0],[21.0, 22.0, 23.0, 24.0]]])
-
权重矩阵
W_q
初始化为:W_q = torch.tensor([[1.0, 0.0, 0.0, 0.0],[0.0, 1.0, 0.0, 0.0],[0.0, 0.0, 1.0, 0.0],[0.0, 0.0, 0.0, 1.0]]) # 单位矩阵
偏置向量
b_q
为:b_q = torch.tensor([1.0, 1.0, 1.0, 1.0])
执行 q_proj(x)
:
- 矩阵乘法:
x
的每个向量与W_q
相乘(这里W_q
是单位矩阵,所以输出等于输入)。
- 加偏置:
- 每个向量加上偏置
[1.0, 1.0, 1.0, 1.0]
。
- 每个向量加上偏置
输出 q
为:
q = [[[ 2.0, 3.0, 4.0, 5.0],[ 6.0, 7.0, 8.0, 9.0],[10.0, 11.0, 12.0, 13.0]],[[14.0, 15.0, 16.0, 17.0],[18.0, 19.0, 20.0, 21.0],[22.0, 23.0, 24.0, 25.0]]]
6. 总结
q_proj(x)
的作用是对输入x
的最后一维(特征维度)进行线性变换,提取注意力机制中需要的查询特征(Query)。- 输入和输出形状保持一致,内容经过了权重矩阵和偏置的变换。
- 在多头注意力(Multi-Head Attention)中,这种线性变换用于生成 Query、Key 和 Value,以便进一步计算注意力分数和上下文表示。
【3】 q @ k.transpose(-2, -1)
是如何进行矩阵乘法的?
q @ k.transpose(-2, -1)
是在多头自注意力机制(Multi-Head Self-Attention)中计算查询向量(query)与键向量(key)的点积注意力分数(attention score)的关键步骤。
具体过程如下:
-
q
的形状:查询向量q
的形状为(B, nh, T, hs)
,其中:B
是批量大小(batch size)。nh
是注意力头的数量(number of heads)。T
是序列长度(sequence length)。hs
是单个注意力头的特征维度(head size)。
-
k.transpose(-2, -1)
的形状:键向量k
的形状原本为(B, nh, T, hs)
,通过k.transpose(-2, -1)
,将最后两维交换,变成(B, nh, hs, T)
。 -
矩阵乘法:
q @ k.transpose(-2, -1)
是两个张量的矩阵乘法:- 查询向量
q
的最后两维(T, hs)
,与转置后的键向量k.transpose(-2, -1)
的前两维(hs, T)
相乘。 - 结果是一个新的张量,形状为
(B, nh, T, T)
。 - 这个结果表示在每个注意力头上,不同序列位置之间的点积注意力分数。
- 查询向量
2. 举例说明
假设:
- 批量大小
B = 1
(只有一个样本)。 - 注意力头数量
nh = 1
(只有一个头)。 - 序列长度
T = 3
(序列中有 3 个时间步)。 - 每个注意力头的特征维度
hs = 2
(每个向量的特征长度为 2)。
输入张量 q
和 k
:
-
查询向量
q
的形状为(1, 1, 3, 2)
:q = torch.tensor([[[[1, 0],[0, 1],[1, 1]]]])
-
键向量
k
的形状为(1, 1, 3, 2)
:k = torch.tensor([[[[1, 0],[1, 1],[0, 1]]]])
计算 k.transpose(-2, -1)
:
将 k
的最后两维转置,形状从 (1, 1, 3, 2)
变为 (1, 1, 2, 3)
:
k_transposed = torch.tensor([[[[1, 1, 0],[0, 1, 1]]]])
计算 q @ k.transpose(-2, -1)
:
执行矩阵乘法,将 q
的最后两维 (3, 2)
与 k_transposed
的前两维 (2, 3)
相乘,结果形状为 (1, 1, 3, 3)
。
矩阵乘法过程(以第一批次和第一注意力头为例):
- 第一行(第一个时间步与所有时间步的点积):
点积 = [ 1 0 ] ⋅ [ 1 1 0 0 1 1 ] = [ 1 1 0 ] \text{点积} = \begin{bmatrix} 1 & 0 \end{bmatrix} \cdot \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 1 & 0 \end{bmatrix} 点积=[10]⋅[101101]=[110] - 第二行(第二个时间步与所有时间步的点积):
点积 = [ 0 1 ] ⋅ [ 1 1 0 0 1 1 ] = [ 0 1 1 ] \text{点积} = \begin{bmatrix} 0 & 1 \end{bmatrix} \cdot \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 1 & 1 \end{bmatrix} 点积=[01]⋅[101101]=[011] - 第三行(第三个时间步与所有时间步的点积):
点积 = [ 1 1 ] ⋅ [ 1 1 0 0 1 1 ] = [ 1 2 1 ] \text{点积} = \begin{bmatrix} 1 & 1 \end{bmatrix} \cdot \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 2 & 1 \end{bmatrix} 点积=[11]⋅[101101]=[121]
最终得到的注意力分数矩阵为:
att = torch.tensor([[[[1, 1, 0],[0, 1, 1],[1, 2, 1]]]])
形状为 (1, 1, 3, 3)
。
3. 总结
q @ k.transpose(-2, -1)
是通过矩阵乘法计算序列中每个时间步之间的点积相似性。- 结果形状
(B, nh, T, T)
:- 表示在每个注意力头中,序列中每个位置(行)对其他位置(列)的相似性。
- 用途:这是多头注意力机制中用于计算注意力权重(attention scores)的核心步骤,下一步通过 softmax 函数,将这些分数归一化为概率分布,表示不同时间步之间的相关性。
在多头自注意力机制中,时间步(Time Step)指的是序列中的每个位置或词的表示(embedding)。如果用一句话 “how are you” 来解析,每个时间步就对应一个单词的表示,例如 “how” 是第一个时间步,“are” 是第二个时间步,“you” 是第三个时间步。
【4】 通过 “how are you” 来解析时间步和矩阵乘法**
1. 序列与时间步的定义
- 句子 “how are you” 可以看作一个序列,长度为 3(T = 3)。
- 每个单词都会被编码成一个向量(embedding),向量的维度为
hs
(head size,比如 2)。 - 这意味着,“how” 的表示是一个二维向量,“are” 和 “you” 也各自是二维向量。
假设以下是编码后的表示:
"how" = [1, 0] # 第一个时间步
"are" = [0, 1] # 第二个时间步
"you" = [1, 1] # 第三个时间步
这些向量会形成矩阵 ( q q q ) 和 ( k k k ),它们的形状都是 ( ( B , n h , T , h s ) (B, nh, T, hs) (B,nh,T,hs) ),这里我们假设批次大小 ( B = 1 B = 1 B=1 ),头的数量 ( n h = 1 nh = 1 nh=1 ),所以 ( q q q ) 和 ( k k k ) 的形状为 ( ( 1 , 1 , 3 , 2 ) (1, 1, 3, 2) (1,1,3,2) )。
2. 键向量 ( k k k ) 和转置 ( k . t r a n s p o s e ( − 2 , − 1 ) k.transpose(-2, -1) k.transpose(−2,−1) )
键向量 ( k k k ) 的矩阵如下(对应 “how”, “are”, “you” 的表示):
k = [[1, 0], # "how"[1, 1], # "are"[0, 1] # "you"
]
转置后(交换最后两维),矩阵变为:
k . t r a n s p o s e ( − 2 , − 1 ) = [ 1 1 0 0 1 1 ] k.transpose(-2, -1) = \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix} k.transpose(−2,−1)=[101101]
3. 查询向量 ( q q q ) 的矩阵表示
查询向量 ( q q q ) 的矩阵如下(也对应 “how”, “are”, “you” 的表示):
q = [[1, 0], # "how"[0, 1], # "are"[1, 1] # "you"
]
4. 矩阵乘法 ( q @ k.transpose(-2, -1) ) 的计算
矩阵乘法的目的是计算每个时间步与序列中其他时间步之间的相似性,通过点积来完成。以下是每一行的具体计算:
-
第一个时间步(“how”)与所有时间步的点积:
点积 = [ 1 0 ] ⋅ [ 1 1 0 0 1 1 ] = [ 1 1 0 ] \text{点积} = \begin{bmatrix} 1 & 0 \end{bmatrix} \cdot \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 1 & 0 \end{bmatrix} 点积=[10]⋅[101101]=[110]- “how” 与 “how” 的相似性为 ( 1 )。
- “how” 与 “are” 的相似性为 ( 1 )。
- “how” 与 “you” 的相似性为 ( 0 )。
-
第二个时间步(“are”)与所有时间步的点积:
点积 = [ 0 1 ] ⋅ [ 1 1 0 0 1 1 ] = [ 0 1 1 ] \text{点积} = \begin{bmatrix} 0 & 1 \end{bmatrix} \cdot \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 1 & 1 \end{bmatrix} 点积=[01]⋅[101101]=[011]- “are” 与 “how” 的相似性为 ( 0 )。
- “are” 与 “are” 的相似性为 ( 1 )。
- “are” 与 “you” 的相似性为 ( 1 )。
-
第三个时间步(“you”)与所有时间步的点积:
点积 = [ 1 1 ] ⋅ [ 1 1 0 0 1 1 ] = [ 1 2 1 ] \text{点积} = \begin{bmatrix} 1 & 1 \end{bmatrix} \cdot \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 2 & 1 \end{bmatrix} 点积=[11]⋅[101101]=[121]- “you” 与 “how” 的相似性为 ( 1 )。
- “you” 与 “are” 的相似性为 ( 2 )。
- “you” 与 “you” 的相似性为 ( 1 )。
5. 最终的注意力分数矩阵
矩阵乘法结果是一个 ( 3 × 3 3 \times 3 3×3 ) 的矩阵,表示序列中每个时间步之间的点积分数:
Attention Scores (未归一化) = [ 1 1 0 0 1 1 1 2 1 ] \text{Attention Scores (未归一化)} = \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \\ 1 & 2 & 1 \end{bmatrix} Attention Scores (未归一化)= 101112011
- 第一行表示 “how” 与其他时间步的相似性。
- 第二行表示 “are” 与其他时间步的相似性。
- 第三行表示 “you” 与其他时间步的相似性。
6. 总结
在 “how are you” 这句话中:
- 每个时间步(单词)都会被表示为一个向量。
- 通过查询向量 ( q q q ) 和键向量 ( k k k ) 的点积计算出序列中每个位置的相似性。
- 注意力分数矩阵中的每一行表示一个单词与其他单词之间的关系。
- 这些分数随后会被归一化(通过 softmax),作为多头注意力机制的权重。
参考
[1] 手撕 MHA,阿里的一面问的真是太细了
后记
2024年12月25日13点18分于上海,在GPT4o大模型辅助下完成。
相关文章:
深入理解 PyTorch 的 view() 函数:以多头注意力机制(Multi-Head Attention)为例 (中英双语)
深入理解 PyTorch 的 view() 函数:以多头注意力机制(Multi-Head Attention)为例 在深度学习模型的实现中,view() 是 PyTorch 中一个非常常用的张量操作函数,它能够改变张量的形状(shape)而不改…...
【每日学点鸿蒙知识】获取是否有网接口、获取udid报错、本地通知、Json转Map、Window10安装Hyper-v
1、有没有获取当前是否真实有网的接口? 比如当前链接的是wifi,但是当前wifi是不能访问网络的,有没有接口可以获取到这个真实的网络访问状态? 请参考说明链接:https://developer.huawei.com/consumer/cn/doc/harmonyo…...
《Vue3 四》Vue 的组件化
组件化:将一个页面拆分成一个个小的功能模块,每个功能模块完成自己部分的独立的功能。任何应用都可以被抽象成一棵组件树。 Vue 中的根组件: Vue.createApp() 中传入对象的本质上就是一个组件,称之为根组件(APP 组件…...
Linux:alias别名永久有效
一、背景 日常使用bash时候,有些常用的命令参数的组合命令太长,很难记,此时可以利用Linux提供的alias命令生成命令的别名(命令的隐射),但是我们会发现,当退出了终端后重新登录就失效了ÿ…...
MicroDiffusion——采用新的掩码方法和改进的 Transformer 架构,实现了低预算的扩散模型
介绍 论文地址:https://arxiv.org/abs/2407.15811 现代图像生成模型擅长创建自然、高质量的内容,每年生成的图像超过十亿幅。然而,从头开始训练这些模型极其昂贵和耗时。文本到图像(T2I)扩散模型降低了部分计算成本&a…...
网神SecFox FastJson反序列化RCE漏洞复现(附脚本)
0x01 产品描述: 网神SecFox是奇安信网神信息技术(北京)股份有限公司推出的一款运维安全管理与审计系统,集“身份认证、账户管理、权限控制、运维审计”于一体,提供统一运维身份认证、细粒度的权限控制、丰富的运维审计报告、多维度的预警…...
解决无法在 Ubuntu 24.04 上运行 AppImage 应用
在 Ubuntu 24.04 中运行 AppImage 应用的完整指南 在 Ubuntu 24.04 中,许多用户可能会遇到 AppImage 应用无法启动的问题。即使你已经设置了正确的文件权限,AppImage 仍然拒绝运行。这通常是由于缺少必要的库文件所致。 问题根源:缺少 FUSE…...
Pytorch | 利用PC-I-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
Pytorch | 利用PC-I-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击 CIFAR数据集PC-I-FGSM介绍算法原理 PC-I-FGSM代码实现PC-I-FGSM算法实现攻击效果 代码汇总pcifgsm.pytrain.pyadvtest.py 之前已经针对CIFAR10训练了多种分类器: Pytorch | 从零构建AlexNet对CIFAR…...
前端往后端传递参数的方式有哪些?
文章目录 1. URL 参数1.1. 查询参数(Query Parameters)1.2. 路径参数(Path Parameters) 2. 请求体(Request Body)2.1. JSON 数据2.2. 表单数据2.3. 文件上传 3. 请求头(Headers)3.1. 自定义请求…...
对抗攻击VA-I-FGSM:Adversarial Examples with Virtual Step and Auxiliary Gradients
文章目录 摘要相关定义算法流程代码:文章链接: Improving Transferability of Adversarial Examples with Virtual Step and Auxiliary Gradients 摘要 深度神经网络已被证明容易受到对抗样本的攻击,这些对抗样本通过向良性样本中添加人类难以察觉的扰动来欺骗神经网络。目…...
【Java】IO流练习
IO流练习 题干: 根据指定要求,完成电话记录、 注册、登录 注册 题干: 完成【注册】功能: 要求: 用户输入用户名、密码存入users.txt文件中 若users.txt文件不存在,创建该文件若users.txt文件存在 输入…...
红魔电竞PadPro平板解BL+ROOT权限-KernelSU+LSPosed框架支持
红魔Padpro设备目前官方未开放解锁BL,也阉割了很多解锁BL指令,造成大家都不能自主玩机。此规则从红魔8开始,就一直延续下来,后续的机型大概率也是一样的情况。好在依旧有开发者进行适配研究,目前红魔PadPro平板&#x…...
小程序配置文件 —— 12 全局配置 - pages配置
全局配置 - pages配置 在根目录下的 app.json 文件中有一个 pages 字段,这里我们介绍一下 pages 字段的具体用法; pages 字段:用来指定小程序由哪些页面组成,用来让小程序知道由哪些页面组成以及页面定义在哪个目录,…...
供应链系统设计-供应链中台系统设计(六)- 商品中心概念篇
概述 我们在供应链系统设计-中台系统设计系列(五)- 供应链中台实践概述 中描述了什么是供应链中台,供应链中台主要包含了那些组成部门。包括业务中台、通用中台等概念。为了后续方便大家对于中台有更深入的理解,我会逐一针对中台…...
leetcode 面试经典 150 题:删除有序数组中的重复项
链接删除有序数组中的重复项题序号26题型数组解题方法双指针难度简单熟练度✅✅✅✅✅ 题目 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保…...
Python 中的 lambda 函数和嵌套函数
Python 中的 lambda 函数和嵌套函数 Python 中的 lambda 函数和嵌套函数Python 中的 lambda 函数嵌套函数(内部函数)封装辅助函数闭包和工厂函数 Python 中的 lambda 函数和嵌套函数 Python 中的 lambda 函数 Lambda 函数是基于单行表达式的匿名函数。…...
Android笔试面试题AI答之Android基础(7)
Android入门请看《Android应用开发项目式教程》,视频、源码、答疑,手把手教 文章目录 1.Android开发如何提高App的兼容性?**1. 支持多版本 Android 系统****2. 适配不同屏幕尺寸和分辨率****3. 处理不同硬件配置****4. 适配不同语言和地区**…...
PhPMyadmin-cms漏洞复现
一.通过日志文件拿Shell 打开靶场连接数据库 来到sql中输入 show global variables like %general%; set global general_logon; //⽇志保存状态开启; set global general_log_file D:/phpstudy/phpstudy_pro/WWW/123.php //修改日志保存位置 show global varia…...
HTML-CSS(day01)
W3C标准: W3C( World Wide Web Consortium,万维网联盟) W3C是万维网联盟,这个组成是用来定义标准的。他们规定了一个网页是由三部分组成,分别是: 三个组成部分:(1&…...
【服务器项目部署】⭐️将本地项目部署到服务器!
目录 🍸前言 🍻一、服务器选择 🍹 二、服务器环境部署 2.1 java 环境部署 2.2 mysql 环境部署 🍸三、项目部署 3.1 静态页面调整 3.2 服务器端口开放 3.3 项目部署 🍹四、测试 🍸前言 小伙伴们大家好…...
计算机网络 (14)数字传输系统
一、定义与原理 数字传输系统,顾名思义,是一种将连续变化的模拟信号转换为离散的数字信号,并通过适当的传输媒介进行传递的系统。在数字传输系统中,信息被编码成一系列的二进制数字,即0和1,这些数字序列能够…...
机器学习周报-TCN文献阅读
文章目录 摘要Abstract 1 TCN通用架构1.1 序列建模任务描述1.2 因果卷积(Causal Convolutions)1.3 扩张卷积(Dilated Convolutions)1.4 残差连接(Residual Connections) 2 TCN vs RNN3 TCN缺点4 代码4.1 TC…...
UniApp 页面布局基础
一、UniApp 页面布局简介 在当今的移动应用开发领域,跨平台开发已成为一种主流趋势。UniApp作为一款极具影响力的跨平台开发框架,凭借其“一套代码,多端运行”的特性,为开发者们提供了极大的便利,显著提升了开发效率。…...
最新的强大的文生视频模型Pyramid Flow 论文阅读及复现
《PYRAMIDAL FLOW MATCHING FOR EFFICIENT VIDEO GENERATIVE MODELING》 论文地址:2410.05954https://arxiv.org/pdf/2410.05954 项目地址: jy0205/Pyramid-Flow: 用于高效视频生成建模的金字塔流匹配代码https://github.com/jy0205/Pyram…...
论文阅读 - 《Large Language Models Are Zero-Shot Time Series Forecasters》
Abstract 通过将时间序列编码为数字组成的字符串,我们可以将时间序列预测当做文本中下一个 token预测的框架。通过开发这种方法,我们发现像GPT-3和LLaMA-2这样的大语言模型在下游任务上可以有零样本时间序列外推能力上持平或者超过专门设计的时间序列训…...
STM32文件详解
STM32文件详解 启动文件打开MDK栈空间开辟堆空间开辟中断向量表复位程序对于 weak 的理解对于_main 函数的分析中断程序堆栈初始化系统启动流程 时钟树时钟源时钟配置函数时钟初始化配置函数 启动文件 启动文件的方式 1、初始化堆栈指针 SP _initial_sp 2、初始化程序计数器指…...
【Spring】详解(上)
Spring 框架核心原理与应用(上) 一、Spring 框架概述 (一)诞生背景 随着 Java 应用程序规模的不断扩大以及复杂度的日益提升,传统的 Java开发方式在对象管理、代码耦合度等方面面临诸多挑战。例如,对象之…...
大数据面试笔试宝典之Flink面试
1.Flink 是如何支持批流一体的? F link 通过一个底层引擎同时支持流处理和批处理. 在流处理引擎之上,F link 有以下机制: 1)检查点机制和状态机制:用于实现容错、有状态的处理; 2)水印机制:用于实现事件时钟; 3)窗口和触发器:用于限制计算范围,并定义呈现结果的…...
Rust编程与项目实战-箱
【图书介绍】《Rust编程与项目实战》-CSDN博客 《Rust编程与项目实战》(朱文伟,李建英)【摘要 书评 试读】- 京东图书 (jd.com) Rust编程与项目实战_夏天又到了的博客-CSDN博客 对于Rust而言,箱(crate)是一个独立的可编译单元&…...
git回退指定版本/复制提交id
1.使用“git reset --hard 目标版本号”命令将版本回退2.使用“git push -f”提交更改 因为我们回退后的本地库HEAD指向的版本比远程库的要旧,此时如果用“git push”会报错。 改为使用 git push -f 即可完成回退后的提交。...
数据库锁的深入探讨
数据库锁(Database Lock)是多用户环境中用于保证数据一致性和隔离性的机制。随着数据库系统的发展,特别是在高并发的场景下,锁的机制变得尤为重要。通过使用锁,数据库能够防止并发操作导致的数据冲突或不一致。本文将深…...
《机器学习》——KNN算法
文章目录 KNN算法简介KNN算法——sklearnsklearn是什么?sklearn 安装sklearn 用法 KNN算法 ——距离公式KNN算法——实例分类问题完整代码——分类问题 回归问题完整代码 ——回归问题 KNN算法简介 一、KNN介绍 全称是k-nearest neighbors,通过寻找k个距…...
iOS开发代码块-OC版
iOS开发代码块-OC版 资源分享资源使用详情Xcode自带代码块自定义代码块 资源分享 自提: 通过网盘分享的文件:CodeSnippets 2.zip 链接: https://pan.baidu.com/s/1Yh8q9PbyeNpuYpasG4IiVg?pwddn1i 提取码: dn1i Xcode中的代码片段默认放在下面的目录中…...
关于在M系列的Mac中使用SoftEtherClient软件
1. 前言 本文说明的是在M系列的苹果的MacBook中如何使用SoftetherClient这款软件,是直接在MacOS操作系统中安装连接使用,不是在PD环境或者非ARM架构的Mac中安装使用。 PS:别费劲百度了,很少有相关解决方案的,在国内会…...
【畅购商城】详情页模块之评论
目录 接口 分析 后端实现:JavaBean 后端实现 前端实现 接口 GET http://localhost:10010/web-service/comments/spu/2?current1&size2 { "code": 20000, "message": "查询成功", "data": { "impressions&q…...
机器学习DAY4续:梯度提升与 XGBoost (完)
本文将通过 XGBoost 框架来实现回归、分类和排序任务,帮助理解和掌握使用 XGBoost 解决实际问题的能力。我们将从基本的数据处理开始,逐步深入到模型训练、评估以及预测。最后,将模型进行保存和加载训练好的模型。 知识点 回归任务分类任务…...
Maven 测试和单元测试介绍
一、测试介绍 二、单元测试 1)介绍 2)快速入门 添加依赖 <dependencies><!-- junit依赖 --><dependency><groupId>org.junit.jupiter</groupId><artifactId>junit-jupiter</artifactId><version>5.9…...
LeetCode7. 整数反转
难度:中等 给你一个 32 位的有符号整数 x ,返回将 x 中的数字部分反转后的结果。 如果反转后整数超过 32 位的有符号整数的范围 [−231, 231 − 1] ,就返回 0。 假设环境不允许存储 64 位整数(有符号或无符号)。 示…...
Java编程题_面向对象和常用API01_B级
Java编程题_面向对象和常用API01_B级 第1题 面向对象、异常、集合、IO 题干: 请编写程序,完成键盘录入学生信息,并计算总分将学生信息与总分一同写入文本文件 需求:键盘录入3个学生信息(姓名,语文成绩,数学成绩) 求出每个学生的总分 ,并…...
WEB攻防-通用漏洞-文件上传-js验证-MIME验证-user.ini-语言特征
目录 定义 1.前端验证 2.MIME验证 3.htaccess文件和.user. ini 4.对内容进行了过滤,做了内容检测 5.[ ]符号过滤 6.内容检测php [] {} ; 7.()也被过滤了 8.反引号也被过滤 9.文件头检测 定义 文件上传漏洞是指攻击者上传了一个可执行文件(如木马…...
ubuntu20.04 调试bcache源码
搭建单步调试bcache的环境,/dev/sdb作为backing dev, /dev/sdc作为cache dev。 一、宿主机环境 1)安装ubuntu 20.04 : 参考ubuntu20.04 搭建kernel调试环境第一篇--安装系统_ubuntu kernel-CSDN博客安装,其中的第六…...
全国青少年信息学奥林匹克竞赛(信奥赛)备考实战之循环结构(for循环语句)(四)
实战训练1—最大差值 问题描述: 输入n个非负整数,找出这个n整数的最大值与最小值,并求最大值和最小值的差值。 输入格式: 共两行,第一行为整数的个数 n(1≤n≤1000)。第二行为n个整数的值(整…...
基于深度学习(HyperLPR3框架)的中文车牌识别系统-python程序开发测试
本篇内容为python开发,通过一个python程序,测试搭建的开发环境,读入一张带有车牌号的图片,成功识别出车牌号。 1、通过PyCharm新建一个工程,如:PlateRecognition,配置虚拟环境。 2、在工程中新…...
【SpringMVC】拦截器
拦截器(Interceptor)是一种用于动态拦截方法调用的机制。在 Spring MVC 中,拦截器能够动态地拦截控制器方法的执行过程。以下是请求发送与接收的基本流程: 当浏览器发出请求时,请求首先到达 Tomcat 服务器。Tomcat 根…...
离线的方式:往Maven的本地仓库里安装依赖
jar文件及源码的绝对路径,gav坐标,打包方式,Maven本地仓库的路径 mvn install:install-file ^-DfileD:\hello-spring-boot-starter-1.0-SNAPSHOT.jar ^-DsourcesD:\hello-spring-boot-starter-1.0-SNAPSHOT-sources.jar ^-DgroupIdcom.examp…...
短视频矩阵系统后端源码搭建实战与技术详解,支持OEM
一、引言 随着短视频行业的蓬勃发展,短视频矩阵系统成为了众多企业和创作者进行多平台内容运营的有力工具。后端作为整个系统的核心支撑,负责处理复杂的业务逻辑、数据存储与交互,其搭建的质量直接影响着系统的性能、稳定性和可扩展性。本文将…...
ArcGIS Pro地形图四至角图经纬度标注与格网标注
今天来看看ArcGIS Pro 如何在地形图上设置四至角点的经纬度。方里网标注。如下图的地形图左下角经纬度标注。 如下图方里网的标注 如下为本期要介绍的例图,如下: 图片可点击放大 接下来我们来介绍一下 推荐学习:GIS入门模型构建器Arcpy批量…...
鸿蒙Next状态管理V2 - @Once初始化用法总结
一、概述 Once装饰器用于实现变量仅在初始化时同步一次外部传入值,后续数据源更改时不会将修改同步给子组件。其必须搭配Param使用,且不影响Param的观测能力,仅拦截数据源变化,与Param装饰变量的先后顺序不影响实际功能ÿ…...
全新免押租赁系统助力商品流通高效安全
内容概要 全新免押租赁系统的推出,可以说是一场商品流通领域的小革命。想象一下,不再为押金烦恼,用户只需通过一个简单的信用评估,就能快速租到所需商品,这种体验简直令人惊喜!这个系统利用代扣支付技术&a…...
VUE前端实现防抖节流 Lodash
方法一:采用Lodash工具库 Lodash 是一个一致性、模块化、高性能的 JavaScript 实用工具库。 (1)采用终端导入Lodash库 $ npm i -g npm $ npm i --save lodash (2)应用 示例:搜索框输入防抖 在这个示例…...