【从0到1搞懂大模型】神经网络的实现:数据策略、模型调优与评估体系(3)
一、数据集的划分
(1)按一定比例划分为训练集和测试集
我们通常取8-2、7-3、6-4、5-5比例切分,直接将数据随机划分为训练集和测试集,然后使用训练集来生成模型,再用测试集来测试模型的正确率和误差,以验证模型的有效性。
这种方法常见于决策树、朴素贝叶斯分类器、线性回归和逻辑回归等任务中。
(2)交叉验证法
交叉验证一般采用k折交叉验证,即k-fold cross validation,往往k取为10。在这种数据集划分法中,我们将数据集划分为k个子集,每个子集均做一次测试集,每次将其余的作为训练集。在交叉验证时,我们重复训练k次,每次选择一个子集作为测试集,并将k次的平均交叉验证的正确率作为最终的结果。
from sklearn.model_selection import KFold
import numpy as np# 将PyTorch Tensor转为Numpy
X_np = X.numpy()
y_np = y.numpy()# 5折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X_np)):# 转换为PyTorch DataLoadertrain_fold = TensorDataset(torch.from_numpy(X_np[train_idx]), torch.from_numpy(y_np[train_idx]))val_fold = TensorDataset(torch.from_numpy(X_np[val_idx]), torch.from_numpy(y_np[val_idx]))train_loader = DataLoader(train_fold, batch_size=32, shuffle=True)val_loader = DataLoader(val_fold, batch_size=32)print(f"Fold {fold+1}: Train {len(train_fold)}, Val {len(val_fold)}")
(3)训练集、验证集、测试集法
我们首先将数据集划分为训练集和测试集,由于模型的构建过程中也需要检验模型,检验模型的配置,以及训练程度,过拟合还是欠拟合,所以会将训练数据再划分为两个部分,一部分是用于训练的训练集,另一部分是进行检验的验证集。验证集可以重复使用,主要是用来辅助我们构建模型的。
训练集用于训练得到神经网络模型,然后用验证集验证模型的有效性,挑选获得最佳效果的模型,直到我们得到一个满意的模型为止。最后,当模型“通过”验证集之后,我们再使用测试集测试模型的最终效果,评估模型的准确率,以及误差等。测试集只在模型检验时使用,绝对不能根据测试集上的结果来调整网络参数配置,以及选择训练好的模型,否则会导致模型在测试集上过拟合。
一般来说,最终的正确率,训练集大于验证集,验证集大于测试集。
对于部分机器学习任务,我们划分的测试集必须是模型从未见过的数据,比如语音识别中一个完全不同的人的说话声,图像识别中一个完全不同的识别个体。这时,一般来说,训练集和验证集的数据分布是同分布的,而测试集的数据分布与前两者会略有不同。在这种情况下,通常,测试集的正确率会比验证集的正确率低得多,这样就可以看出模型的泛化能力,可以预测出实际应用中的真实效果。
下面是按照 8-1-1 划分数据集的代码示例:
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split# 生成示例数据
X = torch.randn(1000, 10) # 1000个样本,10维特征
y = torch.randint(0, 2, (1000,)) # 二分类标签
dataset = TensorDataset(X, y)# 划分比例:80%训练,10%验证,10%测试
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_sizetrain_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size],generator=torch.Generator().manual_seed(42) # 固定随机种子
)# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)print(f"训练集样本数: {len(train_dataset)}") # 直接查看原始数据集
print(f"DataLoader 批次数量: {len(train_loader)}") # 总批次 = 样本数 // batch_size# 获取第一个批次的数据
batch = next(iter(train_loader)) # 返回的是 (features, labels)
features, labels = batch# 查看特征和标签的形状
print("特征张量形状:", features.shape) # [batch_size, 10]
print("标签张量形状:", labels.shape) # [batch_size]
print("标签示例:", labels[:5]) # 查看前5个标签
二、偏差与方差
-
假设这就是数据集,如果给这个数据集拟合一条直线,可能得到一个逻辑回归拟合,但它并不能很好地拟合该数据,这是高偏差(high bias)的情况,我们称为**“欠拟合”(underfitting)**。
-
相反的如果我们拟合一个非常复杂的分类器,比如深度神经网络或含有隐藏单元的神经网络,可能就非常适用于这个数据集,但是这看起来也不是一种很好的拟合方式分类器方差较高(high variance),数据过度拟合(overfitting)。
衡量方式
- 一般可通过查看训练集与验证集误差来诊断。
-
评估偏差(bias),一般看训练集 训练集误差大——偏差较高,欠拟合
-
评估方差(variance),一般看验证集 训练集误差小,验证集误差大——方差较高,过拟合
学习曲线
-
学习曲线作用: 查看模型的学习效果; 通过学习曲线可以清晰的看出模型对数据的过拟合和欠拟合;
-
学习曲线:随着训练样本的逐渐增多,算法训练出的模型的表现能力;
-
表现能力:也就是模型的预测准确情况。
总结就是如果训练样本准确率一直上不去就是欠拟合,如果训练集准确率很高,但是验证集很低,就是过拟合。下面是两个案例
案例 1——欠拟合
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 生成非线性数据(100个样本)
torch.manual_seed(42)
x = torch.unsqueeze(torch.linspace(-5, 5, 100), 1)
y = torch.sin(x) * 2 + torch.normal(0, 0.3, x.shape)# 划分训练集(70%)和验证集(30%)
split = int(0.7 * len(x))
x_train, y_train = x[:split], y[:split]
x_val, y_val = x[split:], y[split:]# 构建欠拟合模型(单层线性回归)
class UnderfitModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(1, 1) # 仅一个线性层def forward(self, x):return self.linear(x)model = UnderfitModel()
optimizer = optim.SGD(model.parameters(), lr=0.01) # 使用低学习率
criterion = nn.MSELoss()# 训练过程(仅50次迭代)
train_loss = []
val_loss = []
for epoch in range(50):# 训练模式model.train()output = model(x_train)loss = criterion(output, y_train)optimizer.zero_grad()loss.backward()optimizer.step()# 验证模式model.eval()with torch.no_grad():val_pred = model(x_val)v_loss = criterion(val_pred, y_val)# 记录损失train_loss.append(loss.item())val_loss.append(v_loss.item())# 可视化损失曲线
plt.figure(figsize=(10,5))
plt.plot(train_loss, label='train_loss')
plt.plot(val_loss, label='valid_loss')
plt.ylim(0, 5)
plt.legend()
plt.title("欠拟合训练过程")
plt.show()# 最终预测可视化
model.eval()
with torch.no_grad():pred = model(x)plt.figure(figsize=(12,5))
plt.scatter(x_train, y_train, c='r', label='train_data')
plt.scatter(x_val, y_val, c='g', label='valid_data')
plt.plot(x, pred.numpy(), 'b-', lw=3, label='model_predict')
plt.plot(x, torch.sin(x)*2, 'k--', label='true_function')
plt.legend()
plt.show()# 输出误差指标
print(f'[最终误差] 训练集:{train_loss[-1]:.4f} | 验证集:{val_loss[-1]:.4f}')
print(f'模型参数:w={model.linear.weight.item():.2f}, b={model.linear.bias.item():.2f}')
它最终的学习曲线如下
可以很明显看到模型预测与真实曲线相差很远
案例 2——过拟合
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 生成少量训练数据(20个样本)
torch.manual_seed(1)
x_train = torch.unsqueeze(torch.linspace(-5, 5, 20), dim=1)
y_train = 1.2 * x_train + 0.8 + torch.normal(0, 0.5, size=x_train.size())# 构建过参数化模型(4层全连接网络)
class OverfitModel(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(1, 100),nn.ReLU(),nn.Linear(100, 100),nn.ReLU(),nn.Linear(100, 100),nn.ReLU(),nn.Linear(100, 1))def forward(self, x):return self.net(x)model = OverfitModel()
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()# 训练循环(3000次迭代)
loss_history = []
for epoch in range(3000):output = model(x_train)loss = criterion(output, y_train)optimizer.zero_grad()loss.backward()optimizer.step()loss_history.append(loss.item())# 生成测试数据(同分布但未参与训练)
x_test = torch.unsqueeze(torch.linspace(-6, 6, 100), dim=1)
y_test = 1.2 * x_test + 0.8 + torch.normal(0, 0.5, size=x_test.size())# 绘制结果对比
plt.figure(figsize=(12,5))
plt.scatter(x_train.numpy(), y_train.numpy(), c='r', label='train_data')
plt.plot(x_test.numpy(), model(x_test).detach().numpy(), 'b-', lw=3, label='predict')
plt.plot(x_test.numpy(), 1.2*x_test+0.8, 'g--', label='true_function')
plt.legend()
plt.show()# 输出训练误差和测试误差
train_loss = criterion(model(x_train), y_train)
test_loss = criterion(model(x_test), y_test)
print(f'训练误差:{train_loss.item():.4f}')
print(f'测试误差:{test_loss.item():.4f}')
可以看到最后生成的拟合曲线如蓝色所示,很明显过拟合,切测试误差比训练误差大很多
三、过拟合&欠拟合的处理方式
1、首先根据训练集效果来判断是否是高偏差?也就是是否欠拟合。
如果不是,跳转到下一步(判断是否高方差)。
如果是,有四种可尝试的方法:
A、新网络,比如:更多的隐藏层或隐藏单元。
B、增加新特征,可以考虑加入进特征组合、高次特征或者添加多项式特征(将线性模型通过添加二次项或者三次项使模型泛化能力更强)
C、用更多时间训练算法。
D、尝试更先进的优化算法。
反复调试,直到偏差降到和接受范围内,然后进行下一步。
2、根据验证集效果来判断是否是高方差?也就是是否过拟合。
如果不是,说明得到了很好的结果,训练结束,开始将该模型放入测试集。
如果是,有三种可尝试的方法:
A、更多数据来训练。
B、正则化来减少过拟合。
C、控制模型的复杂度,用dropout、early stopping等方法
D、尝试新网络框架(有时有用有时没用)。
名词解释
- 正则化
L2正则化:目标函数中增加所有权重w参数的平方之和, 逼迫所有w尽可能趋向零但不为零. 因为过拟合的时候, 拟合函数需要顾忌每一个点, 最终形成的拟合函数波动很大, 在某些很小的区间里, 函数值的变化很剧烈, 也就是某些w非常大. 为此, L2正则化的加入就惩罚了权重变大的趋势.
model = nn.Sequential(nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10)
)# 设置优化器时添加weight_decay参数(L2系数)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
L1正则化:目标函数中增加所有权重w参数的绝对值之和, 逼迫更多w为零(也就是变稀疏. L2因为其导数也趋0, 奔向零的速度不如L1给力了).大家对稀疏规则化趋之若鹜的一个关键原因在于它能实现特征的自动选择。一般来说,xi的大部分元素(也就是特征)都是和最终的输出yi没有关系或者不提供任何信息的,在最小化目标函数的时候考虑xi这些额外的特征,虽然可以获得更小的训练误差,但在预测新的样本时,这些没用的特征权重反而会被考虑,从而干扰了对正确yi的预测。稀疏规则化算子的引入就是为了完成特征自动选择的光荣使命,它会学习地去掉这些无用的特征,也就是把这些特征对应的权重置为0。
L1 正则化 torch 没有直接实现,可以手动实现
def l1_regularization(model, lambda_l1):l1_loss = 0for param in model.parameters():l1_loss += torch.sum(torch.abs(param))return lambda_l1 * l1_loss# 训练循环
for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = F.cross_entropy(output, target)# 添加L1正则项l1_lambda = 0.001loss += l1_regularization(model, l1_lambda)loss.backward()optimizer.step()
下面是一个同时引用 L1 正则化和 L2 正则化的案例
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fclass RegularizedModel(nn.Module):def __init__(self):super().__init__()self.fc = nn.Sequential(nn.Flatten(), nn.Linear(784, 256),nn.BatchNorm1d(256),nn.ReLU(),nn.Linear(256, 10))def forward(self, x):return self.fc(x)def train(model, train_loader, lambda_l1=0, lambda_l2=0):optimizer = optim.AdamW(model.parameters(), lr=0.001)for epoch in range(100):total_loss = 0for x, y in train_loader:optimizer.zero_grad()pred = model(x)loss = F.cross_entropy(pred, y)# L1正则项if lambda_l1 > 0:l1 = sum(p.abs().sum() for p in model.parameters())loss += lambda_l1 * l1# L2正则项if lambda_l2 > 0:l2 = sum(p.pow(2).sum() for p in model.parameters())loss += 0.5 * lambda_l2 * l2loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}")# 使用示例import torch
from torch.utils.data import DataLoader, TensorDataset# 参数配置
batch_size = 16
input_dim = 784 # 对应模型输入维度
num_classes = 10
num_samples = 1000 # 总样本量# 生成正态分布虚拟数据
X = torch.randn(num_samples, input_dim) # 形状 (1000, 784)
y = torch.randint(0, num_classes, (num_samples,)) # 随机标签# 创建数据集和数据加载器
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 验证数据形状
sample_X, sample_y = next(iter(train_loader))
print(f"输入数据形状: {sample_X.shape}") # 应输出 torch.Size([16, 784])
print(f"标签形状: {sample_y.shape}") # 应输出 torch.Size([16])model = RegularizedModel()
train(model, train_loader, lambda_l1=1e-5, lambda_l2=1e-4)
- dropout
在训练的运行的时候,让神经元以超参数p的概率被激活(也就是1-p的概率被设置为0), 每个w因此随机参与, 使得任意w都不是不可或缺的, 效果类似于数量巨大的模型集成。
class RegularizedModel(nn.Module):def __init__(self, input_dim):super().__init__()self.net = nn.Sequential(nn.Linear(input_dim, 128),nn.ReLU(),nn.Dropout(0.5), # 添加Dropout层nn.Linear(128, 64),nn.ReLU(),nn.Dropout(0.3),nn.Linear(64, 1))def forward(self, x):return torch.sigmoid(self.net(x)).squeeze()
- early stop
提前终止训练,即在模型对训练数据集迭代收敛之前停止迭代来防止过拟合,常用的停止条件就是当 N 轮迭代都loss 都没有降低后可以停止迭代
下面案例的停止条件就是当 loss 连续 10 次都没有低于最佳 loss-0.001时就触发,这里设置了一个delta为 0.001,就能保证即使损失有波动,只要未突破阈值就计数
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset# 早停监控器(带模型保存功能)
class EarlyStopper:def __init__(self, patience=5, delta=0, path='best_model.pth'):self.patience = patience # 容忍epoch数self.delta = delta # 视为改进的最小变化量self.path = path # 最佳模型保存路径self.counter = 0 # 未改进计数器self.best_score = None # 最佳监控指标值self.early_stop = False # 停止标志def __call__(self, val_loss, model):score = -val_loss # 默认监控验证损失(越大越好)if self.best_score is None:self.best_score = scoreself.save_checkpoint(model)elif score < self.best_score + self.delta:self.counter += 1print(f'EarlyStopping counter: {self.counter}/{self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(model)self.counter = 0def save_checkpoint(self, model):torch.save(model.state_dict(), self.path)# 生成模拟数据(回归任务)
def generate_data(samples=1000):X = torch.linspace(-10, 10, samples).unsqueeze(1)y = 0.5 * X**3 - 2 * X**2 + 3 * X + torch.randn(X.size()) * 10return X, y# 过参数化的全连接网络
class OverfitModel(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(1, 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, 128),nn.ReLU(),nn.Linear(128, 1))def forward(self, x):return self.net(x)# 训练函数(集成早停)
def train_with_earlystop(model, train_loader, val_loader, epochs=1000):optimizer = optim.Adam(model.parameters(), lr=0.001)criterion = nn.MSELoss()early_stopper = EarlyStopper(patience=10, delta=0.001)train_losses = []val_losses = []for epoch in range(epochs):# 训练阶段model.train()train_loss = 0for X_batch, y_batch in train_loader:optimizer.zero_grad()pred = model(X_batch)loss = criterion(pred, y_batch)loss.backward()optimizer.step()train_loss += loss.item()train_loss /= len(train_loader)train_losses.append(train_loss)# 验证阶段model.eval()val_loss = 0with torch.no_grad():for X_val, y_val in val_loader:pred_val = model(X_val)val_loss += criterion(pred_val, y_val).item()val_loss /= len(val_loader)val_losses.append(val_loss)print(f'Epoch {epoch+1:03d} | 'f'Train Loss: {train_loss:.4f} | 'f'Val Loss: {val_loss:.4f}')# 早停检查early_stopper(val_loss, model)if early_stopper.early_stop:print("==> Early stopping triggered")break# 恢复最佳模型model.load_state_dict(torch.load('best_model.pth'))return train_losses, val_losses# 可视化训练过程
def plot_learning_curve(train_loss, val_loss):plt.figure(figsize=(10, 6))plt.plot(train_loss, label='Training Loss')plt.plot(val_loss, label='Validation Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Learning Curve with Early Stopping')plt.legend()plt.grid(True)plt.show()# 主程序
if __name__ == "__main__":# 数据准备X, y = generate_data()dataset = TensorDataset(X, y)# 划分训练集和验证集(8:2)train_size = int(0.8 * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=32)# 初始化模型model = OverfitModel()# 开始训练train_loss, val_loss = train_with_earlystop(model, train_loader, val_loader, epochs=1000)# 绘制学习曲线plot_learning_curve(train_loss, val_loss)# 最终模型测试model.eval()with torch.no_grad():test_input = torch.tensor([[5.0], [-3.0], [8.0]])predictions = model(test_input)print("\nModel Predictions at x=5, -3, 8:")print(predictions.numpy())
四、模型的效果评估
具体模型的评估方法也可以看我之前的文章,本文主要补充一些代码案例【Python数据分析】数据挖掘建模——分类与预测算法评价(含ROC曲线、F1等指标的解释)_分类f1指标 python-CSDN博客
(1)分类任务评估(混淆矩阵、AUC)
from sklearn.metrics import confusion_matrix, roc_auc_scoredef evaluate_model(model, loader):model.eval()all_preds = []all_labels = []with torch.no_grad():for X_batch, y_batch in loader:outputs = model(X_batch)preds = (outputs > 0.5).float()all_preds.extend(preds.cpu().numpy())all_labels.extend(y_batch.cpu().numpy())# 计算指标cm = confusion_matrix(all_labels, all_preds)auc = roc_auc_score(all_labels, all_preds)print("Confusion Matrix:")print(cm)print(f"AUC Score: {auc:.4f}")# 在测试集上评估
evaluate_model(model, test_loader)
(2)回归任务评估(MAE、MSE)
def evaluate_regression(model, loader):model.eval()total_mae = 0total_mse = 0with torch.no_grad():for X_batch, y_batch in loader:outputs = model(X_batch)mae = torch.abs(outputs - y_batch).mean()mse = ((outputs - y_batch)**2).mean()total_mae += mae.item() * X_batch.size(0)total_mse += mse.item() * X_batch.size(0)mae = total_mae / len(loader.dataset)mse = total_mse / len(loader.dataset)print(f"MAE: {mae:.4f}, MSE: {mse:.4f}")
当然也可以把二者结合一下
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (roc_curve, auc, confusion_matrix, precision_recall_curve,r2_score,mean_squared_error
)
import seaborn as snsdef evaluate_model(model, data_loader, task_type='classification'):"""综合模型评估函数参数:model : 训练好的PyTorch模型data_loader : 数据加载器task_type : 任务类型 ['classification', 'regression']返回:包含各项指标的字典"""model.eval()device = next(model.parameters()).deviceall_targets = []all_outputs = []with torch.no_grad():for inputs, targets in data_loader:inputs = inputs.to(device)outputs = model(inputs)all_targets.append(targets.cpu().numpy())all_outputs.append(outputs.cpu().numpy())y_true = np.concatenate(all_targets)y_pred = np.concatenate(all_outputs)metrics = {}if task_type == 'classification':# 分类任务指标y_prob = torch.softmax(torch.tensor(y_pred), dim=1).numpy()y_pred_labels = np.argmax(y_pred, axis=1)# 多分类AUC计算(OvR策略)fpr = dict()tpr = dict()roc_auc = dict()n_classes = y_prob.shape[1]for i in range(n_classes):fpr[i], tpr[i], _ = roc_curve((y_true == i).astype(int), y_prob[:, i])roc_auc[i] = auc(fpr[i], tpr[i])# 计算宏观平均AUCfpr["macro"], tpr["macro"], _ = roc_curve(y_true.ravel(), y_prob.ravel())roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])metrics.update({'accuracy': np.mean(y_pred_labels == y_true),'auc_macro': roc_auc["macro"],'confusion_matrix': confusion_matrix(y_true, y_pred_labels),'classification_report': classification_report(y_true, y_pred_labels)})# 绘制ROC曲线plt.figure(figsize=(10, 6))for i in range(n_classes):plt.plot(fpr[i], tpr[i], lw=1,label=f'Class {i} (AUC = {roc_auc[i]:.2f})')plt.plot([0, 1], [0, 1], 'k--', lw=1)plt.xlim([0.0, 1.0])plt.ylim([0.0, 1.05])plt.xlabel('False Positive Rate')plt.ylabel('True Positive Rate')plt.title('ROC Curves')plt.legend(loc="lower right")plt.show()elif task_type == 'regression':# 回归任务指标metrics.update({'mse': mean_squared_error(y_true, y_pred),'mae': np.mean(np.abs(y_true - y_pred)),'r2': r2_score(y_true, y_pred)})# 绘制预测值与真实值散点图plt.figure(figsize=(8, 6))plt.scatter(y_true, y_pred, alpha=0.5)plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--')plt.xlabel('True Values')plt.ylabel('Predictions')plt.title('Regression Evaluation')plt.show()return metrics# 使用示例(分类任务)
if __name__ == "__main__":# 假设已有训练好的分类模型和数据加载器from sklearn.metrics import classification_report# 加载测试数据test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 进行评估classification_metrics = evaluate_model(model, test_loader, 'classification')# 打印关键指标print(f"准确率: {classification_metrics['accuracy']:.4f}")print(f"宏观平均AUC: {classification_metrics['auc_macro']:.4f}")print("\n分类报告:")print(classification_metrics['classification_report'])# 绘制混淆矩阵plt.figure(figsize=(10, 8))sns.heatmap(classification_metrics['confusion_matrix'], annot=True, fmt='d', cmap='Blues')plt.title('Confusion Matrix')plt.xlabel('Predicted Label')plt.ylabel('True Label')plt.show()# 回归任务使用示例
# regression_metrics = evaluate_model(model, test_loader, 'regression')
# print(f"MSE: {regression_metrics['mse']:.4f}")
# print(f"R²: {regression_metrics['r2']:.4f}")
好啦 神经网络的基础就在这结束啦,之后就开始进一步讲深度学习内容(包括 RNN、LSTM、transformer 等)最后就引出大模型的原理,继续期待叭~
相关文章:
【从0到1搞懂大模型】神经网络的实现:数据策略、模型调优与评估体系(3)
一、数据集的划分 (1)按一定比例划分为训练集和测试集 我们通常取8-2、7-3、6-4、5-5比例切分,直接将数据随机划分为训练集和测试集,然后使用训练集来生成模型,再用测试集来测试模型的正确率和误差,以验证…...
CTF工具集合-持续更新
工具地址https://github.com/huan-cdm/ctf_tools工具介绍: 1.ARCHPR:压缩包密码破解工具 2.StegSolve-1.4.jar:隐写图片查看工具 3.ctf_decrypt_tool.rar:随波逐流CTF编码工具 4.010_Editor_All_Versions_For_Windows_CracKed.…...
小方摄像头接入本地服务器的方法
最早众筹时买了几个小方摄像头,后来嫌弃分辨率,就淘汰吃灰好几年,最近想折腾个摄像头识别的小项目,秉着不投入先凑合跑起来的原则,想到了尘封已久的小方,想看看能不能通过网络拉取数据流。 搜索了下&#x…...
取反符号~
取反符号 ~ 用于对整数进行按位取反操作。它会将二进制表示中的每一位取反,即 0 变 1,1 变 0。 示例 a 5 # 二进制表示为 0000 0101 b ~a # 按位取反,结果为 1111 1010(补码表示) print(b) # 输出 -6解释 5 的二…...
Jenkins实现自动化构建与部署:上手攻略
一、持续集成与Jenkins核心价值 1.1 为什么需要自动化构建? 在现代化软件开发中,团队每日面临以下挑战: 高频代码提交:平均每个开发者每天提交5-10次代码。多环境部署:开发、测试、预发布、生产环境需频繁同步。复杂…...
爱普生温补晶振 TG5032CFN高精度稳定时钟的典范
在科技日新月异的当下,众多领域对时钟信号的稳定性与精准度提出了极为严苛的要求。爱普生温补晶振TG5032CFN是一款高稳定性温度补偿晶体振荡器(TCXO)。该器件通过内置温度补偿电路,有效抑制环境温度变化对频率稳定性的影响&#x…...
【Java 面试 八股文】计算机网络篇
操作系统篇 1. 什么是HTTP? HTTP 和 HTTPS 的区别?2. 为什么说HTTPS比HTTP安全? HTTPS是如何保证安全的?3. 如何理解UDP 和 TCP? 区别? 应用场景?3.1 TCP 和 UDP 的特点3.2 适用场景 4. 如何理解TCP/IP协议?5. DNS协议 是什么?说说DNS 完整的查询…...
OpenHarmony5.0分布式系统源码实现分析—软总线
一、引言 OpenHarmony 作为一款面向万物互联的操作系统,其分布式软总线(Distributed SoftBus)是实现设备间高效通信和协同的核心技术之一。分布式软总线通过构建一个虚拟的总线网络,使得不同设备能够无缝连接、通信和协同工作。本…...
Spring Boot/Spring Cloud 整合 ELK(Elasticsearch、Logstash、Kibana)详细避坑指南
我们在开发中经常会写日志,所以需要有个日志可视化界面管理,使用ELK可以实现高效集中化的日志管理与分析,提升性能稳定性,满足安全合规要求,支持开发运维工作。 下述是我在搭建ELK时遇到的许许多多的坑,希望…...
云原生周刊:Istio 1.25.0 正式发布
开源项目推荐 Dstack Dstack 是一个开源的 AI 计算管理平台,旨在简化 AI 任务的部署和管理。它支持本地和云端运行 AI 工作负载,并提供自动化的 GPU 资源调度,使开发者能够更高效地利用计算资源。Dstack 兼容 K8s,可以无缝集成到…...
微前端如何拯救大型项目
前言 在前端开发的世界中,我们经常会遇到这样的问题:一个大型项目往往由多个团队共同开发,每个团队负责一部分功能。然而,随着项目的不断扩大和复杂化,前端代码库变得越来越庞大和难以维护。这时,微前端&a…...
RabbitMQ 高级特性:从 TTL 到消息分发的全面解析 (下)
RabbitMQ高级特性 RabbitMQ 高级特性解析:RabbitMQ 消息可靠性保障 (上)-CSDN博客 RabbitMQ 高级特性:从 TTL 到消息分发的全面解析 (下)-CSDN博客 引言 RabbitMQ 作为一款强大的消息队列中间件ÿ…...
OpenManus-通过源码方式本地运行OpenManus,含踩坑及处理方案
前言:最近 Manus 火得一塌糊涂啊,OpenManus 也一夜之间爆火,那么作为程序员应该来尝尝鲜 1、前期准备 FastGithub:如果有科学上网且能正常访问 github 则不需要下载此软件,此软件是提供国内直接访问 githubGit&#…...
Ubuntu22.04修改root用户并安装cuda
由于本人工作原因,经常会遇到需要给ubuntu打显卡驱动的问题,虽然说不难吧,但是耐不住机器多,重复多次也就烦了,于是抽出了一点时间,并且在deepseek的帮助之下,写了一个自动安装驱动的脚本&#…...
Java LeetCode 热题 100 回顾38
干货分享,感谢您的阅读!LeetCode 热题 100 回顾_力code热题100-CSDN博客 一、哈希部分 1.两数之和 (简单) 题目描述 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两…...
MySQL复习笔记
文章目录 1.MySQL1.1什么是数据库1.2 数据库分类1.3 MySQL简介1.4连接数据库 2. 操作数据库2.1 操作数据库2.2 数据库的列类型2.3 数据库的字段属性(重点)2.4 创建数据库表(重点)2.5 数据表的类型2.6 修改数据表 3. MySQL 数据管理…...
解释 TypeScript 中的类型系统,如何定义和使用类型?
1. 类型系统的核心作用 TypeScript类型系统本质上是JavaScript的静态类型增强方案,提供三个核心价值: 开发阶段类型检查(类似编译时eslint)更清晰的API文档(类型即文档)更好的IDE自动补全支持 代码示例&…...
安裝do時出現log file support is not available
“log file support is not available (press RETURN)” 这个提示信息表明日志文件支持不可用,让你按回车键继续。出现这种情况可能是因为 Odoo 的日志相关配置存在问题或者一些必要的依赖没有正确安装配置。以下是一些可以尝试的解决办法: 1. 检查 Odo…...
[HTTP协议]应用层协议HTTP从入门到深刻理解并落地部署自己的云服务(1)知识基础
[HTTP协议]应用层协议HTTP从入门到深刻理解并落地部署自己的云服务(1)知识基础 水墨不写bug 文章目录 (一)概念梳理1.什么是协议?2.什么是应用层?3. 为什么要进行分层? (二)HTTP协议2.1 初识HTTP协议2.2HTTP协议的URL2.2.1域名2.2.2端口号2…...
机票改签请求
示例代码: tool def update_ticket_to_new_flight(ticket_no: str, new_flight_id: int) -> str:"""Update the users ticket to a new valid flight.Args:ticket_no (str): The ticket number to be updated.new_flight_id (int): The ID of th…...
linux下文件读写操作
Linux下,文件I/O是操作系统与文件系统之间进行数据传输的关键部分。文件I/O操作允许程序读取和写入文件,管理文件的打开、关闭、创建和删除等操作。 1. 文件描述符 在Linux中,每个打开的文件都由一个文件描述符来表示。文件描述符是一个非负…...
命名管道的创建和通信实现
目录 命名管道的创建 使用函数创建命名管道的通信 预备创建 makefile设计 server.hpp设计 clent.hpp设计 comm.hpp设计 server.cc设计 clent.cc设计 测试运行 今天我们来学习命名管道 由于匿名管道(pipe())无法在两个毫不相干的进程之间进行通…...
C++和OpenGL实现3D游戏编程【连载24】——父物体和子物体之间的坐标转换
欢迎来到zhooyu的C++和OpenGL游戏专栏,专栏连载的所有精彩内容目录详见下边链接: 🔥C++和OpenGL实现3D游戏编程【总览】 父子物体的坐标转换 1、本节要实现的内容 前面章节我们了解了父物体与子物体的结构,它不仅能够表示物体之间的层次关系,更重要的一个作用就是展示物…...
21.HarmonyOS Next CustomSlider组件步长控制教程(三)
温馨提示:本篇博客的详细代码已发布到 git : https://gitcode.com/nutpi/HarmonyosNext 可以下载运行哦! 文章目录 1. 步长控制概述2. 步长基本概念2.1 什么是步长?2.2 步长的作用 3. 设置步长3.1 基本参数3.2 代码示例 4. 步长与范围的关系4…...
小白学习:rag向量数据库
学习视频: https://www.bilibili.com/video/BV11zf6YyEnT/?spm_id_from333.337.search-card.all.click 例子: 用户提出问题 客服机器人基于rag回答用户问题 过程拆解: 客户问题 – 转化为向量表示 – 在向量数据库中进行相似性搜索 – 系…...
STM32 CAN模块原理与应用详解
目录 概述 一、CAN模块核心原理 1. CAN协议基础 2. STM32 CAN控制器结构 3. 波特率配置 二、CAN模块配置步骤(基于HAL库) 1. 初始化CAN外设 2. 配置过滤器 3. 启动CAN通信 三、数据收发实现 1. 发送数据帧 2. 接收数据帧(中断方式…...
NO.29十六届蓝桥杯备战|string九道练习|reverse|翻转|回文(C++)
P5015 [NOIP 2018 普及组] 标题统计 - 洛谷 #include <bits/stdc.h> using namespace std;int main() {ios::sync_with_stdio(false);cin.tie(nullptr);string s;getline(cin, s);int sz s.size();int cnt 0;for (int i 0; i < sz; i){if (isspace(s[i]))continue…...
最新版本TOMCAT+IntelliJ IDEA+MAVEN项目创建(JAVAWEB)
前期所需: 1.apache-tomcat-10.1.18-windows-x64(tomcat 10.1.8版本或者差不多新的版本都可以) 2.IntelliJ idea 24年版本 或更高版本 3.已经配置好MAVEN了(一定先配置MAVEN再搞TOMCAT会事半功倍很多) 如果有没配置…...
MAC-禁止百度网盘自动升级更新
通过终端禁用更新服务(推荐) 此方法直接移除百度网盘的自动更新组件,无需修改系统文件。 步骤: 1.关闭百度网盘后台进程 按下 Command + Space → 输入「活动监视器」→ 搜索 BaiduNetdisk 或 UpdateAgent → 结束相关进程。 2.删除自动更新配置文件 打开终端…...
Unity DOTS从入门到精通之EntityCommandBufferSystem
文章目录 前言安装 DOTS 包ECBECB可以执行的指令示例: 前言 DOTS(面向数据的技术堆栈)是一套由 Unity 提供支持的技术,用于提供高性能游戏开发解决方案,特别适合需要处理大量数据的游戏,例如大型开放世界游…...
【AIGC系列】6:HunyuanVideo视频生成模型部署和代码分析
AIGC系列博文: 【AIGC系列】1:自编码器(AutoEncoder, AE) 【AIGC系列】2:DALLE 2模型介绍(内含扩散模型介绍) 【AIGC系列】3:Stable Diffusion模型原理介绍 【AIGC系列】4࿱…...
【Linux】使用问题汇总
#1 ssh连接的时候报Key exchange failed 原因:服务端版本高,抛弃了一些不安全的交换密钥算法,且客户端版本比较旧,不支持安全性较高的密钥交换算法。 解决方案: 如果是内网应用,安全要求不这么高…...
nnUNet V2修改网络——全配置替换MultiResBlock模块
更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题 阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。 MultiRes Block 是 MultiResUNet 中核心组件之一,旨在解决传统 U-Net 在处理多尺度医学图像时的局…...
Git合并工具在开发中的使用指南
在团队协作开发中,Git 是最常用的版本控制工具,而代码合并(Merge)是多人协作不可避免的环节。当多个开发者同时修改同一文件的相同区域时,Git 无法自动完成合并,此时需要借助合并工具(Merge Too…...
AutoDL平台租借GPU,创建transformers环境,使用VSCode SSH登录
AutoDL平台租借GPU,创建transformers环境,使用VSCode SSH登录 一、AutoDl平台租用GPU 1.注册并登录AutoDl官网:https://www.autodl.com/home 2.选择算力市场,找到需要的GPU: 我这里选择3090显卡 3.这里我们就选择P…...
listen EACCES: permission denied 0.0.0.0:811
具体错误 npm run serve> bige-v0.0.0 serve > viteThe CJS build of Vites Node API is deprecated. See https://vitejs.dev/guide/troubleshooting.html#vite-cjs-node-api-deprecated for more details. error when starting dev server: Error: listen EACCES: per…...
OpenAI API模型ChatGPT各模型功能对比,o1、o1Pro、GPT-4o、GPT-4.5调用次数限制附ChatGPT订阅教程
本文包含OpenAI API模型对比页面以及ChatGPT各模型功能对比表 - 截至2025最新整理数据:包含模型分类及描述;调用次数限制; 包含模型的类型有: Chat 模型(如 GPT-4o、GPT-4.5、GPT-4)专注于对话,…...
六十天前端强化训练之第十五天React组件基础案例:创建函数式组件展示用户信息(第15-21天:前端框架(React))
欢迎来到编程星辰海的博客讲解 我们已经学了14天了,再坚持坚持,马上我们就可以变得更优秀了,加油,我相信大家,接下来的几天,我会给大家更新前端框架(React),看完可以给一…...
北大一二三四版全套DeepSeek教学资料
DeepSeek学习资料合集:https://pan.quark.cn/s/bb6ebf0e9b4d DeepSeek实操变现指南:https://pan.quark.cn/s/76328991eaa2 你是否渴望深入探索人工智能的前沿领域?是否在寻找一份能引领你从理论到实践,全面掌握AI核心技术的学习…...
计算机网络:计算机网络的组成和功能
计算机网络的组成: 计算机网络的工作方式: 计算机网络的逻辑功能; 总结: 计算机网络的功能: 1.数据通信 2.资源共享 3.分布式处理:计算机网络的分布式处理是指将计算任务分散到网络中的多个节点(计算机或设备&…...
管理网络安全
防火墙在 Linux 系统安全中有哪些重要的作用? 防火墙作为网络安全的第一道防线,能够根据预设的规则,对进出系统的网络流量进行严格筛选。它可以阻止未经授权的外部访问,只允许符合规则的流量进入系统,从而保护系统免受…...
音频进阶学习十九——逆系统(简单进行回声消除)
文章目录 前言一、可逆系统1.定义2.解卷积3.逆系统恢复原始信号过程4.逆系统与原系统的零极点关系 二、使用逆系统去除回声获取原信号的频谱原系统和逆系统幅频响应和相频响应使用逆系统恢复原始信号整体代码如下 总结 前言 在上一篇音频进阶学习十八——幅频响应相同系统、全…...
Redis7系列:设置开机自启
前面的文章讲了Redis和Redis Stack的安装,随着服务器的重启,导致Redis 客户端无法连接。原来的是Redis没有配置开机自启。此文记录一下如何配置开机自启。 1、修改配置文件 前面的Redis和Redis Stack的安装的文章中已经讲了redis.config的配置…...
word甲烷一键下标
Sub 甲烷下标()甲烷下标 宏Selection.Find.ClearFormattingSelection.Find.Replacement.ClearFormattingWith Selection.Find.Text "CH4".Replacement.Text "CHguoshao4".Forward True.Wrap wdFindContinue.Format False.MatchCase False.MatchWhole…...
SSH 连接中主机密钥验证失败问题的解决方法
问题描述 在尝试通过 SSH 建立连接时,出现以下错误信息: WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! Someone could be eavesdropping on you right now (man-in-the-middle attack…...
网络安全工具nc(NetCat)
NetCat是一个非常简单的Unix工具,可以读、写TCP或UDP网络连接(network connection)。它被设计成一个可靠的后端(back-end)工具,能被其它的程序程序或脚本直接地或容易地驱动。同时,它又是一个功能丰富的 网络调试和开发工具,因为它…...
探索在生成扩散模型中基于RAG增强生成的实现与未来
概述 像 Stable Diffusion、Flux 这样的生成扩散模型,以及 Hunyuan 等视频模型,都依赖于在单一、资源密集型的训练过程中通过固定数据集获取的知识。任何在训练之后引入的概念——被称为 知识截止——除非通过 微调 或外部适应技术(如 低秩适…...
【Linux】37.网络版本计算器
文章目录 1. Log.hpp-日志记录器2. Daemon.hpp-守护进程工具3. Protocol.hpp-通信协议解析器4. ServerCal.hpp-计算器服务处理器5. Socket.hpp-Socket通信封装类6. TcpServer.hpp-TCP服务器框架7. ClientCal.cc-计算器客户端8. ServerCal.cc-计算器服务器9. 代码时序1. 服务器启…...
3.6c语言
#define _CRT_SECURE_NO_WARNINGS #include <math.h> #include <stdio.h> int main() {int sum 0,i,j;for (j 1; j < 1000; j){sum 0;for (i 1; i < j; i){if (j % i 0){sum i;} }if (sum j){printf("%d是完数\n", j);}}return 0; }#de…...
【 IEEE出版 | 快速稳定EI检索 | 往届已EI检索】2025年储能及能源转换国际学术会议(ESEC 2025)
重要信息 主会官网:www.net-lc.net 【论文【】投稿】 会议时间:2025年5月9-11日 会议地点:中国-杭州 截稿时间:见官网 提交检索:IEEE Xplore, EI Compendex, Scopus 主会NET-LC 2025已进入IEEE 会议官方列表!&am…...