【实验16】基于双向LSTM模型完成文本分类任务
目录
1 数据集处理- IMDB 电影评论数据集
1.1 认识数据集
1.2 数据加载
1.3 构造Dataset类
1.4 封装DataLoader
1.4.1 collate_fn函数
1.4.2 封装dataloader
2 模型构建
2.1 汇聚层算子
2.2 模型汇总
3 模型训练
4 模型评价
5 模型预测
6 完整代码
7 拓展实验
1 数据集处理- IMDB 电影评论数据集
1.1 认识数据集
IMDB 电影评论数据集是一份关于电影评论的经典二分类数据集.IMDB 按照评分的高低筛选出了积极评论和消极评论,如果评分 ≥7,则认为是积极评论;如果评分 ≤4,则认为是消极评论。
数据集包含训练集和测试集数据,数量各为 25000 条,每条数据都是一段用户关于某个电影的真实评价,以及观众对这个电影的情感倾向。
下载地址:IMDB数据集
1.2 数据加载
这里将原始的测试集平均分为两份,分别作为验证集和测试集,存放于./dataset
目录下。
代码如下:
import osdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载IMDB数据集
def load_imdb_data(path):assert os.path.exists(path)# 初始化数据集列表trainset, devset, testset = [], [], []# 加载训练集数据for label in ['pos', 'neg']:label_path = os.path.join(path, 'train', label)for filename in os.listdir(label_path):if filename.endswith('.txt'):with open(os.path.join(label_path, filename), 'r', encoding='utf-8') as f:sentence = f.read().strip().lower() # 读取并处理每个评论trainset.append((sentence, label))# 加载测试集数据for label in ['pos', 'neg']:label_path = os.path.join(path, 'test', label)for filename in os.listdir(label_path):if filename.endswith('.txt'):with open(os.path.join(label_path, filename), 'r', encoding='utf-8') as f:sentence = f.read().strip().lower() # 读取并处理每个评论testset.append((sentence, label))# 随机拆分测试集的一半作为验证集random.shuffle(testset) # 打乱测试集顺序split_index = len(testset) // 2 # 计算拆分索引devset = testset[:split_index] # 选择测试集前一半作为验证集testset = testset[split_index:] # 剩下的部分作为测试集return trainset, devset, testset# 加载IMDB数据集
train_data, dev_data, test_data = load_imdb_data("./dataset/")# 打印一下加载后的数据样式
print(train_data[4]) # 打印训练集中的第5条数据
this is not the typical mel brooks film. it was much less slapstick than most of his movies and actually had a plot that was followable. leslie ann warren made the movie, she is such a fantastic, under-rated actress. there were some moments that could have been fleshed out a bit more, and some scenes that could probably have been cut to make the room to do so, but all in all, this is worth the price to rent and see it. the acting was good overall, brooks himself did a good job without his characteristic speaking to directly to the audience. again, warren was the best actor in the movie, but "fume" and "sailor" both played their parts well.', 'pos')
从输出结果看,加载后的每条样本包含两部分内容:文本串和标签。
1.3 构造Dataset类
构造IMDBDataset类用于数据管理,输入是文本序列,需要先将其中的每个词转换为该词在词表中的序号 ID,然后根据词表ID查询这些词对应的词向量(词向量)【使用IMDBDataset类中的words_to_id方法】。
利用词表将序列中的每个词映射为对应的数字编号,便于进一步转为为词向量。当序列中的词没有包含在词表时,默认会将该词用[UNK]代替。words_to_id方法利用一个如下图所示的哈希表来进行转换,实验中词表为数据集文件中的imdb.vocab。【这里注意原来的imdb.vocab中没有UNK和PAD映射,需要自行添加】
class IMDBDataset(Dataset):def __init__(self, examples, word2id_dict):super(IMDBDataset, self).__init__()self.word2id_dict = word2id_dictself.examples = self.words_to_id(examples)def words_to_id(self, examples):tmp_examples = []for idx, example in enumerate(examples):seq, label = example# 将单词映射为字典索引的ID, 对于词典中没有的单词用[UNK]对应的ID进行替代seq = [self.word2id_dict.get(word, self.word2id_dict['[UNK]']) for word in seq.split(" ")]# 映射标签: 'pos' -> 1, 'neg' -> 0label = 1 if label == 'pos' else 0 # 将标签从'pos'/'neg'转换为1/0tmp_examples.append([seq, label])return tmp_examplesdef __getitem__(self, idx):seq, label = self.examples[idx]return seq, labeldef __len__(self):return len(self.examples)def load_vocab(path):assert os.path.exists(path) # 确保词表文件路径存在words = [] # 初始化空列表,存储词表中的单词with open(path, "r", encoding="utf-8") as f: # 打开文件并读取内容words = f.readlines() # 读取文件中的所有行words = [word.strip() for word in words if word.strip()] # 移除每个单词的前后空白字符并去掉空行word2id = dict(zip(words, range(len(words)))) # 创建一个字典,将单词与对应的ID映射return word2id # 返回这个字典# 加载词表
word2id_dict = load_vocab("./dataset/imdb.vocab")# 实例化Dataset
train_set = IMDBDataset(train_data, word2id_dict)
dev_set = IMDBDataset(dev_data, word2id_dict)
test_set = IMDBDataset(test_data, word2id_dict)print('训练集样本数:', len(train_set))
print('样本示例:', train_set[4])
运行结果:
([11, 7, 21, 2, 764, 3633, 2822, 0, 8, 13, 74, 324, 2706, 72, 89, 5, 25, 101, 3, 161, 67, 4, 113, 12, 13, 0, 2750, 1900, 3725, 92, 2, 0, 53, 7, 138, 4, 0, 13617, 0, 39, 69, 49, 369, 12, 97, 26, 75, 7239, 46, 4, 221, 0, 3, 49, 137, 12, 97, 234, 26, 75, 644, 6, 96, 2, 667, 6, 83, 0, 18, 30, 9, 0, 11, 7, 283, 2, 1766, 6, 859, 3, 66, 0, 2, 111, 13, 50, 0, 2822, 302, 119, 4, 50, 284, 202, 25, 7517, 1409, 6, 2475, 6, 2, 0, 0, 3725, 13, 2, 117, 266, 9, 2, 0, 18, 0, 3, 0, 192, 248, 65, 512, 0], 1)
训练集样本数: 25000
样本示例: ([11, 7, 21, 2, 764, 3633, 2822, 0, 8, 13, 74, 324, 2706, 72, 89, 5, 25, 101, 3, 161, 67, 4, 113, 12, 13, 0, 2750, 1900, 3725, 92, 2, 0, 53, 7, 138, 4, 0, 13617, 0, 39, 69, 49, 369, 12, 97, 26, 75, 7239, 46, 4, 221, 0, 3, 49, 137, 12, 97, 234, 26, 75, 644, 6, 96, 2, 667, 6, 83, 0, 18, 30, 9, 0, 11, 7, 283, 2, 1766, 6, 859, 3, 66, 0, 2, 111, 13, 50, 0, 2822, 302, 119, 4, 50, 284, 202, 25, 7517, 1409, 6, 2475, 6, 2, 0, 0, 3725, 13, 2, 117, 266, 9, 2, 0, 18, 0, 3, 0, 192, 248, 65, 512, 0], 1)
可知,train_set[4] 的样本为:
this is not the typical mel brooks film. it was much less slapstick than most of his movies and actually had a plot that was followable. leslie ann warren made the movie, she is such a fantastic, under-rated actress. there were some moments that could have been fleshed out a bit more, and some scenes that could probably have been cut to make the room to do so, but all in all, this is worth the price to rent and see it. the acting was good overall, brooks himself did a good job without his characteristic speaking to directly to the audience. again, warren was the best actor in the movie, but "fume" and "sailor" both played their parts well.', 'pos')
成功被转换为该词在词表中的序号 ID:
([11, 7, 21, 2, 764, 3633, 2822, 0, 8, 13, 74, 324, 2706, 72, 89, 5, 25, 101, 3, 161, 67, 4, 113, 12, 13, 0, 2750, 1900, 3725, 92, 2, 0, 53, 7, 138, 4, 0, 13617, 0, 39, 69, 49, 369, 12, 97, 26, 75, 7239, 46, 4, 221, 0, 3, 49, 137, 12, 97, 234, 26, 75, 644, 6, 96, 2, 667, 6, 83, 0, 18, 30, 9, 0, 11, 7, 283, 2, 1766, 6, 859, 3, 66, 0, 2, 111, 13, 50, 0, 2822, 302, 119, 4, 50, 284, 202, 25, 7517, 1409, 6, 2475, 6, 2, 0, 0, 3725, 13, 2, 117, 266, 9, 2, 0, 18, 0, 3, 0, 192, 248, 65, 512, 0], 1)
1.4 封装DataLoader
构造对应的 DataLoader,用于批次数据的迭代。
主要功能:
- 长度限制:需要将序列的长度控制在一定的范围内,避免部分数据过长影响整体训练效果。-----使用max_seq_len参数对于过长的文本进行截断.
- 长度补齐:神经网络模型通常需要同一批处理的数据的序列长度是相同的,然而在分批时通常会将不同长度序列放在同一批,因此需要对序列进行补齐处理.-----先统计该批数据中序列的最大长度,并将短的序列填充一些没有特殊意义的占位符 [PAD],将长度补齐到该批次的最大长度
1.4.1 collate_fn函数
定义一个collate_fn函数来做数据的截断和填充.。该函数可以作为回调函数传入 DataLoader,DataLoader 在返回一批数据之前,调用该函数去处理数据,并返回处理后的序列数据和对应标签。
def collate_fn(batch_data, pad_val=0, max_seq_len=256):seqs, seq_lens, labels = [], [], []max_len = 0for example in batch_data:seq, label = example# 对数据序列进行截断seq = seq[:max_seq_len]# 对数据截断并保存于seqs中seqs.append(seq)seq_lens.append(len(seq))labels.append(label)# 保存序列最大长度max_len = max(max_len, len(seq))# 对数据序列进行填充至最大长度for i in range(len(seqs)):seqs[i] = seqs[i] + [pad_val] * (max_len - len(seqs[i]))return (torch.tensor(seqs).to(device), torch.tensor(seq_lens)), torch.tensor(labels).to(device)
测试一下collate_fn函数的功能,假定一下max_seq_len为5,然后定义序列长度分别为6和3的两条数据,传入collate_fn函数中
# =======测试==============
max_seq_len = 5
batch_data = [[[1, 2, 3, 4, 5, 6], 1], [[2, 4, 6], 0]]
(seqs, seq_lens), labels = collate_fn(batch_data, pad_val=word2id_dict["[PAD]"], max_seq_len=max_seq_len)
print("seqs: ", seqs)
print("seq_lens: ", seq_lens)
print("labels: ", labels)
运行结果:
seqs: tensor([[1, 2, 3, 4, 5],[2, 4, 6, 0, 0]], device='cuda:0')
seq_lens: tensor([5, 3])
labels: tensor([1, 0], device='cuda:0')
可以看到,原始序列中长度为6的序列被截断为5,同时原始序列中长度为3的序列被填充到5,同时返回了非[PAD]
的序列长度。
1.4.2 封装dataloader
将collate_fn作为回调函数传入DataLoader中,其在返回一批数据时,可以通过collate_fn函数处理该批次的数据:
max_seq_len = 256
batch_size = 128
collate_fn = partial(collate_fn, pad_val=word2id_dict["[PAD]"], max_seq_len=max_seq_len)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,shuffle=True, drop_last=False, collate_fn=collate_fn)
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=batch_size,shuffle=False, drop_last=False, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,shuffle=False, drop_last=False, collate_fn=collate_fn)
2 模型构建
实践的整个模型结构如下图所示:
(1)嵌入层:将输入的数字序列(单词映射成的ID)进行向量化,即将每个数字映射为向量。 ---------------使用Pytorch API:torch.nn.Embedding来完成
(2)双向LSTM层:接收向量序列,分别用前向和反向更新循环单元。----------------使用Pytorch API:torch.nn.LSTM来完成【在定义LSTM时设置参数bidirectional为True,可使用双向LSTM。】
(3)汇聚层:将双向LSTM层所有位置上的隐状态进行平均,作为整个句子的表示。
(4)输出层:输出分类的几率。----------调用torch.nn.Linear来完成。
2.1 汇聚层算子
实现了AveragePooling算子进行隐状态的汇聚,首先利用序列长度向量生成掩码(Mask)矩阵【LSTM在传入批次数据的真实长度后,会对[PAD]位置返回零向量,但考虑到汇聚层与处理序列数据的模型进行解耦,因此在汇聚层的实现中,会对[PAD]位置进行掩码】,用于对文本序列中[PAD]位置的向量进行掩蔽,然后将该序列的向量进行相加后取均值。具体操作如下图:
代码实现如下:
class AveragePooling(nn.Module):def __init__(self):super(AveragePooling, self).__init__()def forward(self, sequence_output, sequence_length):# 假设 sequence_length 是一个 PyTorch 张量sequence_length = sequence_length.unsqueeze(-1).to(torch.float32)# 根据sequence_length生成mask矩阵,用于对Padding位置的信息进行maskmax_len = sequence_output.shape[1]mask = torch.arange(max_len, device='cuda') < sequence_length.to('cuda')mask = mask.to(torch.float32).unsqueeze(-1)# 对序列中paddling部分进行masksequence_output = torch.multiply(sequence_output, mask.to('cuda'))# 对序列中的向量取均值batch_mean_hidden = torch.divide(torch.sum(sequence_output, dim=1), sequence_length.to('cuda'))return batch_mean_hidden
2.2 模型汇总
# ===================模型汇总=====================
class Model_BiLSTM_FC(nn.Module):def __init__(self, num_embeddings, input_size, hidden_size, num_classes=2):super(Model_BiLSTM_FC, self).__init__()# 词典大小self.num_embeddings = num_embeddings# 单词向量的维度self.input_size = input_size# LSTM隐藏单元数量self.hidden_size = hidden_size# 情感分类类别数量self.num_classes = num_classes# 实例化嵌入层self.embedding_layer = nn.Embedding(num_embeddings, input_size, padding_idx=0)# 实例化LSTM层self.lstm_layer = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)# 实例化聚合层self.average_layer = AveragePooling()# 实例化输出层self.output_layer = nn.Linear(hidden_size * 2, num_classes)def forward(self, inputs):# 对模型输入拆分为序列数据和maskinput_ids, sequence_length = inputs# 获取词向量inputs_emb = self.embedding_layer(input_ids)packed_input = nn.utils.rnn.pack_padded_sequence(inputs_emb, sequence_length.cpu(), batch_first=True,enforce_sorted=False)# 使用lstm处理数据packed_output, _ = self.lstm_layer(packed_input)# 解包输出sequence_output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)# 使用聚合层聚合sequence_outputbatch_mean_hidden = self.average_layer(sequence_output, sequence_length)# 输出文本分类logitslogits = self.output_layer(batch_mean_hidden)return logits
3 模型训练
from Runner import RunnerV3,Accuracy,plot_training_loss_acc
np.random.seed(0)
random.seed(0)
torch.seed()# 指定训练轮次
num_epochs = 1
# 指定学习率
learning_rate = 0.001
# 指定embedding的数量为词表长度
num_embeddings = len(word2id_dict)
# embedding向量的维度
input_size = 256
# LSTM网络隐状态向量的维度
hidden_size = 256
# 模型保存目录
save_dir = "./checkpoints/best.pdparams"# 实例化模型
model = Model_BiLSTM_FC(num_embeddings, input_size, hidden_size).to(device)
# 指定优化器
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
# 指定损失函数
loss_fn = nn.CrossEntropyLoss()
# 指定评估指标
metric = Accuracy()
# 实例化Runner
runner = RunnerV3(model, optimizer, loss_fn, metric)
# 模型训练
start_time = time.time()
runner.train(train_loader, dev_loader, num_epochs=num_epochs, eval_steps=10, log_steps=10,save_path=save_dir)
end_time = time.time()
print("time: ", (end_time - start_time))
# ============ 绘制训练过程中在训练集和验证集上的损失图像和在验证集上的准确率图像 ===========
# sample_step: 训练损失的采样step,即每隔多少个点选择1个点绘制
# loss_legend_loc: loss 图像的图例放置位置
# acc_legend_loc: acc 图像的图例放置位置
plot_training_loss_acc(runner, fig_size=(16, 6), sample_step=10, loss_legend_loc="lower left",acc_legend_loc="lower right")
epoch =1:
[Evaluate] best accuracy performence has been updated: 0.81120 --> 0.81880
[Train] epoch: 0/1, step: 190/196, loss: 0.31320
[Evaluate] dev score: 0.81912, dev loss: 0.40167
[Evaluate] best accuracy performence has been updated: 0.81880 --> 0.81912
[Evaluate] dev score: 0.81728, dev loss: 0.40298
[Train] Training done!
time: 124.63010001182556
观察到模型未收敛,验证集和训练集的损失都在下降,于是我增大epoch为3;
epoch=3:
[Train] epoch: 0/3, step: 0/588, loss: 0.69394
[Train] epoch: 0/3, step: 10/588, loss: 0.70491
[Evaluate] dev score: 0.52592, dev loss: 0.68407
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.52592
[Train] epoch: 0/3, step: 20/588, loss: 0.66260
[Evaluate] dev score: 0.62080, dev loss: 0.66250
[Evaluate] best accuracy performence has been updated: 0.52592 --> 0.62080
[Train] epoch: 0/3, step: 30/588, loss: 0.61762
[Evaluate] dev score: 0.62880, dev loss: 0.64522......
.....
[Train] epoch: 1/3, step: 370/588, loss: 0.31063
[Evaluate] dev score: 0.84656, dev loss: 0.36641
[Evaluate] best accuracy performence has been updated: 0.84272 --> 0.84656
[Train] epoch: 1/3, step: 380/588, loss: 0.30818
[Evaluate] dev score: 0.84264, dev loss: 0.37259
[Train] epoch: 1/3, step: 390/588, loss: 0.19482
[Evaluate] dev score: 0.84600, dev loss: 0.35535
......
[Evaluate] best accuracy performence has been updated: 0.85048 --> 0.85088
[Train] epoch: 2/3, step: 560/588, loss: 0.25688
[Evaluate] dev score: 0.84792, dev loss: 0.37273
[Train] epoch: 2/3, step: 570/588, loss: 0.12472
[Evaluate] dev score: 0.84856, dev loss: 0.36705
[Train] epoch: 2/3, step: 580/588, loss: 0.11621
[Evaluate] dev score: 0.84848, dev loss: 0.38805
[Evaluate] dev score: 0.84976, dev loss: 0.37620
[Train] Training done!
time: 356.31542706489563
验证集的损失及变化趋于平衡,且有过拟合的迹象,不必再增大epoch,模型在在验证集的准确率也在不断上升 。
4 模型评价
# ======== 模型评价 =============
model_path = "./checkpoints/best.pdparams"
runner.load_model(model_path)
accuracy, _ = runner.evaluate(test_loader)
print(f"Evaluate on test set, Accuracy: {accuracy:.5f}")
epoch =1:
Evaluate on test set, Accuracy: 0.81352
epoch=3:
Evaluate on test set, Accuracy: 0.84704
对比不同epoch的结果,当eooch=3时,模型训练较好,测试集上准确率达到0.847 。
5 模型预测
# =======模型预测==========
id2label={0:"消极情绪", 1:"积极情绪"}
text = "this movie is so great. I watched it three times already"
# 处理单条文本
sentence = text.split(" ")
words = [word2id_dict[word] if word in word2id_dict else word2id_dict['[UNK]'] for word in sentence]
words = words[:max_seq_len]
sequence_length = torch.tensor([len(words)], dtype=torch.int64)
words = torch.tensor(words, dtype=torch.int64).unsqueeze(0)
# 使用模型进行预测
logits = runner.predict((words.to(device), sequence_length.to(device)))
max_label_id = torch.argmax(logits, dim=-1).cpu().numpy()[0]
pred_label = id2label[max_label_id]
print("Label: ", pred_label)
Label: 积极情绪
根据输出可以看出模型预测正确。
6 完整代码
'''
@Function:基于双向LSTM实现文本分类
@Author: lxy
@Date: 2024/12/12
'''
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from functools import partial
import random
import numpy as np
import timedevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
'''
数据集处理部分
'''
# 加载IMDB数据集
def load_imdb_data(path):assert os.path.exists(path)# 初始化数据集列表trainset, devset, testset = [], [], []# 加载训练集数据for label in ['pos', 'neg']:label_path = os.path.join(path, 'train', label)for filename in os.listdir(label_path):if filename.endswith('.txt'):with open(os.path.join(label_path, filename), 'r', encoding='utf-8') as f:sentence = f.read().strip().lower() # 读取并处理每个评论trainset.append((sentence, label))# 加载测试集数据for label in ['pos', 'neg']:label_path = os.path.join(path, 'test', label)for filename in os.listdir(label_path):if filename.endswith('.txt'):with open(os.path.join(label_path, filename), 'r', encoding='utf-8') as f:sentence = f.read().strip().lower() # 读取并处理每个评论testset.append((sentence, label))# 随机拆分测试集的一半作为验证集random.shuffle(testset) # 打乱测试集顺序split_index = len(testset) // 2 # 计算拆分索引devset = testset[:split_index] # 选择测试集前一半作为验证集testset = testset[split_index:] # 剩下的部分作为测试集return trainset, devset, testset
# 加载IMDB数据集
train_data, dev_data, test_data = load_imdb_data("./dataset/")# # 打印一下加载后的数据样式
# print(train_data[4]) # 打印训练集中的第5条数据class IMDBDataset(Dataset):def __init__(self, examples, word2id_dict):super(IMDBDataset, self).__init__()self.word2id_dict = word2id_dictself.examples = self.words_to_id(examples)def words_to_id(self, examples):tmp_examples = []for idx, example in enumerate(examples):seq, label = example# 将单词映射为字典索引的ID, 对于词典中没有的单词用[UNK]对应的ID进行替代seq = [self.word2id_dict.get(word, self.word2id_dict['[UNK]']) for word in seq.split(" ")]# 映射标签: 'pos' -> 1, 'neg' -> 0label = 1 if label == 'pos' else 0 # 将标签从'pos'/'neg'转换为1/0tmp_examples.append([seq, label])return tmp_examplesdef __getitem__(self, idx):seq, label = self.examples[idx]return seq, labeldef __len__(self):return len(self.examples)# ===============ID映射=====================
def load_vocab(path):assert os.path.exists(path) # 确保词表文件路径存在words = [] # 初始化空列表,存储词表中的单词with open(path, "r", encoding="utf-8") as f: # 打开文件并读取内容words = f.readlines() # 读取文件中的所有行words = [word.strip() for word in words if word.strip()] # 移除每个单词的前后空白字符并去掉空行word2id = dict(zip(words, range(len(words)))) # 创建一个字典,将单词与对应的ID映射return word2id # 返回这个字典# 加载词表
word2id_dict = load_vocab("./dataset/imdb.vocab")
# 实例化Dataset
train_set = IMDBDataset(train_data, word2id_dict)
dev_set = IMDBDataset(dev_data, word2id_dict)
test_set = IMDBDataset(test_data, word2id_dict)# print('训练集样本数:', len(train_set))
# print('样本示例:', train_set[4])def collate_fn(batch_data, pad_val=0, max_seq_len=256):seqs, seq_lens, labels = [], [], []max_len = 0for example in batch_data:seq, label = example# 对数据序列进行截断seq = seq[:max_seq_len]# 对数据截断并保存于seqs中seqs.append(seq)seq_lens.append(len(seq))labels.append(label)# 保存序列最大长度max_len = max(max_len, len(seq))# 对数据序列进行填充至最大长度for i in range(len(seqs)):seqs[i] = seqs[i] + [pad_val] * (max_len - len(seqs[i]))return (torch.tensor(seqs).to(device), torch.tensor(seq_lens)), torch.tensor(labels).to(device)
# =======测试==============
# max_seq_len = 5
# batch_data = [[[1, 2, 3, 4, 5, 6], 1], [[2, 4, 6], 0]]
# (seqs, seq_lens), labels = collate_fn(batch_data, pad_val=word2id_dict["[PAD]"], max_seq_len=max_seq_len)
# print("seqs: ", seqs)
# print("seq_lens: ", seq_lens)
# print("labels: ", labels)# ===============封装dataloader=========================
max_seq_len = 256
batch_size = 128
collate_fn = partial(collate_fn, pad_val=word2id_dict["[PAD]"], max_seq_len=max_seq_len)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,shuffle=True, drop_last=False, collate_fn=collate_fn)
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=batch_size,shuffle=False, drop_last=False, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,shuffle=False, drop_last=False, collate_fn=collate_fn)
'''
数据集处理部分结束
''''''
模型构建部分
'''
# ======================汇聚层====================
class AveragePooling(nn.Module):def __init__(self):super(AveragePooling, self).__init__()def forward(self, sequence_output, sequence_length):# 假设 sequence_length 是一个 PyTorch 张量sequence_length = sequence_length.unsqueeze(-1).to(torch.float32)# 根据sequence_length生成mask矩阵,用于对Padding位置的信息进行maskmax_len = sequence_output.shape[1]mask = torch.arange(max_len, device='cuda') < sequence_length.to('cuda')mask = mask.to(torch.float32).unsqueeze(-1)# 对序列中paddling部分进行masksequence_output = torch.multiply(sequence_output, mask.to('cuda'))# 对序列中的向量取均值batch_mean_hidden = torch.divide(torch.sum(sequence_output, dim=1), sequence_length.to('cuda'))return batch_mean_hidden
# ===================模型汇总=====================
class Model_BiLSTM_FC(nn.Module):def __init__(self, num_embeddings, input_size, hidden_size, num_classes=2):super(Model_BiLSTM_FC, self).__init__()# 词典大小self.num_embeddings = num_embeddings# 单词向量的维度self.input_size = input_size# LSTM隐藏单元数量self.hidden_size = hidden_size# 情感分类类别数量self.num_classes = num_classes# 实例化嵌入层self.embedding_layer = nn.Embedding(num_embeddings, input_size, padding_idx=0)# 实例化LSTM层self.lstm_layer = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)# 实例化聚合层self.average_layer = AveragePooling()# 实例化输出层self.output_layer = nn.Linear(hidden_size * 2, num_classes)def forward(self, inputs):# 对模型输入拆分为序列数据和maskinput_ids, sequence_length = inputs# 获取词向量inputs_emb = self.embedding_layer(input_ids)packed_input = nn.utils.rnn.pack_padded_sequence(inputs_emb, sequence_length.cpu(), batch_first=True,enforce_sorted=False)# 使用lstm处理数据packed_output, _ = self.lstm_layer(packed_input)# 解包输出sequence_output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)# 使用聚合层聚合sequence_outputbatch_mean_hidden = self.average_layer(sequence_output, sequence_length)# 输出文本分类logitslogits = self.output_layer(batch_mean_hidden)return logits
'''
模型构建部分结束
'''
'''
模型训练部分
'''
# ===============模型训练===================
from Runner import RunnerV3,Accuracy,plot_training_loss_acc
np.random.seed(0)
random.seed(0)
torch.seed()# 指定训练轮次
num_epochs = 1
# 指定学习率
learning_rate = 0.001
# 指定embedding的数量为词表长度
num_embeddings = len(word2id_dict)
# embedding向量的维度
input_size = 256
# LSTM网络隐状态向量的维度
hidden_size = 256
# 模型保存目录
save_dir = "./checkpoints/best.pdparams"# 实例化模型
model = Model_BiLSTM_FC(num_embeddings, input_size, hidden_size).to(device)
# 指定优化器
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
# 指定损失函数
loss_fn = nn.CrossEntropyLoss()
# 指定评估指标
metric = Accuracy()
# 实例化Runner
runner = RunnerV3(model, optimizer, loss_fn, metric)
# 模型训练
start_time = time.time()
runner.train(train_loader, dev_loader, num_epochs=num_epochs, eval_steps=10, log_steps=10,save_path=save_dir)
end_time = time.time()
print("time: ", (end_time - start_time))
# ============ 绘制训练过程中在训练集和验证集上的损失图像和在验证集上的准确率图像 ===========
# sample_step: 训练损失的采样step,即每隔多少个点选择1个点绘制
# loss_legend_loc: loss 图像的图例放置位置
# acc_legend_loc: acc 图像的图例放置位置
plot_training_loss_acc(runner, fig_size=(16, 6), sample_step=10, loss_legend_loc="lower left",acc_legend_loc="lower right")
'''
模型训练部分结束
'''# ==================== 模型评价 =============
model_path = "./checkpoints/best.pdparams"
runner.load_model(model_path)
accuracy, _ = runner.evaluate(test_loader)
print(f"Evaluate on test set, Accuracy: {accuracy:.5f}")# =====================模型预测==========
id2label={0:"消极情绪", 1:"积极情绪"}
text = "this movie is so great. I watched it three times already"
# 处理单条文本
sentence = text.split(" ")
words = [word2id_dict[word] if word in word2id_dict else word2id_dict['[UNK]'] for word in sentence]
words = words[:max_seq_len]
sequence_length = torch.tensor([len(words)], dtype=torch.int64)
words = torch.tensor(words, dtype=torch.int64).unsqueeze(0)
# 使用模型进行预测
logits = runner.predict((words.to(device), sequence_length.to(device)))
max_label_id = torch.argmax(logits, dim=-1).cpu().numpy()[0]
pred_label = id2label[max_label_id]
print("Label: ", pred_label)
7 拓展实验
点击跳转--基于双向LSTM和注意力机制的文本分类
8 参考链接
参考资料: |
NNDL 实验6(下) - HBU_DAVID - 博客园 |
情感分析--数据集来源 |
IMDB数据集的解释_imdb数据集介绍-CSDN博客 |
一幅图真正理解LSTM、BiLSTM_bilstm和lstm的区别-CSDN博客 写的超细!! |
长短期记忆神经网络(LSTM)介绍以及简单应用分析 - 舞动的心 - 博客园 |
【掩码】深度学习时为什么需要掩码(Mask)? |
相关文章:
【实验16】基于双向LSTM模型完成文本分类任务
目录 1 数据集处理- IMDB 电影评论数据集 1.1 认识数据集 1.2 数据加载 1.3 构造Dataset类 1.4 封装DataLoader 1.4.1 collate_fn函数 1.4.2 封装dataloader 2 模型构建 2.1 汇聚层算子 2.2 模型汇总 3 模型训练 4 模型评价 5 模型预测 6 完整代码 7 拓展实验 …...
【中工开发者】鸿蒙商城app
这学期我学习了鸿蒙,想用鸿蒙做一个鸿蒙商城app,来展示一下。 项目环境搭建: 1.开发环境:DevEco Studio2.开发语言:ArkTS3.运行环境:Harmony NEXT base1 软件要求: DevEco Studio 5.0.0 Rel…...
SpringBoot 整合 MongoDB 实现文档存储
一、MongoDB 简介 MongoDB(来自于英文单词“Humongous”,中文含义为“庞大”)是可以应用于各种规模的企业、各个行业以及各类应用程序的开源数据库。基于分布式文件存储的数据库。由C语言编写。旨在为 WEB 应用提供可扩展的高性能数据存储解…...
鲲鹏麒麟安装ElasticSearch7.8.0
因项目需求需要在鲲鹏麒麟服务器上安装ElasticSearch7.8.0,考虑Docker方式安装比较简单,因此使用Docker方式安装 环境信息 操作系统:Kylin Linux Advanced Server release V10 (Tercel) Docker:18.09.0 [rootserver ~]# uname …...
NDN命名数据网络和域名的区别
NDN(Named Data Networking)网络的概念 NDN是一种新型的网络架构,也被称为命名数据网络。与传统的以IP地址为中心的网络架构不同,NDN是以数据(内容)本身命名为中心的网络架构。在传统网络中,我们通过IP地址来寻找主机设备,然后获取该设备上存储的内容。而在NDN网络中,…...
PyTorch基本使用-自动微分模块
学习目的:掌握自动微分模块的使用 训练神经网络时,最常用的算法就是反向传播。在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整。为了计算这些梯度,PyTorch 内置了名为 torch.autogra…...
关于linux kernel softlockup 的探究
1. 基本解释 softlockup:发生在某个 CPU 长时间占用资源,但 CPU 仍然可以响应中断 和调度器。软死锁通常不会导致系统崩溃,但可能会使系统响应变慢. 2. 驱动模拟softlockup 以下为代码实现 #include <linux/module.h> #include <…...
MySQL 时区参数 time_zone 详解
文章目录 前言1. 时区参数影响2. 如何设置3. 字段类型选择 前言 MySQL 时区参数 time_zone 有什么用?修改它有什么影响?如何设置该参数,本篇文章会详细介绍。 1. 时区参数影响 time_zone 参数影响着 MySQL 系统函数还有字段的 DEFAULT CUR…...
【计算机网络层】数据链路层 :局域网和交换机
🧸安清h:个人主页 🎥个人专栏:【计算机网络】【Mybatis篇】 🚦作者简介:一个有趣爱睡觉的intp,期待和更多人分享自己所学知识的真诚大学生。 目录 🎯局域网 🚦局域网…...
WebSocket、Socket、TCP 与 HTTP:深入探讨与对比
随着互联网技术的快速发展,现代Web应用对于实时通信的需求越来越高。传统的HTTP协议由于其无状态和请求-响应模式的限制,在实现高效、低延迟的实时通信方面存在一定的局限性。为了解决这一问题,WebSocket协议应运而生,它提供了一种…...
【开源免费】基于SpringBoot+Vue.JS在线办公系统(JAVA毕业设计)
本文项目编号 T 001 ,文末自助获取源码 \color{red}{T001,文末自助获取源码} T001,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…...
Vue指令
创建项目: vue init webpack 项目名称 element-ui npm i element-ui -saxios npm i axios1.1.3 -S vuex npm i vuex3.6.2 -S vuex持久化 npm i -S vuex-persistedstate4.1.0代理模版 proxyTable: {/api: {target: http://localhost:8081/,changeOrigin: true,pathRe…...
经典文献阅读之--ATI-CTLO(基于自适应时间间隔的连续时间Lidar-Only里程计)
0. 简介 激光雷达扫描中的运动失真,由机器人的激烈运动和环境地形特征引起,显著影响了3D激光雷达里程计的定位和制图性能。现有的失真校正解决方案在计算复杂性和准确性之间难以平衡。《ATI-CTLO: Adaptive Temporal Interval-based Continuous-Time Li…...
【GitHub分享】you-get项目
【GitHub分享】you-get 一、介绍二、安装教程三、使用教程四、配置ffmpeg五,卸载 如果大家想要更具体地操作可去开源网站查看手册,这里只是一些简单介绍,但是也够用一般,有什么问题,也可以留言。 一、介绍 you-get是一…...
右玉200MW光伏电站项目 微气象、安全警卫、视频监控系统
一、项目名称 山西右玉200MW光伏电站项目 微气象、安全警卫、视频监控系统 二、项目背景: 山西右玉光伏发电项目位于右玉县境内,总装机容量为200MW,即太阳能电池阵列共由200个1MW多晶硅电池阵列子方阵组成,每个子方阵包含太阳能…...
box 提取
box 提取 import json import os import shutilimport cv2 import numpy as np import pypinyinclass Aaa():passdef pinyin(word):s for i in pypinyin.pinyin(word, stylepypinyin.NORMAL):s .join(i)return s if __name__ __main__:selfAaa()base_dirrE:\data\dao\20241…...
【合作原创】使用Termux搭建可以使用的生产力环境(六)
前言 在上一篇【合作原创】使用Termux搭建可以使用的生产力环境(五)-CSDN博客我们讲到了如何美化xfce4桌面,达到类似于Windows的效果,这一篇将继续在上一篇桌面的基础上给我们的系统装上必要的软件,让它做到真正可以使…...
C#—索引器
C#—索引器 索引器(Indexer)是类中的一个特殊成员,它能够让对象以类似数组的形式来操作,使程序看起来更为直观,更容易编写。索引器与属性类似,在定义索引器时同样会用到 get 和 set 访问器,不同…...
Microsemi Libero SoC免费许可证申请指南(Microchip官网2024最新方法)
点击如下链接: https://www.microchip.com/en-us/products/fpgas-and-plds/fpga-and-soc-design-tools/fpga/licensing 点击右侧,请求免费的License 如果提示登录,请先登录Microchip账号。 点击Request Free License。 选项一年免费的Li…...
【CSS in Depth 2 精译_074】第 12 章 CSS 排版与间距概述 + 12.1 间距设置(下):行内元素的间距设置
当前内容所在位置(可进入专栏查看其他译好的章节内容) 第四部分 视觉增强技术 ✔️【第 12 章 CSS 排版与间距】 ✔️ 12.1 间距设置 12.1.1 使用 em 还是 px12.1.2 对行高的深入思考12.1.3 行内元素的间距设置 ✔️ 12.2 Web 字体12.3 谷歌字体 文章目…...
React 18
文章目录 React 18自动批处理并发特性Suspense 组件增强新 HookscreateRoot API 替代 ReactDOM.renderStrict Mode严格模式服务器端渲染改进性能优化 React 18 React 18 引入了一系列新特性和改进,旨在提升性能、改善用户体验,并简化开发流程。以下是 R…...
yolov,coco,voc标记的睡岗检测数据集,可识别在桌子上趴着睡,埋头睡觉,座椅上靠着睡,平躺着睡等多种睡姿的检测,6549张图片
yolov,coco,voc标记的睡岗检测数据集,可识别在桌子上趴着睡,埋头睡觉,座椅上靠着睡,平躺着睡等多种睡姿的检测,6549张图片 数据集分割 6549总图像数 训练组91% 5949图片 有效集9&#x…...
Pydantic中的discriminator:优雅地处理联合类型详解
Pydantic中的discriminator:优雅地处理联合类型详解 引言1. 什么是discriminator?2. 基本使用示例3. discriminator的工作原理4. 更复杂的实际应用场景5. 使用建议6. 潜在陷阱和注意事项结论最佳实践 引言 在Python的类型系统中,有时我们需要…...
vue实现文件流形式的导出下载
文章目录 Vue 项目中下载返回的文件流操作步骤一、使用 Axios 请求文件流数据二、设置响应类型为 ‘blob’三、创建下载链接并触发下载四、在 Vue 组件中集成下载功能五、解释与实例说明1、使用 Axios 请求文件流数据:设置响应类型为 blob:创建下载链接并…...
Dify工具前奏:一个好玩的镜像,selenium
文章目录 按照惯例,闲聊开篇通义千问给出的回答,蛮有趣的。什么是selenium?使用场景缺点按照惯例,闲聊开篇 眼看就要过0点了,今天写点有把握的。 我先卖个关子,问你们一个问题: 我用mobaxterm或者其它的工具,ssh访问到远程服务器。但我想在那台机器上打开浏览器该怎么…...
警惕!手动调整服务器时间可能引发的系统灾难
警惕!手动调整服务器时间可能引发的系统灾难 1. 鉴权机制1.1 基于时间戳的签名验证1.2 基于会话的认证机制(JWT、TOTP) 2. 雪花算法生成 ID 的影响2.1 时间戳回拨导致 ID 冲突2.2 ID 顺序被打乱 3. 日志记录与审计3.1 日志顺序错误3.2 审计日…...
Python泛型编程:TypeVar和Generic详解 - 写给初学者的指南
Python泛型编程:TypeVar和Generic详解 - 写给初学者的指南 前言1. 为什么需要泛型?2. TypeVar:定义泛型类型变量3. Generic:创建泛型类4. 多个泛型类型变量5. 使用场景小结结语 前言 大家好!今天我们来聊一聊Python中…...
单片机:实现控制LED灯亮灭(附带源码)
使用单片机控制LED灯的亮灭是一个非常基础的嵌入式应用项目,适合初学者学习如何操作GPIO(通用输入输出)端口以及如何控制外设。通过该项目,您可以学习如何通过按键输入、定时器控制或其他触发条件来控制LED灯的开关状态。 1. 项目…...
Dcoker安装nginx,完成反向代理和负载均衡
1. 简介 官网:nginx Nginx是一个高性能的 HTTP 和反向代理 Web 服务器。它的主要功能包括反向代理、负载均衡和动静分离等。正因为 Nginx的这些功能能够为系统带来性能和安全方面的诸多优势,我们在项目部署时需要引入 Nginx组件。接下来我们会逐一向大…...
Java转C之C/C++ 的宏定义和预处理
C/C 宏定义和预处理总结 C/C 的宏定义和预处理器是在编译前执行的一系列文本处理操作,用于包含文件、定义常量、条件编译和控制编译器行为。以下是全面总结,涵盖各种知识点、注意事项以及示例。 表1:C/C 预处理指令和功能 预处理指令功能描…...
【老白学 Java】数字格式化
数字格式化 文章来源:《Head First Java》修炼感悟。 很多时候需要对数字或日期进行格式化操作,来达到某些输出效果。Java 的 Formatter 类提供了很多扩展性功能用于字符串的格式化,只要调用 String 静态方法 format() ,传入参数…...
elementUI修改table样式
在Vue项目中,如果使用的是单文件组件(.vue),并且样式是通过<style>标签定义的,vue2可以使用/deep/,vue3可以使用::v-deep选择器来修改ElementUI组件的样式。 1.修改表头背景色 /deep/.el-table__head…...
Invalid bound statement (not found) 错误解决
出现这个错误提示:Invalid bound statement (not found): com.xxx.small_reservior.dao.WaterRainMapper.getWaterRainByRegion,通常表示 MyBatis 框架无法找到与给定的 getWaterRainByRegion 方法匹配的 SQL 映射语句。这种问题通常发生在以下几种情况中…...
AWD学习(二)
学习参考: AWD攻防学习总结(草稿状态,待陆续补充)_awd攻防赛入门-CSDN博客国赛分区赛awd赛后总结-安心做awd混子-安全客 - 安全资讯平台 记第一次 AWD 赛前准备与赛后小结-腾讯云开发者社区-腾讯云 AWD学习笔记 - DiaosSamas Blog…...
VUE常见问题汇总
目录 1、80端口占用问题 2、sass版本安装问题 3、Missing binding node_modules/node-sass/vendor/darwin-x64-72/binding.node 4、Downloading binary from https://github.com/sass/node-sass/releases/download/v4.14.1/win32-x64-83_binding.node Cannot download &qu…...
H.323音视频协议
概述 H.323是国际电信联盟(ITU)的一个标准协议栈,该协议栈是一个有机的整体,根据功能可以将其分为四类协议,也就是说该协议从系统的总体框架(H.323)、视频编解码(H.263)、…...
第六届地博会开幕,世界酒中国菜美食文化节同期启幕推动地标发展
第六届知交会暨地博会开幕,辽黔欧三地馆亮点纷呈,世界酒中国菜助力地理标志产品发展 第六届知交会暨地博会盛大开幕,多地展馆亮点频出,美食文化节同期启幕推动地标产业发展 12月9日,第六届粤港澳大湾区知识产权交易博…...
如何在安卓系统里面用C++写一个获取摄像头原始数据并保存成.yuv文件
在 Android 系统中使用 C 编写一个获取摄像头原始数据并保存为 .yuv 文件的程序,并且通过 Android.bp 编译,你需要结合 V4L2 和 Android NDK 的特性来实现。以下是详细的步骤和代码示例。 步骤 1: 设置权限 确保你的应用程序有访问摄像头的权限。在 An…...
算法分析与设计之分治算法
文章目录 前言一、分治算法divide and conquer1.1 分治定义1.2 分治法的复杂性分析:递归方程1.2.1 主定理1.2.2 递归树法1.2.3 迭代法 二、典型例题2.1 Mergesort2.2 Counting Inversions2.3 棋盘覆盖2.4 最大和数组2.5 Closest Pair of Points2.6 Karatsuba算法&am…...
AI大模型学习笔记|多目标算法梳理、举例
多目标算法学习内容推荐: 1.通俗易懂讲算法-多目标优化-NSGA-II(附代码讲解)_哔哩哔哩_bilibili 2.多目标优化 (python pyomo pareto 最优)_哔哩哔哩_bilibili 学习笔记: 通过网盘分享的文件:多目标算法学习笔记 链接: https://pan.baidu.com…...
【竞技宝】LOL:JDG官宣yagao离队
北京时间2024年12月13日,在英雄联盟S14全球总决赛结束之后,各大赛区都已经进入了休赛期,目前休赛期也快进入尾声,LPL大部分队伍都开始陆续官宣转会期的动向,其中JDG就在近期正式官宣中单选手yagao离队,而后者大概率将直接选择退役。 近日,JDG战队在官方微博上连续发布阵容变动消…...
iOS runtime总结数据结构,消息传递、转发和应用场景
runtime篇 首先看一下runtiem底层的数据结构 首先从objc_class这么一个结构体(数据结构)开始,objc_class继承于objc_object。 objc_object当中有一个成员变量叫isa_t,那么这个isa_t指针就指向一个objc_class类型的类对象ÿ…...
An error happened while trying to locate the file on the Hub and we cannot f
An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on. 关于上述comfy ui使用control net预处理器的报错问…...
力扣刷题TOP101: 32.BM39 序列化二叉树
目录: 目的 思路 复杂度 记忆秘诀 python代码 目的: 请实现两个函数,分别用来序列化和反序列化二叉树,不对序列化之后的字符串进行约束,但要求能够根据序列化之后的字符串重新构造出一棵与原二叉树相同的树。 思路…...
modern-screenshot: 一个比html2canvas 性能更好的网页截屏工具
在低代码平台等设计工具中,生成缩略图是非常基础的功能,最开始我们使用的是 html2canvas 工具,但是带来的问题是这个工具非常吃性能,生成缩略图时间动辄6s以上,后来发现 modern-screenshot 这个工具性能会好一些&#…...
使用 GD32F470ZGT6,手写 I2C 的实现
我的代码:https://gitee.com/a1422749310/gd32_-official_-code I2C 具体代码位置:https://gitee.com/a1422749310/gd32_-official_-code/blob/master/Hardware/i2c/i2c.c 黑马 - I2C原理 官方 - IIC 协议介绍 个人学习过程中的理解,有错误&…...
力扣 53. 最大子数组和 (动态规划)
给你一个整数数组 nums ,请你找出一个具有最大和的连续子数组(子数组最少包含一个元素),返回其最大和。 子数组 是数组中的一个连续部分。 示例 1: 输入:nums [-2,1,-3,4,-1,2,1,-5,4] 输出:…...
【牛客小白月赛107 题解】
比赛链接 A. Cidoai的吃饭 题目大意 给定一个正整数 n n n,再给定三个正整数 a , b , c a, \ b, \ c a, b, c。初始时 a n s 0 ans 0 ans0。现在开始循环,每次循环按照从上到下的顺序选择第一条符合的执行(即执行完就再从 1. 1. 1. …...
Web day11 SpringBoot原理
目录 1.配置优先级: 2.Bean的管理: bean的作用域: 第三方bean: 方案一: 方案二: SpringBoot原理: 扫描第三方包: 方案1:ComponentScan 组件扫描 方案2࿱…...
JAVA实战:借助阿里云实现短信发送功能
亲爱的小伙伴们😘,在求知的漫漫旅途中,若你对深度学习的奥秘、JAVA 、PYTHON与SAP 的奇妙世界,亦或是读研论文的撰写攻略有所探寻🧐,那不妨给我一个小小的关注吧🥰。我会精心筹备,在…...