pytorch实现半监督学习
人工智能例子汇总:AI常见的算法和例子-CSDN博客
半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下:
1. 数据准备
- 有标签数据(Labeled Data):数据集的一部分带有真实的类别标签。
- 无标签数据(Unlabeled Data):数据集的另一部分没有标签,仅有特征信息。
- 数据预处理:对数据进行清理、标准化、特征工程等处理,以保证数据质量。
2. 选择半监督学习方法
常见的半监督学习方法包括:
- 基于生成模型(Generative Models):如高斯混合模型(GMM)、变分自编码器(VAE)。
- 基于一致性正则化(Consistency Regularization):如 MixMatch、FixMatch,利用数据增强来约束模型预测一致性。
- 基于伪标签(Pseudo-Labeling):先用模型预测无标签数据的类别,然后将高置信度的预测作为新标签加入训练。
- 图神经网络(Graph-Based Methods):如 Label Propagation,通过构造数据之间的图结构传播标签信息。
3. 训练初始模型
- 仅使用有标签数据训练一个初始模型。
- 选择合适的损失函数,如交叉熵损失(Cross-Entropy Loss)或均方误差(MSE Loss)。
- 训练过程中可以使用数据增强、正则化等优化策略。
4. 利用无标签数据增强训练
- 伪标签方法:用初始模型对无标签数据进行预测,筛选高置信度样本,加入有标签数据训练。
- 一致性正则化:对无标签数据进行不同变换,要求模型的预测结果一致。
- 联合训练:构造有监督损失(Supervised Loss)和无监督损失(Unsupervised Loss),综合优化。
5. 模型迭代更新
- 重新利用训练后的模型预测无标签数据,产生新的伪标签或调整模型参数。
- 通过半监督策略不断优化模型,使其对无标签数据的预测更加稳定。
6. 评估和测试
- 使用测试集(通常是有标签的数据)评估模型性能。
- 选择合适的评估指标,如准确率(Accuracy)、F1-score、AUC-ROC 等。
7. 调优和部署
- 根据实验结果调整超参数,如伪标签置信度阈值、学习率等。
- 结合业务需求,将最终模型部署到实际应用中。
关键步骤:
- 初始化模型:首先使用有标签数据训练模型。
- 生成伪标签:用训练好的模型对无标签数据进行预测,生成伪标签。
- 结合有标签和伪标签数据进行训练:用带有标签和无标签(伪标签)数据一起训练模型。
- 迭代训练:不断迭代,使用更新的模型生成新的伪标签,进一步优化模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt# 简化的神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 8, kernel_size=3) # 缩小卷积层的输出通道self.fc1 = nn.Linear(8 * 26 * 26, 10) # 调整全连接层的输入和输出尺寸def forward(self, x):x = F.relu(self.conv1(x))x = x.view(x.size(0), -1) # 展平x = self.fc1(x)return x# 自定义数据集
class CustomDataset(Dataset):def __init__(self, data, labels=None):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):if self.labels is not None:return self.data[idx], self.labels[idx]else:return self.data[idx], -1 # 无标签数据# 半监督训练函数
def pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device, threshold=0.95):model.train()labeled_loss_value = 0pseudo_loss_value = 0for (labeled_data, labeled_labels), (unlabeled_data, _) in zip(labeled_loader, unlabeled_loader):labeled_data, labeled_labels = labeled_data.to(device), labeled_labels.to(device)unlabeled_data = unlabeled_data.to(device)# 1. 有标签数据训练optimizer.zero_grad()labeled_output = model(labeled_data)labeled_loss = F.cross_entropy(labeled_output, labeled_labels)labeled_loss.backward()# 2. 无标签数据伪标签生成unlabeled_output = model(unlabeled_data)probs = F.softmax(unlabeled_output, dim=1)max_probs, pseudo_labels = torch.max(probs, dim=1)# 伪标签置信度筛选pseudo_mask = max_probs > threshold # 置信度大于阈值的数据作为伪标签if pseudo_mask.sum() > 0:pseudo_labels = pseudo_labels[pseudo_mask]unlabeled_data_pseudo = unlabeled_data[pseudo_mask]# 3. 使用伪标签数据进行训练(确保无标签数据参与反向传播)optimizer.zero_grad() # 清除之前的梯度pseudo_output = model(unlabeled_data_pseudo)pseudo_loss = F.cross_entropy(pseudo_output, pseudo_labels)pseudo_loss.backward() # 计算反向梯度optimizer.step() # 更新模型参数# 累加损失用于展示labeled_loss_value += labeled_loss.item()if pseudo_mask.sum() > 0:pseudo_loss_value += pseudo_loss.item()return labeled_loss_value / len(labeled_loader), pseudo_loss_value / len(unlabeled_loader)# 模拟数据
num_labeled = 1000
num_unlabeled = 5000
data_dim = (1, 28, 28) # 28x28 灰度图像
num_classes = 10labeled_data = torch.randn(num_labeled, *data_dim)
labeled_labels = torch.randint(0, num_classes, (num_labeled,))
unlabeled_data = torch.randn(num_unlabeled, *data_dim)labeled_dataset = CustomDataset(labeled_data, labeled_labels)
unlabeled_dataset = CustomDataset(unlabeled_data)labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True) # 缩小批量大小
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=True) # 缩小批量大小# 模型、优化器和设备设置
device = torch.device("cpu") # 临时使用 CPU
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练过程并记录损失
num_epochs = 10
labeled_losses = []
pseudo_losses = []for epoch in range(num_epochs):labeled_loss, pseudo_loss = pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device)labeled_losses.append(labeled_loss)pseudo_losses.append(pseudo_loss)print(f"Epoch [{epoch + 1}/{num_epochs}] | Labeled Loss: {labeled_loss:.4f} | Pseudo Loss: {pseudo_loss:.4f}")# 绘制损失曲线
plt.plot(range(num_epochs), labeled_losses, label='Labeled Loss')
plt.plot(range(num_epochs), pseudo_losses, label='Pseudo Label Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Losses Over Epochs')
plt.show()# 展示伪标签生成效果(可视化一些样本的伪标签预测结果)
model.eval()
with torch.no_grad():sample_unlabeled_data = unlabeled_data[:10].to(device)output = model(sample_unlabeled_data)probs = F.softmax(output, dim=1)_, predicted_labels = torch.max(probs, dim=1)# 展示预测的标签print("Generated Pseudo Labels for Samples:")print(predicted_labels)# 假设这些是伪标签预测的图片fig, axes = plt.subplots(2, 5, figsize=(12, 5))for i, ax in enumerate(axes.flat):# 将tensor转换为NumPy数组img = sample_unlabeled_data[i].cpu().numpy().squeeze() # 转为NumPy数组ax.imshow(img, cmap='gray') # 使用灰度显示图像ax.set_title(f"Pred: {predicted_labels[i].item()}")ax.axis('off')plt.show()
相关文章:
pytorch实现半监督学习
人工智能例子汇总:AI常见的算法和例子-CSDN博客 半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下: 1. 数…...
X Window System 架构概述
X Window System 架构概述 1. X Server 与 X Client 这里引入一张维基百科的图,在Linux系统中,若用户需要图形化界面,则可以使用X Window System,其使用**Client-Server**架构,并通过网络传输相关信息。 X…...
中国证券基本知识汇总
中国证券市场是一个多层次、多领域的市场,涉及到各种金融工具、交易方式、市场参与者等内容。以下是中国证券基本知识的汇总: 1. 证券市场概述 证券市场:是指买卖证券(如股票、债券、基金等)的市场。证券市场可以分为…...
虚幻基础17:动画蓝图
能帮到你的话,就给个赞吧 😘 文章目录 animation blueprint图表(Graph): 编辑动画逻辑。变量(Variables): 管理动画参数。函数(Functions): 自定义…...
初入机器学习
写在前面 本专栏专门撰写深度学习相关的内容,防止自己遗忘,也为大家提供一些个人的思考 一切仅供参考 概念辨析 深度学习: 本质是建模,将训练得到的模型作为系统的一部分使用侧重于发现样本集中隐含的规律难点是认识并了解模型&…...
中间件的概念及基本使用
什么是中间件 中间件是ASP.NET Core的核心组件,MVC框架、响应缓存、身份验证、CORS、Swagger等都是内置中间件。 广义上来讲:Tomcat、WebLogic、Redis、IIS;狭义上来讲,ASP.NET Core中的中间件指ASP.NET Core中的一个组件。中间件…...
Docker 部署教程jenkins
Docker 部署 jenkins 教程 Jenkins 官方网站 Jenkins 是一个开源的自动化服务器,主要用于持续集成(CI)和持续交付(CD)过程。它帮助开发人员自动化构建、测试和部署应用程序,显著提高软件开发的效率和质量…...
LeetCode:53.最大子序和
跟着carl学算法,本系列博客仅做个人记录,建议大家都去看carl本人的博客,写的真的很好的! 代码随想录 LeetCode:53.最大子序和 给你一个整数数组 nums ,请你找出一个具有最大和的连续子数组(子数…...
C++ 游戏开发:完整指南
目录 什么是游戏开发? 为什么选择 C 进行游戏开发? C 游戏开发:完整指南 1. 理解游戏开发的基础 2. 学习游戏引擎 3. 精通 C 进行游戏开发 4. 学习数学在游戏开发中的应用 5. 探索图形编程 6. 专注于游戏开发的某一领域 7. 通过游戏项目进行实…...
数据结构:时间复杂度
文章目录 为什么需要时间复杂度分析?一、大O表示法:复杂度的语言1.1 什么是大O?1.2 常见复杂度速查表 二、实战分析:解剖C语言代码2.1 循环结构的三重境界单层循环:线性时间双重循环:平方时间动态边界循环&…...
测试工程师的DS使用指南
目录 引言DeepSeek在测试设计中的应用 2.1 智能用例生成2.2 边界值分析2.3 异常场景设计DeepSeek在自动化测试中的应用 3.1 脚本智能转换3.2 日志智能分析3.3 测试数据生成DeepSeek在质量保障体系中的应用 4.1 测试策略优化4.2 缺陷模式预测4.3 技术方案验证DeepSeek在测试效能…...
http3网站的设置(AI不会配,得人工配)
堡塔PHP项目中配置nginx1.26.0设置http3协议 # 文件所在服务器中的路径 /www/server/nginx/conf/nginx.confuser www www; worker_processes auto; error_log /www/wwwlogs/nginx_error.log crit; pid /www/server/nginx/logs/nginx.pid; worker_rlimit_nofile 512…...
搜索引擎快速收录:关键词布局的艺术
本文来自:百万收录网 原文链接:https://www.baiwanshoulu.com/21.html 搜索引擎快速收录中的关键词布局,是一项既精细又富有策略性的工作。以下是对关键词布局艺术的详细阐述: 一、关键词布局的重要性 关键词布局影响着后期页面…...
WPF进阶 | WPF 动画特效揭秘:实现炫酷的界面交互效果
WPF进阶 | WPF 动画特效揭秘:实现炫酷的界面交互效果 前言一、WPF 动画基础概念1.1 什么是 WPF 动画1.2 动画的基本类型1.3 动画的核心元素 二、线性动画详解2.1 DoubleAnimation 的使用2.2 ColorAnimation 实现颜色渐变 三、关键帧动画深入3.1 DoubleAnimationUsin…...
基于微信小程序的辅助教学系统的设计与实现
标题:基于微信小程序的辅助教学系统的设计与实现 内容:1.摘要 摘要:随着移动互联网的普及和微信小程序的兴起,基于微信小程序的辅助教学系统成为了教育领域的一个新的研究热点。本文旨在设计和实现一个基于微信小程序的辅助教学系统,以提高教…...
给AI加知识库
1、加载 Document Loader文档加载器 在 langchain_community. document_loaders 里有很多种文档加载器 from langchain_community. document_loaders import *** 1、纯文本加载器:TextLoader,纯文本(不包含任何粗体、下划线、字号格式&am…...
【LeetCode 刷题】回溯算法(5)-棋盘问题
此博客为《代码随想录》二叉树章节的学习笔记,主要内容为回溯算法棋盘问题相关的题目解析。 文章目录 51. N皇后37. 解数独332.重新安排行程 51. N皇后 题目链接 class Solution:def solveNQueens(self, n: int) -> List[List[str]]:board [[. for _ in rang…...
Vue.js组件开发-实现字母向上浮动
使用Vue实现字母向上浮动的效果 实现步骤 创建Vue项目:使用Vue CLI来创建一个新的Vue项目。定义组件结构:在组件的模板中,定义包含字母的元素。添加样式:使用CSS动画来实现字母向上浮动的效果。绑定动画类:在Vue组件…...
2025蓝桥杯JAVA编程题练习Day2
1.大衣构造字符串 问题描述 已知对于一个由小写字母构成的字符串,每次操作可以选择一个索引,将该索引处的字符用三个相同的字符副本替换。 现有一长度为 NN 的字符串 UU,请帮助大衣构造一个最小长度的字符串 SS,使得经过任意次…...
WPF进阶 | WPF 样式与模板:打造个性化用户界面的利器
WPF进阶 | WPF 样式与模板:打造个性化用户界面的利器 一、前言二、WPF 样式基础2.1 什么是样式2.2 样式的定义2.3 样式的应用 三、WPF 模板基础3.1 什么是模板3.2 控件模板3.3 数据模板 四、样式与模板的高级应用4.1 样式继承4.2 模板绑定4.3 资源字典 五、实际应用…...
趣味Python100例初学者练习01
1. 1 抓交通肇事犯 一辆卡车违反交通规则,撞人后逃跑。现场有三人目击该事件,但都没有记住车号,只记下了车号的一些特征。甲说:牌照的前两位数字是相同的;乙说:牌照的后两位数字是相同的,但与前…...
每日一题——有效括号序列
有效括号序列 题目描述数据范围:复杂度要求: 示例题解代码实现代码解析1. 定义栈和栈操作2. 栈的基本操作3. 主函数 isValid4. 返回值 时间和空间复杂度分析 题目描述 给出一个仅包含字符 (, ), {, }, [, ] 的字符串,判断该字符串是否是一个…...
MQTT 术语表
Broker 有时我们也会直接将服务端称为 Broker,这两个术语可以互换使用。 Clean Start 客户端可以在连接时使用这个字段来指示是期望从已存在的会话中恢复通信,还是创建一个全新的会话。仅限 MQTT v5.0。 Client 使用 MQTT 协议连接到服务端的设备或…...
每天学点小知识之设计模式的艺术-策略模式
行为型模式的名称、定义、学习难度和使用频率如下表所示: 1.如何理解模板方法模式 模板方法模式是结构最简单的行为型设计模式,在其结构中只存在父类与子类之间的继承关系。通过使用模板方法模式,可以将一些复杂流程的实现步骤封装在一系列基…...
ubuntuCUDA安装
系列文章目录 移动硬盘制作Ubuntu系统盘 前言 根据前篇“移动硬盘制作Ubuntu系统盘”安装系统后,还不能够使用显卡。 如果需要使用显卡,还需要进行相关驱动的安装(如使用的为Nvidia显卡,就需要安装相关的Nvidia显卡驱动ÿ…...
信息学奥赛一本通 2113:【24CSPJ普及组】小木棍(sticks) | 洛谷 P11229 [CSP-J 2024] 小木棍
【题目链接】 ybt 2113:【24CSPJ普及组】小木棍(sticks) 洛谷 P11229 [CSP-J 2024] 小木棍 【题目考点】 1. 思维题,找规律 【解题思路】 解法1:找规律 该题为:求n根木棍组成的无前导0的所有可能的数…...
【数据结构】(5) ArrayList 顺序表
一、使用 ArrayList ArrayList 就是数组的封装,但是数组只有 [] 操作存取值,和 .length 操作获取数组内存长度;而 ArrayList 有更多的功能: 1、创建对象 2、扩容机制 ArrayList 有自动扩容机制,在插入元素时不用担心数…...
Elasticsearch 指南 [8.17] | Search APIs
Search API 返回与请求中定义的查询匹配的搜索结果。 http GET /my-index-000001/_search Request GET /<target>/_search GET /_search POST /<target>/_search POST /_search Prerequisites 如果启用了 Elasticsearch 安全功能,针对目标数据流…...
【自开发工具介绍】SQLSERVER的ImpDp和ExpDp工具03
SQLSERVER的ImpDp和ExpDp工具 1、全部的表导出(仅表结构导出) 2、导出的表结构,导入到新的数据库 导入前,test3数据没有任何表 导入 导入结果确认:表都被做成,但是没有数据 3、全部的表导出&#x…...
JVM-运行时数据区
JVM的组成 运行时数据区-总览 Java虚拟机在运行Java程序过程中管理的内存区域,称之为运行时数据区。 《Java虚拟机规范》中规定了每一部分的作用 运行时数据区-应用场景 Java的内存分成哪几部分? Java内存中哪些部分会内存溢出? JDK7 和J…...
经典本地影音播放器MPC-BE.
经典本地影音播放器MPC-BE 链接:https://pan.xunlei.com/s/VOIAZbbIuBM1haFdMYCubsU-A1?pwd4iz3# MPC-BE(Media Player Classic Black Edition)是来自 MPC-HC(Media Player Classic Home Cinema)的俄罗斯开发者重新…...
求水仙花数,提取算好,打表法。或者暴力解出来。
暴力解法 #include<bits/stdc.h> using namespace std; int main() {int n,m;cin>>n>>m;if(n<3||n>7||m<0){cout<<"-1";return 0;}int powN[10];//记录0-9的n次方for(int i0;i<10;i){powN[i](int)pow(i,n);}int low(int) pow(1…...
后盾人JS -- 原型
没有原型的对象 也有没有原型的对象 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document<…...
Deepseek-R1 和 OpenAI o1 这样的推理模型普遍存在“思考不足”的问题
每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…...
Nginx 命令行参数
文章来源:命令行参数 -- nginx中文文档|nginx中文教程 nginx 支持以下命令行参数: -?| — 打印帮助 以获取命令行参数。-h-c file— 使用替代项 configuration 而不是 default 文件。file-e file— 使用替代项 error log 来存储日志 而不是默认文件 &…...
YOLOV11-1:YoloV11-安装和CLI方式训练模型
YoloV11-安装和CLI方式训练模型 1.安装和运行1.1安装的基础环境1.2安装yolo相关组件1.3命令行方式使用1.3.1 训练1.3.2 预测 本文介绍yoloV11的安装和命令行接口 1.安装和运行 1.1安装的基础环境 GPU环境,其中CUDA是12.4版本 1.2安装yolo相关组件 # 克隆github…...
Docker Hub 镜像 Pull 失败的解决方案
目录 引言一、问题二、原因三、解决方法四、参考文献 引言 在云原生技术火热的当下,Docker可谓是其基础,由于其简单以及方便性,让开发人员不必再为环境配置问题而伤脑筋,因为可将其看作一个虚拟机程序去理解。所以掌握好它可谓是…...
重新思考绩效管理变革
Peter Cappelli 和 Anna Tavis 在绩效管理变革一文中,为我们带来了很多关于绩效管理变革的思考。企业为什么做绩效管理变革,为什么现在需要?让我们看看这些学者是如何思考的。 摘要 受到老板和下属的痛恨,传统的绩效考核已经被超…...
内核定时器2-高分辨率定时器
高分辨率定时器与低分辨率定时器 高分辨率定时器与低分辨率定时器相比,有如下两个根本性的不同。 (1) 高分辨率定时器使用红黑树对定时器进行管理。 (2) 定时器独立于周期时钟。即不基于jiffies,精度可以达到纳秒级别。 内核2.6.16版本开始ÿ…...
【自开发工具介绍】SQLSERVER的ImpDp和ExpDp工具02
工具运行前的环境准备 1、登录用户管理员权限确认 工具使用的登录用户(-u后面的用户),必须具有管理员的权限,因为需要读取系统表 例:Export.bat -s 10.48.111.12 -d db1 -u test -p test -schema dbo 2、Powershell的安全策略确认…...
数据结构【单链表操作大全详解】【c语言版】(只有输入输出为了方便用的c++)
单链表操作的C/C实现详解 在数据结构中,单链表是一种非常基础且重要的数据结构。它由一系列节点组成,每个节点包含数据和指向下一个节点的指针。今天我们就来深入探讨用C/C实现的单链表及其各种操作。 一、单链表的定义 const int N 1e5; //单链表 t…...
【R语言】环境空间
一、环境空间种类 R语言中有5种环境: 全局环境:也叫用户环境,指在当前用户下R程序运行的环境空间。 内部环境:通过“new.env()”命令创建的环境空间,也可以是匿名的环境空间。 父环境:当前环境空间所处…...
Python处理数据库:MySQL与SQLite详解
Python处理数据库:MySQL与SQLite详解 在数据处理和存储方面,数据库扮演着至关重要的角色。Python提供了多种与数据库交互的方式,其中pymysql库用于连接和操作MySQL数据库,而SQLite则是一种轻量级的嵌入式数据库,Pytho…...
软考高项笔记 信息技术及其发展
信息技术及其发展 ❝ 信息系统项目管理师第二章第一节 1. 网络标准协议的定义 网络协议是为计算机网络中进行数据交换而建立的规则、标准或约定的集合。网络协议由三个要素组成,分别是语义、语法和时序。 语义:解释控制信息每个部分的含义,它…...
HAO的Graham学习笔记
前置知识:凸包 摘录oiwiki 在平面上能包含所有给定点的最小凸多边形叫做凸包。 其定义为:对于给定集合 X,所有包含 X 的凸集的交集 S 被称为 X 的 凸包。 说人话就是用一个橡皮筋包含住所有给定点的形态 如图: 正题:…...
C#基础知识
0 C#介绍 定义与背景 C#(发音为C - sharp)是微软公司开发的一种高级编程语言。它是专门为构建在微软的.NET平台上运行的各种应用程序而设计的。在2000年左右推出,目的是结合当时编程语言的优点,如C的强大功能和Java的简单性与安全…...
Kafka中文文档
文章来源:https://kafka.cadn.net.cn 什么是事件流式处理? 事件流是人体中枢神经系统的数字等价物。它是 为“永远在线”的世界奠定技术基础,在这个世界里,企业越来越多地使用软件定义 和 automated,而软件的用户更…...
Tyrant(暴君):反向Shell-后门注入与持久化控制的渗透测试工具
Tyrant Tyrant 是一款用于渗透测试和远程控制持久化的恶意工具,具备以下功能: 反向Shell:允许攻击者通过指定用户UID进行反弹对应权限的Shell会话。后门注入与持久化:在目标系统中注入后门并确保即使重启后依然能恢复控制。Tyran…...
leetcode刷题-贪心04
代码随想录贪心算法part04|452. 用最少数量的箭引爆气球、435. 无重叠区间、763.划分字母区间 452. 用最少数量的箭引爆气球435. 无重叠区间763.划分字母区间 今天的三道题目,都算是 重叠区间 问题,大家可以好好感受一下。 都属于那种看起来好复杂&#…...
系统学习算法: 专题八 二叉树中的深搜
深搜其实就是深度优先遍历(dfs),与此相对的还有宽度优先遍历(bfs) 如果学完数据结构有点忘记,如下图,左边是dfs,右边是bfs 而二叉树的前序,中序,后序遍历都可…...