TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(16)——注意力机制详解
- 0. 前言
- 1. 引入注意力机制
- 2. 注意力机制
- 2.1 注意力机制原理
- 2.2 注意力机制分类
- 3. 添加注意机制的 Seq2Seq 模型
- 3.1 数据处理
- 3.2 模型构建与训练
- 3.3 模型性能评估
- 小结
- 系列链接
0. 前言
在传统的神经网络中,所有的输入都被平等地处理,而注意力机制通过为输入的不同部分分配不同的权重(即注意力权重),使得网络能够更关注于对当前任务最重要的信息。例如在机器翻译中,某个单词在句子中比其他单词更关键,注意力机制会将更多的权重分配给该单词,从而使网络在生成翻译时能够更好地理解上下文。在本中,介绍了注意力机制的原理,以及如何利用注意力机制来提高 Seq2Seq 模型的性能。
1. 引入注意力机制
在基于 Seq2Seq 实现机器翻译一节中,我们学习了如何将编码器最后一个时间步的上下文作为初始隐藏状态输入到解码器中。随着上下文在解码器的时间步中流动,信号与解码器输出结合,并逐渐变弱,导致上下文对解码器的后续时间步的影响有限。此外,解码器输出的某些部分可能更多地依赖于输入的特定部分。例如,输入 “thank you very much
”,以及相应的英语到法语的翻译输出 “merci beaucoup
”,英语短语 “thank you
” 和 “very much
” 分别对应于法语的 “merci
” 和 “beaucoup
”,然而,这些信息通过单个上下文向量传达时并不充分。
注意力机制 (Attention Mechanism
) 在解码器的每个时间步提供对所有编码器隐藏状态的访问。解码器学习哪些编码器状态部分应更受关注,注意力机制的使用显著提高了机器翻译的质量以及多种标准自然语言处理任务的效果。
注意力机制的使用不仅限于 Seq2Seq
网络,注意力是“嵌入、编码、注意力、预测” (Embed, Encode, Attend, Predict
) 中的关键组件,用于创建先进的深度学习模型以处理自然语言处理任务。其中,注意力用来在将大型表示缩小到更紧凑的表示时尽可能保留更多信息,例如,将一系列词向量缩减为一个句子向量。
2. 注意力机制
2.1 注意力机制原理
本质上,注意力机制提供了一种评分方式,可以将目标中的词元与源中的所有词元进行比较,并相应地修改解码器的输入信号。考虑一个编码器-解码器架构,其中输入和输出的时间步由索引 i i i 和 j j j 表示,编码器和解码器在这些时间步上的隐藏状态分别由 h i h_i hi 和 s j s_j sj 表示。编码器的输入由 x i x_i xi 表示,解码器的输出由 y j y_j yj 表示。在没有注意力的编码器-解码器网络中,解码器状态 s j s_j sj 的值由前一个时间步的隐藏状态 s j − 1 s_{j-1} sj−1 和输出 y j − 1 y_{j-1} yj−1 给出。注意力机制添加了第三个信号 c j c_j cj,称为注意力上下文。因此,有了注意力,解码器的隐藏状态 s j s_j sj 是 y j − 1 y_{j-1} yj−1、 s j − 1 s_{j-1} sj−1 和 c j c_j cj 的函数:
s j = f ( y j − 1 , s j − 1 , c j ) s_j=f(y_{j-1},s_{j-1},c_j) sj=f(yj−1,sj−1,cj)
注意力上下文信号 c j c_j cj 的计算如下所示。对于每个解码器时间步 j j j,计算解码器状态 s j − 1 s_{j-1} sj−1 与每个编码器状态 h i h_i hi 之间的对齐 (alignment
),这为每个解码器状态 j j j 给出了一组 N N N 个相似性值 e i j e_{ij} eij,然后通过计算它们对应的 softmax
值 b i j b_{ij} bij 将其转换为概率分布。最后,注意力上下文 c j c_j cj 计算为所有编码器( N N N 个)时间步的编码器状态 h i h_i hi 及其对应的 softmax
权重 b i j b_{ij} bij 的加权和。以下方程组概括了每个解码器时间步 j j j 的这种转换:
e i j = a l i g n ( h i , s j − 1 ) ∀ i b i j = s o f t m a x ( e i j ) c j = ∑ i = 0 N h i b i j e_{ij}=align(h_i,s_{j-1})\forall i\\ b_{ij}=softmax(e_{ij})\\ c_j=\sum_{i=0}^Nh_ib_{ij} eij=align(hi,sj−1)∀ibij=softmax(eij)cj=i=0∑Nhibij
2.2 注意力机制分类
根据对齐方式的不同,研究人员提出了多种注意力机制。为了方便起见,我们将编码器的状态向量 h i h_i hi 表示为 h h h,并将解码器的状态向量 s j − 1 s_{j-1} sj−1 表示为 s s s。
对齐的最简单公式是基于内容的注意力 (content-based attention
),基于内容的注意力实际上是编码器和解码器状态之间的余弦相似度。使用此公式的前提条件是编码器和解码器上的隐藏状态向量必须具有相同的维度:
e = c o s i n e ( h , s ) e=cosine(h,s) e=cosine(h,s)
另一种公式称为加性注意力 (additive attention
, 或 Bahdanau attention
),这种方法使用一个小型神经网络中的可学习权重来组合状态向量,表示为以下方程,其中, s s s 和 h h h 向量连接后乘以学习权重 W W W,等效于使用两个学习权重 W s W_s Ws 和 W h W_h Wh 分别与 s s s 和 h h h 相乘,然后将结果相加:
e = v T tanh ( W [ s ; h ] ) e=v^T\text {tanh}(W[s;h]) e=vTtanh(W[s;h])
Luong
、Pham
和 Manning
提出了一组三种注意力形式,即点积 (dot
)、通用 (general
) 和连接 (concat
),其中通用形式也称为乘法 (multiplicative
) 或 Luong
注意力 (Luong’s attention
)。点积和连接注意力形式类似于基于内容和加性注意力形式。乘法注意力形式可以由以下方程表示:
e = h T W s e=h^TWs e=hTWs
Vaswani
等人提出了基于内容的注意力的变体,称为缩放点积注意力 (scaled dot-product attention
),其公式如下所示。其中, N N N 是编码器隐藏状态 h h h 的维度,Transformer
架构采用缩放点积注意力:
e = h T s N e=\frac {h^Ts}{\sqrt N} e=NhTs
注意力机制还可以根据其关注的内容进行分类。根据这种分类方式,注意力机制可以分为自注意力 (self-attention
)、全局注意力( global attention
,也称软注意力,soft attention
),以及局部注意力( local attention
,硬注意力,hard attention
)。自注意力是指在同一序列的不同部分之间计算对齐,已被证明在机器阅读、文本摘要和图像字幕等应用中非常有效。
软注意力或全局注意力是指在整个输入序列上计算对齐,而硬注意力或局部注意力是指在序列的一部分上计算对齐。软注意力的优势在于它是可微分的,但计算成本可能较高。相反,硬注意力在推理时计算成本较低,但是不可微分,并且在训练过程中需要更复杂的技术。
接下来,我们将学习如何将注意机制与 Seq2Seq 模型集成,以提升网络性能。
3. 添加注意机制的 Seq2Seq 模型
在本节中,同样使用 Seq2Seq
架构进行英语到法语的机器翻译任务,区别在于是,本节中解码器将使用加性注意机制和乘法注意机制来处理编码器输出。
3.1 数据处理
(1) 首先,导入所需库,并定义常量:
import nltk
import numpy as np
import re
import shutil
import tensorflow as tf
import os
import unicodedata
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunctionNUM_SENT_PAIRS = 30000
EMBEDDING_DIM = 256
ENCODER_DIM, DECODER_DIM = 1024, 1024
BATCH_SIZE = 64
NUM_EPOCHS = 30
(2) 下载数据集,并进行预处理,相关预处理过程参考 Seq2Seq 一节:
def clean_up_logs(data_dir):checkpoint_dir = os.path.join(data_dir, "checkpoints")if os.path.exists(checkpoint_dir):shutil.rmtree(checkpoint_dir, ignore_errors=True)os.makedirs(checkpoint_dir)return checkpoint_dirdef preprocess_sentence(sent):sent = "".join([c for c in unicodedata.normalize("NFD", sent) if unicodedata.category(c) != "Mn"])sent = re.sub(r"([!.?])", r" \1", sent)sent = re.sub(r"[^a-zA-Z!.?]+", r" ", sent)sent = re.sub(r"\s+", " ", sent)sent = sent.lower()return sentdef download_and_read(url, num_sent_pairs=30000):local_file = url.split('/')[-1]if not os.path.exists(local_file):os.system("wget -O {:s} {:s}".format(local_file, url))with zipfile.ZipFile(local_file, "r") as zip_ref:zip_ref.extractall(".")local_file = os.path.join(".", "fra.txt")en_sents, fr_sents_in, fr_sents_out = [], [], []with open(local_file, "r") as fin:for i, line in enumerate(fin):en_sent, fr_sent = line.strip().split('\t')[:2]en_sent = [w for w in preprocess_sentence(en_sent).split()]fr_sent = preprocess_sentence(fr_sent)fr_sent_in = [w for w in ("BOS " + fr_sent).split()]fr_sent_out = [w for w in (fr_sent + " EOS").split()]en_sents.append(en_sent)fr_sents_in.append(fr_sent_in)fr_sents_out.append(fr_sent_out)if i >= num_sent_pairs - 1:breakreturn en_sents, fr_sents_in, fr_sents_outdata_dir = "./data"
checkpoint_dir = clean_up_logs(data_dir)# data preparation
download_url = "http://www.manythings.org/anki/fra-eng.zip"
sents_en, sents_fr_in, sents_fr_out = download_and_read(download_url, num_sent_pairs=NUM_SENT_PAIRS)tokenizer_en = tf.keras.preprocessing.text.Tokenizer(filters="", lower=False)
tokenizer_en.fit_on_texts(sents_en)
data_en = tokenizer_en.texts_to_sequences(sents_en)
data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en, padding="post")tokenizer_fr = tf.keras.preprocessing.text.Tokenizer(filters="", lower=False)
tokenizer_fr.fit_on_texts(sents_fr_in)
tokenizer_fr.fit_on_texts(sents_fr_out)
data_fr_in = tokenizer_fr.texts_to_sequences(sents_fr_in)
data_fr_in = tf.keras.preprocessing.sequence.pad_sequences(data_fr_in, padding="post")
data_fr_out = tokenizer_fr.texts_to_sequences(sents_fr_out)
data_fr_out = tf.keras.preprocessing.sequence.pad_sequences(data_fr_out, padding="post")maxlen_en = data_en.shape[1]
maxlen_fr = data_fr_out.shape[1]
print("seqlen (en): {:d}, (fr): {:d}".format(maxlen_en, maxlen_fr))batch_size = BATCH_SIZE
dataset = tf.data.Dataset.from_tensor_slices((data_en, data_fr_in, data_fr_out))
dataset = dataset.shuffle(10000)
test_size = NUM_SENT_PAIRS // 4
test_dataset = dataset.take(test_size).batch(batch_size, drop_remainder=True)
train_dataset = dataset.skip(test_size).batch(batch_size, drop_remainder=True)vocab_size_en = len(tokenizer_en.word_index)
vocab_size_fr = len(tokenizer_fr.word_index)
word2idx_en = tokenizer_en.word_index
idx2word_en = {v:k for k, v in word2idx_en.items()}
word2idx_fr = tokenizer_fr.word_index
idx2word_fr = {v:k for k, v in word2idx_fr.items()}
print("vocab size (en): {:d}, vocab size (fr): {:d}".format(vocab_size_en, vocab_size_fr))
# vocab size (en): 57, vocab size (fr): 123
3.2 模型构建与训练
(1) 与未使用注意力机制的 Seq2Seq
不同,编码器不再返回单个上下文向量,而是在每个时间步返回输出,因为注意机制将需要这些信息:
class Encoder(tf.keras.Model):def __init__(self, vocab_size, num_timesteps, embedding_dim, encoder_dim, **kwargs):super(Encoder, self).__init__(**kwargs)self.encoder_dim = encoder_dimself.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, input_length=num_timesteps)self.rnn = tf.keras.layers.GRU(encoder_dim, return_sequences=True, return_state=True)def call(self, x, state):x = self.embedding(x)x, state = self.rnn(x, initial_state=state)return x, statedef init_state(self, batch_size):return tf.zeros((batch_size, self.encoder_dim))
(2) 解码器最大的变化是注意力层的声明,首先需要定义注意力层。首先定义加性注意力类,加性注意力将解码器在每个时间步的隐藏状态与所有编码器的隐藏状态结合起来,以生成下一个时间步解码器的输入:
e = v T W[s;h] e=v^T\text{W[s;h]} e=vTW[s;h]
其中, W [ s ; h ] W[s;h] W[s;h] 是 s s s 和 h h h 的两个独立线性变换的简写(形式为 y = W x + b y=Wx+b y=Wx+b),一个是 s s s 的变换,另一个是 h h h 的变换,这两个线性变换实现为全连接层。我们将子类化 tf.keras Layer
对象,因为最终目标是将其作为网络中的一个层使用。call()
方法接受查询( query
,解码器状态)和值( value
,编码器状态),计算得分,然后通过相应的 softmax
计算对齐,并生成上下文向量,最后返回。上下文向量的形状为 (batch_size, num_decoder_timesteps)
,对齐的形状为 (batch_size, num_encoder_timesteps, 1)
。全连接层 W1
、W2
和 V
张量的权重在训练期间学习:
class BahdanauAttention(tf.keras.layers.Layer):def __init__(self, num_units):super(BahdanauAttention, self).__init__()self.W1 = tf.keras.layers.Dense(num_units)self.W2 = tf.keras.layers.Dense(num_units)self.V = tf.keras.layers.Dense(1)def call(self, query, values):# query is the decoder state at time step j# query.shape: (batch_size, num_units)# values are encoder states at every timestep i# values.shape: (batch_size, num_timesteps, num_units)# add time axis to query: (batch_size, 1, num_units)query_with_time_axis = tf.expand_dims(query, axis=1)# compute score:score = self.V(tf.keras.activations.tanh(self.W1(values) + self.W2(query_with_time_axis)))# compute softmaxalignment = tf.nn.softmax(score, axis=1)# compute attended outputcontext = tf.reduce_sum(tf.linalg.matmul(tf.linalg.matrix_transpose(alignment),values), axis=1)context = tf.expand_dims(context, axis=1)return context, alignment
(3) 乘法注意力 (Luong
注意力)的实现方式与加性注意力类似。只声明一个线性变换 W
,而不是三个 (W1
、W2
和 V
)。call()
方法的步骤类似,首先根据 Luong
注意力的方程计算得分,然后,我们将得分的对应 softmax
结果计算为对齐,然后将对齐和值的点积作为上下文向量,全连接层 W
表示的权重矩阵在训练过程中学习:
class LuongAttention(tf.keras.layers.Layer):def __init__(self, num_units):super(LuongAttention, self).__init__()self.W = tf.keras.layers.Dense(num_units)def call(self, query, values):# add time axis to queryquery_with_time_axis = tf.expand_dims(query, axis=1)# compute scorescore = tf.linalg.matmul(query_with_time_axis, self.W(values), transpose_b=True)# compute softmaxalignment = tf.nn.softmax(score, axis=2)# compute attended outputcontext = tf.matmul(alignment, values)return context, alignment
(4) 为了验证以上两个类是否可以互换使用,运行以下代码,构建一些随机输入,并将它们输入到这两个注意力类中:
batch_size = BATCH_SIZE
num_timesteps = MAXLEN_EN
num_units = ENCODER_DIMquery = np.random.random(size=(batch_size, num_units))
values = np.random.random(size=(batch_size, num_timesteps, num_units))# check out dimensions for Bahdanau attention
b_attn = BahdanauAttention(num_units)
context, alignments = b_attn(query, values)
print("Bahdanau: context.shape:", context.shape, "alignments.shape:", alignments.shape)# check out dimensions for Luong attention
l_attn = LuongAttention(num_units)
context, alignments = l_attn(query, values)
print("Luong: context.shape:", context.shape, "alignments.shape:", alignments.shape)
# Luong: context.shape: (64, 1024) alignments.shape: (64, 8, 1)
输出结果如下所示,和我们的预期一致,这两个类在给定相同输入时产生了形状完全相同的输出。因此,它们可以互换使用:
Bahdanau: context.shape: (64, 1024) alignments.shape: (64, 8, 1)
Luong: context.shape: (64, 1024) alignments.shape: (64, 8, 1)
(5) 实现注意力类后,构建解码器。init()
方法增加了注意力类变量,将其设置为 BahdanauAttention
类。此外,我们还有两个额外的变换,Wc
和 Ws
,它们将应用于解码器循环神经网络 (Recurrent Neural Network, RNN) 的输出。第一个变换使用 tanh
激活函数,将输出调节在 -1
到 +1
之间,而第二个则是标准的线性变换。与没有注意力解码器组件的 Seq2Seq
网络相比,本节的解码器在其 call()
方法中需要额外的参数 encoder_output
,并返回一个额外的上下文向量:
class Decoder(tf.keras.Model):def __init__(self, vocab_size, embedding_dim, num_timesteps,decoder_dim, **kwargs):super(Decoder, self).__init__(**kwargs)self.decoder_dim = decoder_dim# self.attention = LuongAttention(embedding_dim)self.attention = BahdanauAttention(embedding_dim)self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, input_length=num_timesteps)self.rnn = tf.keras.layers.GRU(decoder_dim, return_sequences=True, return_state=True)self.Wc = tf.keras.layers.Dense(decoder_dim, activation="tanh")self.Ws = tf.keras.layers.Dense(vocab_size)def call(self, x, state, encoder_out):x = self.embedding(x)context, alignment = self.attention(x, encoder_out)x = tf.expand_dims(tf.concat([x, tf.squeeze(context, axis=1)], axis=1), axis=1)x, state = self.rnn(x, state)x = self.Wc(x)x = self.Ws(x)return x, state, alignment
(6) 训练循环也与没有注意力机制的 Seq2Seq
网络不同,没有注意力机制的 Seq2Seq
网络使用 Teacher Forcing
来加快训练速度,使用注意力意味着必须逐个处理解码器输入,这是因为前一步的解码器输出通过注意力机制对当前时间步的输出影响更大。训练循环由 train_step
函数定义,明显比没有使用注意力的 Seq2Seq
网络的训练循环慢得多,但本节实现的训练循环也可以用于没有使用注意力的 Seq2Seq
网络,特别是需要实现调度采样策略时:
def loss_fn(ytrue, ypred):scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)mask = tf.math.logical_not(tf.math.equal(ytrue, 0))mask = tf.cast(mask, dtype=tf.int64)loss = scce(ytrue, ypred, sample_weight=mask)return loss@tf.function
def train_step(encoder_in, decoder_in, decoder_out, encoder_state):with tf.GradientTape() as tape:encoder_out, encoder_state = encoder(encoder_in, encoder_state)decoder_state = encoder_stateloss = 0for t in range(decoder_out.shape[1]):decoder_in_t = decoder_in[:, t]decoder_pred_t, decoder_state, _ = decoder(decoder_in_t,decoder_state, encoder_out)loss += loss_fn(decoder_out[:, t], decoder_pred_t)variables = encoder.trainable_variables + decoder.trainable_variablesgradients = tape.gradient(loss, variables)optimizer.apply_gradients(zip(gradients, variables))return loss / decoder_out.shape[1]
(7) predict()
和 evaluate()
方法同样涉及实现解码器的新数据流,包括额外的 encoder_out
参数和额外的 context
返回值:
def predict(encoder, decoder, batch_size, sents_en, data_en, sents_fr_out, word2idx_fr, idx2word_fr):random_id = np.random.choice(len(sents_en))print("input : ", " ".join(sents_en[random_id]))print("label : ", " ".join(sents_fr_out[random_id]))encoder_in = tf.expand_dims(data_en[random_id], axis=0)decoder_out = tf.expand_dims(sents_fr_out[random_id], axis=0)encoder_state = encoder.init_state(1)encoder_out, encoder_state = encoder(encoder_in, encoder_state)decoder_state = encoder_statepred_sent_fr = []decoder_in = tf.expand_dims(tf.constant(word2idx_fr["BOS"]), axis=0)while True:decoder_pred, decoder_state, _ = decoder(decoder_in, decoder_state, encoder_out)decoder_pred = tf.argmax(decoder_pred, axis=-1)pred_word = idx2word_fr[decoder_pred.numpy()[0][0]]pred_sent_fr.append(pred_word)if pred_word == "EOS":breakdecoder_in = tf.squeeze(decoder_pred, axis=1)print("predicted: ", " ".join(pred_sent_fr))def evaluate_bleu_score(encoder, decoder, test_dataset, word2idx_fr, idx2word_fr):bleu_scores = []smooth_fn = SmoothingFunction()for encoder_in, decoder_in, decoder_out in test_dataset:encoder_state = encoder.init_state(batch_size)encoder_out, encoder_state = encoder(encoder_in, encoder_state)decoder_state = encoder_stateref_sent_ids = np.zeros_like(decoder_out)hyp_sent_ids = np.zeros_like(decoder_out)for t in range(decoder_out.shape[1]):decoder_out_t = decoder_out[:, t]decoder_in_t = decoder_in[:, t]decoder_pred_t, decoder_state, _ = decoder(decoder_in_t, decoder_state, encoder_out)decoder_pred_t = tf.argmax(decoder_pred_t, axis=-1)for b in range(decoder_pred_t.shape[0]):ref_sent_ids[b, t] = decoder_out_t.numpy()[0]hyp_sent_ids[b, t] = decoder_pred_t.numpy()[0][0]for b in range(ref_sent_ids.shape[0]):ref_sent = [idx2word_fr[i] for i in ref_sent_ids[b] if i > 0]hyp_sent = [idx2word_fr[i] for i in hyp_sent_ids[b] if i > 0]# remove trailing EOSref_sent = ref_sent[0:-1]hyp_sent = hyp_sent[0:-1]bleu_score = sentence_bleu([ref_sent], hyp_sent,smoothing_function=smooth_fn.method1)bleu_scores.append(bleu_score)return np.mean(np.array(bleu_scores))
(8) 训练两个使用不同注意力机制的 Seq2Seq
网络,一个采用加性 (Bahdanau
) 注意力,另一个采用乘法 (Luong
) 注意力。两个网络均训练 50
个 epoch
,然而,无论是哪种注意力机制,翻译质量与不使用注意力的 Seq2Seq
网络训练 250
个 epoch
的情况相似:
embedding_dim = EMBEDDING_DIM
encoder_dim, decoder_dim = ENCODER_DIM, DECODER_DIMencoder = Encoder(vocab_size_en+1, embedding_dim, maxlen_en, encoder_dim)
decoder = Decoder(vocab_size_fr+1, embedding_dim, maxlen_fr, decoder_dim)optimizer = tf.keras.optimizers.Adam()
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,encoder=encoder,decoder=decoder)num_epochs = NUM_EPOCHS
eval_scores = []for e in range(num_epochs):encoder_state = encoder.init_state(batch_size)for batch, data in enumerate(train_dataset):encoder_in, decoder_in, decoder_out = data# print(encoder_in.shape, decoder_in.shape, decoder_out.shape)loss = train_step(encoder_in, decoder_in, decoder_out, encoder_state)print("Epoch: {}, Loss: {:.4f}".format(e + 1, loss.numpy()))if e % 10 == 0:checkpoint.save(file_prefix=checkpoint_prefix)predict(encoder, decoder, batch_size, sents_en, data_en,sents_fr_out, word2idx_fr, idx2word_fr)eval_score = evaluate_bleu_score(encoder, decoder, test_dataset, word2idx_fr, idx2word_fr)print("Eval Score (BLEU): {:.3e}".format(eval_score))# eval_scores.append(eval_score)checkpoint.save(file_prefix=checkpoint_prefix)
3.3 模型性能评估
训练结束时,与不带注意力的 Seq2Seq
网络相比,带有注意力机制的 Seq2Seq
网络在训练结束时损失略低,并且在测试集上的双语评估 (BiLingual Evaluation Understudy
, BLEU
) 分数略高:
模型 | epoch | loss | BLUE |
---|---|---|---|
Seq2Seq | 250 | 0.967 | 4.869e-02 |
使用加性注意力的 Seq2Seq | 50 | 0.0893 | 5.508e-02 |
使用乘法注意力的 Seq2Seq | 50 | 0.0706 | 5.563e-02 |
两个网络生成的翻译结果如下,需要注意的是,即使翻译结果与真实标签不完全一致,但仍然是原文的有效翻译:
注意力类型 | epoch | 英语 | 法语(标签值) | 法语(预测值) |
Bahdanau | 20 | your cat is fat. | ton chat est gras. | ton chat est mouille. |
25 | i had to go back. | il m a fallu retourner. | il me faut partir. | |
30 | try to find it. | tentez de le trouver. | tentez de le trouver. | |
Luong | 20 | that s peculiar. | c est etrange. | c est deconcertant. |
25 | tom is athletic. | thomas est sportif. | tom est sportif. | |
30 | it s dangerous. | c est dangereux. | c est dangereux. |
可以通过在 Decoder
类的 init()
方法中注释掉其中一种注意力机制在 Bahdanau
(加性)或 Luong
(乘法)注意力机制之间切换。
小结
注意力机制使得神经网络能够动态地聚焦于最重要的信息部分,从而提高了模型的表现和效率。最初在自然语言处理任务中得到了广泛应用,如今它已被扩展到计算机视觉、语音识别等多个领域。其核心思想是模拟人类注意力的分配方式,赋予网络在处理信息时对不同部分的不同关注程度。
系列链接
TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络
TensorFlow深度学习实战(9)——构建VGG模型实现图像分类
TensorFlow深度学习实战(10)——迁移学习详解
TensorFlow深度学习实战(11)——风格迁移详解
TensorFlow深度学习实战(12)——词嵌入技术详解
TensorFlow深度学习实战(13)——神经嵌入详解
TensorFlow深度学习实战(14)——循环神经网络详解
TensorFlow深度学习实战(15)——编码器-解码器架构
相关文章:
TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(16)——注意力机制详解 0. 前言1. 引入注意力机制2. 注意力机制2.1 注意力机制原理2.2 注意力机制分类 3. 添加注意机制的 Seq2Seq 模型3.1 数据处理3.2 模型构建与训练3.3 模型性能评估 小结系列链接 0. 前言 在传统的神经网络…...
架空防静电地板材质全解析:选对材质,守护精密空间的“安全卫士”
在现代科技驱动的社会中,无论是数据中心、实验室、手术室,还是高端电子厂房,静电都是精密设备的“隐形杀手”。而架空防静电地板作为这些场所的“安全卫士”,其材质选择直接决定了防静电性能、承重能力及使用寿命。今天࿰…...
Linux系统中部署java服务(docker)
1、不使用docker ✅ 1. 检查并安装 Java 环境 检查 Java 是否已安装: java -version✅ 2. 上传 Java 项目 JAR 文件 可以创建一个server文件夹,然后上传目录 查看当前目录 然后创建目录上传jar包 ✅ 3. 启动 Java 服务 java -jar hywl-server.jar…...
PyGame游戏开发(入门知识+组件拆分+历史存档/回放+人机策略)
前言: 本章实现游戏组件的复用解耦,以及使用配置文件替代原有硬编码形式,进而只需要改动配置文件即可实现整个游戏的难度和地图变化,同时增加历史记录功能,在配置文件开启后即可保存每一局的记录为json形式作为后续强化…...
【上位机——WPF】Window标签常用属性
常用属性 常用属性程序退出 常用属性都是写在Window标签中的 <Window x:Class"WpfDemo1.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"…...
K8S Gateway AB测试、蓝绿发布、金丝雀(灰度)发布
假设有如下三个节点的 K8S 集群: k8s31master 是控制节点 k8s31node1、k8s31node2 是工作节点 容器运行时是 containerd 一、场景分析 阅读本文,默认您已经安装了 K8S Gateway。 关于 AB 测试、金丝雀发布,可以看这篇文章。 二、实验准…...
人大金仓数据库 与django结合
要在Django项目中连接人大金仓数据库(Kingbase),你需要使用一个适合的数据库适配器。人大金仓数据库是基于PostgreSQL的,因此你可以使用psycopg2库来与Django连接。但是,由于人大金仓数据库有其特定的功能和配置&#…...
RK3588 桌面系统配置WiFi和蓝牙配置
桌面右上角点击,打开选项,找到WiFi的选择网络或者WiFi设置 在弹出的窗口中选择需要连接的WiFi,然后右下角选择连接,然后输入WiFi密码即可连接。 25.4. 命令行连接wifi路由器 命令行配置wifi的方法有很多,下面介绍几种…...
TLV格式
TLV格式(Tag-Length-Value)是一种常用的数据序列化格式,主要用于数据包或消息的有效载荷编码。TLV格式将数据划分为三个主要部分:Tag(标签)、Length(长度)和Value(值…...
2024年9月电子学会等级考试五级第三题——整数分解
题目 3、整数分解 正整数 N 的 K-P 分解是指将 N 写成 K 个正整数的 P 次方的和。本题就请你对任意给定的正整数 N、K、P,写出 N 的 K-P 分解。 时间限制:8000 内存限制:262144 输入 输入在一行给出 3 个正整数 N (≤ 400)、K (≤ N)、P (1 …...
软考 系统架构设计师系列知识点之杂项集萃(60)
接前一篇文章:软考 系统架构设计师系列知识点之杂项集萃(59) 第97题 在面向对象设计中,()可以实现界面控制、外部接口和环境隔离。()作为完成用例业务的责任承担者,协调…...
使用Python开发经典俄罗斯方块游戏
使用Python开发经典俄罗斯方块游戏 在这篇教程中,我们将学习如何使用Python和Pygame库开发一个经典的俄罗斯方块游戏。这个项目将帮助你理解游戏开发的基本概念,包括图形界面、用户输入处理、碰撞检测等重要内容。 项目概述 我们将实现以下功能&…...
C++:字符数组与字符串指针变量的大小
#include<iostream> #include<cstring> int main(int argc, char const *argv[]) {// 字符数组char str[128] "hello world";std::cout<<sizeof(str)<<std::endl;std::cout<<strlen(str)<<std::endl;// 字符串指针变量char *st…...
stm32使用freertos时延时时间间隔不对,可能是晶振频率没设置
freertos 获取频率的接口 在 FreeRTOSConfig.h 文件中声明一个函数作为freertos的接口 /// /// brief 获取 SysTick 的频率 /// /// note arm cortex-m 系列 CPU 有一个 systick ,里面有一个 CTRL 寄存器,其中的 bit2 /// 可以用来控制 systick 的时钟…...
51c~C语言~合集5
我自己的原文哦~ https://blog.51cto.com/whaosoft/13913911 一、大厂C语言编程10大规范 1 代码总体原则 1、清晰第一 清晰性是易于维护、易于重构的程序必需具备的特征。代码首先是给人读的,好的代码应当可以像文章一样发声朗诵出来。 目前软件维护期成本…...
前端流行框架Vue3教程:17. _组件数据传递
_组件数据传递 我们之前讲解过了组件之间的数据传递,props 和自定义事件 两种方式 props:父传子 自定义事件:子传父 除了上述的方案,props也可以实现子传父 一、项目结构 src/ └── components/├── ComponentsA.vue # 父…...
Stack overflow
本文来源 :腾讯元宝 Stack Overflow - Where Developers Learn, Share, & Build Careers 开发者学习,分享 通过学习、工作和经验积累等方式,逐步建立和发展自己的职业生涯。 Find answers to your technical questions and help othe…...
SpringBoot 3.4.5版本导入Lomobok依赖后无法生效的问题
问题背景 最近,随着DeepSeek的爆火,小编也编写了一个前后端分离的“知库随考”系统,由于Spring AI官方提示想要使用Spring AI的话要求Spring Boot的版本在“3.4.x”以上,所以我在创建SpringBoot项目的时候选择了了Server URL:http…...
FPGA: UltraScale+ bitslip实现(ISERDESE3)
收获 一晃五年~ 五年前那个夏夜,我对着泛蓝的屏幕敲下《给十年后的自己》,在2020年的疫情迷雾中编织着对未来的想象。此刻回望,第四届集创赛的参赛编号仍清晰如昨,而那个在家熬夜焊电路板的"不眠者",现在…...
Electron详解:原理与不足
Electron是一个集成项目,它通过定制Chromium和Node.js,并将它们集成在内部来实现其功能。具体来说,Electron做了以下几个重要的工作: 定制Chromium:并将定制版本的Chromium集成在Electron内部。定制Node.js࿱…...
Spring Boot多数据源配置的陷阱与终极解决方案
引言 在微服务架构和复杂业务场景中,一个Spring Boot应用连接多个数据库的需求日益普遍。许多开发者尝试通过简单复制单数据源配置来实现多数据源,结果却遭遇了Bean冲突、事务失效、连接泄漏等隐蔽问题。本文将深入剖析Spring Boot自动配置的底层逻辑&a…...
android display 笔记(十四)VAU 和GSP 分别代表什么
VAU 和 GSP 的解释 GSP (Graphics/GPU Subsystem Processor) 含义: 图形处理子系统,通常指 SoC(系统级芯片)中负责 2D/3D 图形渲染、显示合成、图像后处理(如缩放、旋转、色彩管理) 的硬件模块。 在部分芯…...
tomcat 400 The valid characters are defined in RFC 7230 and RFC 3986
在遇到 Tomcat 因 URL 非法字符返回 400 Bad Request 时,选择在 Nginx 还是 Tomcat 中配置错误处理,需根据实际场景和需求权衡。以下是两种方案的详细对比及配置方法: 一、选择建议 方案适用场景优点缺点Nginx 配置- 需要统一处理所有后端服务(如多个 Tomcat 实例)的 400 …...
nginx负载均衡及keepalive高可用
实验前期准备: 5台虚拟机:4台当做服务器,1台当做客户机(当然,也可以使用主机的浏览器),4台服务器中,2台服务器当做后端真实访问服务器;另外2台服务器当做负载均衡服务器…...
漏洞修复:tomcat 升级版本 spring-boot-starter-tomcat 的依赖项
在Spring Boot项目中修复Tomcat漏洞(如CVE-2024-21733)时,通常需要升级内嵌Tomcat版本。以下是具体操作步骤和注意事项: 一、确认当前Tomcat版本 通过启动日志查看 启动项目时,控制台日志中会显示类似 Starting Servlet engine: [Apache Tomcat/9.0.43] 的信息,直接查看版…...
二、IGMP
目录 1. IGMPv1 1.1 IGMPv1 报⽂类型 1.2 IGMPv1 工作机制 1.3 成员加入 1.4 离组机制 2. IGMPv2 2.1 IGMPv2 报文 2.3 查询器选举 & 维护 2.4 成员加入 2.4 离组机制 3. IGMPv3 3.1 IGMPv3 vs. IGMPv2 3.2 IGMPv3 报文 3.3 IGMPv3 工作机制 4. IGMP Proxy …...
Redis--基础知识点--27--redis缓存分类树
在 Redis 中存储分类树,通常需要选择合适的数据结构来表现层级关系。以下是使用 字符串(String) 和 哈希(Hash) 两种常见方案的举例说明,结合电商分类场景(如 电子产品 > 手机 > 智能手机…...
【2025最新】VSCode Cline插件配置教程:免费使用Claude 3.7提升编程效率
2025年最新VSCode Cline插件安装配置教程,详解多种免费使用Claude 3.7的方法,集成DeepSeek-R1与5大实用功能,专业编程效率提升指南。 Cline是VSCode中功能最强大的AI编程助手插件之一,它能与Claude、OpenAI等多种大模型无缝集…...
SQL笔记一
SQL的分类 DDL(数据定义语言):CREATE(创建) ALTER(修改) DROP(删除结构) RENAME(重命名) TRUNCATE(清空) DML࿰…...
高可靠低纹波国产4644电源芯片在工业设备的应用
摘要 随着工业自动化和智能化的飞速发展,工业设备对于电源芯片的性能和可靠性提出了前所未有的严格要求。电源芯片作为工业设备的核心供电组件,其性能直接影响到整个设备的运行效率和稳定性。本文以国科安芯的ASP4644四通道降压稳压器为例,通…...
MyBatis 的分页插件 c
前言 大型项目的数据体量很大,在前端界面展示时为保障展示效果,会要求接口快速返回,这时候后端会选择分页获取数据,只传递要查询的页码数据。这就避免了大多问题,达到快速返回的效果。 常用的分页有2种: …...
网络安全-等级保护(等保) 2-5 GB/T 25070—2019《信息安全技术 网络安全等级保护安全设计技术要求》-2019-05-10发布【现行】
################################################################################ GB/T 22239-2019 《信息安全技术 网络安全等级保护基础要求》包含安全物理环境、安全通信网络、安全区域边界、安全计算环境、安全管理中心、安全管理制度、安全管理机构、安全管理人员、安…...
嵌软面试每日一阅----FreeRTOS
一. FreeRTOS 创建任务的方法及区别 在 FreeRTOS 中,任务创建主要有两种方式:动态内存分配(xTaskCreate())和静态内存分配(xTaskCreateStatic())。以下是两者的核心区别及使用场景: 1. 动态创建…...
EasyExcel详解
文章目录 一、easyExcel1.什么是easyExcel2.easyExcel示例demo3.easyExcel read的底层逻辑~~4.easyExcel write的底层逻辑~~ 二、FastExcel1.为什么更换为fastExcel2.fastExcel新功能 一、easyExcel 1.什么是easyExcel 内容摘自官方:Java解析、生成Excel比较有名的…...
023-C语言预处理详解
C语言预处理详解 文章目录 C语言预处理详解1. 预定义符号2. #define定义常量3. #define定义宏4. 带有副作用的宏参数5. 宏替换的规则6. 宏函数的对比7. #和##7.1 #运算符7.2 ##运算符 8. 命名约定9. #undef10. 命令行定义11. 条件编译12. 头文件包含12.1 头文件被包含方式12.1.…...
C#自定义控件-实现了一个支持平移、缩放、双击重置的图像显示控件
1. 控件概述 这是一个继承自 Control 的自定义控件,主要用于图像的显示和交互操作,具有以下核心功能: 图像显示与缩放(支持鼠标滚轮缩放)图像平移(支持鼠标拖拽)视图重置(双击重置…...
MarkitDown:AI时代的文档转换利器
在当今AI快速发展的时代,如何高效地将各种格式的文档转换为机器可读的格式,成为了一个迫切需要解决的问题。今天,我们来介绍一款由微软开发的强大工具——MarkitDown,它正是为解决这一问题而生的。 什么是MarkitDown? MarkitDown是一个用Python编写的轻量级工具,专门用…...
《数字分身进化论:React Native与Flutter如何打造沉浸式虚拟形象编辑》
React Native,依托JavaScript语言,借助其成熟的React生态系统,开发者能够快速上手,将前端开发的经验巧妙运用到移动应用开发中。它通过JavaScript桥接机制调用原生组件,实现与iOS和Android系统的深度交互,这…...
DeerFlow:字节新一代 DeepSearch 框架
项目地址:https://github.com/bytedance/deer-flow/ 【全新的 Multi-Agent 架构设计】独家设计的 Research Team 机制,支持多轮对话、多轮决策和多轮任务执行。与 LangChain 原版 Supervisor 相比,显著减少 Tokens 消耗和 API 调用次数&#…...
数字孪生工厂实战指南:基于Unreal Engine/Omniverse的虚实同步系统开发
引言:工业元宇宙的基石技术 在智能制造2025与工业元宇宙的交汇点,数字孪生技术正重塑传统制造业。本文将手把手指导您构建基于Unreal Engine 5.4与NVIDIA Omniverse的实时数字孪生工厂系统,集成Kafka实现毫秒级虚实同步,最终交付…...
牛客网NC22015:最大值和最小值
牛客网NC22015:最大值和最小值 题目描述 题目要求 输入:一行,包含三个整数 a, b, c (1≤a,b,c≤1000000) 输出:两行,第一行输出最大数,第二行输出最小数。 样例输入: …...
Uniapp中小程序调用腾讯地图(获取定位地址)
1、先配置权限: 这是上图的代码: "permission": { "scope.userLocation": { "desc": "你的位置信息将用于小程序位置接口的效果展示" } } 第二步:写代码: //下面是uniapp的模版代码 主…...
2025 后端自学UNIAPP【项目实战:旅游项目】5、个人中心页面:微信登录,同意授权,获取用户信息
一、框架以及准备工作 1、前端项目文件结构展示 2、后端项目文件结构展示 3、登录微信公众平台,注册一个个人的程序,获取大appid(前端后端都需要)和密钥(后端需要) 微信公众平台微信公众平台&…...
隆重推荐(Android 和 iOS)UI 自动化工具—Maestro
文章目录 前言一、为什么选择 Maestro?二、使用步骤1.安装(Windows)2.运行示例 三、Maestro Studio (重点)轻松编辑测试 四、价格总结 前言 当前移动 UI 自动化工具的实际效能与预期存在显著差距,团队推行…...
C#发送文件到蓝牙设备
测试环境: visual studio 2022 win11笔记本电脑,具有蓝牙功能 .net6控制台 测试步骤如下: 1 新增名为BluetoothDemo控制台项目 2 通过nuget安装InTheHand.Net.Bluetooth,版本选择4.2.1和安装InTheHand.Net.Obex,版…...
采用sherpa-onnx 实现 ios语音唤起的调研
背景 项目中需要实现一个语音唤起的功能,选择sherpa-onnx进行调研,要求各端都要验证没有问题。个人负责ios这部分的调研。查询官方发现没有直接针对ios语音唤起的稳定,主要技术平台也没有相关的可以借鉴。经过调研之后补充了一个。 一、下载…...
磁盘I/O瓶颈排查:面试通关“三部曲”心法
想象一下,你就是线上系统的“交通调度总指挥”,服务器的磁盘是所有数据进出的“核心枢纽港口”。当这个“港口”突然拥堵不堪,卡车(数据请求)排起长龙,进不去也出不来,整个系统的“物流”&#…...
磁盘性能测试与分析:结合fio和iostat的完整方案
磁盘性能测试与分析:结合fio和iostat的完整方案 磁盘性能是影响现代计算机系统整体运行效率的关键因素之一,特别是对于高I/O负载的应用如数据库、虚拟化环境等。本文将详细介绍如何利用fio和iostat工具全面评估磁盘性能,包括IOPS、带宽、延迟…...
随机森林(Random Forest)
随机森林(Random Forest)是一种基于决策树的集成学习算法,它通过构建多个决策树并将它们的预测结果进行综合,从而提高模型的准确性和稳定性。 1.基本原理 随机森林属于集成学习中的“Bagging”方法。其核心思想是通过构建多个决…...
C#数据类型
🧩 一、布尔值(bool) 表示逻辑值:true 或 false bool isTrue true; bool isFalse false;📌 二、整数(Integer Types) C# 支持多种有符号和无符号整数类型: 类型大小范围sbyte8…...