当前位置: 首页 > news >正文

RNN实现精神分裂症患者诊断(pytorch)

RNN理论知识

RNN(Recurrent Neural Network,循环神经网络) 是一种 专门用于处理序列数据(如时间序列、文本、语音、视频等)的神经网络。与普通的前馈神经网络(如 MLP、CNN)不同,RNN 具有“记忆”能力,能够利用过去的信息来影响当前的计算结果。

1. RNN 的基本结构

RNN 的核心特点是 “循环”结构,它会将前一个时间步 ( t − 1 ) (t-1) t1计算出的隐藏状态 h t − 1 h_{t-1} ht1 传递给当前时间步 ( t ) (t) t,使得网络可以保留历史信息。

这种结构可以表示为:

h t = f ( W x X t + W h h t − 1 + b ) h_t=f(W_xX_t+W_hh_{t-1}+b) ht=f(WxXt+Whht1+b)

其中:

  • X t X_t Xt:当前时刻的输入数据。
  • h t h_t ht:当前时刻的隐藏状态 。
  • W x 、 W h 、 b W_x、W_h、b WxWhb:可训练的参数 。
  • f f f:激活函数(通常是 tanh 或ReLU)。

RNN 的展开结构:
在时间步(time step)上,RNN 结构可以展开成如下形式:
在这里插入图片描述
图示解释:

X 1 , X 2 , X 3 , . . . X_1,X_2,X_3,... X1,X2,X3,... 代表输入的 序列数据(如文本、时间序列信号)。
h 0 , h 1 , h 2 , h 3 , . . . h_0,h_1,h_2,h_3,... h0,h1,h2,h3,... 代表 隐藏状态,用于存储过去的信息。
Y 1 , Y 2 , Y 3 , . . . Y_1,Y_2,Y_3,... Y1,Y2,Y3,...代表 输出。
在每个时间步,RNN 使用当前输入 X t X_t Xt 和前一时刻的隐藏状态 h t − 1 h_{t-1} ht1来计算新的隐藏状态 h t h_t ht,然后生成输出 Y t Y_t Yt

2. RNN 的缺点

尽管 RNN 在处理序列数据方面有独特的优势,但它也存在一些明显的问题:
(1)梯度消失(Vanishing Gradient)
在长序列训练时,误差的梯度会随着时间步增多而逐渐变小,导致网络无法有效学习较远时间步的信息。
解决方案:使用 LSTM(长短时记忆网络) 或 GRU(门控循环单元) 结构。
(2)梯度爆炸(Exploding Gradient)
如果梯度在反向传播过程中不断累积,可能会变得 非常大,导致模型更新过快或无法收敛。
解决方案:使用 梯度裁剪(Gradient Clipping) 来防止梯度过大。
(3)无法并行计算
由于 RNN 依赖前一个时间步的计算结果,因此无法像 CNN 那样并行计算,这导致训练速度较慢。
解决方案:使用 Transformer 模型(如 BERT、GPT)来替代 RNN。

3. RNN 的改进版本

由于 RNN 存在梯度消失等问题,研究人员提出了更强大的 变种 RNN 结构:
(1)LSTM(Long Short-Term Memory)
在这里插入图片描述

  • LSTM 引入了 “记忆单元” 和 “门机制”,使得它能够保留长期信息,解决梯度消失问题。
  • 包含 遗忘门(Forget Gate)、输入门(Input Gate)、输出门(Output Gate) 三部分来控制信息流。

(2)GRU(Gated Recurrent Unit)

  • GRU 是 LSTM 的简化版本,只包含 更新门(Update Gate) 和 重置门(Reset Gate),计算效率更高。

数据集

精神分裂症数据集,是一个包含精神分裂症人口统计和临床数据的综合数据集。该数据集包括患者的诊断状态、症状评分、治疗史和社会因素。

代码目标

基于给定的特征(如性别、年龄、收入、症状评分等),预测一个人的诊断标签(是否患有精神分裂症),通过可视化训练损失和计算准确率,评估模型的训练效果与性能。

一、前期准备工作

我的环境:

  • 操作系统:windows10
  • 语言环境:Python3.9
  • 编译器:Jupyter notebook
  • 数据集:精神分裂症患者数据集(“schizophrenia_dataset.csv”)

1. 导入库,设置硬件设备

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
import torch#设置GPU训练,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

代码输出:

device(type='cpu')

使用 torch.device() 方法检查当前系统是否有 GPU,并根据条件设置计算设备为 GPU(CUDA)或 CPU。

2. 导入数据

读取指定路径的 CSV 文件,并加载到 pandas 的 DataFrame 中,然后打印出数据框的前五行,用于检查数据的内容。

# 读取数据
file_path = 'schizophrenia_dataset.csv'     # 设置数据文件的路径
df = pd.read_csv(file_path)                 # 使用pandas的read_csv函数读取CSV文件,结果存储在DataFrame对象df中
print(df.head())            # 打印数据框的前五行,检查数据的结构和内容

代码输出:

   Hasta_ID  Yaş  Cinsiyet  Eğitim_Seviyesi  Medeni_Durum  Meslek  \
0         1   72         1                4             2       0   
1         2   49         1                5             2       2   
2         3   53         1                5             3       2   
3         4   67         1                3             2       0   
4         5   54         0                1             2       0   Gelir_Düzeyi  Yaşadığı_Yer  Tanı  Hastalık_Süresi  Hastaneye_Yatış_Sayısı  \
0             2             1     0                0                       0   
1             1             0     1               35                       1   
2             1             0     1               32                       0   
3             2             0     0                0                       0   
4             2             1     0                0                       0   Ailede_Şizofreni_Öyküsü  Madde_Kullanımı  İntihar_Girişimi  \
0                        0                0                 0   
1                        1                1                 1   
2                        1                0                 0   
3                        0                1                 0   
4                        0                0                 0   Pozitif_Semptom_Skoru  Negatif_Semptom_Skoru  GAF_Skoru  Sosyal_Destek  \
0                     32                     48         72              0   
1                     51                     63         40              2   
2                     72                     85         51              0   
3                     10                     21         74              1   
4                      4                     27         98              0   Stres_Faktörleri  İlaç_Uyumu  
0                 2           2  
1                 2           0  
2                 1           1  
3                 1           2  
4                 1           0  

二、构建数据集

1. 划分数据集

处理数据中的不必要列(唯一标识符)和缺失值,以准备好干净的数据进行模型训练。

df = df.drop(columns=['Hasta_ID'])      # 删除 'Hasta_ID' 列,因为该列是唯一标识符,不需要用作模型输入
df = df.fillna(df.mean())      # 使用每一列的均值填充数据框中的缺失值。这里使用 `df.mean()` 来计算均值,并用它来填充缺失值

数据处理流程:

  • 使用 LabelEncoder 将类别变量转换为数值。
  • 将数据划分为特征(X)和目标(y)。
  • 标准化特征数据。
  • 将数据划分为训练集和测试集。
  • 将数据转换为 PyTorch 张量。
  • 调整张量维度以符合 RNN 模型的要求。
label_encoder = LabelEncoder()     # 创建LabelEncoder实例,用于将类别变量转换为数值
df['Cinsiyet'] = label_encoder.fit_transform(df['Cinsiyet'])       # 将 'Cinsiyet'列中的类别值转化为数值
df['Medeni_Durum'] = label_encoder.fit_transform(df['Medeni_Durum'])     # 将 'Medeni_Durum'列中的类别值转化为数值
df['Yaşadığı_Yer'] = label_encoder.fit_transform(df['Yaşadığı_Yer'])     # 将 'Yaşadığı_Yer'列中的类别值转化为数值# 将特征和目标分开
X = df.drop(columns=['Tanı'])     # 将数据框中的 'Tanı' 列移除,剩下的列作为特征(X)
y = df['Tanı']      # 'Tanı' 列作为目标变量(y),表示是否患有精神分裂症(二分类标签)scaler = StandardScaler()     # 创建 StandardScaler 实例,用于标准化特征数据
X_scaled = scaler.fit_transform(X)     # 对特征进行标准化,使得每列的均值为0,标准差为1# 使用 train_test_split 将数据随机划分为训练集和测试集,测试集占20%。random_state=42 设置随机种子,以确保每次划分结果相同
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 将数据转换为PyTorch的tensor
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)        # 将训练特征数据转换为PyTorch的tensor格式,并指定数据类型为float32
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)          # 将测试特征数据转换为PyTorch的tensor格式,并指定数据类型为float32
y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)    # 将训练目标数据转换为PyTorch的tensor格式,并指定数据类型为long(用于分类问题)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.long)      # 将测试目标数据转换为PyTorch的tensor格式,并指定数据类型为long(用于分类问题)# 确保数据的形状符合RNN的要求: [batch_size, seq_len, features]
X_train_tensor = X_train_tensor.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]
X_test_tensor = X_test_tensor.unsqueeze(1)    # [batch_size, features] --> [batch_size, 1, features]# 输出tensor的形状,确保数据正确
print(f"训练数据形状: {X_train_tensor.shape}")     # 打印训练数据的形状,检查是否正确
print(f"测试数据形状: {X_test_tensor.shape}")      # 打印测试数据的形状,检查是否正确

代码输出:

训练数据形状: torch.Size([8000, 1, 18])
测试数据形状: torch.Size([2000, 1, 18])

2. 构建数据加载器

将训练集和测试集的数据(特征和标签)封装成 TensorDataset 对象,并使用 DataLoader 创建数据加载器。
训练集和测试集被分批次加载,每个批次包含 64 个样本。
shuffle=False 表示数据在加载时不进行打乱,在评估的时候顺序保持一致。

from torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train_tensor, y_train_tensor),     # 将训练数据、目标数据包装成一个数据集,并创建一个训练数据加载器batch_size=64, shuffle=False)test_dl  = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),      # 将测试数据、目标数据包装成一个数据集,并创建一个测试数据加载器shuffle=False)

三、模型训练

1. 构建模型

import torch.nn as nn#定义一个名为 _RNN_Base 的类,继承自 nn.Module。该类实现了 RNN(包括 RNN、LSTM 和 GRU)的基础结构
class _RNN_Base(nn.Module):def __init__(self, c_in, c_out, hidden_size=100, n_layers=1, bias=True, rnn_dropout=0, bidirectional=False, fc_dropout=0., init_weights=True):"""RNN基础类,支持不同RNN单元(如RNN、LSTM、GRU)的实现。"""super(_RNN_Base, self).__init__()  # 确保正确调用父类的构造函数# 定义RNN层,支持RNN、LSTM、GRU等self.rnn = self._cell(c_in, hidden_size, num_layers=n_layers, bias=bias, batch_first=True, dropout=rnn_dropout, bidirectional=bidirectional)# 定义全连接层的dropout,如果fc_dropout为0则直接用Identityself.dropout = nn.Dropout(fc_dropout) if fc_dropout else nn.Identity()self.fc = nn.Linear(hidden_size * (1 + bidirectional), c_out)def forward(self, x): """        参数:- x: 形状为[batch_size, n_vars, seq_len]。返回:- output: 形状为[batch_size, c_out]。"""# [batch_size, n_vars, seq_len] --> [batch_size, seq_len, n_vars]x = x.transpose(2,1)  # 输出形状为[batch_size, seq_len, hidden_size * (1 + bidirectional)]output, _ = self.rnn(x) # 取最后一个时间步的输出,形状为[batch_size, hidden_size * (1 + bidirectional)]output = output[:, -1]  output = self.fc(self.dropout(output))return output# 定义RNN类,继承自_RNN_Base
class RNN(_RNN_Base):_cell = nn.RNN  # 使用nn.RNN单元# 定义LSTM类,继承自_RNN_Base
class LSTM(_RNN_Base):_cell = nn.LSTM  # 使用nn.LSTM单元# 定义GRU类,继承自_RNN_Base
class GRU(_RNN_Base):_cell = nn.GRU  # 使用nn.GRU单元

定义名为 _RNN_Base 的类,继承自 nn.Module。该类实现了 RNN(包括 RNN、LSTM 和 GRU)的基础结构。

_RNN_Base 类的参数解释:

  • c_in:输入特征的维度,即每个时间步的特征数量。
  • c_out:输出类别数量,即模型的输出维度。
  • hidden_size:RNN隐藏层的大小。
  • n_layers:RNN的层数。
  • bias:是否在RNN层中使用偏置项。
  • rnn_dropout:RNN层中的dropout比例。
  • bidirectional:是否使用双向RNN。
  • fc_dropout:全连接层的dropout比例。
  • init_weights:是否初始化权重。

关于_cell ,定义 RNN 层。self._cell 是一个占位符,它将会被具体子类(RNN、LSTM、GRU)的 _cell 属性替代,相关参数解释:

  • c_in:输入特征的数量。
  • hidden_size:RNN单元的隐藏层大小。
  • num_layers:RNN的层数。
  • bias:是否使用偏置项。
  • batch_first=True:意味着输入和输出的格式为 [batch_size, seq_len,features]。
  • dropout=rnn_dropout:RNN中dropout的概率,用来防止过拟合。
  • bidirectional=bidirectional:是否使用双向RNN(即处理序列时同时考虑正向和反向的时间步)。
# 创建一个基于 RNN 的神经网络模型,并将模型移动到指定的设备(CPU 或 GPU)
model = RNN(c_in=X_train_tensor.shape[1], c_out=2).to(device)    
model 

代码输出:

RNN((rnn): RNN(1, 100, batch_first=True)(dropout): Identity()(fc): Linear(in_features=100, out_features=2, bias=True)
)
from torchinfo import summaryrnn_model = RNN(c_in=3, c_out=5, hidden_size=100,n_layers=2,bidirectional=True, rnn_dropout=.5, fc_dropout=.5)    # 初始化一个 RNN 模型,并设置相关参数summary(rnn_model, input_size=(16, 3, 5))    # 调用 summary 函数,输出 rnn_model 的结构和每一层的详细信息

代码输出:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
RNN                                      --                        --
├─RNN: 1-1                               [16, 5, 200]              81,400
├─Dropout: 1-2                           [16, 200]                 --
├─Linear: 1-3                            [16, 5]                   1,005
==========================================================================================
Total params: 82,405
Trainable params: 82,405
Non-trainable params: 0
Total mult-adds (M): 6.53
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.13
Params size (MB): 0.33
Estimated Total Size (MB): 0.46
==========================================================================================

2. 定义训练函数

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)   # 批次数目train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取数据及其标签X, y = X.to(device), y.to(device)# 1. 确保输入数据有三个维度,添加一个seq_len维度if X.dim() == 2:  # 如果是二维输入,添加一个序列长度维度X = X.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]# 2. 前向传播pred = model(X)  # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的损失# 3. 反向传播optimizer.zero_grad()  # 清零梯度loss.backward()        # 反向传播optimizer.step()       # 更新参数# 记录准确率和损失train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

3. 定义测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)   # 批次数目test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)# 1. 确保输入数据有三个维度,添加一个seq_len维度if X.dim() == 2:  # 如果是二维输入,添加一个序列长度维度X = X.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]# 2. 计算损失pred = model(X)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4. 正式训练模型

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 2e-5   # 学习率
opt        = torch.optim.Adam(model.parameters(),lr=learn_rate)    # 使用 Adam 优化器,并将学习率 learn_rate 应用到优化器中
epochs     = 20     # 设置训练的总轮数为 20。每轮训练都将通过整个训练集一次train_loss = []  # 初始化一个空列表用于记录每一轮的训练损失
train_acc  = []  # 初始化一个空列表用于记录每一轮的训练准确率
test_loss  = []  # 初始化一个空列表用于记录每一轮的测试损失
test_acc   = []  # 初始化一个空列表用于记录每一轮的测试准确率# 循环遍历训练轮数
for epoch in range(epochs):model.train()    # 设置模型为训练模式epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()    # 设置模型为评估模式epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)  # 将当前训练轮的准确率添加到列表中train_loss.append(epoch_train_loss)  # 将当前训练轮的损失添加到列表中test_acc.append(epoch_test_acc)  # 将当前测试轮的准确率添加到列表中test_loss.append(epoch_test_loss)  # 将当前测试轮的损失添加到列表中# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']# 格式化输出每一轮训练和测试的准确率、损失以及当前学习率template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)

代码输出:

Epoch: 1, Train_acc:70.1%, Train_loss:0.665, Test_acc:70.9%, Test_loss:0.636, Lr:2.00E-05
Epoch: 2, Train_acc:71.4%, Train_loss:0.596, Test_acc:70.3%, Test_loss:0.558, Lr:2.00E-05
Epoch: 3, Train_acc:72.7%, Train_loss:0.507, Test_acc:80.2%, Test_loss:0.442, Lr:2.00E-05
Epoch: 4, Train_acc:90.8%, Train_loss:0.337, Test_acc:95.7%, Test_loss:0.259, Lr:2.00E-05
Epoch: 5, Train_acc:95.9%, Train_loss:0.212, Test_acc:96.4%, Test_loss:0.179, Lr:2.00E-05
Epoch: 6, Train_acc:96.0%, Train_loss:0.161, Test_acc:96.4%, Test_loss:0.146, Lr:2.00E-05
Epoch: 7, Train_acc:96.2%, Train_loss:0.137, Test_acc:96.7%, Test_loss:0.128, Lr:2.00E-05
Epoch: 8, Train_acc:96.5%, Train_loss:0.121, Test_acc:96.7%, Test_loss:0.116, Lr:2.00E-05
Epoch: 9, Train_acc:96.6%, Train_loss:0.110, Test_acc:96.8%, Test_loss:0.107, Lr:2.00E-05
Epoch:10, Train_acc:96.8%, Train_loss:0.103, Test_acc:96.7%, Test_loss:0.100, Lr:2.00E-05
Epoch:11, Train_acc:96.9%, Train_loss:0.097, Test_acc:96.7%, Test_loss:0.095, Lr:2.00E-05
Epoch:12, Train_acc:96.9%, Train_loss:0.092, Test_acc:96.7%, Test_loss:0.091, Lr:2.00E-05
Epoch:13, Train_acc:97.0%, Train_loss:0.089, Test_acc:96.8%, Test_loss:0.088, Lr:2.00E-05
Epoch:14, Train_acc:97.1%, Train_loss:0.085, Test_acc:96.9%, Test_loss:0.084, Lr:2.00E-05
Epoch:15, Train_acc:97.2%, Train_loss:0.082, Test_acc:97.0%, Test_loss:0.081, Lr:2.00E-05
Epoch:16, Train_acc:97.3%, Train_loss:0.078, Test_acc:97.0%, Test_loss:0.077, Lr:2.00E-05
Epoch:17, Train_acc:97.4%, Train_loss:0.075, Test_acc:97.2%, Test_loss:0.073, Lr:2.00E-05
Epoch:18, Train_acc:97.5%, Train_loss:0.071, Test_acc:97.4%, Test_loss:0.070, Lr:2.00E-05
Epoch:19, Train_acc:97.6%, Train_loss:0.068, Test_acc:97.5%, Test_loss:0.065, Lr:2.00E-05
Epoch:20, Train_acc:97.9%, Train_loss:0.063, Test_acc:97.9%, Test_loss:0.061, Lr:2.00E-05
==================== Done ====================

四、模型评估

1. Loss与Accuracy图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 200        #分辨率from datetime import datetime
current_time = datetime.now() # 获取当前时间epochs_range = range(epochs)plt.figure(figsize=(12, 3))   # 创建一个新的图表,并设置图表的大小
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')   # 绘制训练准确率曲线
plt.plot(epochs_range, test_acc, label='Test Accuracy')    # 绘制测试准确率曲线
plt.legend(loc='lower right')      # 显示图例,位置为右下角
plt.title('Training and Validation Accuracy')     # 设置子图的标题
plt.xlabel(current_time)    # 将当前时间作为横坐标标签plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')   # 绘制训练损失曲线
plt.plot(epochs_range, test_loss, label='Test Loss')    # 绘制测试损失曲线
plt.legend(loc='upper right')     # 显示图例,位置为右上角
plt.title('Training and Validation Loss')     # 设置子图的标题
plt.show()    # 显示图表

代码输出:

在这里插入图片描述

2. 混淆矩阵

混淆矩阵(Confusion Matrix) 是一种常用的分类模型评估工具,特别适用于 二分类 和 多分类问题。它能够清晰地展示模型的 真实类别(True Labels) 与 预测类别(Predicted Labels) 之间的对应关系,深入分析模型的分类性能。

# 确保输入数据的维度为 [batch_size, seq_len, features]
print("==============输入数据Shape为==============")
print("X_test.shape:", X_test_tensor.shape)
print("y_test.shape:", y_test_tensor.shape)# 获取预测结果
pred = model(X_test_tensor.to(device)).argmax(1).cpu().numpy()print("\n==============输出数据Shape为==============")
print("pred.shape:", pred.shape)

代码输出:

==============输入数据Shape为==============
X_test.shape: torch.Size([2000, 1, 18])
y_test.shape: torch.Size([2000])==============输出数据Shape为==============
pred.shape: (2000,)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns# 计算混淆矩阵
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))    # 创建一个新的图形,设置图形的大小为 6x5 英寸
plt.suptitle('')     # 设置图形的总标题,这里设置为空字符串 '',即不显示总标题
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")    # 使用 seaborn 的热力图函数绘制混淆矩阵# 修改字体大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label", fontsize=10)
plt.ylabel("True Label", fontsize=10)# 显示图
plt.tight_layout()  # 调整布局防止重叠
plt.show()

代码输出:

在这里插入图片描述

3. 调用模型进行预测

# 选择单个样本并调整形状为 [batch_size, seq_len, features] 
test_X = X_test_tensor[0].reshape(1, 1, -1)  # 注意这里调整为三维的 [1, 1, features] # 获取模型的预测结果
pred = model(test_X.to(device)).argmax(1).item()print("模型预测结果为:", pred)
print("==" * 20)
print("0:未患病")
print("1:已患病")

代码输出:

模型预测结果为: 0
========================================
0:未患病
1:已患病

相关文章:

RNN实现精神分裂症患者诊断(pytorch)

RNN理论知识 RNN(Recurrent Neural Network,循环神经网络) 是一种 专门用于处理序列数据(如时间序列、文本、语音、视频等)的神经网络。与普通的前馈神经网络(如 MLP、CNN)不同,RNN…...

Python中字符串的常用操作

一、r原样输出 在 Python 中,字符串前加 r(即 r"string" 或 rstring)表示创建一个原始字符串(raw string)。下面详细介绍原始字符串的特点、使用场景及与普通字符串的对比。 特点 忽略转义字符&#xff1…...

uniapp 本地数据库多端适配实例(根据运行环境自动选择适配器)

项目有个需求,需要生成app和小程序,app支持离线数据库,如果当前没有网络提醒用户开启离线模式,所以就随便搞了下,具体的思路就是: 一个接口和多个实现类(类似后端的模板设计模式)&am…...

Spring Cloud Gateway 整合Spring Security

做了一个Spring Cloud项目,网关采用 Spring Cloud Gateway,想要用 Spring Security 进行权限校验,由于 Spring Cloud Gateway 采用 webflux ,所以平时用的 mvc 配置是无效的,本文实现了 webflu 下的登陆校验。 1. Sec…...

【异地访问本地DeepSeek】Flask+内网穿透,轻松实现本地DeepSeek的远程访问

写在前面:本博客仅作记录学习之用,部分图片来自网络,如需引用请注明出处,同时如有侵犯您的权益,请联系删除! 文章目录 前言依赖Flask构建本地网页访问LM Studio 开启网址访问DeepSeek 调用模板Flask 访问本…...

Windows对比MacOS

Windows对比MacOS 文章目录 Windows对比MacOS1-环境变量1-Windows添加环境变量示例步骤 1:打开环境变量设置窗口步骤 2:添加系统环境变量 2-Mac 系统添加环境变量示例步骤 1:打开终端步骤 2:编辑环境变量配置文件步骤 3&#xff1…...

React实现无缝滚动轮播图

实现效果: 由于是演示代码,我是直接写在了App.tsx里面在 文件位置如下: App.tsx代码如下: import { useState, useEffect, useCallback, useRef } from "react"; import { ImageContainer } from "./view/ImageC…...

Ubuntu20.04确认cuda和cudnn已经安装成功

当我们通过官网安装cuda和cudnn时,终端执行完命令后我们仍不能确定是否已经安装成功。接下来教大家用几句命令测试。 cuda 检测版本号 nvcc -V如果输出如下,则安装成功。 可以看到版本号是11.2 cudnn检测版本号 有两种命令:如果你的cudn…...

sqlilab 46 关(布尔、时间盲注)

sqlilabs 46关(布尔、时间盲注) 46关有变化了,需要我们输入sort,那我们就从sort1开始 递增测试: 发现测试到sort4就出现报错: 我们查看源码: 从图中可看出:用户输入的sort值被用于查…...

AI时代保护自己的隐私

人工智能最重要的就是数据,让我们面对现实,大多数人都不知道他们每天要向人工智能提供多少数据。你输入的每条聊天记录,你发出的每条语音命令,人工智能生成的每张图片、电子邮件和文本。我建设了一个网站(haptool.com)&#xff0c…...

模型优化之强化学习(RL)与监督微调(SFT)的区别和联系

强化学习(RL)与监督微调(SFT)是机器学习中两种重要的模型优化方法,它们在目标、数据依赖、应用场景及实现方式上既有联系又有区别。 想了解有关deepseek本地训练的内容可以看我的文章: 本地基于GGUF部署的…...

Buildroot 添加自定义模块-内置文件到文件系统

目录 概述实现步骤1. 创建包目录和文件结构2. 配置 Config.in3. 定义 cp_bin_files.mk4. 添加源文件install.shmy.conf 5. 配置与编译 概述 Buildroot 是一个高度可定制和模块化的嵌入式 Linux 构建系统,适用于从简单到复杂的各种嵌入式项目. buildroot的源码中bui…...

蓝牙接近开关模块感应开锁手机靠近解锁支持HID低功耗

ANS-BT101M是安朔科技推出的蓝牙接近开关模块,低功耗ble5.1,采用UART通信接口,实现手机自动无感连接,无需APP,人靠近车门自动开锁,支持苹果、安卓、鸿蒙系统,也可以通过手机手动开锁或上锁&…...

计算机毕业设计SpringBoot+Vue.js基于工程教育认证的计算机课程管理平台(源码+文档+PPT+讲解)

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…...

企业知识库搭建:14款开源与免费系统选择

本文介绍了以下14 款知识库管理系统:1.Worktile;2.PingCode;3.石墨文档; 4. 语雀; 5. 有道云笔记; 6. Bitrix24; 7. Logseq等。 在如今的数字化时代,企业和团队面临着越来越多的信息…...

蓝桥杯(握手问题)

小蓝组织了一场算法交流会议,总共有 50 人参加了本次会议。在会议上,大家进行了握手交流。按照惯例他们每个人都要与除自己以外的其他所有人进行一次握手 (且仅有一次)。 但有 7个人,这 7 人彼此之间没有进行握手 (但这 7 人与除这 7 人以外…...

如何使用 Jenkins 实现 CI/CD 流水线:从零开始搭建自动化部署流程

如何使用 Jenkins 实现 CI/CD 流水线:从零开始搭建自动化部署流程 在软件开发过程中,持续集成(CI)和持续交付(CD)已经成为现代开发和运维的标准实践。随着代码的迭代越来越频繁,传统的手动部署方式不仅低效,而且容易出错。为了提高开发效率和代码质量,Jenkins作为一款…...

c++字符编码/乱码问题

基本概念 c11版本引入了char16_t和char32_t两个类型,他们的特点分别如下: char16_t 16位的unicode字符类型用于表示UTF-16编码大小:2字节字面量前缀:u char32_t 32位unicode字符类型用于表示UTF-32编码大小:4字节…...

侯捷 C++ 课程学习笔记:深入理解类与继承

文章目录 每日一句正能量一、课程背景二、学习内容:类与继承(一)类的基本概念1. 类的定义与实例化2. 构造函数与析构函数 (二)继承1. 单继承与多继承2. 虚函数与多态 三、学习心得四、总结 每日一句正能量 有种承担&am…...

初始化列表

一:声明,定义,赋值的区别 ①:声明 这里,int _year; int _month;int _day; 是成员变量的声明,它们告诉编译器: 类 Date中有三个成员变量_year和 _month和_day。 它们的类型分别都是 int 此…...

7.1 - 定时器之中断控制LED实验

文章目录 1 实验任务2 系统框图3 软件设计 1 实验任务 本实验任务是通过CPU私有定时器的中断,每 200ms 控制一次PS LED灯的亮灭。 2 系统框图 3 软件设计 注意事项: 定时器中断在收到中断后,只需清除中断状态,无需禁用中断、启…...

Pytest之fixture的常见用法

文章目录 1.前言2.使用fixture执行前置操作3.使用conftest共享fixture4.使用yield执行后置操作 1.前言 在pytest中,fixture是一个非常强大和灵活的功能,用于为测试函数提供固定的测试数据、测试环境或执行一些前置和后置操作等, 与setup和te…...

【分库分表】基于mysql+shardingSphere的分库分表技术

目录 1.什么是分库分表 2.分片方法 3.测试数据 4.shardingSphere 4.1.介绍 4.2.sharding jdbc 4.3.sharding proxy 4.4.两者之间的对比 5.留个尾巴 1.什么是分库分表 分库分表是一种场景解决方案,它的出现是为了解决一些场景问题的,哪些场景喃…...

合并两个有序链表:递归与迭代的实现分析

合并两个有序链表:递归与迭代的实现分析 在算法与数据结构的世界里,链表作为一种基本的数据结构,经常被用来解决各种问题。特别是对于有序链表的合并,既是经典面试题,也是提高编程能力的重要练习之一。合并两个有序链…...

HTML AI 编程助手

HTML AI 编程助手 引言 随着人工智能技术的飞速发展,编程领域也迎来了新的变革。HTML,作为网页制作的基础语言,与AI技术的结合,为开发者带来了前所未有的便利。本文将探讨HTML AI编程助手的功能、应用场景以及如何利用它提高编程…...

备战蓝桥杯Day11 DFS

DFS 1.要点 (1)朴素dfs 下面保存现场和恢复现场就是回溯法的思想,用dfs实现,而本质是用递归实现,代码框架: ans; //答案,常用全局变量表示 int mark[N]; //记录状态i是否被处理过 …...

Oracle 认证为有哪几个技术方向

Oracle 认证技术方向,分别是数据库管理、开发、云平台,每个方向都有不同的学习等级 数据库运维方向 Oracle Certified Professional(OCP):19c OCA内容已和OCP合并 OCP 19c属于oracle认证专家,要求考生掌握深…...

25物理学研究生复试面试问题汇总 物理学专业知识问题很全! 物理学复试全流程攻略 物理学考研复试调剂真题汇总

正在为物理考研复试专业面试发愁的你,是不是不知道从哪开始准备? 学姐告诉你,其实物理考研复试并没有你想象的那么难!只要掌握正确的备考方法,稳扎稳打,你也可以轻松拿下高分!今天给大家准备了…...

网络安全技术与应用

文章详细介绍了网络安全及相关技术,分析了其中的一类应用安全问题——PC机的安全问题,给出了解决这类问题的安全技术——PC防火墙技术。 1 网络安全及相关技术 自20世纪…...

APISIX Dashboard上的配置操作

文章目录 登录配置路由配置消费者创建后端服务项目配置上游再创建一个路由测试 登录 http://192.168.10.101:9000/user/login?redirect%2Fdashboard 根据docker 容器里的指定端口: 配置路由 通过apisix 的API管理接口来创建(此路由,直接…...

深度剖析数据分析职业成长阶梯

一、数据分析岗位剖析 目前,数据分析领域主要有以下几类岗位:业务数据分析师、商业数据分析师、数据运营、数据产品经理、数据工程师、数据科学家等,按照工作侧重点不同,本文将上述岗位分为偏业务和偏技术两大类,并对…...

HarmonyOS学习第11天:布局秘籍RelativeLayout进阶之路

布局基础:RelativeLayout 初印象 在 HarmonyOS 的界面开发中,布局是构建用户界面的关键环节,它决定了各个组件在屏幕上的位置和排列方式。而 RelativeLayout(相对布局)则是其中一种功能强大且灵活的布局方式&#xff0…...

问题修复-后端返给前端的时间展示错误

问题现象: 后端给前端返回的时间展示有问题。 需要按照yyyy-MM-dd HH:mm:ss 的形式展示 两种办法: 第一种 在实体类的属性上添加JsonFormat注解 第二种(建议使用) 扩展mvc框架中的消息转换器 代码: 因为配置类继…...

怎么排查页面响应慢的问题

一、排查流程图 -----------------| 全局监控报警触发 |-----------------|▼-----------------| 定位异常服务节点 |-----------------|------------------▼ ▼ ----------------- ----------------- | 基础设施层排查 | | 应用层代码排查 | | (网…...

第二十四:5.2【搭建 pinia 环境】axios 异步调用数据

第一步安装&#xff1a;npm install pinia 第二步&#xff1a;操作src/main.ts 改变里面的值的信息&#xff1a; <div class"count"><h2>当前求和为&#xff1a;{{ sum }}</h2><select v-model.number"n">  // .number 这里是…...

SpringBoot——生成Excel文件

在Springboot以及其他的一些项目中&#xff0c;或许我们可能需要将数据查询出来进行生成Excel文件进行数据的展示&#xff0c;或者用于进行邮箱发送进行附件添加 依赖引入 此处demo使用maven依赖进行使用 <dependency><groupId>org.apache.poi</groupId>&…...

java高级(IO流多线程)

file 递归 字符集 编码 乱码gbk&#xff0c;a我m&#xff0c;utf-8 缓冲流 冒泡排序 //冒泡排序 public static void bubbleSort(int[] arr) {int n arr.length;for (int i 0; i < n - 1; i) { // 外层循环控制排序轮数for (int j 0; j < n -i - 1; j) { // 内层循环…...

MySQL 用户权限管理深度解析:从基础到高阶实践(2000字指南)

MySQL 用户权限管理是数据库安全与运维的核心环节。无论是本地开发环境还是企业级生产环境,合理配置用户权限、理解版本差异、遵循安全规范都至关重要。本文将从 ​基础权限配置、版本差异详解、安全加固策略、高阶权限管理、故障排查​ 等多个维度展开,覆盖 MySQL 5.7、8.0 …...

【0011】HTML其他文本格式化标签详解(em标签、strong标签、b标签、i标签、sup标签、sub标签......)

如果你觉得我的文章写的不错&#xff0c;请关注我哟&#xff0c;请点赞、评论&#xff0c;收藏此文章&#xff0c;谢谢&#xff01; 本文内容体系结构如下&#xff1a; 本文旨在深入探讨HTML中其他的文本格式化标签&#xff0c;主要有<em> 标签、<strong> 标签、…...

数据虚拟化的中阶实践:从概念到实现

数据虚拟化的中阶实践:从概念到实现 在大数据时代,数据的数量、种类和来源呈现爆炸式增长,如何高效、灵活地访问和利用这些数据成为了企业面临的重要问题。数据虚拟化作为一种创新的技术,正逐渐成为解决这一难题的关键。它通过抽象化层将底层数据源与应用程序隔离,使得数…...

AI辅助学习vue第十四章

第十四章&#xff1a;技术引领与未来展望 在第十五章&#xff0c;你已经在Vue技术领域深耕许久&#xff0c;积累了丰富的经验与卓越的影响力。此时&#xff0c;你将站在行业前沿&#xff0c;引领技术走向&#xff0c;为Vue技术的未来发展开辟新道路。 1. 引领Vue技术发展方向…...

DeepEP库开源啦!DeepSeek优化GPU通信,破算力瓶颈。

在人工智能和大数据日益盛行的今天&#xff0c;算力成为了制约技术发展的关键因素之一。随着模型规模的不断扩大&#xff0c;GPU间的通信瓶颈问题日益凸显&#xff0c;成为了制约深度学习训练效率的一大难题。近日&#xff0c;DeepSeek团队开源了DeepEP库&#xff0c;旨在通过优…...

蓝桥杯web第三天

展开扇子题目&#xff0c; #box:hover #item1 { transform:rotate(-60deg); } 当悬浮在父盒子&#xff0c;子元素旋转 webkit display: -webkit-box&#xff1a;将元素设置为弹性伸缩盒子模型。-webkit-box-orient: vertical&#xff1a;设置伸缩盒子的子元素排列方…...

Gin从入门到精通 (七)文件上传和下载

文件上传和下载 1.文件上传 1.1单文件上传 在 Gin 中处理单文件上传&#xff0c;可以使用 c.FormFile 方法获取上传的文件&#xff0c;然后使用 c.SaveUploadedFile 方法保存文件。 package mainimport ("github.com/gin-gonic/gin""log" )func main()…...

【Java】Stream API

概述 Stream API ( java.util.stream) 把真正的函数式编程风格引入到Java中。这是目前为止对Java类库最好的补充&#xff0c;因为Stream API可以极大提供Java程序员的生产力&#xff0c;让程序员写出高效率、干净、简洁的代码。 Stream是Java8中处理集合的关键抽象概念&#…...

linux-Dockerfile及docker-compose.yml相关字段用途

文章目录 计算机系统5G云计算LINUX Dockerfile及docker-conpose.yml相关字段用途一、Dockerfile1、基础指令2、.高级指令3、多阶段构建指令 二、Docker-Compose.yml1、服务定义&#xff08;services&#xff09;2、高级服务配置3、网络配置 (networks)4、卷配置 (volumes)5、扩…...

基于Selenium的Python淘宝评论爬取教程

文章目录 前言1. 环境准备安装 Python&#xff1a;安装 Selenium&#xff1a;下载浏览器驱动&#xff1a; 2. 实现思路3. 代码实现4. 代码解释5. 注意事项 前言 以下是一个基于 Selenium 的 Python 淘宝评论爬取教程&#xff0c;需要注意的是&#xff0c;爬取网站数据应当遵守…...

网络空间安全(7)攻防环境搭建

一、搭建前的准备 硬件资源&#xff1a;至少需要两台计算机&#xff0c;一台作为攻击机&#xff0c;用于执行攻击操作&#xff1b;另一台作为靶机&#xff0c;作为被攻击的目标。 软件资源&#xff1a; 操作系统&#xff1a;如Windows、Linux等&#xff0c;用于安装在攻击机和…...

【Veristand】Veristand 预编写教程目录

很久没有更新&#xff0c;最近打算出一期Veristand教程&#xff0c;暂时目录列成下面这个表格&#xff0c;如果各位有关心的遗漏的点&#xff0c;可以在评论区提问&#xff0c;我后期可以考虑添加进去&#xff0c;但是提前声明&#xff0c;太过小众的点我不会&#xff0c;欢迎各…...

大白话页面加载速度,如何优化提升?

大白话页面加载速度&#xff0c;如何优化提升&#xff1f; 咱来好好唠唠页面加载速度这事儿&#xff0c;再说说怎么把它提上去。 页面加载速度是咋回事儿 页面加载速度啊&#xff0c;就好比你去餐厅吃饭&#xff0c;从你坐下点餐到饭菜端上桌的时间。在网页里&#xff0c;就…...