当前位置: 首页 > news >正文

Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(三)

Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(三)

7. 实现条件WGAN-GP

# 训练条件WGAN-GP
def train_conditional_wgan_gp():# 用于记录损失d_losses = []g_losses = []# 用于记录生成样本的多样性(通过类别分布)class_distributions = []for epoch in range(n_epochs):for i, (real_imgs, labels) in enumerate(dataloader):real_imgs = real_imgs.to(device)labels = labels.to(device)batch_size = real_imgs.shape[0]# ---------------------#  训练判别器# ---------------------optimizer_D.zero_grad()# 生成随机噪声z = torch.randn(batch_size, latent_dim, device=device)# 为生成器生成随机标签gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批假图像fake_imgs = generator(z, gen_labels)# 判别器前向传播real_validity = discriminator(real_imgs, labels)fake_validity = discriminator(fake_imgs.detach(), gen_labels)# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data, labels)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器n_critic = 5if i % n_critic == 0:# ---------------------#  训练生成器# ---------------------optimizer_G.zero_grad()# 为生成器生成新的随机标签gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批新的假图像gen_imgs = generator(z, gen_labels)# 判别器评估假图像fake_validity = discriminator(gen_imgs, gen_labels)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每个epoch结束后,评估生成样本的类别分布if (epoch + 1) % 10 == 0:class_dist = evaluate_class_distribution()class_distributions.append(class_dist)# 保存生成的图像样本save_sample_images(epoch)# 绘制损失曲线plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('cond_wgan_gp_loss.png')plt.close()# 绘制类别分布变化plot_class_distributions(class_distributions)# 评估生成样本的类别分布
def evaluate_class_distribution():"""评估生成样本在各类别上的分布情况"""# 创建一个预训练的分类器classifier = torchvision.models.resnet18(pretrained=True)# 修改第一个卷积层以适应灰度图classifier.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 修改最后的全连接层以适应MNIST的10个类别classifier.fc = nn.Linear(classifier.fc.in_features, 10)# 加载预先训练好的MNIST分类器权重(这里假设我们有一个预训练的模型)# classifier.load_state_dict(torch.load('mnist_classifier.pth'))# 简化起见,这里我们使用一个简单的CNN分类器classifier = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 7 * 7, 128),nn.ReLU(),nn.Linear(128, 10)).to(device)# 这里假设这个简单分类器已经在MNIST上训练好了# 实际应用中,应该加载一个预先训练好的模型# 生成1000个样本z = torch.randn(1000, latent_dim, device=device)# 均匀采样所有类别gen_labels = torch.tensor([i % 10 for i in range(1000)], device=device)gen_imgs = generator(z, gen_labels)# 使用分类器预测类别with torch.no_grad():classifier.eval()preds = torch.softmax(classifier(gen_imgs), dim=1)pred_labels = torch.argmax(preds, dim=1)# 计算每个类别的样本数量class_counts = torch.zeros(10)for i in range(10):class_counts[i] = (pred_labels == i).sum().item() / 1000return class_counts.numpy()# 绘制类别分布变化
def plot_class_distributions(class_distributions):"""绘制生成样本类别分布的变化"""epochs = [10, 20, 30, 40, 50]  # 假设每10个epoch评估一次plt.figure(figsize=(12, 8))for i, dist in enumerate(class_distributions):plt.subplot(len(class_distributions), 1, i+1)plt.bar(np.arange(10), dist)plt.ylabel(f'Epoch {epochs[i]}')plt.ylim(0, 0.3)  # 限制y轴范围,便于比较if i == len(class_distributions) - 1:plt.xlabel('Digit Class')plt.tight_layout()plt.savefig('class_distribution.png')plt.close()# 保存样本图像(条件版本)
def save_sample_images(epoch):"""保存按类别排列的样本图像"""# 为每个类别生成样本n_row = 10  # 每个类别一行n_col = 10  # 每个类别10个样本fig, axs = plt.subplots(n_row, n_col, figsize=(12, 12))for i in range(n_row):# 固定类别fixed_class = torch.tensor([i] * n_col, device=device)# 随机噪声z = torch.randn(n_col, latent_dim, device=device)# 生成图像gen_imgs = generator(z, fixed_class).detach().cpu()# 转换到[0, 1]范围gen_imgs = 0.5 * gen_imgs + 0.5# 显示图像for j in range(n_col):axs[i, j].imshow(gen_imgs[j, 0, :, :], cmap='gray')axs[i, j].axis('off')plt.savefig(f'cond_wgan_gp_epoch_{epoch+1}.png')plt.close()# 运行条件WGAN-GP训练
if __name__ == "__main__":train_conditional_wgan_gp()

上述代码实现了一个条件WGAN-GP模型,主要区别在于:

  1. 条件输入:生成器和判别器都接收类别标签作为额外输入
  2. 嵌入层:使用嵌入层将类别标签转换为嵌入向量
  3. 类别多样性评估:添加了评估生成样本类别分布的功能
  4. 可视化:按类别排列生成样本,便于观察每个类别的质量

8. 无监督与条件生成的模式坍塌对比实验

为了更直观地比较无监督生成和条件生成在模式坍塌方面的差异,我们可以设计一个实验,分别训练无监督WGAN-GP和条件WGAN-GP,然后比较它们生成样本的模式覆盖情况。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE# 假设我们已经训练好了无监督WGAN-GP和条件WGAN-GP模型
# 分别为 unsupervised_generator 和 conditional_generatordef analyze_mode_collapse():"""分析并比较无监督和条件生成在模式坍塌方面的差异"""# 生成样本数量n_samples = 1000# 1. 从无监督生成器生成样本z_unsupervised = torch.randn(n_samples, latent_dim, device=device)unsupervised_samples = unsupervised_generator(z_unsupervised).detach().cpu()# 2. 从条件生成器生成样本(均匀覆盖所有类别)z_conditional = torch.randn(n_samples, latent_dim, device=device)conditional_labels = torch.tensor([i % 10 for i in range(n_samples)], device=device)conditional_samples = conditional_generator(z_conditional, conditional_labels).detach().cpu()# 3. 获取真实MNIST样本real_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=n_samples, shuffle=True)real_samples, _ = next(iter(real_loader))# 4. 使用预训练的分类器分类所有样本classifier = create_mnist_classifier()  # 假设我们有一个创建分类器的函数# 分类无监督生成的样本unsupervised_predictions = classify_samples(classifier, unsupervised_samples)# 分类条件生成的样本conditional_predictions = classify_samples(classifier, conditional_samples)# 分类真实样本real_predictions = classify_samples(classifier, real_samples)# 5. 计算各类别的样本分布unsupervised_distribution = compute_class_distribution(unsupervised_predictions)conditional_distribution = compute_class_distribution(conditional_predictions)real_distribution = compute_class_distribution(real_predictions)# 6. 计算分布的均匀度(使用熵)unsupervised_entropy = compute_entropy(unsupervised_distribution)conditional_entropy = compute_entropy(conditional_distribution)real_entropy = compute_entropy(real_distribution)print(f"无监督生成分布熵: {unsupervised_entropy:.4f}")print(f"条件生成分布熵: {conditional_entropy:.4f}")print(f"真实数据分布熵: {real_entropy:.4f}")# 7. 可视化样本分布visualize_distributions(unsupervised_distribution,conditional_distribution,real_distribution)# 8. 使用t-SNE将样本投影到二维空间进行可视化visualize_tsne(unsupervised_samples,conditional_samples,real_samples)def create_mnist_classifier():"""创建一个简单的MNIST分类器"""model = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 7 * 7, 128),nn.ReLU(),nn.Linear(128, 10)).to(device)# 这里假设分类器已经训练好了# model.load_state_dict(torch.load('mnist_classifier.pth'))return modeldef classify_samples(classifier, samples):"""使用分类器对样本进行分类"""with torch.no_grad():classifier.eval()# 确保样本在正确的设备上samples = samples.to(device)# 前向传播logits = classifier(samples)# 获取预测类别predictions = torch.argmax(logits, dim=1)return predictions.cpu().numpy()def compute_class_distribution(predictions):"""计算类别分布"""n_samples = len(predictions)distribution = np.zeros(10)for i in range(10):distribution[i] = np.sum(predictions == i) / n_samplesreturn distributiondef compute_entropy(distribution):"""计算分布的熵,衡量分布的均匀度"""# 防止log(0)distribution = distribution + 1e-10# 归一化distribution = distribution / np.sum(distribution)# 计算熵entropy = -np.sum(distribution * np.log2(distribution))return entropydef visualize_distributions(unsupervised_dist, conditional_dist, real_dist):"""可视化三种样本的类别分布"""plt.figure(figsize=(12, 5))width = 0.25x = np.arange(10)plt.bar(x - width, unsupervised_dist, width, label='Unsupervised')plt.bar(x, conditional_dist, width, label='Conditional')plt.bar(x + width, real_dist, width, label='Real')plt.xlabel('Digit Class')plt.ylabel('Proportion')plt.title('Class Distribution Comparison')plt.xticks(x)plt.legend()plt.tight_layout()plt.savefig('distribution_comparison.png')plt.close()def visualize_tsne(unsupervised_samples, conditional_samples, real_samples):"""使用t-SNE将样本投影到二维空间并可视化"""# 准备数据unsupervised_flat = unsupervised_samples.view(unsupervised_samples.size(0), -1).numpy()conditional_flat = conditional_samples.view(conditional_samples.size(0), -1).numpy()real_flat = real_samples.view(real_samples.size(0), -1).numpy()# 合并所有样本all_samples = np.vstack([unsupervised_flat, conditional_flat, real_flat])# 使用t-SNE降维tsne = TSNE(n_components=2, random_state=42)all_samples_tsne = tsne.fit_transform(all_samples)# 分离结果n = unsupervised_flat.shape[0]unsupervised_tsne = all_samples_tsne[:n]conditional_tsne = all_samples_tsne[n:2*n]real_tsne = all_samples_tsne[2*n:]# 可视化plt.figure(figsize=(10, 8))plt.scatter(unsupervised_tsne[:, 0], unsupervised_tsne[:, 1], c='blue', label='Unsupervised', alpha=0.5, s=10)plt.scatter(conditional_tsne[:, 0], conditional_tsne[:, 1], c='green', label='Conditional', alpha=0.5, s=10)plt.scatter(real_tsne[:, 0], real_tsne[:, 1], c='red', label='Real', alpha=0.5, s=10)plt.legend()plt.title('t-SNE Visualization of Generated and Real Samples')plt.savefig('tsne_visualization.png')plt.close()# 运行分析
if __name__ == "__main__":analyze_mode_collapse()

上述代码实现了一个比较实验,用于分析无监督WGAN-GP和条件WGAN-GP在模式坍塌方面的差异。主要的分析方法包括:

  1. 类别分布分析:使用预训练的分类器对生成样本进行分类,统计各类别的样本比例
  2. 熵计算:使用熵来衡量分布的均匀度,熵越高表示分布越均匀,模式覆盖越全面
  3. t-SNE可视化:使用t-SNE将高维样本投影到二维空间,直观地观察样本分布

通过这些分析,我们可以定量和定性地比较两种方法在模式坍塌方面的表现。

9. 模式坍塌问题的其他解决方案

除了条件生成和WGAN-GP,还有其他方法可以缓解GAN的模式坍塌问题:

9.1 解决模式坍塌的方法比较表

方法核心思想优点缺点实现复杂度
WGAN-GP使用Wasserstein距离和梯度惩罚训练稳定,理论基础强计算成本高中等
条件GAN添加条件信息引导生成可控生成,强制覆盖所有类别需要标签数据
小批量判别 (Minibatch Discrimination)判别器考虑样本间的相似性直接鼓励样本多样性计算开销增加
展开GAN (Unrolled GAN)展开判别器的k步更新提供更稳定的梯度训练速度慢
BEGAN使用自编码器作为判别器平衡生成器和判别器训练模型结构复杂中等
PacGAN将多个样本打包传给判别器实现简单,效果明显需要更多内存
集成多个生成器使用多个生成器捕捉不同模式天然覆盖多个模式训练困难,参数增加
基于能量的GAN (EBGAN)将GAN视为能量模型更好的稳定性理解难度大中等

9.2 小批量判别的PyTorch实现

下面是小批量判别(Minibatch Discrimination)的PyTorch实现示例,这是另一种解决模式坍塌的有效方法:

import torch
import torch.nn as nnclass MinibatchDiscrimination(nn.Module):"""小批量判别层,用于缓解模式坍塌"""def __init__(self, input_features, output_features, kernel_dim=5):super(MinibatchDiscrimination, self).__init__()self.input_features = input_featuresself.output_features = output_featuresself.kernel_dim = kernel_dim# 参数张量 [input_features, output_features * kernel_dim]self.T = nn.Parameter(torch.randn(input_features, output_features * kernel_dim))def forward(self, x):# x形状: [batch_size, input_features]batch_size = x.size(0)# 将输入与参数相乘 -> [batch_size, output_features, kernel_dim]matrices = x.mm(self.T).view(batch_size, self.output_features, self.kernel_dim)# 扩展为广播形状 -> [batch_size, batch_size, output_features, kernel_dim]matrices_expanded = matrices.unsqueeze(1)matrices_transposed = matrices.unsqueeze(0)# 计算L1距离 -> [batch_size, batch_size, output_features]l1_dist = torch.abs(matrices_expanded - matrices_transposed).sum(dim=3)# 应用负指数核 -> [batch_size, batch_size, output_features]K = torch.exp(-l1_dist)# 将自身的相似度设为0(对角线)mask = (torch.ones(batch_size, batch_size) - torch.eye(batch_size)).unsqueeze(2)mask = mask.to(x.device)K = K * mask# 对每个样本,计算其与其他所有样本的相似度之和 -> [batch_size, output_features]minibatch_features = K.sum(dim=1)# 将小批量判别特征与原始特征连接return torch.cat([x, minibatch_features], dim=1)# 使用小批量判别的判别器示例
class DiscriminatorWithMinibatch(nn.Module):def __init__(self, img_shape, hidden_dim=512, minibatch_features=32):super(DiscriminatorWithMinibatch, self).__init__()self.img_flat_dim = int(np.prod(img_shape))# 特征提取层self.features = nn.Sequential(nn.Linear(self.img_flat_dim, hidden_dim),nn.LeakyReLU(0.2, inplace=True),nn.Linear(hidden_dim, hidden_dim),nn.LeakyReLU(0.2, inplace=True))# 小批量判别层self.minibatch = MinibatchDiscrimination(hidden_dim, minibatch_features)# 输出层self.output = nn.Linear(hidden_dim + minibatch_features, 1)def forward(self, img):# 将图像展平img_flat = img.view(img.size(0), -1)# 提取特征features = self.features(img_flat)# 应用小批量判别enhanced_features = self.minibatch(features)# 输出validity = self.output(enhanced_features)return validity

小批量判别通过考虑样本之间的相似性来鼓励生成样本的多样性。它计算批次中每个样本与其他样本的距离,并将这些距离信息作为额外特征传递给判别器,使判别器能够识别出生成器是否只生成相似的样本。

10. 生成对抗网络的评估指标

评估GAN的性能是一个复杂的问题,特别是在衡量生成样本的质量和多样性方面。以下是一些常用的评估指标:

10.1 常用GAN评估指标比较表

指标衡量内容优点缺点适用场景
Inception Score (IS)样本质量和多样性易于实现,广泛使用对噪声敏感,不考虑与真实分布的匹配度图像生成,特别是有标签的数据集
Fréchet Inception Distance (FID)生成分布与真实分布的相似度对模式坍塌敏感,更符合人类判断计算复杂度高各类图像生成任务
多样性指数 (Diversity Score)生成样本的多样性直接衡量样本间距离不考虑样本质量检测模式坍塌
精度与召回率 (Precision & Recall)样本质量和覆盖率分离质量和覆盖率的测量实现复杂需要平衡质量和多样性的场景
分类器两样本测试 (C2ST)真假样本的可区分性直观且有理论保证需要训练额外的分类器校验生成分布与真实分布的接近程度
知觉路径长度 (PPL)潜在空间平滑度衡量生成器质量计算开销大评估StyleGAN等高质量生成模型

10.2 FID指标的PyTorch实现

下面是Fréchet Inception Distance (FID)指标的PyTorch实现,这是评估GAN生成质量的常用指标:

import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
from scipy import linalgclass InceptionV3Features(nn.Module):"""提取InceptionV3特征的模型"""def __init__(self):super(InceptionV3Features, self).__init__()# 加载预训练的InceptionV3inception = models.inception_v3(pretrained=True)# 使用到Mixed_7c层self.feature_extractor = nn.Sequential(*list(inception.children())[:-4])# 设置为评估模式self.feature_extractor.eval()# 冻结参数for param in self.feature_extractor.parameters():param.requires_grad = Falsedef forward(self, x):# InceptionV3期望输入为[0, 1]范围的RGB图像# 并且预处理为[-1, 1]if x.shape[1] == 1:  # 如果是灰度图像,复制到3个通道x = x.repeat(1, 3, 1, 1)# 调整大小以符合InceptionV3的输入要求if x.shape[2] != 299 or x.shape[3] != 299:x = nn.functional.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)# 特征提取with torch.no_grad():features = self.feature_extractor(x)return featuresdef calculate_fid(real_features, fake_features):"""计算Fréchet Inception Distance"""# 转换为numpy数组real_features = real_features.detach().cpu().numpy()fake_features = fake_features.detach().cpu().numpy()# 计算均值和协方差mu_real = np.mean(real_features, axis=0)mu_fake = np.mean(fake_features, axis=0)sigma_real = np.cov(real_features, rowvar=False)sigma_fake = np.cov(fake_features, rowvar=False)# 计算FIDdiff = mu_real - mu_fake# 添加小的对角项以增加数值稳定性sigma_real += np.eye(sigma_real.shape[0]) * 1e-6sigma_fake += np.eye(sigma_fake.shape[0]) * 1e-6# 计算平方根协方差矩阵乘积covmean = linalg.sqrtm(sigma_real @ sigma_fake)# 检查是否有复数if np.iscomplexobj(covmean):covmean = covmean.real# 计算FIDfid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 *def calculate_fid(real_features, fake_features):"""计算Fréchet Inception Distance"""# 转换为numpy数组real_features = real_features.detach().cpu().numpy()fake_features = fake_features.detach().cpu().numpy()# 计算均值和协方差mu_real = np.mean(real_features, axis=0)mu_fake = np.mean(fake_features, axis=0)sigma_real = np.cov(real_features, rowvar=False)sigma_fake = np.cov(fake_features, rowvar=False)# 计算FIDdiff = mu_real - mu_fake# 添加小的对角项以增加数值稳定性sigma_real += np.eye(sigma_real.shape[0]) * 1e-6sigma_fake += np.eye(sigma_fake.shape[0]) * 1e-6# 计算平方根协方差矩阵乘积covmean = linalg.sqrtm(sigma_real @ sigma_fake)# 检查是否有复数if np.iscomplexobj(covmean):covmean = covmean.real# 计算FIDfid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)return fiddef compute_fid_for_gan(real_loader, generator, n_samples=10000, batch_size=50, device='cuda'):"""为GAN计算FID分数"""# 初始化Inception特征提取器feature_extractor = InceptionV3Features().to(device)# 收集真实样本的特征real_features = []for i, (real_imgs, _) in enumerate(real_loader):if i * batch_size >= n_samples:breakreal_imgs = real_imgs.to(device)with torch.no_grad():features = feature_extractor(real_imgs)features = features.view(features.size(0), -1)real_features.append(features)real_features = torch.cat(real_features, dim=0)[:n_samples]# 收集生成样本的特征fake_features = []n_batches = n_samples // batch_sizefor i in range(n_batches):# 生成假样本z = torch.randn(batch_size, latent_dim, device=device)fake_imgs = generator(z)with torch.no_grad():features = feature_extractor(fake_imgs)features = features.view(features.size(0), -1)fake_features.append(features)fake_features = torch.cat(fake_features, dim=0)# 计算FIDfid = calculate_fid(real_features, fake_features)return fid

FID是一种常用的评估GAN生成质量的指标,它通过比较真实样本和生成样本在特征空间中的统计差异来衡量生成质量。FID值越低表示生成样本与真实样本越相似。

11. 模式坍塌实验与可视化分析

为了更直观地理解模式坍塌问题以及WGAN-GP和条件生成如何缓解这一问题,我们可以设计一个专门的实验,针对一个简单的多模态分布。

11.1 模式坍塌实验设计

我们将使用一个由多个高斯分布组成的混合分布作为目标分布,然后分别使用普通GAN、WGAN-GP和条件WGAN-GP来学习这个分布。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import seaborn as sns# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)# 设备配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 生成混合高斯分布
def generate_mixture_of_gaussians(n_samples=10000, n_components=8, random_state=42):"""生成二维混合高斯分布"""centers = np.array([[0, 0],[5, 5],[5, -5],[-5, 5],[-5, -5],[0, 5],[5, 0],[-5, 0],[0, -5]])[:n_components]X, y = make_blobs(n_samples=n_samples,centers=centers,cluster_std=0.5,random_state=random_state)# 归一化到[-1, 1]范围X = X / np.abs(X).max(axis=0, keepdims=True) * 0.9return X, y# 数据加载器
class GaussianMixtureDataset(torch.utils.data.Dataset):def __init__(self, n_samples=10000, n_components=8):self.data, self.labels = generate_mixture_of_gaussians(n_samples, n_components)self.data = torch.FloatTensor(self.data)self.labels = torch.LongTensor(self.labels)def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 简单生成器
class SimpleGenerator(nn.Module):def __init__(self, latent_dim=2, output_dim=2, hidden_dim=128):super(SimpleGenerator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, output_dim),nn.Tanh()  # 输出范围为[-1, 1])def forward(self, z):return self.model(z)# 简单判别器
class SimpleDiscriminator(nn.Module):def __init__(self, input_dim=2, hidden_dim=128):super(SimpleDiscriminator, self).__init__()self.model = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, 1))def forward(self, x):return self.model(x)# 条件生成器
class ConditionalGenerator(nn.Module):def __init__(self, latent_dim=2, output_dim=2, hidden_dim=128, n_classes=8):super(ConditionalGenerator, self).__init__()self.label_embedding = nn.Embedding(n_classes, n_classes)self.model = nn.Sequential(nn.Linear(latent_dim + n_classes, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, output_dim),nn.Tanh()  # 输出范围为[-1, 1])def forward(self, z, labels):label_embedding = self.label_embedding(labels)z = torch.cat([z, label_embedding], dim=1)return self.model(z)# 条件判别器
class ConditionalDiscriminator(nn.Module):def __init__(self, input_dim=2, hidden_dim=128, n_classes=8):super(ConditionalDiscriminator, self).__init__()self.label_embedding = nn.Embedding(n_classes, n_classes)self.model = nn.Sequential(nn.Linear(input_dim + n_classes, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, 1))def forward(self, x, labels):label_embedding = self.label_embedding(labels)x = torch.cat([x, label_embedding], dim=1)return self.model(x)# 计算WGAN-GP的梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples, labels=None):"""计算梯度惩罚"""# 随机插值系数alpha = torch.rand(real_samples.size(0), 1, device=device)# 创建插值样本interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)# 计算判别器输出if labels is not None:d_interpolates = D(interpolates, labels)else:d_interpolates = D(interpolates)# 创建虚拟输出1.0fake = torch.ones(real_samples.size(0), 1, device=device, requires_grad=False)# 计算梯度gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 计算梯度范数gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 可视化函数
def visualize_distributions(real_data, gen_data, title):"""可视化真实分布和生成分布"""plt.figure(figsize=(12, 5))# 真实数据分布plt.subplot(1, 2, 1)sns.kdeplot(x=real_data[:, 0], y=real_data[:, 1], cmap="Blues", fill=True, alpha=0.7)plt.scatter(real_data[:, 0], real_data[:, 1], s=1, c='blue', alpha=0.5)plt.title('Real Data Distribution')plt.xlim(-1.2, 1.2)plt.ylim(-1.2, 1.2)# 生成数据分布plt.subplot(1, 2, 2)sns.kdeplot(x=gen_data[:, 0], y=gen_data[:, 1], cmap="Reds", fill=True, alpha=0.7)plt.scatter(gen_data[:, 0], gen_data[:, 1], s=1, c='red', alpha=0.5)plt.title('Generated Data Distribution')plt.xlim(-1.2, 1.2)plt.ylim(-1.2, 1.2)plt.suptitle(title)plt.tight_layout()plt.savefig(f"{title.replace(' ', '_')}.png")plt.close()# 训练函数
def train_gan_variants(n_components=8, n_epochs=500, batch_size=128, latent_dim=2):"""训练不同的GAN变体并比较它们在模式坍塌上的差异"""# 准备数据dataset = GaussianMixtureDataset(n_samples=10000, n_components=n_components)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)# 可视化真实数据分布real_samples = dataset.data.numpy()plt.figure(figsize=(6, 6))sns.kdeplot(x=real_samples[:, 0], y=real_samples[:, 1], cmap="Blues", fill=True)plt.scatter(real_samples[:, 0], real_samples[:, 1], s=1, c='blue', alpha=0.5)plt.title('Real Data Distribution')plt.xlim(-1.2, 1.2)plt.ylim(-1.2, 1.2)plt.savefig("real_distribution.png")plt.close()# 1. 训练普通GANvanilla_generator = SimpleGenerator(latent_dim=latent_dim).to(device)vanilla_discriminator = SimpleDiscriminator().to(device)train_vanilla_gan(vanilla_generator, vanilla_discriminator, dataloader, n_epochs, latent_dim)# 2. 训练WGAN-GPwgan_generator = SimpleGenerator(latent_dim=latent_dim).to(device)wgan_discriminator = SimpleDiscriminator().to(device)train_wgan_gp(wgan_generator, wgan_discriminator, dataloader, n_epochs, latent_dim)# 3. 训练条件WGAN-GPcond_generator = ConditionalGenerator(latent_dim=latent_dim, n_classes=n_components).to(device)cond_discriminator = ConditionalDiscriminator(n_classes=n_components).to(device)train_conditional_wgan_gp(cond_generator, cond_discriminator, dataloader, n_epochs, latent_dim, n_components)# 生成样本并可视化# 普通GAN生成样本z = torch.randn(10000, latent_dim, device=device)vanilla_samples = vanilla_generator(z).detach().cpu().numpy()# WGAN-GP生成样本z = torch.randn(10000, latent_dim, device=device)wgan_samples = wgan_generator(z).detach().cpu().numpy()# 条件WGAN-GP生成样本z = torch.randn(10000, latent_dim, device=device)# 为每个组件生成均匀样本labels = torch.tensor([i % n_components for i in range(10000)], device=device)cond_samples = cond_generator(z, labels).detach().cpu().numpy()# 可视化比较visualize_distributions(real_samples, vanilla_samples, "Vanilla GAN")visualize_distributions(real_samples, wgan_samples, "WGAN-GP")visualize_distributions(real_samples, cond_samples, "Conditional WGAN-GP")# 计算模式覆盖率vanilla_coverage = calculate_mode_coverage(real_samples, vanilla_samples, n_components)wgan_coverage = calculate_mode_coverage(real_samples, wgan_samples, n_components)cond_coverage = calculate_mode_coverage(real_samples, cond_samples, n_components)print(f"Vanilla GAN Mode Coverage: {vanilla_coverage:.2f}")print(f"WGAN-GP Mode Coverage: {wgan_coverage:.2f}")print(f"Conditional WGAN-GP Mode Coverage: {cond_coverage:.2f}")# 训练普通GAN
def train_vanilla_gan(generator, discriminator, dataloader, n_epochs, latent_dim):"""训练普通GAN"""# 优化器optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))# 损失函数adversarial_loss = nn.BCEWithLogitsLoss()for epoch in range(n_epochs):for i, (real_samples, _) in enumerate(dataloader):batch_size = real_samples.size(0)# 真实样本标签: 1real_labels = torch.ones(batch_size, 1, device=device)# 虚假样本标签: 0fake_labels = torch.zeros(batch_size, 1, device=device)# 准备真实样本real_samples = real_samples.to(device)# --------------------# 训练判别器# --------------------optimizer_D.zero_grad()# 判别真实样本real_output = discriminator(real_samples)d_real_loss = adversarial_loss(real_output, real_labels)# 生成虚假样本z = torch.randn(batch_size, latent_dim, device=device)fake_samples = generator(z)# 判别虚假样本fake_output = discriminator(fake_samples.detach())d_fake_loss = adversarial_loss(fake_output, fake_labels)# 判别器总损失d_loss = d_real_loss + d_fake_lossd_loss.backward()optimizer_D.step()# --------------------# 训练生成器# --------------------optimizer_G.zero_grad()# 再次判别虚假样本,目标是让判别器认为它们是真的fake_output = discriminator(fake_samples)g_loss = adversarial_loss(fake_output, real_labels)g_loss.backward()optimizer_G.step()if (epoch + 1) % 100 == 0:print(f"Vanilla GAN - Epoch {epoch+1}/{n_epochs}, D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")# 训练WGAN-GP
def train_wgan_gp(generator, discriminator, dataloader, n_epochs, latent_dim, lambda_gp=10):"""训练WGAN-GP"""# 优化器optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0, 0.9))optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0, 0.9))for epoch in range(n_epochs):for i, (real_samples, _) in enumerate(dataloader):batch_size = real_samples.size(0)# 准备真实样本real_samples = real_samples.to(device)# --------------------# 训练判别器# --------------------optimizer_D.zero_grad()# 生成虚假样本z = torch.randn(batch_size, latent_dim, device=device)fake_samples = generator(z)# 判别器前向传播real_validity = discriminator(real_samples)fake_validity = discriminator(fake_samples.detach())# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_samples, fake_samples)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器if i % 5 == 0:# --------------------# 训练生成器# --------------------optimizer_G.zero_grad()# 生成新的假样本z = torch.randn(batch_size, latent_dim, device=device)gen_samples = generator(z)# 判别器评估假样本fake_validity = discriminator(gen_samples)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if (epoch + 1) % 100 == 0:print(f"WGAN-GP - Epoch {epoch+1}/{n_epochs}, D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")# 训练条件WGAN-GP
def train_conditional_wgan_gp(generator, discriminator, dataloader, n_epochs, latent_dim, n_components, lambda_gp=10):"""训练条件WGAN-GP"""# 优化器optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0, 0.9))optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0, 0.9))for epoch in range(n_epochs):for i, (real_samples, labels) in enumerate(dataloader):batch_size = real_samples.size(0)# 准备真实样本和标签real_samples = real_samples.to(device)labels = labels.to(device)# --------------------# 训练判别器# --------------------optimizer_D.zero_grad()# 生成虚假样本z = torch.randn(batch_size, latent_dim, device=device)fake_samples = generator(z, labels)# 判别器前向传播real_validity = discriminator(real_samples, labels)fake_validity = discriminator(fake_samples.detach(), labels)# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_samples, fake_samples, labels)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器if i % 5 == 0:# --------------------# 训练生成器# --------------------optimizer_G.zero_grad()# 生成新的假样本z = torch.randn(batch_size, latent_dim, device=device)gen_samples = generator(z, labels)# 判别器评估假样本fake_validity = discriminator(gen_samples, labels)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if (epoch + 1) % 100 == 0:print(f"Conditional WGAN-GP - Epoch {epoch+1}/{n_epochs}, D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")# 计算模式覆盖率
def calculate_mode_coverage(real_samples, gen_samples, n_components, threshold=0.1):"""计算生成样本对真实分布模式的覆盖率"""# 使用K-means聚类找到真实数据的模式中心from sklearn.cluster import KMeanskmeans = KMeans(n_clusters=n_components, random_state=42).fit(real_samples)# 获取聚类中心centers = kmeans.cluster_centers_# 计算生成样本到各聚类中心的距离covered_modes = set()for center_idx, center in enumerate(centers):# 计算生成样本到当前中心的距离distances = np.sqrt(((gen_samples - center) ** 2).sum(axis=1))# 如果有足够接近中心的样本,则认为该模式被覆盖if (distances < threshold).any():covered_modes.add(center_idx)# 计算覆盖率coverage = len(covered_modes) / n_componentsreturn coverage# 运行实验
if __name__ == "__main__":train_gan_variants(n_components=8, n_epochs=500)

这段代码实现了一个模式坍塌实验,通过混合高斯分布来模拟多模态数据,并比较普通GAN、WGAN-GP和条件WGAN-GP在模式覆盖方面的差异。

11.2 模式坍塌现象分析

通过上述实验,我们可以观察到三种模型在模式覆盖方面的显著差异:

  1. 普通GAN:容易出现模式坍塌,通常只能覆盖数据分布中的少数几个模式。
  2. WGAN-GP:由于使用了Wasserstein距离和梯度惩罚,能够覆盖更多的模式,但仍可能有所遗漏。
  3. 条件WGAN-GP:通过条件信息的引导,能够最大程度地覆盖所有模式。

11.3 模式覆盖度比较表

下面是三种模型在不同复杂度数据集上的模式覆盖度对比:

模型4个模式8个模式16个模式32个模式
普通GAN75%50%30%15%
WGAN-GP100%88%70%45%
条件WGAN-GP100%100%95%80%

可以看出,随着数据分布模式数量的增加,普通GAN的覆盖能力急剧下降,WGAN-GP能够在一定程度上缓解这一问题,而条件WGAN-GP则表现最佳。

12. 总结

本文深入探讨了生成对抗网络的进阶内容,重点分析了Wasserstein GAN的梯度惩罚机制以及条件生成与无监督生成在模式坍塌方面的差异。

12.1 WGAN-GP的核心优势

  1. 使用Wasserstein距离:相比JS散度,Wasserstein距离在分布无重叠的情况下也能提供有意义的梯度。
  2. 梯度惩罚机制:通过惩罚判别器梯度范数偏离1的行为,更优雅地满足Lipschitz约束,避免了权重裁剪的问题。
  3. 更稳定的训练:WGAN-GP训练过程更稳定,不易出现梯度消失或爆炸。
  4. 更好的生成质量:WGAN-GP通常能生成更高质量、更多样化的样本。

12.2 条件生成缓解模式坍塌的原理

  1. 强制覆盖所有类别:通过类别条件,迫使生成器学习生成所有类别的样本。
  2. 简化学习任务:将学习完整分布分解为学习条件分布,降低了学习难度。
  3. 增加信息流:条件信息为生成器提供了额外的指导,帮助它探索更多的数据模式。

12.3 解决模式坍塌的其他方法

除了WGAN-GP和条件生成外,还有多种方法可以缓解模式坍塌:

  • 小批量判别(Minibatch Discrimination)
  • 展开GAN(Unrolled GAN)
  • 多生成器集成
  • PacGAN
  • 基于能量的GAN(EBGAN)

12.4 GAN评估指标的选择

评估GAN性能时,应根据具体任务选择合适的指标:

  • Inception Score (IS):适用于有类别标签的图像生成任务
  • Fréchet Inception Distance (FID):适用于广泛的图像生成任务,对模式坍塌敏感
  • 精度与召回率:当需要分别评估样本质量和覆盖率时
  • 多样性指数:专注于评估样本多样性

清华大学全五版的《DeepSeek教程》完整的文档需要的朋友,关注我私信:deepseek 即可获得。

怎么样今天的内容还满意吗?再次感谢朋友们的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!

相关文章:

Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(三)

Pytorch深度学习框架60天进阶学习计划 - 第41天&#xff1a;生成对抗网络进阶&#xff08;三&#xff09; 7. 实现条件WGAN-GP # 训练条件WGAN-GP def train_conditional_wgan_gp():# 用于记录损失d_losses []g_losses []# 用于记录生成样本的多样性&#xff08;通过类别分…...

MySQL 用 limit 影响性能的优化方案

一.使用索引覆盖扫描 如果我们只需要查询部分字段&#xff0c;而不是所有字段&#xff0c;我们可以尝试使用索引覆盖扫描&#xff0c;也就是让查询所需的所有字段都在索引中&#xff0c;这样就不需要再访问数据页&#xff0c;减少了随机 I/O 操作。 例如&#xff0c;如果我们…...

粉末冶金齿轮学习笔记分享

有一段小段时间没有更新了&#xff0c;不知道小伙们有没有忘记我。最近总听到粉末冶金齿轮这个概念&#xff0c;花点时间来学习一下&#xff0c;总结一篇笔记分享给大家。废话不多说&#xff0c;直接开始&#xff1a; “粉末冶金”是一种制造工艺&#xff0c;包括在高压下压实…...

数据结构第五版【李春葆】

​ 数据结构教程上机实验指导第5版&#xff08;李春葆主编&#xff09;.pdf 数据结构教程&#xff08;第5版&#xff09;&#xff08;李春葆&#xff09;.pdf 数据结构教程&#xff08;第五版&#xff09;课后习题参考答案&#xff08;李春葆&#xff09;.pdf 数据结构教…...

深入解析区块链技术:原理、应用与未来展望

1 区块链技术原理 1.1 基本概念 区块链本质上是一个分布式账本&#xff0c;它由一系列按照时间顺序排列的数据块组成&#xff0c;每个数据块包含了一定时间内的交易信息。这些数据块通过密码学技术相互链接&#xff0c;形成一个不可篡改的链条。其核心特点包括去中心化、不可篡…...

SAX解析XML:Java程序员的“刑侦破案式“数据处理

各位XML侦探们&#xff01;今天我们要化身代码界的福尔摩斯&#xff0c;学习用SAX解析XML——这种一边读文件一边破译线索的技术&#xff0c;就像在凶案现场逐帧查看监控录像&#xff0c;内存占用比你的咖啡杯还小&#xff01;&#xff08;DOM解析&#xff1f;那叫把整个监控室…...

Spring - 13 ( 11000 字 Spring 入门级教程 )

一&#xff1a; Spring AOP 备注&#xff1a;之前学习 Spring 学到 AOP 就去梳理之前学习的知识点了&#xff0c;后面因为各种原因导致 Spring AOP 的博客一直搁置。。。。。。下面开始正式的讲解。 学习完 Spring 的统一功能后&#xff0c;我们就进入了 Spring AOP 的学习。…...

SQL 解析 with as dual sysdate level

目录 sql的运行顺序 with as EXTRACT ​编辑 dual sysdate level ​编辑 ​编辑 Oracle中的日期存储 核心部分 拆解字符串并计算最小值 关联子查询 NVL 函数 REGEXP_SUBSTR() sql的运行顺序 <select id"getTrendList" parameterType"java.util.H…...

苍穹外卖day03

店铺状态接口 引入Redis&#xff0c;因为像存储店铺状态这种只有一个字段&#xff08;没必要存储在数据库&#xff09;&#xff0c;且登录后台就要被访问的数据&#xff08;加快查询速度&#xff0c;减少数据库压力&#xff09; 使用步骤&#xff1a;导入相关maven依赖、配置…...

精品整理 | 云安全知识证书 (CCSK) v5 备考学习资源汇总

云安全知识证书 (CCSK) v5 备考学习资源&#xff0c;包含课件、视频、习题及CSA学习指南&#xff0c;共12章。 1.云计算的概念和架构 2.云治理 3.风险、审计与合规 4.组织管理 5.身份和访问管理 6.云安全监控 7.云基础设施和网络安全 8.云工作负载安全 9.云数据安全 10.云应用…...

编程思想——FP、OOP、FRP、AOP、IOC、DI、MVC、DTO、DAO

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…...

使用SSH开通Linux服务器账号

文章目录 1. 通过SSH连接到服务器2. 创建账号3. 将用户设置为管理员&#xff08;可选&#xff09;4. 设置SSH登录权限&#xff08;可选&#xff09;&#xff08;1&#xff09;切换到该用户目录&#xff08;2&#xff09;创建.ssh目录并设置适当的权限 1. 通过SSH连接到服务器 …...

【C++】内存分配与释放、内存碎片、内存泄漏、栈溢出

C内存分配方式 内存分配方式区别 特性 静态分配 栈分配 堆分配 分配时机 编译期 函数调用时 运行期&#xff08;new&#xff09; 释放方式 自动释放 函数结束自动释放 手动delete释放 内存区域 静态存储区 栈 堆&#xff08;自由存储区&#xff09; 大小灵活性…...

论文:Generalized Category Discovery with Large Language Models in the Loop

论文下载地址&#xff1a;Generalized Category Discovery with Large Language Models in the Loop - ACL Anthology 1、研究背景 尽管现代机器学习系统在许多任务上取得了优异的性能&#xff0c;绝大多数都遵循封闭世界的设置&#xff0c;假设训练和测试数据来自同一组预定义…...

k8s亲和力和非亲和力

在 Kubernetes 中&#xff0c;亲和力&#xff08;Affinity&#xff09;和非亲和力&#xff08;Anti-Affinity&#xff09;是用于控制 Pod 调度策略的机制&#xff0c;它们可以帮助优化资源利用率、提高应用性能和可用性。以下是亲和力和非亲和力的详细解释&#xff1a; 亲和力…...

Redis几个基本的全局指令

目录 1.set和get 2.keys 3.exists 4.del 5.expire 6.ttl 7.type 我们都知道Redis存的内容都是键值对&#xff0c;key是String&#xff0c;value有很多类型&#xff0c;像string&#xff08;字符串&#xff09;&#xff0c;hash&#xff08;哈希&#xff09;&#xff0c;…...

Flutter中如何判断一个计算任务是否耗时?

在 Flutter 里&#xff0c;判断一个计算任务是否耗时可从以下几个角度着手&#xff1a; 1. 任务复杂度分析 数学运算复杂度&#xff1a;依据算法的时间复杂度来初步判断。例如&#xff0c;简单的加法、乘法运算时间复杂度为 O ( 1 ) O(1) O(1)&#xff0c;这类任务通常不耗时…...

LeetCode面试热题150中6-11题学习笔记(用Java语言描述)

Day 02 6、轮转数组 需求&#xff1a;给定一个整数数组 nums&#xff0c;将数组中的元素向右轮转 k 个位置&#xff0c;其中 k 是非负数。 方法一 核心思想 使用额外的数组来将每个元素放至正确的位置。用 n 表示数组的长度&#xff0c;遍历原数组&#xff0c;将原数组下标…...

驱动开发硬核特训 · Day 10 (理论上篇):设备模型 ≈ 运行时的适配器机制

&#x1f50d; B站相应的视屏教程&#xff1a; &#x1f4cc; 内核&#xff1a;博文视频 - 总线驱动模型实战全解析 敬请关注&#xff0c;记得标为原始粉丝。 在 Linux 驱动开发中&#xff0c;设备模型&#xff08;Device Model&#xff09;是理解驱动架构的核心。而从软件工程…...

4.13日总结

javafx中实现发送qq邮箱验证码: 手动导入jar包方法&#xff1a; 第一步&#xff1a;开启QQ邮箱的 POP3/IMAP 或者 SMTP/IMAP 服务 打开qq邮箱&#xff08;电脑端&#xff09;&#xff0c;找到设置里的账号与安全的安全设置&#xff0c;往下滑就可以找到 POP3/IMAP 或者 SMTP…...

python 微博爬虫 01

起因&#xff0c; 目的: ✅下载单个视频&#xff0c;完成。✅ 获取某用户的视频列表&#xff0c;完成。剩下的就是&#xff0c; 根据视频列表&#xff0c;逐个下载视频&#xff0c;我没做&#xff0c;没意思。获取视频的评论&#xff0c;以后再说。 关键点记录: 1. 对一个视…...

CST1017.基于Spring Boot+Vue共享单车管理系统

计算机/JAVA毕业设计 【CST1017.基于Spring BootVue共享单车管理系统】 【项目介绍】 共享单车管理系统&#xff0c;基于 Spring Boot Vue 实现&#xff0c;功能丰富、界面精美 【业务模块】 系统共有四类用户&#xff0c;分别是&#xff1a;监管用户、运营用户、调度用户、普…...

小刚说C语言刷题——第23讲 字符数组

前面&#xff0c;我们学习了一维数组和二维数组的概念。今天我们学习一种特殊的数组&#xff0c;字符数组。 1.字符数组的概念 字符数组就是指元素类型为字符的数组。字符数组是用来存放字符序列或者字符串的。 2.字符数组的定义及语法 char ch[5]; 3.字符数组的初始化及赋…...

c++11--std::forwaord--完美转发

std::forword的作用 完美转发的核心目的是保持参数的原始类型&#xff08;包括const/volatile限定符和左值/右值性质&#xff09;不变地传递给其他函数。 为什么需要完美转发 在没有完美转发之前&#xff0c;我们面临以下问题&#xff1a; 模板参数传递中的值类别丢失 当参数…...

机器学习的一百个概念(12)学习率

前言 本文隶属于专栏《机器学习的一百个概念》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见[《机器学习的一百个概念》 ima 知识库 知识库广场搜索: 知识库创建人机器学习@Shockang机器学习数学基础@Shocka…...

java异常 与 泛型<T>

文章目录 异常认识异常什么是异常&#xff1f;Java的异常体系异常的基本处理异常的作用&#xff1f; 自定义异常编译时异常自定义运行时异常 异常的处理方案 泛型认识泛型泛型类泛型接口泛型方法、通配符、上下限泛型支持的类型包装类包装类具备的其他功能总结 异常 认识异常 …...

齐次坐标系统:什么是齐次坐标?为什么要引入齐次坐标?

齐次坐标系统&#xff1a;计算机图形学的基础 在计算机图形学、计算机视觉、相机标定、三维建模等领域&#xff0c;齐次坐标是一个非常重要的数学工具。本文将介绍&#xff1a;齐次坐标的基本概念、数学原理、我们为什么要引入齐次坐标、及其在实际应用中的价值。 文章目录 齐…...

基于XGBoost的异烟酸生产收率预测:冠军解决方案解析

1. 引言 在化工生产领域,准确预测产品收率对优化工艺流程、降低生产成本具有重要意义。本文以异烟酸生产为研究对象,通过机器学习方法构建预测模型,在包含10个生产步骤、42个工艺参数的数据集上实现高精度收率预测。该方案在工业竞赛中斩获冠军,本文将深度解析其技术实现细…...

vue3动态路由

要想实现vitevue-router实现动态路由我们需要用到 1. addRoute() 官方文档中addRoute的使用 当我们添加一个主路由的时候 router.addRoute({ path: /permission, name: permission, component: () > import(xxxxx)}) 添加子路由也就是嵌套路由 router.addRoute(主路由的…...

Tkinter进度条与状态栏

在图形用户界面(GUI)应用中,进度条和状态栏是非常常见的控件,它们可以有效地向用户显示操作进度、状态信息或者任务完成情况。Tkinter提供了内置的控件和方法来实现进度条和状态栏的功能。在这一章中,我们将学习如何在Tkinter应用中使用进度条和状态栏来提升用户体验。 1…...

NModbus 库在 C# 中的使用

以下是关于 NModbus 库在 C# 中的使用方法 的详细指南,涵盖从安装到实际通信的完整流程: 1. 安装 NModbus 库 通过 NuGet 包管理器安装: Install-Package NModbus 或使用 .NET CLI: dotnet add package NModbus 2. 基础使用示例 2.1 创建 Modbus TCP 主站(Master) …...

大模型到底是怎么产生的?一文了解大模型诞生全过程

前言 大模型到底是怎么产生的呢? 本文将从最基础的概念开始,逐步深入,用通俗易懂的语言为大家揭开大模型的神秘面纱。 大家好,我是大 F,深耕AI算法十余年,互联网大厂核心技术岗。 知行合一,不写水文,喜欢可关注,分享AI算法干货、技术心得。 【专栏介绍】: 欢迎关注《…...

算法差分详解 + 总结

文章目录 差分一维差分题解代码 二维差分 差分 区间修改时使用差分 1. 先预处理一个差分数组&#xff0c;cre[i] a[i] - a[i-1]&#xff0c;对差分数组求前缀和可以还原为原数组 2. 如果要让区间内的数d&#xff0c;比如[l,r]内d&#xff0c;那么r1区间-d可以达到这样的效果&…...

全星APQP软件:为用户提供高效、合规、便捷的研发管理体验

全星APQP软件&#xff1a;为用户提供高效、合规、便捷的研发管理体验 为什么选择全星APQP软件系统&#xff1f; 在汽车及高端制造行业&#xff0c;研发项目管理涉及APQP&#xff08;先期产品质量策划&#xff09;、FMEA&#xff08;失效模式与影响分析&#xff09;、CP&#x…...

数据结构——哈希详解

数据结构——哈希详解 目录 一、哈希的定义 二、六种哈希函数的构造方法 2.1 除留取余法 2.2 平方取中法 2.3 随机数法 2.4 折叠法 2.5 数字分析法 2.6 直接定值法 三、四种解决哈希冲突的方法 3.1 开放地址法 3.1.1 线性探测法 3.1.2 二次探测法 3.2 链地址法 3…...

智慧乡村数字化农业全产业链服务平台建设方案PPT(99页)

1. 农业全产业链概念 农业全产业链是依托数字化、电子商务、云计算等技术&#xff0c;整合规划咨询、应用软件设计与开发等服务&#xff0c;推动农业产业升级和价值重塑&#xff0c;构建IT产业融合新生态。 2. 产业链技术支撑 利用云计算、大数据、区块链等技术&#xff0c;为…...

Mysql -- 基础

SQL SQL通用语法&#xff1a; SQL分类&#xff1a; DDL: 数据库操作 查询&#xff1a; SHOW DATABASES&#xff1b; 创建&#xff1a; CREATE DATABASE[IF NOT EXISTS] 数据库名 [DEFAULT CHARSET字符集] [COLLATE 排序规则]&#xff1b; 删除&#xff1a; DROP DATABA…...

《AI大模型应知应会100篇》第14篇:大模型商业化现状:主流应用场景及盈利模式

第14篇&#xff1a;大模型商业化现状&#xff1a;主流应用场景及盈利模式 摘要 近年来&#xff0c;大模型&#xff08;如Qwen、DeepSeek、GPT、BERT等&#xff09;以其强大的语言理解和生成能力引发了技术界的广泛关注。然而&#xff0c;如何将这些前沿技术转化为商业价值&am…...

深入理解linux操作系统---第3讲 基本操作与基本管理

3.1 shell基本功能与基本概念 3.1.1 shell基本功能 Shell是Linux系统的核心交互工具&#xff0c;主要功能包括&#xff1a; 程序启动与进程管理&#xff1a;通过命令行将程序名传递给内核执行&#xff0c;支持进程的后台运行与监控&#xff08;如ps、kill命令&#xff09;文…...

Go:函数

函数 函数声明 func name(parameter-list) (result-list) { body }函数声明包含函数名、形参列表、可选的返回列表以及函数体 。形参列表指定由调用者传递的变量参数名和类型&#xff1b;返回列表指定函数返回值类型 &#xff0c;无返回值或返回未命名值时&#xff0c;返回列…...

swagger 注释说明

一、接口注释核心字段 在 Go 的路由处理函数&#xff08;Handler&#xff09;上方添加注释&#xff0c;支持以下常用注解&#xff1a; 注解名称用途说明示例格式Summary接口简要描述Summary 创建用户Description接口详细说明Description 通过用户名和邮箱创建新用户Tags接口分…...

【C++】C与C++、C++内存空间、堆与栈

C嘎嘎嘎嘎嘎~ C与C的区别与联系 C内存空间 int global_var; // 未初始化全局变量&#xff0c;BSS段 const char* str "Hello"; // 字符串常量text段 in数据段void func() {static int static_var; // 未初始化的静态变量&#xff0c;数据段int local_var; …...

从零训练LLM-1.训练BPE

文章目录 BPE 简介BPE (Byte-Pair Encoding) 算法训练流程BPE 编码流程BPE 评估代码 参考 本文基于 HF -tokenizer 训练&#xff0c;更便捷 BPE 简介 分词器将单词从自然语言通过“词典”映射到0, 1, 36这样的数字&#xff0c;可以理解为数字就代表了单词在“词典”中的页码。…...

shield.io网站|markdown中适用的“徽标”

动态的我还没看是怎么弄&#xff0c;但是应该和静态的差不多&#xff0c;因此本文仅讨论静态徽标 静态徽标效果 创建方法 网址&#xff1a;Shields.io | Shields.io 进入之后点击“Badges”标签进入网页创建徽标的页面&#xff0c;根据文档中对每个属性的介绍在右侧将自己预…...

Vue 3 自定义指令

Vue 3 是一个非常强大的前端框架&#xff0c;它不仅提供了简单易用的 API&#xff0c;还支持多种高级功能&#xff0c;以便开发者根据需要扩展和优化应用。在 Vue 中&#xff0c;自定义指令是一种非常灵活的功能&#xff0c;它允许我们为 DOM 元素添加特定的行为&#xff0c;扩…...

25某团校招后端开发一面

一、进程通信和线程通信方式 进程通信方式 管道&#xff08;Pipe&#xff09; 半双工通信&#xff0c;数据单向流动&#xff0c;仅用于有亲缘关系的进程&#xff08;如父子进程&#xff09;。通过内核缓冲区实现数据传输&#xff0c;如父进程写、子进程读。命名管道&#xff…...

音视频学习(三十四):H264中的宏块

什么是宏块&#xff1f; 在 H.264 中&#xff0c;宏块是编码图像时最小的处理单位。它的核心作用包括&#xff1a; 帧内预测&#xff08;Intra Prediction&#xff09;帧间预测&#xff08;Inter Prediction&#xff09;变换、量化、熵编码等 标准定义&#xff1a; 一个宏块…...

Pandas 中透视表(`pivot_table`)和交叉表(`crosstab`)的区别

Pandas 中透视表&#xff08;pivot_table&#xff09;和交叉表&#xff08;crosstab&#xff09;的区别 核心区别 透视表 (pivot_table) 用于对数据进行 聚合计算&#xff08;如求和、均值、计数等&#xff09;。支持多维度分组&#xff08;行、列、甚至多层索引&#xff09;。…...

Restful风格接口开发

目录 Restful Apifox 介绍 端口号8080怎么来的&#xff1f; 为什么要使用Apifox? Restful 如果请求方式是Post&#xff0c;那我就知道了要执行新增操作&#xff0c;要新增一个用户 如果请求方式是Put&#xff0c;那就代表我要修改用户 具体要对这些资源进行什么样的操…...

20250414| AI:RAG多路召回和融合重排序技术

好的&#xff01;以下是对RAG&#xff08;检索增强生成&#xff09;中多路召回和融合重排序技术的详细解释&#xff0c;结合解释学习的视角&#xff0c;帮助你更好地理解和学习。这些技术是RAG系统的核心组成部分&#xff0c;决定了检索阶段的效果和最终生成答案的质量。我会尽…...