当前位置: 首页 > news >正文

大模型学习:从零到一实现一个BERT微调

目录

一、准备阶段

1.导入模块

2.指定使用的是GPU还是CPU

3.加载数据集

二、对数据添加词元和分词

1.根据BERT的预训练,我们要将一个句子的句头添加[CLS]句尾添加[SEP]

2.激活BERT词元分析器

3.填充句子为固定长度

代码解释:

三、数据处理

1.创建masks掩码矩阵

代码解释:

2.拆分数据集

3.将所有的数据转换为torch张量

4.选择批量大小并创建迭代器

代码解释:

四、BERT模型配置

1.初始化一个不区分大小写的 BERT 配置:

代码解释:

2.这些配置参数的作用:

3.加载模型

4.优化器分组参数

代码解释:

5.训练循环的超参数

代码解释:

五、训练循环

代码解释:

训练图解:

 六、使用测试数据集进行预测和评估

七、使用马修斯相关系数(MCC)评估

2. 代码实现:

1. 测试数据预处理与预测

2.模型预测与结果收集

3.计算MCC

到这里就完美收官咯!!!!! 大家点个赞吧!!!


本章将微调一个BERT模型来预测下游的可接受性判断任务,如果你的电脑还没有配置相关环境的可以去使用 Colaboratory - Colab,里面已经全部帮你配置好啦!而且还可以免费使用GPU。

一、准备阶段

1.导入模块

导入所需的预训练相关模块,包括用于词元化的 BertTokenizer、用于配置 BERT 模型的 BertConfig,还有 Adam 优化器(AdamW),以及序列分类模块(BertFo SequenceClassification):

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset,DataLoader,RandomSampler,SequentialSampler
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer,BertConfig
from transformers import BertForSequenceClassification,get_linear_schedule_with_warmup
# from transformers import AdamW
from torch.optim import AdamW
from tqdm import tqdm,trange
import pandas as pd
import io
import numpy as np
import matplotlib.pyplot as plt# 导入进度条
from tqdm import tqdm,trange
# 导入常用的标注python模块
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

2.指定使用的是GPU还是CPU

使用GPU加速对我们的训练非常又有帮助

device = torch.device("cuda" if torch.cuda.is_avsilable() else "cpu")

3.加载数据集

这里的代码和数据是参考https://github.com/Denis2054/Transformers-for-NLP-2nd-Edition/tree/main/Chapter03

使用git仓库拉取一下就可以得到数据了

df=pd.read_csv("你拉取的数据in_domain_train.tsv的路径",delimiter='\t',header=None,names=['sentence_source','label','label_notes','sentence'])
# 展示数据维度
df.shape  # (8551,4)

随机抽取十个样本数据的看看:

df.sample(10)

 以看到数据集中的数据包含了以下四列(即.tsv文件中四个用制表符分隔的列)。

● 第1列:句子来源(用编号表示)

● 第2列:标注(0=不可接受,1= 可接受)

● 第3列:作者的标注

● 第4列;要分类的句子

二、对数据添加词元和分词

1.根据BERT的预训练,我们要将一个句子的句头添加[CLS]句尾添加[SEP]

代码解释:将数据中的需要分类的句子提取为sentences,循环出每个句子,在每个句子的句头添加[CLS]句尾添加[SEP],将数据中的标签提取为labels

sentences=df.sentence.values
sentences=["[CLS]"+ sentence+"[SEP]" for sentence in sentences]
labels=df.label.values

2.激活BERT词元分析器

这里是初始化一个预训练BERT词元分析器。相比与从头开始训练一个词元分析器相对,节省很多时间和资源。们选择了一个不区分大小写的词元分析器,激活它,并展示对第一个句子词元 化之后的结果:

代码讲解:tokenizer是我们初始化的词元分析器,BertTokenizer.from_pretrained('bert-base-uncased')是使用BERT中自带的预训练好的参数,关于BertTokenizer.from_pretrained可以去看我的另外一章博客BertTokenizer.from_pretreined。

        tokenizer_texts是已经每个句子词元分析好的一个迭代器,因为sentences中保存的句子的type为array类型,而tokenize中要传入的是字符串类型,所以这里要强转一下

词元分析后的一条句子为:Tokenize the first sentence: ['[CLS]', 'our', 'friends', 'wo', 'n', "'", 't', 'buy', 'this', 'analysis', ',', 'let', 'alone', 'the', 'next', 'one', 'we', 'propose', '.', '[SEP]']

tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer_texts=[tokenizer.tokenize(str(sent)) for sent in sentences]
print("Tokenize the first sentence: ")
print(tokenizer_texts[0])
# Tokenize the first sentence: 
['[CLS]', 'our', 'friends', 'wo', 'n', "'", 't', 'buy', 'this', 'analysis', ',', 'let', 'alone', 'the', 'next', 'one', 'we', 'propose', '.', '[SEP]']

3.填充句子为固定长度

上面的处理中我们不难想到,每个句子分析后的长度会随句子的大小而改变,而在BERT微调的时候需要保证句子的长度应用,所以我们要将长度不够的句子进行填充,我们将这个最大长度设置为128,对于长度超过128的数据我们将它截断,保证每个句子序列的大小都为128

代码解释:

input_ids:里面保存的是将上面的句子分词后的词元列表转化成对应数字的列表,其中将词元一个个的循环出来后经过tokenizer.convert_tokens_to_ids,它能把分词后的词元转化为整数 ID,从而让深度学习模型能够处理文本数据。在不同的库中,其使用方式可能会有所不同,但核心功能是一致的。

第二个input_ids:保存的是将每个词元序列转化成相同大小后的迭代器,pad_sequences函数的主要作用是将多个序列填充或截断至相同的长度,这在处理序列数据(像文本序列)时十分关键,因为神经网络通常要求输入数据具有统一的形状

from tensorflow.keras.preprocessing.sequence import pad_sequences
MAX_LEN=128
input_ids=[tokenizer.convert_tokens_to_ids(x) for x in tokenizer_texts]
print(input_ids[0])
input_ids=pad_sequences(input_ids,maxlen=MAX_LEN,dtype='long',truncating='post',padding='post')
print(input_ids[0])

三、数据处理

1.创建masks掩码矩阵

如果不知道为什么需要掩码的可以去看看transformers架构

为了防止模型对填充词元进行注意力计算,我们在前面的步骤中对序列进行了填充补齐。但是我们希望防止模型对这些填充 的词元进行注意力计算!首先创建一个空的 attention_masks 列表,用于存储每个序列的注意力掩码。然后, 对于输入序列(input_ids)中的每个序列(seq),我们遍历其中的每个词元。

针对每个词元,我们判断其索引是否大于 0。如果大于 0,则将对应位置的掩码 值设置为1,表示该词元是有效词元。如果等于0,则将对应位置的掩码值设置为0, 表示该词元是填充词元。最终得到的 attention_masks 列表中的每个元素都是一个与对应输入序列长度相同的 列表,其中每个位置的掩码值表示该位置的词元是否有效(1表示有效,0表示填充)。

通过使用注意力掩码,可确保在模型的注意力计算中,只有真实的词元会被考虑, 而填充词元则被忽略。这样可提高计算效率,并减少模型学习无用信息的概率。

代码解释:

attention_masks是保存所有词元掩码的列表,上面我们说到input_ids保存的是将每个词元序列转化成相同大小后的迭代器,我们将里面的每个词元序列遍历为seq,判断seq中的值是否大于0,如果大于0,那么它是有效词元,将他对应的掩码设置为1,反正为0

attention_masks=[]
for seq in input_ids:seq_mask=[float(i>0) for i in seq]attention_masks.append(seq_mask)
print(attention_masks[0])
"""[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"""

2.拆分数据集

将数据拆分成训练集和验证集,训练集和测试集的比例为9:1

# 拆分训练集和验证集 (90%训练,10%验证)
train_inputs, val_inputs, train_labels, val_labels = train_test_split(input_ids, labels, test_size=0.1, random_state=2025)
train_masks, val_masks, _, _ = train_test_split(attention_masks, labels, test_size=0.1, random_state=2025)  

3.将所有的数据转换为torch张量

微调模型需要使用 torch 张量,所以我们需要将数据转换为 torch 张量:

train_inputs = torch.tensor(train_inputs)validation_inputs = torch.tensor(validation_inputs)txain_labels=torch.tensor(train_labels)validation_labels = torch.tensor(validation_labels)train_masks = torch.tensor(train_masks)validation_masks = torch.tensor(validation_masks)

4.选择批量大小并创建迭代器

如果一股脑地将所有数据都喂进机器,会导致机器因为内存不足而崩溃。所以需 要将数据一批一批地喂给机器。这里将把批量大小(batch size)设置为 32 并创建迭代 器。然后将迭代器与 torch的 DataLoader 相结合,以批量训练大量数据集,以免导致 机器因为内存不足而崩溃:

代码解释:

这里我们选的批量大小(batch_size)为32,使用TensorDataset和DataLoader创建训练数据迭代器,如果对TensorDataset和DataLoader不清楚的可以去看看我的另外一篇博客:TensorData和DataLoader

RandomSampler 是一种随机采样器,从给定的数据集中随机抽取样本,可选择有放回或无放回采样。无放回采样时,从打乱的数据集里抽取样本;有放回采样时,可指定抽取的样本数量 num_samples

batch_size=32
# 训练数据迭代器
train_data=TensorDataset(train_inputs,train_masks,train_labels)
train_sampler=RandomSampler(train_data)
train_dataloader=DataLoader(train_data,sampler=train_sampler,batch_size=batch_size)
# 测试数据迭代器validation_data=TensorDataset(validation_inputs,validation_masks,validation_label)
validation_sampler=RandomSampler(train_data)
validation_dataloader=DataLoader(validation_data,sampler=validation_sampler,batch_size=batch_size)

四、BERT模型配置

1.初始化一个不区分大小写的 BERT 配置:

代码解释:

后面的代码寻妖用到transformers这个包,如果没有的pip安装一下。

configuration是初始化了一个包含BERT预训练模型中所有超参数的配置实例,如果BertConfig中不加任何的参数,那么会生成一个标准的BERT-base配置:

{"hidden_size": 768,          # 每个Transformer层的维度"num_hidden_layers": 12,     # Transformer层数(深度)"num_attention_heads": 12,    # 注意力头的数量"intermediate_size": 3072,   # FeedForward层的中间维度"vocab_size": 30522,         # 词表大小(需与预训练模型一致)"max_position_embeddings": 512, # 最大序列长度...
}

 model:BertModel是根据配置生成一个随机初始化权重的BERT模型,根据传入的配置信息生成。

configuration是一个保存模型内部存储的配置信息副本

try:import transformers
except:print("installing transformers")
from transformers import BertModel,BertConfig
configuration=BertConfig()model=BertModel(config=configuration)
configuration=model.config
print(configuration)
"""
输出为:
BertConfig {"_attn_implementation_autoset": true,"attention_probs_dropout_prob": 0.1,"classifier_dropout": null,"hidden_act": "gelu","hidden_dropout_prob": 0.1,"hidden_size": 768,"initializer_range": 0.02,"intermediate_size": 3072,"layer_norm_eps": 1e-12,"max_position_embeddings": 512,"model_type": "bert","num_attention_heads": 12,"num_hidden_layers": 12,"pad_token_id": 0,"position_embedding_type": "absolute","transformers_version": "4.50.0","type_vocab_size": 2,"use_cache": true,"vocab_size": 30522
}
"""

2.这些配置参数的作用:

"""

attention probs_dropout_prob:对注意力概率应用的 dropout 率,这里设置为

0.1。

● hidden_act;编码器中的非线性激活函数,这里使用 gelu。gelu 是高斯误差线

性单位(Gaussian Eror Linear Units)激活函数的简称,它对输入按幅度加权,

使其成为非线性。

● hidden_dropout_prob:应用于全连接层的 dropout 概率。嵌入、编码器和汇聚

器层中都有全连接。输出不总是对序列内容的良好反映。汇聚隐藏状态的序

第3章 微调BERT 模型

列可改善输出序列。这里设置为0.1。

● hidden_size:编码器层的维度,也是汇聚层的维度,这里设置为768。

● initializer_range:初始化权重矩阵时的标准偏差值,这里设置为0.02。

· intermediate_size:编码器前馈层的维度,这里设置为3072。

● layer_norm_eps:是层规范化层的 epsilon 值,这里设置为le-12。

● max_position_embeddings:模型使用的最大长度,这里设置为512。

● model_type:模型的名称,这里设置为 bert。

● numattention_heads:注意力头数,这里设置为12。

· num_hidden_layers:层数,这里设置为12。

● pad_tokenid:使用0作为填充词元的HD,以避免对填充词元进行训练。

57

 · type_vocab_size:token_type_ids的大小用于标识序列。例如,“the dog[SEP]

 The cat.[SEP]”可用词元 ID [0,0,0,1,1,1]表示。

· vocab_size:模型用于表示 input_ids 的不同词元数量。换句话说,这是模型

可以识别和处理的不同词元或单词的总数。在训练过程中,模型会根据给定

的词表将文本输入转换为对应的词元序列,其中包含的词元数量是

vocab_size。通过使用这个词表,模型能够理解和表示更广泛的语言特征。这

里设置为 30522。

讲解完这些参数后,接下来将加载预训练模型。

"""

3.加载模型

现在开始加载预训练BERT模型

BertForSequenceClassification.from_pretrained 能够让你加载预训练的 BERT 模型权重,并且可以根据需求调整模型以适应特定的序列分类任务。这个方法非常实用,因为借助预训练的权重,模型通常能更快收敛,并且在特定任务上表现更优。第一个参数bert-base-uncased意思是加载BERT的默认权重,如果你有别的模型权重可以填写它的名字或者路径;nums_labels:表示你这个任务中的类别数,我们这个任务的label只有两种,所以这里是2。

DataParallel:DataParallel 是一种数据并行的实现方式,其核心思想是将大规模的数据集分割成若干个较小的数据子集,然后将这些子集分配到不同的计算节点(如 GPU)上,每个节点运行相同的模型副本,但处理不同的数据子集。在每一轮训练结束后,各节点会将计算得到的梯度进行汇总,并更新模型参数。如果不知道分布式计算的可以去看看我的另外一篇博客如何在多个GPU上训练

model=BertForSequenceClassification.from_pretrained("bert-base-uncased",num_labels=2)
model=nn.DataParallel(model)
model.to(device)

4.优化器分组参数

在将为模型的参数初始化优化器。在进行模型微调的过程中,首先需要初始化 预训练模型已学到的参数值。 微调一个预训练模型时,通常会使用之前在大规模数据上训练好的模型作为初始 模型。这些预训练模型已通过大量数据和计算资源进行了训练,学到了很多有用的特 征表示和参数权重。因此,我们希望在微调过程中保留这些已经学到的参数值,而不 是重新随机初始化它们。 所以,程序会使用预训练模型的参数值来初始化优化器,以便在微调过程中更好 地利用这些已经学到的参数。这样可以加快模型收敛速度并提高微调效果;

代码解释:

这段代码是用与为BERT模型的参数设置差异化的权重衰减策略,是训练Transformer模型时的常用技巧

param_optimizer是以字典的方式保存模型中所有可训练参数的名称和值,named_parameters()函数是获取模型中所有的参数名称和值。

no_decay是定义无需权重衰减的参数类型,权重衰减对偏置项bias和归一化层的weight无益,bias可能破坏模型对称性,LayerNorm的weight需保持灵活性,正则化会抑制其适应性。

optimizer_grouped_parametes: 组1:分组设置优化策略,筛选出参数名称中不包含bias和LayerNorm.weight的参数然后将权重衰减率设为0.1。组2:禁止权重衰减的参数,筛选出参数名称中包含bias和LayerNorm.weight的参数将权重衰减率设为0.0

param_optimizer=list(model.named_parameters())no_decay=['bias','LayerNorm.weight']
optimizer_grouped_parametes=[{'params':[p for n,p in param_optimizer if not any(nd in n for nd in no_decay)],'weight_decay_rate':0.1},{"params":[p for n,p in param_optimizer if any(nd in n for nd in no_decay)],'weight_decay_rate':0.0}
]

5.训练循环的超参数

训练循环中的超参数非常重要,尽管它们看起来可能无害。例如,Adam 优化器 会激活权重衰减并经历一个预热阶段。学习率(lr)和预热率(warnup)应该在优化阶段的早期设置为一个非常小的值,在一 定迭代次数后逐渐增加。这样可以避免出现过大的梯度和超调问题,以更好地优化模 型目标。

代码解释:

使用AdamW优化器,将模型的参数传入,并设置初始学习率为2e-5

定义一个函数calculate_accuracy来度量准确率,用于测试结果与标注进行比较,向函数中传入预测结果的概率分布和真实标签,pred_flat是查找这个概率分布中的最大概率的索引,flatten是将数组一维化,labels_flat是真实结果的一维化。最后,计算预测结果展平后的数组中与展平后的标签数组相等的元素数量占标签数组长度的比例,并将这个比例作为结果返回。

optimizer=AdamW(optimizer_grouped_parametes,lr=2e-5)
def calculate_accuracy(preds, labels):"""计算准确率的优化版本"""preds = np.argmax(preds, axis=1).flatten()labels = labels.flatten()return np.sum(preds == labels) / len(labels)

五、训练循环

我们的训练循环将遵循标准的学习过程。轮数(epochs)设置为 4,并将绘制损失和 准确率的度量值。训练循环使用 dataloader 来加载和训练批量。我们将对训练过程进 行度量和评估。 首先初始化 train_loss_set(用于存储损失和准确率的数值,以便后续绘图)。然后 开始训练每一轮,并运行标准的训练循环,

代码解释:

这段代码实现了BERT模型的完整训练和验证流程,包含以下核心步骤:

  1. 初始化训练记录容器
  2. 循环训练多个epoch
  3. 每个epoch包含训练阶段和验证阶段
  4. 记录并输出训练指标

train_loss_history:  储存每个epoch的平均训练损失

val_accuracy_history:存储每个epoch的验证集准确率

代码太多了,大家在代码中看注释吧,这里主要说一下训练步骤

1.初始化记录容器

2.设置外层epoch(循环次数)循环

3.设置model为训练模式

4.将数据挨个前向传播和反向传播更新参数

5.计算平均训练损失

6.将model设置为评估阶段并进行评估

7.计算平均验证准确率

8.打印训练信息

train_loss_history = []  # 存储每个epoch的平均训练损失
val_accuracy_history = []  # 存储每个epoch的验证集准确率for epoch_i in trange(epochs, desc="Epoch"):# ========== 训练阶段 ==========model.train()  # 设置模型为训练模式total_train_loss = 0  # 初始化累计损失for batch in train_dataloader:# 数据转移到GPUb_input_ids, b_input_mask, b_labels = tuple(t.to(device) for t in batch)# 梯度清零model.zero_grad()# 前向传播outputs = model(b_input_ids,attention_mask=b_input_mask,labels=b_labels)# 多GPU处理:取平均损失loss = outputs.loss.mean()# 反向传播loss.backward()"""在深度学习训练时,梯度可能会变得非常大,这会导致训练不稳定,甚至引发梯度爆炸的问题。torch.nn.utils.clip_grad_norm_ 函数通过对梯度的范数进行裁剪,避免梯度变得过大,从而让训练过程更加稳定。"""torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # 梯度裁剪# 参数更新optimizer.step()scheduler.step()total_train_loss += loss.item()# 记录平均训练损失avg_train_loss = total_train_loss / len(train_dataloader)train_loss_history.append(avg_train_loss)  # 记录历史损失# ========== 验证阶段 ==========model.eval()total_eval_accuracy = 0model.eval()  # 设置模型为评估模式total_eval_accuracy = 0  # 初始化累计准确率with torch.no_grad():  # 禁用梯度计算for batch in val_dataloader:b_input_ids, b_input_mask, b_labels = tuple(t.to(device) for t in         batch)# 前向传播outputs = model(b_input_ids, attention_mask=b_input_mask)# .logits是将model中还没有经过激活函数的值提取出来,因为验证不需要激活# 这样节省了显存,.detach将张量从计算图中分离,​断开梯度追踪。因为验证阶段不需要计算梯度,这一步可以节省内存并避免不必要的计算。# .cpu():如果张量在GPU上(例如通过.to('cuda')加载),这一步会将其移动到CPU内存中。NumPy无法直接处理GPU上的张量,必须转移到CPU。logits = outputs.logits.detach().cpu().numpy()label_ids = b_labels.to('cpu').numpy()total_eval_accuracy += calculate_accuracy(logits, label_ids)avg_val_accuracy = total_eval_accuracy / len(val_dataloader)val_accuracy_history.append(avg_val_accuracy)# 打印训练信息print(f"\nEpoch {epoch_i + 1}/{epochs}")print(f"Train loss: {avg_train_loss:.4f}")print(f"Validation Accuracy: {avg_val_accuracy:.4f}")

训练图解:

graph TDA[开始训练] --> B[设置训练模式]B --> C[遍历训练数据]C --> D[数据转GPU]D --> E[梯度清零]E --> F[前向传播]F --> G[计算损失]G --> H[反向传播]H --> I[梯度裁剪]I --> J[参数更新]J --> K[学习率更新]K --> CC --> L[计算平均损失]L --> M[验证模式]M --> N[遍历验证数据]N --> O[前向传播]O --> P[计算准确率]P --> Q[计算平均准确率]Q --> R[记录结果]R --> S[打印信息]S --> T[完成epoch?]T --是--> U[结束训练]T --否--> B

 六、使用测试数据集进行预测和评估

们使用了in_domain_traintsv 数据集训练 BERT下游模型。现在我们将使用基 于留出法!分出的测试数据集 outof_domain_dev.v 文件进行预测。我们的目标是预 测句子在语法上是否正确。 以下代码展示了测试数据准备过程:

# 加载测试数据
test_df = pd.read_csv("out_of_domain_dev.tsv", delimiter='\t', header=None,names=['sentence_source', 'label', 'label_notes', 'sentence'])# 预处理测试数据
test_input_ids, test_attention_masks, test_labels = preprocess_data(test_df, tokenizer, max_len=128)# 创建预测DataLoader
prediction_dataset = TensorDataset(test_input_ids, test_attention_masks, test_labels)
prediction_dataloader = DataLoader(prediction_dataset, sampler=SequentialSampler(prediction_dataset),batch_size=batch_size)# 初始化存储
predictions = []
true_labels = []model.eval()
for batch in prediction_dataloader:batch = tuple(t.to(device) for t in batch)b_input_ids, b_input_mask, b_labels = batchwith torch.no_grad():outputs = model(b_input_ids, attention_mask=b_input_mask)logits = outputs.logits.detach().cpu().numpy()label_ids = b_labels.cpu().numpy()predictions.append(logits)true_labels.append(label_ids)# 计算准确率
flat_predictions = np.concatenate(predictions, axis=0)
flat_predictions = np.argmax(flat_predictions, axis=1)
flat_true_labels = np.concatenate(true_labels, axis=0)accuracy = np.sum(flat_predictions == flat_true_labels) / len(flat_true_labels)
print(f"Test Accuracy: {accuracy:.4f}")

七、使用马修斯相关系数(MCC)评估

​1、MCC的核心原理与优势

马修斯相关系数(Matthews Correlation Coefficient, MCC)是一种综合评估二分类模型性能的指标,尤其适用于类别不平衡数据集。其优势包括:

  1. 全面性:同时考虑真阳性(TP)、真阴性(TN)、假阳性(FP)、假阴性(FN)。、
  2. 鲁棒性:在类别分布不均衡时仍能准确反映模型性能(例如医学诊断中的罕见病检测)。
  3. 可解释性:取值范围为[-1, 1],1表示完美预测,0表示随机猜测,-1表示完全错误。

计算公式:

2. 代码实现:

1. 测试数据预处理与预测
# 加载测试数据(示例路径需替换为实际路径)
test_df = pd.read_csv("out_of_domain_dev.tsv", delimiter='\t', header=None,names=['sentence_source', 'label', 'label_notes', 'sentence'])# 预处理(复用preprocess_data函数)
test_input_ids, test_attention_masks, test_labels = preprocess_data(test_df, tokenizer, max_len=128)# 创建DataLoader
prediction_dataset = TensorDataset(test_input_ids, test_attention_masks, test_labels)
prediction_dataloader = DataLoader(prediction_dataset, sampler=SequentialSampler(prediction_dataset), batch_size=batch_size)
2.模型预测与结果收集

 

# 初始化存储
predictions = []
true_labels = []model.eval()
for batch in prediction_dataloader:batch = tuple(t.to(device) for t in batch)b_input_ids, b_input_mask, b_labels = batchwith torch.no_grad():outputs = model(b_input_ids, attention_mask=b_input_mask)logits = outputs.logits.detach().cpu().numpy()label_ids = b_labels.cpu().numpy()predictions.append(logits)true_labels.append(label_ids)# 合并结果
flat_predictions = np.concatenate(predictions, axis=0)
flat_predictions = np.argmax(flat_predictions, axis=1)  # 将logits转为类别(0/1)
flat_true_labels = np.concatenate(true_labels, axis=0)
3.计算MCC
from sklearn.metrics import matthews_corrcoefmcc = matthews_corrcoef(flat_true_labels, flat_predictions)
print(f"Test MCC: {mcc:.4f}")

到这里就完美收官咯!!!!! 大家点个赞吧!!!

相关文章:

大模型学习:从零到一实现一个BERT微调

目录 一、准备阶段 1.导入模块 2.指定使用的是GPU还是CPU 3.加载数据集 二、对数据添加词元和分词 1.根据BERT的预训练,我们要将一个句子的句头添加[CLS]句尾添加[SEP] 2.激活BERT词元分析器 3.填充句子为固定长度 代码解释: 三、数据处理 1.…...

Git和GitCode使用(从Git安装到上传项目一条龙)

第一步 菜鸟教程-Git教程 点击上方链接,完成Git的安装,并了解Git 工作流程,知道Git 工作区、暂存区和版本库的区别 第二步 GitCode官方帮助文档-SSH 公钥管理 点击上方链接,完成SSH公钥设置 第三步(GitCode的官方引…...

NodeJs之http模块

一、概念: 1、协议:双方必须共同遵从的一组约定。 Hypertext Transfer Protocol:HTTP,超文本传输协议 2、请求: ① 请求报文的组成: 请求行请求头空行请求体 ② 请求行: 请求方法URLHTTP版本…...

【深度学习与实战】2.3、线性回归模型与梯度下降法先导案例--最小二乘法(向量形式求解)

为了求解损失函数 对 的导数,并利用最小二乘法向量形式求解 的值‌ 这是‌线性回归‌的平方误差损失函数,目标是最小化预测值 与真实值 之间的差距。 ‌损失函数‌: 考虑多个样本的情况,损失函数为所有样本的平方误差之和&a…...

在word中使用zotero添加参考文献并附带超链接

一、引言 在写大论文时,为了避免文中引用与文末参考文献频繁对照、修改文中引用顺序/引用文献时手动维护参考文献耗易出错,拟在 word 中使用 zotero 插入参考文献,并为每个参考文献附加超链接,实现交互式阅读。 版本&#xff1a…...

在Linux系统中将html保存为PNG图片

1 前言 之前使用Pyecharts库在Windows系统中生成图表并转换为PNG格式图片(传送门),现将代码放于Linux服务器上运行,结果发现错误,生成html文件之后无法保存图片。 2 原理 基于Selenium库的保存方案,其原…...

presto任务优化参数

presto引擎业内通常用它来做即席查询,它基于内存计算效率确实快,不过它自身的任务优化参数比较杂,不同类型的catalog能用的参数不完全一样,在官网上倒是可以看到相关资料,配置文件中写的见https://prestodb.io/docs/cu…...

uniapp + Axios + 小程序封装网络请求

前言 小程序自带的网络请求使用起来比较麻烦,不便于管理,就需要封装网络请求,减少繁琐步骤,封装最终效果,根据类别将网络请求封装在文件中,使用得时候调用文件名名称加文件中得自定义名称,就可…...

《网络管理》实践环节01:OpenEuler22.03sp4安装zabbix6.2

兰生幽谷,不为莫服而不芳; 君子行义,不为莫知而止休。 1 环境 openEuler 22.03 LTSsp4PHP 8.0Apache 2Mysql 8.0zabbix6.2.4 表1-1 Zabbix网络规划(用你们自己的特征网段规划) 主机名 IP 功能 备注 zbx6svr 19…...

4.6js面向对象

js原型继承 JavaScript 的原型链继承是其核心特性之一,理解原型链对于掌握 JavaScript 的面向对象编程至关重要。 1. ​原型(Prototype)基础 在 JavaScript 中,每个对象都有一个内部属性 [[Prototype]](可以通过 __p…...

【云服务器】在Linux CentOS 7上快速搭建我的世界 Minecraft Fabric 服务器搭建,Fabric 模组详细搭建教程

【云服务器】在Linux CentOS 7上快速搭建我的世界 Minecraft Fabric 服务器搭建,Fabric 模组详细搭建教程 一、 服务器介绍二、安装 JDK 21三、搭建 Minecraft 服务端四、本地测试连接五、如何添加模组(mods)六、添加服务,并设置开…...

SQL SELECT DISTINCT 语句详解:精准去重的艺术

SQL SELECT DISTINCT 语句详解:精准去重的艺术 一、为什么需要数据去重? 在日常数据库操作中,我们经常会遇到这样的场景:查询客户表时发现重复的邮箱地址,统计销售数据时出现冗余的订单记录,分析用户行为…...

从ChatGPT到AutoGPT——AI Agent的范式迁移

一、AI Agent的范式迁移 1. ChatGPT的局限性与Agent化需求 单轮对话的“工具属性” vs. 多轮复杂任务的“自主性” ChatGPT 作为强大的生成式AI,虽然能够进行连贯对话,但本质上仍然是“工具型”AI,依赖用户提供明确的指令,而无法自主规划和执行任务。 人类介入成本过高:提…...

SQL EXISTS 与 NOT EXISTS 运算符

EXISTS 和 NOT EXISTS 是 SQL 中的逻辑运算符,用于检查子查询是否返回任何行。它们通常用在 WHERE 子句中,与子查询一起使用。 EXISTS 运算符 EXISTS 运算符用于检查子查询是否返回至少一行数据。如果子查询返回任何行,EXISTS 返回 TRUE&…...

AGI 的概念、意义与未来展望

随着人工智能技术的飞速发展,我们已经见证了在图像识别、自然语言处理等特定领域取得的巨大突破。然而,这些成就都属于弱人工智能(Narrow AI)的范畴,它们只能在预设的任务范围内高效工作。 人们对于一种拥有更广泛、更…...

基于Java与Go的下一代DDoS防御体系构建实战

引言:混合云时代的攻防对抗新格局 2024年某金融平台遭遇峰值2.3Tbps的IPv6混合攻击,传统WAF方案在新型AI驱动攻击面前全面失效。本文将以Java与Go为技术栈,揭示如何构建具备智能决策能力的防御系统。 一、攻击防御技术矩阵重构 1.1 混合攻击特征识别 攻击类型Java检测方案…...

FPGA调试笔记

XILINX SSTL属性电平报错 错误如下: [DRC BIVRU-1] Bank IO standard Vref utilization: Bank 33 contains ports that use a reference voltage. In order to use such standards in a bank that is not configured to use INTERNAL_VREF, the banks VREF pin mu…...

Axure项目实战:智慧城市APP(七)我的、消息(显示与隐藏交互)

亲爱的小伙伴,在您浏览之前,烦请关注一下,在此深表感谢! 课程主题:智慧城市APP 主要内容:我的、消息、活动模块页面 应用场景:消息页设计、我的页面设计以及活动页面设计 案例展示&#xff…...

深度学习——图像余弦相似度

计算机视觉是研究图像的学问,在图像的最终评价时,往往需要用到一些图像相似度的度量指标,因此,在本文中我们将详细地介绍原生和调用第三方库的计算图像余弦相似度的方法。 使用原生numpy实现 import numpy as npdef image_cosin…...

求矩阵某列的和

设计函数sum_column( int A[E1(n)][E2(n)], int j ),E1(n)和E2(n)分别为用宏定义的行数和列数,j为列号。在该函数中,设计指针ptr&A[0][j],通过*ptr及ptrptrE2(n)访问第j列元素,从而求得第j列元素的和。在主函数中定…...

【论文分析】无人机轨迹规划,Fast-Planner:实时避障+全局最优的路径引导优化算法

这篇论文《Robust Real-time UAV Replanning Using Guided Gradient-based Optimization and Topological Paths》由香港科技大学提出,主要针对无人机(UAV)在复杂环境中的实时轨迹重新规划问题,提出了一种结合梯度优化和拓扑路径搜…...

李飞飞、吴佳俊团队新作:FlowMo如何以零卷积、零对抗损失实现ImageNet重构新巅峰

目录 一、摘要 二、引言 三、相关工作 四、方法 基于扩散先前的离散标记化器利用广告 架构 阶段 1A:模式匹配预训练 阶段 1B:模式搜索后训练 采样 第二阶段:潜在生成建模 五、Coovally AI模型训练与应用平台 六、实验 主要结果 …...

AutoDev 2.0 正式发布:智能体 x 开源生态,AI 自动开发新标杆

在我们等待了几个月之后,国内终于有模型(DeepSeek V3-0324)能支持 AutoDev 的能力,也因此是时候发布 AutoDev 2.0 了!在 AutoDev 2.0 中,你可以: 编码智能体 Sketch 进行自动化编程自动化编程的…...

PHP 应用MYSQL 架构SQL 注入跨库查询文件读写权限操作

MYSQL 注入:(目的获取当前 web 权限) 1 、判断常见四个信息(系统,用户,数据库名,版本) 2 、根据四个信息去选择方案 root 用户:先测试读写,后测试获取…...

鸿蒙-全屏播放页面(使用相对布局)---持续更新中

最终实现效果图: 实现步骤 创建FullScreenPlay.ets全品播放页面 并将其修改为启动页面。 全屏播放,屏幕必然横过来,所以要将窗口横过来。 编辑 src/main/ets/entryability/EntryAbility.ets 若写在/EntryAbility.ets中,则所有…...

第4期:重构软件测试体系——生成式AI如何让BUG无所遁形

真实战场报告 某金融系统上线前,测试团队用AI生成3000条边缘用例,发现了一个隐藏极深的并发漏洞——该BUG在传统用例覆盖下需要7年才会触发一次。这次发现直接避免了可能上亿元的资金风险! 一、测试革命:当AI遇见质量保障 场景1&…...

力扣.旋转矩阵Ⅱ

59. 螺旋矩阵 II - 力扣&#xff08;LeetCode&#xff09; 代码区&#xff1a; class Solution {const int MAX25; public:vector<vector<int>> generateMatrix(int n) {vector<vector<int>> ans;vector<int> hang;int len_nn;int arry[25][25]…...

Docker 安装部署Harbor 私有仓库

Docker 安装部署Harbor 私有仓库 系统环境:redhat x86_64 一、首先部署docker 环境 定制软件源 wget https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo -O /etc/yum.repos.d/docker-ce.repoyum install -y yum-utils device-mapper-persistent-data lvm2…...

SQL Server 中常见的数据类型及其详细解释、内存占用和适用场景

以下是 SQL Server 中常见的数据类型及其详细解释、内存占用和适用场景&#xff1a; 数据类型类别数据类型解释内存占用适用场景整数类型bigint用于存储范围较大的整数&#xff0c;范围是 -2^63 (-9,223,372,036,854,775,808) 到 2^63-1 (9,223,372,036,854,775,807)8 字节需要…...

javascript实现一个函数,将字符串中的指定子串全部替换为另一个字符串的原理,以及多种方法实现。

大白话javascript实现一个函数&#xff0c;将字符串中的指定子串全部替换为另一个字符串的原理&#xff0c;以及多种方法实现。 在JavaScript里&#xff0c;要是你想把字符串里的指定子串都替换成另外一个字符串&#xff0c;有不少方法可以实现。下面我会详细介绍实现的原理&a…...

Python 3 与 MySQL 数据库连接:mysql-connector 模块详解

Python 3 与 MySQL 数据库连接&#xff1a;mysql-connector 模块详解 概述 在Python 3中&#xff0c;与MySQL数据库进行交互是一个常见的需求。mysql-connector是一个流行的Python模块&#xff0c;它提供了与MySQL数据库连接和交互的接口。本文将详细介绍mysql-connector模块…...

HCIA-Datacom高阶:基础的单区域 OSPF 与多区域 OSPF的配置

动态路由协议是实现网络高效通信的关键技术之一。开放式最短路径优先&#xff08;Open Shortest Path First&#xff0c;OSPF&#xff09;协议作为内部网关协议&#xff08;IGP&#xff09;的一种&#xff0c;因其高效性、稳定性和扩展性&#xff0c;在大型网络中得到了广泛应用…...

蓝桥杯单片机刷题——E2PROM记录开机次数

设计要求 使用E2PROM完成数据记录功能&#xff0c;单片机复位次数记录到E2PROM的地址0中。每复位一次数值加1&#xff0c;按下按键S4&#xff0c;串口发送复位次数。串口发送格式如下&#xff1a; Number&#xff1a;1 备注&#xff1a; 单片机IRC振荡器频率设置为12MHz。 …...

杂草YOLO系列数据集4000张

一份开源数据集——杂草YOLO数据集&#xff0c;该数据集适用于农业智能化、植物识别等计算机视觉应用场景。 数据集详情 ​训练集&#xff1a;3,664张高清标注图像​测试集&#xff1a;180张多样性场景样本​验证集&#xff1a;359张严格筛选数据 下载链接 杂草YOLO数据集分…...

Python自动化面试通关秘籍

Python自动化测试工程师面试&#xff0c;不仅仅是考察你的代码能力&#xff0c;更看重你如何在项目中灵活运用工具和框架解决实际问题。如果你正准备面试&#xff0c;这篇文章将为你总结最常见的高频考题及答题技巧&#xff0c;帮助你快速上手&#xff0c;通关面试&#xff0c;…...

机器学习的一百个概念(1)单位归一化

前言 本文隶属于专栏《机器学习的一百个概念》&#xff0c;该专栏为笔者原创&#xff0c;引用请注明来源&#xff0c;不足和错误之处请在评论区帮忙指出&#xff0c;谢谢&#xff01; 本专栏目录结构和参考文献请见[《机器学习的一百个概念》 ima 知识库 知识库广场搜索&…...

Python 笔记 (二)

Python Note 2 1. Python 慢的原因2. 三个元素3. 标准数据类型4. 字符串5. 比较大小: 富比较方法 rich comparison6. 数据容器 (支持*混装* )一、允许重复类 (list、tuple、str)二、不允许重复类 (set、dict)1、集合(set)2、字典(dict)3、特殊: 双端队列 deque 三、数据容器的共…...

【商城实战(97)】ELK日志管理系统的全面应用

【商城实战】专栏重磅来袭!这是一份专为开发者与电商从业者打造的超详细指南。从项目基础搭建,运用 uniapp、Element Plus、SpringBoot 搭建商城框架,到用户、商品、订单等核心模块开发,再到性能优化、安全加固、多端适配,乃至运营推广策略,102 章内容层层递进。无论是想…...

3.使用epoll实现单线程并发服务器

目录 1. epoll的概述 2. 多线程与epoll的处理流程 2.1 多线程处理流程 2.2 epoll处理流程 3. epoll与多线程的比较 4. epoll的操作函数 4.1 epoll_create() 4.2 epoll_ctl() 4.3 epoll_wait() 5. 示例代码 6. epoll的工作模式 7. 使用O_NONBLOCK防止阻塞 8.运行代…...

蓝桥杯真题------R格式(高精度乘法,高精度加法)

对于高精度乘法和加法的同学可以学学这几个题 高精度乘法 高精度加法 文章目录 题意分析部分解全解 后言 题意 给出一个整数和一个浮点数&#xff0c;求2的整数次幂和这个浮点数相乘的结果最后四舍五入。、 分析 我们可以发现&#xff0c;n的范围是1000,2的1000次方非常大&am…...

PyCharm操作基础指南

一、安装与配置 1. 版本选择 专业版&#xff1a;支持 Web 开发&#xff08;Django/Flask&#xff09;、数据库工具、科学计算等&#xff08;需付费&#xff09;。 社区版&#xff1a;免费&#xff0c;适合纯 Python 开发。 2. 安装步骤 访问 JetBrains 官网 下载对应版本。…...

21 python __name__ 与 __main__

在办公室里&#xff0c;每个员工都有自己的工牌&#xff0c;上面写着姓名和部门。 一、__name__&#xff1a;模块的名字 Python 模块也有类似的 "工牌"——__name__属性&#xff0c;它记录了模块的身份&#xff1a; 直接运行时 → __name__ "__main__"&…...

NixVis 开源轻量级 Nginx 日志分析工具

NixVis NixVis 是一款基于 Go 语言开发的、开源轻量级 Nginx 日志分析工具&#xff0c;专为自部署场景设计。它提供直观的数据可视化和全面的统计分析功能&#xff0c;帮助您实时监控网站流量、访问来源和地理分布等关键指标&#xff0c;无需复杂配置即可快速部署使用。 演示…...

elementUI el-image图片加载失败解决

是不是&#xff0c;在网上找了一些&#xff0c;都不行&#xff0c;这里一行代码&#xff0c;解决&#xff0c;后端返回图片路径&#xff0c;el-image图片加载失败的问题 解决办法&#xff0c; vue项目里&#xff0c;index.html文件里加一行代码就可 <meta name"refe…...

lxd-dashboard 图形管理LXD/LXC

前言 LXD-WEBGUI是一个完全用AngularJS编写的Web应用程序,无需应用服务器、数据库或其他后端服务支持。只需要简单地托管静态HTML和JavaScript文件,就能立即投入使用。这个项目目前处于测试阶段,提供了直观的用户界面,帮助用户便捷地管理和控制LXD实例。 安装lxd-dashboa…...

C# MemoryStream 使用详解

总目录 前言 在.NET开发中&#xff0c;流&#xff08;Stream&#xff09;是一个用于处理输入和输出的抽象类&#xff0c;MemoryStream是流的一个具体实现&#xff0c;它允许我们在内存中读写数据&#xff0c;就像操作文件一样&#xff0c;而无需涉及磁盘 I/O 操作。尤其适合需…...

(二)万字长文解析:deepResearch如何用更长的思考时间换取更高质量的回复?各家产品对比深度详解

DeepResearch的研究背景 业务背景&#xff1a;用更长的等待时间&#xff0c;换取更高质量、更具实用性的结果 当前AI技术发展正经历从“即时响应”到“深度思考”的范式转变。用户对延迟的容忍度显著提升&#xff0c;从传统200ms的交互响应放宽至数秒甚至数分钟&#xff0c;以…...

Redis场景问题1:缓存穿透

Redis 缓存穿透是指在缓存系统&#xff08;如 Redis&#xff09;中&#xff0c;当客户端请求的数据既不在缓存中&#xff0c;也不在数据库中时&#xff0c;每次请求都会直接穿透缓存访问数据库&#xff0c;从而给数据库带来巨大压力&#xff0c;甚至可能导致数据库崩溃。下面为…...

数据结构(并查集,图)

并查集 练习版 class UnionFindSet { public:void swap(int* a, int* b){int tmp *a;*a *b;*b tmp;}UnionFindSet(size_t size):_ufs(size,-1){}int UnionFind(int x){}void Union(int x1, int x2){}//长分支改为相同节点int FindRoot(int x){}bool InSet(int x1, int x2)…...

深度学习篇---断点重训模型部署文件

文章目录 前言一、断点重训&#xff08;Checkpoint&#xff09;文件1. 动态图&#xff08;DyGraph&#xff09;模式.pdparams 文件.pdopt 文件.pdscaler 文件.pdmeta 或 .pkl 文件 2. 静态图&#xff08;Static Graph&#xff09;模式.pdparams 和 .pdopt 文件.ckpt 文件 3. 恢…...