Pytorch 第十四回:神经网络编码器——变分自动编解码器
Pytorch 第十四回:神经网络编码器——变分自动编解码器
本次开启深度学习第十四回,基于Pytorch的神经网络编码器。本回分享VAE变分自动编码器。在本回中,通过minist数据集来分享如何建立一个变分自动编码器。接下来给大家分享具体思路。
本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0
文章目录
- Pytorch 第十四回:神经网络编码器——变分自动编解码器
- 前言
- 1 变分自动编码
- 2 证据下界
- 一、数据准备
- 二、使用步骤
- 1、定义模型
- 2、定义损失函数和优化函数
- 3、定义图片生成函数
- 三、模型训练
- 1、实例化模型
- 2、迭代训练模型
- 3、生成图片展示
- 总结
前言
讲述模型前,先讲述两个概念,统一下思路:
1 变分自动编码
变分自动编码(Variational Auto-Encoder,VAE) 是一种通过“概率压缩+结构约束”实现数据生成与特征提取的深度学习模型。简而言之,编码器不直接输出一个隐藏变量,而是输出一个均值和一个方差,这两个变量刻画了隐藏变量的高斯分布。这样,可以从这个分布中随机采样隐藏变量,再用解码器生成新图片。其结构图如下所示:
注:
VAE多用在数据生成、插值、半监督学习等场景中。VAE是自动编码器在生成能力上的概率化扩展,通过引入潜在空间分布约束和变分推断优化目标,克服了传统自动编码器无法生成新样本的缺陷。
2 证据下界
证据下界(Evidence Lower Bound, ELBO) 是VAE模型优化的核心,其本质是对数据真实分布的下界近似。ELBO可拆解为两部分,分别对应VAE的两个核心目标:一个是重构项,即解码器从隐藏变量 z 生成的数据 x 与原始输入 x 的匹配程度(本代码中为均方误差MES);另一个是KL散度项(正则化项),约束近似后验分布 q(z∣x) 与先验分布 p(z)(通常为标准正态分布)的相似性,其作用为防止过拟合,保证潜在空间连续性(使 z 的分布平滑,便于生成新样本)。
注:为何需要证据下界?
1)VAE的目标是建模输入数据分布 ,但由于数据复杂(如图像、文本),直接建模非常困难。因此引入隐藏变量 z,借助隐藏变量生成数据的方式得到数据分布 ,即隐变量建模。
2)但这个过程中积分难以直接计算,导致优化目标不可行。因此采用一个近似分布 (编码器)逼近真实后验分布 ,进而推导出一个可优化的下界(ELBO)(换句话说,就是用近似的代替真实的进行积分计算),即变分推断。
闲言少叙,直接展示逻辑,先上引用:
import os
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
import time
一、数据准备
首先准备MNIST数据集,并进行数据的预处理。关于数据集、数据加载器的介绍,可以查看第五回内容。
data_treating = tfs.Compose([tfs.ToTensor(),tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('./data', transform=data_treating)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
二、使用步骤
1、定义模型
模型由五部分组成,分别是初始化函数、编码函数、解码函数、参数重构函数、前向传播函数。各函数的连接对应上面所展示的VAE结构图。即,编码器输出均值和方差,再经过重构函数进行参数重构,最后通过解码器传出新的数据。在初始化函数中,c21用来获得编码器的均值参数,c22用获得编码器的对数方差参数。其代码如下:
class vae_net(nn.Module):def __init__(self):super(vae_net, self).__init__()self.c1 = nn.Linear(28 * 28, 400)self.c21 = nn.Linear(400, 20)self.c22 = nn.Linear(400, 20)self.c3 = nn.Linear(20, 400)self.c4 = nn.Linear(400, 28*28)def encode(self, x):code1 = F.relu(self.c1(x))return self.c21(code1), self.c22(code1)# 重参数化def reparameterization(self, b, a):value1 = torch.exp(a.mul(0.5))value2 = torch.FloatTensor(value1.size()).normal_()if torch.cuda.is_available():value2 = value2.cuda()return value2.mul(value1).add_(b)def decode(self, x):code2 = F.relu(self.c3(x))return F.tanh(self.c4(code2))def forward(self, x):b, a = self.encode(x)value = self.reparameterization(b, a)return self.decode(value), b, a
2、定义损失函数和优化函数
本次定义的损失函数比较特殊,不能和之前分享的内容一样,直接采用均方误差损失函数计算的损失值。这里,需要计算KL散度项,通过KL散度约束获得模型的损失函数。具体代码如下所示:
reconstruction_function = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)def loss_f(recons_x, x, b, a):MSE = reconstruction_function(recons_x, x)KLD_element = b.pow(2).add_(a.exp()).mul_(-1).add_(1).add_(a)KLD = torch.sum(KLD_element).mul_(-0.5)return MSE + KLD
3、定义图片生成函数
模型训练后,需要将图片进行保存。由于训练时修改了数据的格式,因此这里需要还原数据的图片格式,代码如下。
def change_image(x):x = 0.5 * (x + 1.)x = x.clamp(0, 1)x = x.view(x.shape[0], 1, 28, 28)return x
三、模型训练
1、实例化模型
这里实例化了一个变分自动编码的模型,同时为了加快模型训练,引入GPU进行训练(如何引入GPU进行训练,可以查看第六回)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device=", device)
net = vae_net().to(device)
2、迭代训练模型
进行100次迭代训练,并将生成的图片保存。
time1 = time.time()
for e in range(100):for image, _ in train_data:image = image.view(image.shape[0], -1)image = image.to(device)optimizer.zero_grad()recons_image, b, a = net(image)loss = loss_f(recons_image, image, b, a) / image.shape[0]loss.backward()optimizer.step()time_consume = time.time() - time1if (e + 1) % 10 == 0:print('epoch: {}, Loss: {:.4f},consume time:{:.3f}s'.format(e + 1, loss, time_consume))new_image = change_image(recons_image.cpu().data)if not os.path.exists('./vae_image'):os.mkdir('./vae_image')save_image(new_image, './vae_image/image_{}.png'.format(e + 1))
3、生成图片展示
训练10次的图片
训练100次的图片
相对来说,经过100次迭代训练,生成的数字轮廓还是比较清晰的。
总结
1)数据准备:准备MNIST集;
2)模型准备:定义变分自动编码模型、损失函数和优化器;
3)数据训练:实例化模型并训练,生成新的图片数据
相关文章:
Pytorch 第十四回:神经网络编码器——变分自动编解码器
Pytorch 第十四回:神经网络编码器——变分自动编解码器 本次开启深度学习第十四回,基于Pytorch的神经网络编码器。本回分享VAE变分自动编码器。在本回中,通过minist数据集来分享如何建立一个变分自动编码器。接下来给大家分享具体思路。 本次…...
hive排序函数
在 Hive 中,排序可以通过几种不同的方法来实现,通常依赖于 ORDER BY 或 SORT BY 等函数。这里简要介绍这几种排序方法: 1. ORDER BY ORDER BY 用于对结果集进行全局排序。它会将所有数据加载到一个节点进行排序,因此可能会导致性能问题,尤其是在数据量很大的时候。 语法…...
Android测试王炸:Appium + UI Automator2
Android平台主流开源框架简介 在Android平台上,有多个开源且好用的自动化测试框架。以下是几个被广泛使用和认可的框架: 1.1 Appium Appium是一个跨平台的移动测试工具,支持iOS和Android上的原生、混合及移动Web应用。 它使用了供应商提供的…...
用Python打造增强现实的魔法:实时对象叠加系统全解析
友友们好! 我是Echo_Wish,我的的新专栏《Python进阶》以及《Python!实战!》正式启动啦!这是专为那些渴望提升Python技能的朋友们量身打造的专栏,无论你是已经有一定基础的开发者,还是希望深入挖掘Python潜力的爱好者,这里都将是你不可错过的宝藏。 在这个专栏中,你将会…...
const let var 在react jsx中的使用方法 。
在 JavaScript 里,const 和 let 都是 ES6(ES2015)引入的用于声明变量的关键字,它们和之前的 var 关键字有所不同。下面为你详细介绍 const 和 let 的区别: 1. 块级作用域 const 和 let 都具备块级作用域,…...
C++隐式转换的机制、风险与消除方法
引言 C作为一门强类型语言,类型安全是其核心特性之一。 然而,隐式转换(Implicit Conversion)的存在既为开发者提供了便利,也可能成为程序中的“隐藏炸弹”。 一、隐式转换的定义与分类 1.1 什么是隐式转换…...
Python 为什么要保留显式的 self ?
当你在类中定义方法时,Python要求第一个参数必须表示当前对象实例。当你调用obj.method(),Python 本质上会将它转换为ClassName.method(obj)。 所以你需要通过self参数显式接收这个实例,才能访问该对象的属性和其他方法。如果不加self&#…...
Linux 性能调优之CPU认知
写在前面 博文内容为《性能之巅 系统、企业与云可观测性(第2版)》CPU 章节课后习题答案整理内容涉及: CPU 术语,指标认知CPU 性能问题分析解决CPU 资源负载特征分析应用程序用户态CPU用量分析理解不足小伙伴帮忙指正对每个人而言,真正的职责只有一个:找到自我。然后在心中…...
认识vue中的install和使用场景
写在前面 install 在实际开发中如果你只是一个简单的业务实现者,那么大部分时间你是用不到install的,因为你用到的基本上都是别人封装好的插件、组件、方法、指令等等,但是如果你需要给公司的架构做建设,install就是你避不开的一个…...
C++Cherno 学习笔记day17 [66]-[70] 类型双关、联合体、虚析构函数、类型转换、条件与操作断点
b站Cherno的课[66]-[70] 一、C的类型双关二、C的union(联合体、共用体)三、C的虚析构函数四、C的类型转换五、条件与操作断点——VisualStudio小技巧 一、C的类型双关 作用:在C中绕过类型系统 C是强类型语言 有一个类型系统,不…...
3.神经网络
神经网络 神经元与大脑 神经网络神经元的结构: 输入(Input):接收来自前一层神经元的信息。 权重(Weights):每个输入都有一个权重,表示其重要性。 加权和(Weighted Sum&a…...
CentOS 7安装Python3.12
文章目录 使用pyenv安装python3.12一、gitub下载pyenv二、升级GCC三.升级openssl这样python3.12.9就完成安装在CentOS上啦! 使用pyenv安装python3.12 一、gitub下载pyenv https://github.com/pyenv/pyenv 按照README,pyenv教程安装即可 二、升级GCC 安…...
微服务无感发布实践:基于Nacos的客户端缓存与故障转移机制
微服务无感发布实践:基于Nacos的客户端缓存与故障转移机制 背景与问题场景 在微服务架构中,服务的动态扩缩容、滚动升级是常态,而服务实例的上下线需通过注册中心(如Nacos)实现服务发现的实时同步。但在实际生产环境…...
5.2 自定义通知操作按钮(UNNotificationAction)
在本地推送通知中添加自定义操作按钮可以增强用户交互性,让用户无需打开应用就能执行一些快速操作。本节将详细介绍如何在SwiftUI应用中实现这一功能。 基本概念 UNNotificationAction 和 UNNotificationCategory 是UserNotifications框架中用于定义通知交互的核心…...
Python与链上数据分析:解锁区块链数据的潜力
Python与链上数据分析:解锁区块链数据的潜力 引言 区块链技术的兴起不仅改变了金融行业,也为数据分析领域带来了全新的机遇。链上数据(On-chain Data)是区块链网络中公开透明的交易记录和活动数据,它为我们提供了一个独特的视角,去观察用户行为、市场趋势以及网络健康状…...
数字化转型:未来已来,企业如何抢占先机?
近年来,“数字化转型”从一个技术热词逐渐演变为各行各业的“必选项”。无论是全球市场还是中国市场,数字化浪潮正以不可逆的姿态重塑商业生态。据IDC预测,到2028年,中国数字化转型市场规模将突破7300亿美元,全球投资规…...
Web3游戏全栈开发实战指南:智能合约与去中心化生态构建全解析
在GameFi市场规模突破千亿美元的当下,去中心化游戏系统开发正面临技术架构升级与生态融合的双重机遇。本文基于Solidity、Rust等多链智能合约开发经验,结合Truffle、Hardhat等主流框架,深度解析如何构建高性能、高收益的链游生态系统。 一、…...
Windows 图形显示驱动开发-WDDM 2.0功能_IoMmu 模型
概述 输入输出内存管理单元 (IOMMU) 是一个硬件组件,它将支持具有 DMA 功能的 I/O 总线连接到系统内存。 它将设备可见的虚拟地址映射到物理地址,使其在虚拟化中很有用。 在 WDDM 2.0 IoMmu 模型中,每个进程都有一个虚拟地址空间࿰…...
uniapp微信小程序基于wu-input二次封装TInput组件(支持点击下拉选择、支持整数、电话、小数、身份证、小数点位数控制功能)
一、 最终效果 二、实现了功能 1、支持输入正整数---设置specifyTypeinteger 2、支持输入数字(含小数点)---设置specifyTypedecimal,可设置decimalLimit来调整小数点位数 3、支持输入手机号--设置specifyTypephone 4、支持输入身份证号---设…...
Java 大厂面试题 -- JVM 深度剖析:解锁大厂 Offe 的核心密钥
最近佳作推荐: Java大厂面试高频考点|分布式系统JVM优化实战全解析(附真题)(New) Java大厂面试题 – JVM 优化进阶之路:从原理到实战的深度剖析(2)(New&#…...
小白入门JVM、字节码、类加载机制图解
前提知识~ JDK 基本介绍 JDK 的全称(Java Development Kit Java 开发工具包)JDK JRE java 的开发工具[java, javac,javadoc,javap 等]JDK 是提供给Java 开发人员使用的,其中包含了java 的开发工具,也包括了JRE。可开发、编译、调试…… JRE 基本介绍…...
新能源汽车动力性与经济性优化中的经典数学模型
一、动力性优化数学模型 动力性优化的核心目标是提升车辆的加速性能、最高车速及爬坡能力,主要数学模型包括: 1. 车辆纵向动力学模型 模型方程: 应用场景: 计算不同工况下的驱动力需求匹配电机扭矩与减速器速比案例ÿ…...
高级java每日一道面试题-2025年3月25日-微服务篇[Nacos篇]-Nacos中的命名空间(Namespace)有什么作用?
如果有遗漏,评论区告诉我进行补充 面试官: Nacos中的命名空间(Namespace)有什么作用? 我回答: 在Java高级面试中,关于Nacos中的命名空间(Namespace)的作用,是一个考察候选人对微服务架构和配…...
5.JVM-G1垃圾回收器
一、什么是G1 二、G1的三种垃圾回收方式 region默认2048 三、YGC的过程(Step1) 3.1相关代码 public class YGC1 {/*-Xmx128M -XX:UseG1GC -XX:PrintGCTimeStamps -XX:PrintGCDetails -XX:UnlockExperimentalVMOptions -XX:G1LogLevelfinest128m5% 60%6.4M 75M*/private stati…...
变量、数据、值类型引用类型的存储方式
代码写了也有2年了,对于这些基础的程序名词,说出口也是模棱两可,心里很不爽,很多基础还是模糊不清,清算一下...... Example值类型: int x 10; 变量:“x”是一个标识符,它对应着栈…...
分布式和微服务的区别
1. 定义 在讨论分布式系统和微服务的区别之前,我们先明确两者的定义: 分布式系统:是一组相互独立的计算机,通过网络协同工作,共同完成某个任务的系统。其核心在于资源的分布和任务的分解。 微服务架构:是…...
数组的常见算法一
注: 本文来自尚硅谷-宋红康仅用来学习备份 6.1 数值型数组特征值统计 这里的特征值涉及到:平均值、最大值、最小值、总和等 **举例1:**数组统计:求总和、均值 public class TestArrayElementSum {public static void main(String[] args)…...
Leedcode刷题 | Day27_贪心算法01
一、学习任务 455.分发饼干代码随想录376. 摆动序列53. 最大子序和 二、具体题目 1.455分发饼干455. 分发饼干 - 力扣(LeetCode) 假设你是一位很棒的家长,想要给你的孩子们一些小饼干。但是,每个孩子最多只能给一块饼干。 对…...
springboot集成kafka,后续需要通过flask封装restful接口
Spring Boot与Kafka的整合 在现代软件开发中,消息队列是实现服务解耦、异步消息处理、流量削峰等场景的重要组件。Apache Kafka是一个分布式流处理平台,它具有高吞吐量、可扩展性和容错性等特点。Spring Boot作为一个轻量级的、用于构建微服务的框架&am…...
MYSQL数据库语法补充2
一,数据库设计范式(原则) 数据库设计三大范式: 第一范式: 保证列的原子性(列不可再分) 反例:联系方式(手机,邮箱,qq) 正例: 手机号,qq,邮箱. 第二范式: 要有主键,其他列依赖于主键列,因为主键是唯一的,依赖了主键,这行数据就是唯一的. 第三范式: 多表关联时,在…...
Flask返回文件方法详解
在 Flask 中返回文件可以通过 send_file 或 send_from_directory 方法实现。以下是详细方法和示例: 1. 使用 send_file 返回文件 这是最直接的方法,适用于返回任意路径的文件。 from flask import Flask, send_fileapp = Flask(__name__)@app.route("/download")…...
随机数据下的最短路问题(Dijstra优先队列)
题目描述 给定 NN 个点和 MM 条单向道路,每条道路都连接着两个点,每个点都有自己编号,分别为 1∼N1∼N 。 问你从 SS 点出发,到达每个点的最短路径为多少。 输入描述 输入第一行包含三个正整数 N,M,SN,M,S。 第 22 到 M1M1 行…...
CPP杂项
注意:声明类,只是告知有这个类未完全定义,若使用类里具体属性,要么将访问块(函数之类)放到定义后,要么直接完全定义 C 友元函数/类的使用关键点(声明顺序为核心) 1. 友元…...
Idea将Java工程打包成war包并发布
1、问题概述? 项目开发之后,我们需要将Java工程打包成war后缀,并进行发布。之前在网上看到很多的文章,但是都不齐全,今天将提供一个完整的实现打包war工程,并发布的文章,希望对大家有所帮助,主要解决如下问题: 1、war工程需要满足的相关配置 2、如何解决项目中的JDK…...
多级缓存模型设计
为了有效避免缓存击穿、穿透和雪崩的问题。最基本的缓存设计就是从数据库中查询数据时,无论数据库中是否存在数据,都会将查询的结果缓存起来,并设置一定的有效期。后续请求访问缓存时,如果缓存中存在指定Key时,哪怕对应…...
SGLang实战:从KV缓存复用到底层优化,解锁大模型高效推理的全栈方案
在当今快速发展的人工智能领域,大型语言模型(LLM)的应用已从简单对话扩展到需要复杂逻辑控制、多轮交互和结构化输出的高级任务。面对这一趋势,如何高效地微调并部署这些大模型成为开发者面临的核心挑战。本文将深入探讨SGLang——这一专为大模型设计的高…...
【中大厂面试题】腾讯 后端 校招 最新面试题
操作系统 进程和线程的区别 本质区别:进程是操作系统资源分配的基本单位,而线程是任务调度和执行的基本单位 在开销方面:每个进程都有独立的代码和数据空间(程序上下文),程序之间的切换会有较大的开销&am…...
数据结构第六章(一) -图
数据结构第六章(一) 图一、图的基本概念1.图的分类、度、权等2.路径、回路、连通等3.连通分量、生成树等 二、几种特殊的图1.完全图2.稠密图、稀疏图3.树、有向树 三、常见考点总结 图 一、图的基本概念 感觉要想怎么定义图,图就是顶点和边组…...
【技术白皮书】外功心法 | 第三部分 | 数据结构与算法基础(常用的数据结构)
数据结构与算法基础 什么是算法?算法效率查询和排序为什么排序如此重要?思考问题如何确定复杂性数据结构连续或链接的数据结构链表的优点数组的优点数组集合Set 声明的一些方法有Multiset多元集合栈和队列何时使用栈和队列数据字典字典Hash的实现时间复杂度对时间的影响什么是…...
spark简介和安装
spark概念 Spark 是一种基于内存的快速、通用、可扩展的大数据分析计算引擎 spark核心模块 Spark Core Spark Core 中提供了 Spark 最基础与最核心的功能,Spark 其他的功能如:Spark SQL,Spark Streaming,GraphX, MLlib 都是在 …...
如何在 CentOS 7 系统上以容器方式部署 GitLab,使用 ZeroNews 通过互联网访问 GitLab 私有仓库,进行代码版本发布与更新
第 1 步: 部署 GitLab 容器 在开始部署 GitLab 容器之前,您需要创建本地目录来存储 GitLab 数据、配置和日志: #创建本地目录 mkdir -p /opt/docker/gitlab/data mkdir -p /opt/docker/gitlab/config mkdir -p /opt/docker/gitlab/log#gi…...
springboot Filter实现请求响应全链路拦截!完整日志监控方案
一、为什么你需要这个过滤器? 日志痛点: 🚨 请求参数散落在各处? 🚨 响应数据无法统一记录? 🚨 日志与业务代码严重耦合? 解决方案: 一个Filter同时拦截请…...
spring mvc 在拦截器、控制器和视图中获取和使用国际化区域信息的完整示例
在拦截器、控制器和视图中获取和使用国际化区域信息的完整示例 1. 核心组件代码示例 1.1 配置类(Spring Boot) import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.spring…...
docker部署jenkins并成功自动化部署微服务
一、环境版本清单: docker 26.1.4JDK 17.0.28Mysql 8.0.27Redis 6.0.5nacos 2.5.1maven 3.8.8jenkins 2.492.2 二、服务架构:有gateway,archives,system这三个服务 三、部署步骤 四、安装linux 五、在linux上安装redis&#…...
android 14.0 工厂模式 测试音频的一些问题(高通)
1之前用tinycap,现在得用agmcap 执行----agmcap /data/test.wav -D 100 -d 101 -i CODEC_DMA-LPAIF_RXTX-TX-3 -T 3 报错1 agmcap data/test.wav -D 100 -d 101 -i CODEC_DMA-LPAIF_RXTX-TX-3 -T 3 Failed to open xml file name /vendor/etc/backend_co…...
vue:前端预览 / chrome浏览器设置 / <iframe> 方法预览 doc、pdf / vue-pdf 预览pdf
一、本文目标 <iframe> 方法预览 pdf 、word vue-pdf 预览pdf 二、<iframe> 方法 2.1、iframe 方法预览需要 浏览器 设置为: chrome:设置-隐私设置和安全性-网站设置-更多内容设置-PDF文档 浏览器访问: chrome://settings/co…...
Axure疑难杂症:详解横向菜单左右拖动效果及阈值说明
亲爱的小伙伴,在您浏览之前,烦请关注一下,在此深表感谢! 课程主题:横向菜单左右拖动 主要内容:菜单设计、拖动效果、阈值设计 应用场景:APP横向菜单栏、台账菜单栏、功能选择栏等 案例展示: 案例视频: 横向菜单左右拖动效果 正文内容: 最近很多粉丝私信我,横向…...
多视图几何--立体校正--Bouguet方法
Bouguet算法的数学原理详解 Bouguet算法的核心目标是实现双目相机的极线校正,使左右图像的对应点位于同一水平线上,从而简化立体匹配。其数学原理围绕旋转矩阵分解、极线约束构造和重投影优化展开,以下是分步推导: 一、坐标系定义…...
debian12 mysql完全卸载
MySQL 的数据目录通常是 /var/lib/mysql,配置文件通常在 /etc/mysql 目录下。使用以下命令删除这些目录: sudo rm -rf /var/lib/mysql sudo rm -rf /etc/mysql清理残留文件 sudo find / -name "mysql*" -exec rm -rf {} \; 2>/dev/null验…...
随机动作指令活体检测技术的广泛应用,为人脸识别安全保驾护航
随着人脸识别技术在金融支付、门禁系统、手机解锁等领域的广泛应用,攻击手段也日益多样化,如照片、视频回放、3D面具等伪造方式对系统安全构成严重威胁。传统的人脸识别技术难以区分真实人脸与伪造攻击,正是在这样的背景下,随机动…...