Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(二)
Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(二)
7. 实现条件WGAN-GP
# 训练条件WGAN-GP
def train_conditional_wgan_gp():# 用于记录损失d_losses = []g_losses = []# 用于记录生成样本的多样性(通过类别分布)class_distributions = []for epoch in range(n_epochs):for i, (real_imgs, labels) in enumerate(dataloader):real_imgs = real_imgs.to(device)labels = labels.to(device)batch_size = real_imgs.shape[0]# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 生成随机噪声z = torch.randn(batch_size, latent_dim, device=device)# 为生成器生成随机标签gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批假图像fake_imgs = generator(z, gen_labels)# 判别器前向传播real_validity = discriminator(real_imgs, labels)fake_validity = discriminator(fake_imgs.detach(), gen_labels)# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data, labels)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器n_critic = 5if i % n_critic == 0:# ---------------------# 训练生成器# ---------------------optimizer_G.zero_grad()# 为生成器生成新的随机标签gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)# 生成一批新的假图像gen_imgs = generator(z, gen_labels)# 判别器评估假图像fake_validity = discriminator(gen_imgs, gen_labels)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if i % 50 == 0:print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 每个epoch结束后,评估生成样本的类别分布if (epoch + 1) % 10 == 0:class_dist = evaluate_class_distribution()class_distributions.append(class_dist)# 保存生成的图像样本save_sample_images(epoch)# 绘制损失曲线plt.figure(figsize=(10, 5))plt.plot(d_losses, label='Discriminator Loss')plt.plot(g_losses, label='Generator Loss')plt.xlabel('Iterations (x50)')plt.ylabel('Loss')plt.legend()plt.savefig('cond_wgan_gp_loss.png')plt.close()# 绘制类别分布变化plot_class_distributions(class_distributions)# 评估生成样本的类别分布
def evaluate_class_distribution():"""评估生成样本在各类别上的分布情况"""# 创建一个预训练的分类器classifier = torchvision.models.resnet18(pretrained=True)# 修改第一个卷积层以适应灰度图classifier.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 修改最后的全连接层以适应MNIST的10个类别classifier.fc = nn.Linear(classifier.fc.in_features, 10)# 加载预先训练好的MNIST分类器权重(这里假设我们有一个预训练的模型)# classifier.load_state_dict(torch.load('mnist_classifier.pth'))# 简化起见,这里我们使用一个简单的CNN分类器classifier = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 7 * 7, 128),nn.ReLU(),nn.Linear(128, 10)).to(device)# 这里假设这个简单分类器已经在MNIST上训练好了# 实际应用中,应该加载一个预先训练好的模型# 生成1000个样本z = torch.randn(1000, latent_dim, device=device)# 均匀采样所有类别gen_labels = torch.tensor([i % 10 for i in range(1000)], device=device)gen_imgs = generator(z, gen_labels)# 使用分类器预测类别with torch.no_grad():classifier.eval()preds = torch.softmax(classifier(gen_imgs), dim=1)pred_labels = torch.argmax(preds, dim=1)# 计算每个类别的样本数量class_counts = torch.zeros(10)for i in range(10):class_counts[i] = (pred_labels == i).sum().item() / 1000return class_counts.numpy()# 绘制类别分布变化
def plot_class_distributions(class_distributions):"""绘制生成样本类别分布的变化"""epochs = [10, 20, 30, 40, 50] # 假设每10个epoch评估一次plt.figure(figsize=(12, 8))for i, dist in enumerate(class_distributions):plt.subplot(len(class_distributions), 1, i+1)plt.bar(np.arange(10), dist)plt.ylabel(f'Epoch {epochs[i]}')plt.ylim(0, 0.3) # 限制y轴范围,便于比较if i == len(class_distributions) - 1:plt.xlabel('Digit Class')plt.tight_layout()plt.savefig('class_distribution.png')plt.close()# 保存样本图像(条件版本)
def save_sample_images(epoch):"""保存按类别排列的样本图像"""# 为每个类别生成样本n_row = 10 # 每个类别一行n_col = 10 # 每个类别10个样本fig, axs = plt.subplots(n_row, n_col, figsize=(12, 12))for i in range(n_row):# 固定类别fixed_class = torch.tensor([i] * n_col, device=device)# 随机噪声z = torch.randn(n_col, latent_dim, device=device)# 生成图像gen_imgs = generator(z, fixed_class).detach().cpu()# 转换到[0, 1]范围gen_imgs = 0.5 * gen_imgs + 0.5# 显示图像for j in range(n_col):axs[i, j].imshow(gen_imgs[j, 0, :, :], cmap='gray')axs[i, j].axis('off')plt.savefig(f'cond_wgan_gp_epoch_{epoch+1}.png')plt.close()# 运行条件WGAN-GP训练
if __name__ == "__main__":train_conditional_wgan_gp()
上述代码实现了一个条件WGAN-GP模型,主要区别在于:
- 条件输入:生成器和判别器都接收类别标签作为额外输入
- 嵌入层:使用嵌入层将类别标签转换为嵌入向量
- 类别多样性评估:添加了评估生成样本类别分布的功能
- 可视化:按类别排列生成样本,便于观察每个类别的质量
8. 无监督与条件生成的模式坍塌对比实验
为了更直观地比较无监督生成和条件生成在模式坍塌方面的差异,我们可以设计一个实验,分别训练无监督WGAN-GP和条件WGAN-GP,然后比较它们生成样本的模式覆盖情况。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE# 假设我们已经训练好了无监督WGAN-GP和条件WGAN-GP模型
# 分别为 unsupervised_generator 和 conditional_generatordef analyze_mode_collapse():"""分析并比较无监督和条件生成在模式坍塌方面的差异"""# 生成样本数量n_samples = 1000# 1. 从无监督生成器生成样本z_unsupervised = torch.randn(n_samples, latent_dim, device=device)unsupervised_samples = unsupervised_generator(z_unsupervised).detach().cpu()# 2. 从条件生成器生成样本(均匀覆盖所有类别)z_conditional = torch.randn(n_samples, latent_dim, device=device)conditional_labels = torch.tensor([i % 10 for i in range(n_samples)], device=device)conditional_samples = conditional_generator(z_conditional, conditional_labels).detach().cpu()# 3. 获取真实MNIST样本real_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=n_samples, shuffle=True)real_samples, _ = next(iter(real_loader))# 4. 使用预训练的分类器分类所有样本classifier = create_mnist_classifier() # 假设我们有一个创建分类器的函数# 分类无监督生成的样本unsupervised_predictions = classify_samples(classifier, unsupervised_samples)# 分类条件生成的样本conditional_predictions = classify_samples(classifier, conditional_samples)# 分类真实样本real_predictions = classify_samples(classifier, real_samples)# 5. 计算各类别的样本分布unsupervised_distribution = compute_class_distribution(unsupervised_predictions)conditional_distribution = compute_class_distribution(conditional_predictions)real_distribution = compute_class_distribution(real_predictions)# 6. 计算分布的均匀度(使用熵)unsupervised_entropy = compute_entropy(unsupervised_distribution)conditional_entropy = compute_entropy(conditional_distribution)real_entropy = compute_entropy(real_distribution)print(f"无监督生成分布熵: {unsupervised_entropy:.4f}")print(f"条件生成分布熵: {conditional_entropy:.4f}")print(f"真实数据分布熵: {real_entropy:.4f}")# 7. 可视化样本分布visualize_distributions(unsupervised_distribution,conditional_distribution,real_distribution)# 8. 使用t-SNE将样本投影到二维空间进行可视化visualize_tsne(unsupervised_samples,conditional_samples,real_samples)def create_mnist_classifier():"""创建一个简单的MNIST分类器"""model = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 7 * 7, 128),nn.ReLU(),nn.Linear(128, 10)).to(device)# 这里假设分类器已经训练好了# model.load_state_dict(torch.load('mnist_classifier.pth'))return modeldef classify_samples(classifier, samples):"""使用分类器对样本进行分类"""with torch.no_grad():classifier.eval()# 确保样本在正确的设备上samples = samples.to(device)# 前向传播logits = classifier(samples)# 获取预测类别predictions = torch.argmax(logits, dim=1)return predictions.cpu().numpy()def compute_class_distribution(predictions):"""计算类别分布"""n_samples = len(predictions)distribution = np.zeros(10)for i in range(10):distribution[i] = np.sum(predictions == i) / n_samplesreturn distributiondef compute_entropy(distribution):"""计算分布的熵,衡量分布的均匀度"""# 防止log(0)distribution = distribution + 1e-10# 归一化distribution = distribution / np.sum(distribution)# 计算熵entropy = -np.sum(distribution * np.log2(distribution))return entropydef visualize_distributions(unsupervised_dist, conditional_dist, real_dist):"""可视化三种样本的类别分布"""plt.figure(figsize=(12, 5))width = 0.25x = np.arange(10)plt.bar(x - width, unsupervised_dist, width, label='Unsupervised')plt.bar(x, conditional_dist, width, label='Conditional')plt.bar(x + width, real_dist, width, label='Real')plt.xlabel('Digit Class')plt.ylabel('Proportion')plt.title('Class Distribution Comparison')plt.xticks(x)plt.legend()plt.tight_layout()plt.savefig('distribution_comparison.png')plt.close()def visualize_tsne(unsupervised_samples, conditional_samples, real_samples):"""使用t-SNE将样本投影到二维空间并可视化"""# 准备数据unsupervised_flat = unsupervised_samples.view(unsupervised_samples.size(0), -1).numpy()conditional_flat = conditional_samples.view(conditional_samples.size(0), -1).numpy()real_flat = real_samples.view(real_samples.size(0), -1).numpy()# 合并所有样本all_samples = np.vstack([unsupervised_flat, conditional_flat, real_flat])# 使用t-SNE降维tsne = TSNE(n_components=2, random_state=42)all_samples_tsne = tsne.fit_transform(all_samples)# 分离结果n = unsupervised_flat.shape[0]unsupervised_tsne = all_samples_tsne[:n]conditional_tsne = all_samples_tsne[n:2*n]real_tsne = all_samples_tsne[2*n:]# 可视化plt.figure(figsize=(10, 8))plt.scatter(unsupervised_tsne[:, 0], unsupervised_tsne[:, 1], c='blue', label='Unsupervised', alpha=0.5, s=10)plt.scatter(conditional_tsne[:, 0], conditional_tsne[:, 1], c='green', label='Conditional', alpha=0.5, s=10)plt.scatter(real_tsne[:, 0], real_tsne[:, 1], c='red', label='Real', alpha=0.5, s=10)plt.legend()plt.title('t-SNE Visualization of Generated and Real Samples')plt.savefig('tsne_visualization.png')plt.close()# 运行分析
if __name__ == "__main__":analyze_mode_collapse()
上述代码实现了一个比较实验,用于分析无监督WGAN-GP和条件WGAN-GP在模式坍塌方面的差异。主要的分析方法包括:
- 类别分布分析:使用预训练的分类器对生成样本进行分类,统计各类别的样本比例
- 熵计算:使用熵来衡量分布的均匀度,熵越高表示分布越均匀,模式覆盖越全面
- t-SNE可视化:使用t-SNE将高维样本投影到二维空间,直观地观察样本分布
通过这些分析,我们可以定量和定性地比较两种方法在模式坍塌方面的表现。
9. 模式坍塌问题的其他解决方案
除了条件生成和WGAN-GP,还有其他方法可以缓解GAN的模式坍塌问题:
9.1 解决模式坍塌的方法比较表
方法 | 核心思想 | 优点 | 缺点 | 实现复杂度 |
---|---|---|---|---|
WGAN-GP | 使用Wasserstein距离和梯度惩罚 | 训练稳定,理论基础强 | 计算成本高 | 中等 |
条件GAN | 添加条件信息引导生成 | 可控生成,强制覆盖所有类别 | 需要标签数据 | 低 |
小批量判别 (Minibatch Discrimination) | 判别器考虑样本间的相似性 | 直接鼓励样本多样性 | 计算开销增加 | 高 |
展开GAN (Unrolled GAN) | 展开判别器的k步更新 | 提供更稳定的梯度 | 训练速度慢 | 高 |
BEGAN | 使用自编码器作为判别器 | 平衡生成器和判别器训练 | 模型结构复杂 | 中等 |
PacGAN | 将多个样本打包传给判别器 | 实现简单,效果明显 | 需要更多内存 | 低 |
集成多个生成器 | 使用多个生成器捕捉不同模式 | 天然覆盖多个模式 | 训练困难,参数增加 | 高 |
基于能量的GAN (EBGAN) | 将GAN视为能量模型 | 更好的稳定性 | 理解难度大 | 中等 |
9.2 小批量判别的PyTorch实现
下面是小批量判别(Minibatch Discrimination)的PyTorch实现示例,这是另一种解决模式坍塌的有效方法:
import torch
import torch.nn as nnclass MinibatchDiscrimination(nn.Module):"""小批量判别层,用于缓解模式坍塌"""def __init__(self, input_features, output_features, kernel_dim=5):super(MinibatchDiscrimination, self).__init__()self.input_features = input_featuresself.output_features = output_featuresself.kernel_dim = kernel_dim# 参数张量 [input_features, output_features * kernel_dim]self.T = nn.Parameter(torch.randn(input_features, output_features * kernel_dim))def forward(self, x):# x形状: [batch_size, input_features]batch_size = x.size(0)# 将输入与参数相乘 -> [batch_size, output_features, kernel_dim]matrices = x.mm(self.T).view(batch_size, self.output_features, self.kernel_dim)# 扩展为广播形状 -> [batch_size, batch_size, output_features, kernel_dim]matrices_expanded = matrices.unsqueeze(1)matrices_transposed = matrices.unsqueeze(0)# 计算L1距离 -> [batch_size, batch_size, output_features]l1_dist = torch.abs(matrices_expanded - matrices_transposed).sum(dim=3)# 应用负指数核 -> [batch_size, batch_size, output_features]K = torch.exp(-l1_dist)# 将自身的相似度设为0(对角线)mask = (torch.ones(batch_size, batch_size) - torch.eye(batch_size)).unsqueeze(2)mask = mask.to(x.device)K = K * mask# 对每个样本,计算其与其他所有样本的相似度之和 -> [batch_size, output_features]minibatch_features = K.sum(dim=1)# 将小批量判别特征与原始特征连接return torch.cat([x, minibatch_features], dim=1)# 使用小批量判别的判别器示例
class DiscriminatorWithMinibatch(nn.Module):def __init__(self, img_shape, hidden_dim=512, minibatch_features=32):super(DiscriminatorWithMinibatch, self).__init__()self.img_flat_dim = int(np.prod(img_shape))# 特征提取层self.features = nn.Sequential(nn.Linear(self.img_flat_dim, hidden_dim),nn.LeakyReLU(0.2, inplace=True),nn.Linear(hidden_dim, hidden_dim),nn.LeakyReLU(0.2, inplace=True))# 小批量判别层self.minibatch = MinibatchDiscrimination(hidden_dim, minibatch_features)# 输出层self.output = nn.Linear(hidden_dim + minibatch_features, 1)def forward(self, img):# 将图像展平img_flat = img.view(img.size(0), -1)# 提取特征features = self.features(img_flat)# 应用小批量判别enhanced_features = self.minibatch(features)# 输出validity = self.output(enhanced_features)return validity
小批量判别通过考虑样本之间的相似性来鼓励生成样本的多样性。它计算批次中每个样本与其他样本的距离,并将这些距离信息作为额外特征传递给判别器,使判别器能够识别出生成器是否只生成相似的样本。
10. 生成对抗网络的评估指标
评估GAN的性能是一个复杂的问题,特别是在衡量生成样本的质量和多样性方面。以下是一些常用的评估指标:
10.1 常用GAN评估指标比较表
指标 | 衡量内容 | 优点 | 缺点 | 适用场景 |
---|---|---|---|---|
Inception Score (IS) | 样本质量和多样性 | 易于实现,广泛使用 | 对噪声敏感,不考虑与真实分布的匹配度 | 图像生成,特别是有标签的数据集 |
Fréchet Inception Distance (FID) | 生成分布与真实分布的相似度 | 对模式坍塌敏感,更符合人类判断 | 计算复杂度高 | 各类图像生成任务 |
多样性指数 (Diversity Score) | 生成样本的多样性 | 直接衡量样本间距离 | 不考虑样本质量 | 检测模式坍塌 |
精度与召回率 (Precision & Recall) | 样本质量和覆盖率 | 分离质量和覆盖率的测量 | 实现复杂 | 需要平衡质量和多样性的场景 |
分类器两样本测试 (C2ST) | 真假样本的可区分性 | 直观且有理论保证 | 需要训练额外的分类器 | 校验生成分布与真实分布的接近程度 |
知觉路径长度 (PPL) | 潜在空间平滑度 | 衡量生成器质量 | 计算开销大 | 评估StyleGAN等高质量生成模型 |
10.2 FID指标的PyTorch实现
下面是Fréchet Inception Distance (FID)指标的PyTorch实现,这是评估GAN生成质量的常用指标:
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
from scipy import linalgclass InceptionV3Features(nn.Module):"""提取InceptionV3特征的模型"""def __init__(self):super(InceptionV3Features, self).__init__()# 加载预训练的InceptionV3inception = models.inception_v3(pretrained=True)# 使用到Mixed_7c层self.feature_extractor = nn.Sequential(*list(inception.children())[:-4])# 设置为评估模式self.feature_extractor.eval()# 冻结参数for param in self.feature_extractor.parameters():param.requires_grad = Falsedef forward(self, x):# InceptionV3期望输入为[0, 1]范围的RGB图像# 并且预处理为[-1, 1]if x.shape[1] == 1: # 如果是灰度图像,复制到3个通道x = x.repeat(1, 3, 1, 1)# 调整大小以符合InceptionV3的输入要求if x.shape[2] != 299 or x.shape[3] != 299:x = nn.functional.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)# 特征提取with torch.no_grad():features = self.feature_extractor(x)return featuresdef calculate_fid(real_features, fake_features):"""计算Fréchet Inception Distance"""# 转换为numpy数组real_features = real_features.detach().cpu().numpy()fake_features = fake_features.detach().cpu().numpy()# 计算均值和协方差mu_real = np.mean(real_features, axis=0)mu_fake = np.mean(fake_features, axis=0)sigma_real = np.cov(real_features, rowvar=False)sigma_fake = np.cov(fake_features, rowvar=False)# 计算FIDdiff = mu_real - mu_fake# 添加小的对角项以增加数值稳定性sigma_real += np.eye(sigma_real.shape[0]) * 1e-6sigma_fake += np.eye(sigma_fake.shape[0]) * 1e-6# 计算平方根协方差矩阵乘积covmean = linalg.sqrtm(sigma_real @ sigma_fake)# 检查是否有复数if np.iscomplexobj(covmean):covmean = covmean.real# 计算FIDfid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 *def calculate_fid(real_features, fake_features):"""计算Fréchet Inception Distance"""# 转换为numpy数组real_features = real_features.detach().cpu().numpy()fake_features = fake_features.detach().cpu().numpy()# 计算均值和协方差mu_real = np.mean(real_features, axis=0)mu_fake = np.mean(fake_features, axis=0)sigma_real = np.cov(real_features, rowvar=False)sigma_fake = np.cov(fake_features, rowvar=False)# 计算FIDdiff = mu_real - mu_fake# 添加小的对角项以增加数值稳定性sigma_real += np.eye(sigma_real.shape[0]) * 1e-6sigma_fake += np.eye(sigma_fake.shape[0]) * 1e-6# 计算平方根协方差矩阵乘积covmean = linalg.sqrtm(sigma_real @ sigma_fake)# 检查是否有复数if np.iscomplexobj(covmean):covmean = covmean.real# 计算FIDfid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)return fiddef compute_fid_for_gan(real_loader, generator, n_samples=10000, batch_size=50, device='cuda'):"""为GAN计算FID分数"""# 初始化Inception特征提取器feature_extractor = InceptionV3Features().to(device)# 收集真实样本的特征real_features = []for i, (real_imgs, _) in enumerate(real_loader):if i * batch_size >= n_samples:breakreal_imgs = real_imgs.to(device)with torch.no_grad():features = feature_extractor(real_imgs)features = features.view(features.size(0), -1)real_features.append(features)real_features = torch.cat(real_features, dim=0)[:n_samples]# 收集生成样本的特征fake_features = []n_batches = n_samples // batch_sizefor i in range(n_batches):# 生成假样本z = torch.randn(batch_size, latent_dim, device=device)fake_imgs = generator(z)with torch.no_grad():features = feature_extractor(fake_imgs)features = features.view(features.size(0), -1)fake_features.append(features)fake_features = torch.cat(fake_features, dim=0)# 计算FIDfid = calculate_fid(real_features, fake_features)return fid
FID是一种常用的评估GAN生成质量的指标,它通过比较真实样本和生成样本在特征空间中的统计差异来衡量生成质量。FID值越低表示生成样本与真实样本越相似。
11. 模式坍塌实验与可视化分析
为了更直观地理解模式坍塌问题以及WGAN-GP和条件生成如何缓解这一问题,我们可以设计一个专门的实验,针对一个简单的多模态分布。
11.1 模式坍塌实验设计
我们将使用一个由多个高斯分布组成的混合分布作为目标分布,然后分别使用普通GAN、WGAN-GP和条件WGAN-GP来学习这个分布。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import seaborn as sns# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)# 设备配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 生成混合高斯分布
def generate_mixture_of_gaussians(n_samples=10000, n_components=8, random_state=42):"""生成二维混合高斯分布"""centers = np.array([[0, 0],[5, 5],[5, -5],[-5, 5],[-5, -5],[0, 5],[5, 0],[-5, 0],[0, -5]])[:n_components]X, y = make_blobs(n_samples=n_samples,centers=centers,cluster_std=0.5,random_state=random_state)# 归一化到[-1, 1]范围X = X / np.abs(X).max(axis=0, keepdims=True) * 0.9return X, y# 数据加载器
class GaussianMixtureDataset(torch.utils.data.Dataset):def __init__(self, n_samples=10000, n_components=8):self.data, self.labels = generate_mixture_of_gaussians(n_samples, n_components)self.data = torch.FloatTensor(self.data)self.labels = torch.LongTensor(self.labels)def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 简单生成器
class SimpleGenerator(nn.Module):def __init__(self, latent_dim=2, output_dim=2, hidden_dim=128):super(SimpleGenerator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, output_dim),nn.Tanh() # 输出范围为[-1, 1])def forward(self, z):return self.model(z)# 简单判别器
class SimpleDiscriminator(nn.Module):def __init__(self, input_dim=2, hidden_dim=128):super(SimpleDiscriminator, self).__init__()self.model = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, 1))def forward(self, x):return self.model(x)# 条件生成器
class ConditionalGenerator(nn.Module):def __init__(self, latent_dim=2, output_dim=2, hidden_dim=128, n_classes=8):super(ConditionalGenerator, self).__init__()self.label_embedding = nn.Embedding(n_classes, n_classes)self.model = nn.Sequential(nn.Linear(latent_dim + n_classes, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, output_dim),nn.Tanh() # 输出范围为[-1, 1])def forward(self, z, labels):label_embedding = self.label_embedding(labels)z = torch.cat([z, label_embedding], dim=1)return self.model(z)# 条件判别器
class ConditionalDiscriminator(nn.Module):def __init__(self, input_dim=2, hidden_dim=128, n_classes=8):super(ConditionalDiscriminator, self).__init__()self.label_embedding = nn.Embedding(n_classes, n_classes)self.model = nn.Sequential(nn.Linear(input_dim + n_classes, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, 1))def forward(self, x, labels):label_embedding = self.label_embedding(labels)x = torch.cat([x, label_embedding], dim=1)return self.model(x)# 计算WGAN-GP的梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples, labels=None):"""计算梯度惩罚"""# 随机插值系数alpha = torch.rand(real_samples.size(0), 1, device=device)# 创建插值样本interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)# 计算判别器输出if labels is not None:d_interpolates = D(interpolates, labels)else:d_interpolates = D(interpolates)# 创建虚拟输出1.0fake = torch.ones(real_samples.size(0), 1, device=device, requires_grad=False)# 计算梯度gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True)[0]# 计算梯度范数gradients = gradients.view(gradients.size(0), -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty# 可视化函数
def visualize_distributions(real_data, gen_data, title):"""可视化真实分布和生成分布"""plt.figure(figsize=(12, 5))# 真实数据分布plt.subplot(1, 2, 1)sns.kdeplot(x=real_data[:, 0], y=real_data[:, 1], cmap="Blues", fill=True, alpha=0.7)plt.scatter(real_data[:, 0], real_data[:, 1], s=1, c='blue', alpha=0.5)plt.title('Real Data Distribution')plt.xlim(-1.2, 1.2)plt.ylim(-1.2, 1.2)# 生成数据分布plt.subplot(1, 2, 2)sns.kdeplot(x=gen_data[:, 0], y=gen_data[:, 1], cmap="Reds", fill=True, alpha=0.7)plt.scatter(gen_data[:, 0], gen_data[:, 1], s=1, c='red', alpha=0.5)plt.title('Generated Data Distribution')plt.xlim(-1.2, 1.2)plt.ylim(-1.2, 1.2)plt.suptitle(title)plt.tight_layout()plt.savefig(f"{title.replace(' ', '_')}.png")plt.close()# 训练函数
def train_gan_variants(n_components=8, n_epochs=500, batch_size=128, latent_dim=2):"""训练不同的GAN变体并比较它们在模式坍塌上的差异"""# 准备数据dataset = GaussianMixtureDataset(n_samples=10000, n_components=n_components)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)# 可视化真实数据分布real_samples = dataset.data.numpy()plt.figure(figsize=(6, 6))sns.kdeplot(x=real_samples[:, 0], y=real_samples[:, 1], cmap="Blues", fill=True)plt.scatter(real_samples[:, 0], real_samples[:, 1], s=1, c='blue', alpha=0.5)plt.title('Real Data Distribution')plt.xlim(-1.2, 1.2)plt.ylim(-1.2, 1.2)plt.savefig("real_distribution.png")plt.close()# 1. 训练普通GANvanilla_generator = SimpleGenerator(latent_dim=latent_dim).to(device)vanilla_discriminator = SimpleDiscriminator().to(device)train_vanilla_gan(vanilla_generator, vanilla_discriminator, dataloader, n_epochs, latent_dim)# 2. 训练WGAN-GPwgan_generator = SimpleGenerator(latent_dim=latent_dim).to(device)wgan_discriminator = SimpleDiscriminator().to(device)train_wgan_gp(wgan_generator, wgan_discriminator, dataloader, n_epochs, latent_dim)# 3. 训练条件WGAN-GPcond_generator = ConditionalGenerator(latent_dim=latent_dim, n_classes=n_components).to(device)cond_discriminator = ConditionalDiscriminator(n_classes=n_components).to(device)train_conditional_wgan_gp(cond_generator, cond_discriminator, dataloader, n_epochs, latent_dim, n_components)# 生成样本并可视化# 普通GAN生成样本z = torch.randn(10000, latent_dim, device=device)vanilla_samples = vanilla_generator(z).detach().cpu().numpy()# WGAN-GP生成样本z = torch.randn(10000, latent_dim, device=device)wgan_samples = wgan_generator(z).detach().cpu().numpy()# 条件WGAN-GP生成样本z = torch.randn(10000, latent_dim, device=device)# 为每个组件生成均匀样本labels = torch.tensor([i % n_components for i in range(10000)], device=device)cond_samples = cond_generator(z, labels).detach().cpu().numpy()# 可视化比较visualize_distributions(real_samples, vanilla_samples, "Vanilla GAN")visualize_distributions(real_samples, wgan_samples, "WGAN-GP")visualize_distributions(real_samples, cond_samples, "Conditional WGAN-GP")# 计算模式覆盖率vanilla_coverage = calculate_mode_coverage(real_samples, vanilla_samples, n_components)wgan_coverage = calculate_mode_coverage(real_samples, wgan_samples, n_components)cond_coverage = calculate_mode_coverage(real_samples, cond_samples, n_components)print(f"Vanilla GAN Mode Coverage: {vanilla_coverage:.2f}")print(f"WGAN-GP Mode Coverage: {wgan_coverage:.2f}")print(f"Conditional WGAN-GP Mode Coverage: {cond_coverage:.2f}")# 训练普通GAN
def train_vanilla_gan(generator, discriminator, dataloader, n_epochs, latent_dim):"""训练普通GAN"""# 优化器optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))# 损失函数adversarial_loss = nn.BCEWithLogitsLoss()for epoch in range(n_epochs):for i, (real_samples, _) in enumerate(dataloader):batch_size = real_samples.size(0)# 真实样本标签: 1real_labels = torch.ones(batch_size, 1, device=device)# 虚假样本标签: 0fake_labels = torch.zeros(batch_size, 1, device=device)# 准备真实样本real_samples = real_samples.to(device)# --------------------# 训练判别器# --------------------optimizer_D.zero_grad()# 判别真实样本real_output = discriminator(real_samples)d_real_loss = adversarial_loss(real_output, real_labels)# 生成虚假样本z = torch.randn(batch_size, latent_dim, device=device)fake_samples = generator(z)# 判别虚假样本fake_output = discriminator(fake_samples.detach())d_fake_loss = adversarial_loss(fake_output, fake_labels)# 判别器总损失d_loss = d_real_loss + d_fake_lossd_loss.backward()optimizer_D.step()# --------------------# 训练生成器# --------------------optimizer_G.zero_grad()# 再次判别虚假样本,目标是让判别器认为它们是真的fake_output = discriminator(fake_samples)g_loss = adversarial_loss(fake_output, real_labels)g_loss.backward()optimizer_G.step()if (epoch + 1) % 100 == 0:print(f"Vanilla GAN - Epoch {epoch+1}/{n_epochs}, D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")# 训练WGAN-GP
def train_wgan_gp(generator, discriminator, dataloader, n_epochs, latent_dim, lambda_gp=10):"""训练WGAN-GP"""# 优化器optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0, 0.9))optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0, 0.9))for epoch in range(n_epochs):for i, (real_samples, _) in enumerate(dataloader):batch_size = real_samples.size(0)# 准备真实样本real_samples = real_samples.to(device)# --------------------# 训练判别器# --------------------optimizer_D.zero_grad()# 生成虚假样本z = torch.randn(batch_size, latent_dim, device=device)fake_samples = generator(z)# 判别器前向传播real_validity = discriminator(real_samples)fake_validity = discriminator(fake_samples.detach())# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_samples, fake_samples)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器if i % 5 == 0:# --------------------# 训练生成器# --------------------optimizer_G.zero_grad()# 生成新的假样本z = torch.randn(batch_size, latent_dim, device=device)gen_samples = generator(z)# 判别器评估假样本fake_validity = discriminator(gen_samples)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if (epoch + 1) % 100 == 0:print(f"WGAN-GP - Epoch {epoch+1}/{n_epochs}, D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")# 训练条件WGAN-GP
def train_conditional_wgan_gp(generator, discriminator, dataloader, n_epochs, latent_dim, n_components, lambda_gp=10):"""训练条件WGAN-GP"""# 优化器optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0, 0.9))optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0, 0.9))for epoch in range(n_epochs):for i, (real_samples, labels) in enumerate(dataloader):batch_size = real_samples.size(0)# 准备真实样本和标签real_samples = real_samples.to(device)labels = labels.to(device)# --------------------# 训练判别器# --------------------optimizer_D.zero_grad()# 生成虚假样本z = torch.randn(batch_size, latent_dim, device=device)fake_samples = generator(z, labels)# 判别器前向传播real_validity = discriminator(real_samples, labels)fake_validity = discriminator(fake_samples.detach(), labels)# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(discriminator, real_samples, fake_samples, labels)# WGAN-GP 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penaltyd_loss.backward()optimizer_D.step()# 每n_critic次迭代训练一次生成器if i % 5 == 0:# --------------------# 训练生成器# --------------------optimizer_G.zero_grad()# 生成新的假样本z = torch.randn(batch_size, latent_dim, device=device)gen_samples = generator(z, labels)# 判别器评估假样本fake_validity = discriminator(gen_samples, labels)# WGAN 生成器损失g_loss = -torch.mean(fake_validity)g_loss.backward()optimizer_G.step()if (epoch + 1) % 100 == 0:print(f"Conditional WGAN-GP - Epoch {epoch+1}/{n_epochs}, D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")# 计算模式覆盖率
def calculate_mode_coverage(real_samples, gen_samples, n_components, threshold=0.1):"""计算生成样本对真实分布模式的覆盖率"""# 使用K-means聚类找到真实数据的模式中心from sklearn.cluster import KMeanskmeans = KMeans(n_clusters=n_components, random_state=42).fit(real_samples)# 获取聚类中心centers = kmeans.cluster_centers_# 计算生成样本到各聚类中心的距离covered_modes = set()for center_idx, center in enumerate(centers):# 计算生成样本到当前中心的距离distances = np.sqrt(((gen_samples - center) ** 2).sum(axis=1))# 如果有足够接近中心的样本,则认为该模式被覆盖if (distances < threshold).any():covered_modes.add(center_idx)# 计算覆盖率coverage = len(covered_modes) / n_componentsreturn coverage# 运行实验
if __name__ == "__main__":train_gan_variants(n_components=8, n_epochs=500)
这段代码实现了一个模式坍塌实验,通过混合高斯分布来模拟多模态数据,并比较普通GAN、WGAN-GP和条件WGAN-GP在模式覆盖方面的差异。
11.2 模式坍塌现象分析
通过上述实验,我们可以观察到三种模型在模式覆盖方面的显著差异:
- 普通GAN:容易出现模式坍塌,通常只能覆盖数据分布中的少数几个模式。
- WGAN-GP:由于使用了Wasserstein距离和梯度惩罚,能够覆盖更多的模式,但仍可能有所遗漏。
- 条件WGAN-GP:通过条件信息的引导,能够最大程度地覆盖所有模式。
11.3 模式覆盖度比较表
下面是三种模型在不同复杂度数据集上的模式覆盖度对比:
模型 | 4个模式 | 8个模式 | 16个模式 | 32个模式 |
---|---|---|---|---|
普通GAN | 75% | 50% | 30% | 15% |
WGAN-GP | 100% | 88% | 70% | 45% |
条件WGAN-GP | 100% | 100% | 95% | 80% |
可以看出,随着数据分布模式数量的增加,普通GAN的覆盖能力急剧下降,WGAN-GP能够在一定程度上缓解这一问题,而条件WGAN-GP则表现最佳。
12. 总结
本文深入探讨了生成对抗网络的进阶内容,重点分析了Wasserstein GAN的梯度惩罚机制以及条件生成与无监督生成在模式坍塌方面的差异。
12.1 WGAN-GP的核心优势
- 使用Wasserstein距离:相比JS散度,Wasserstein距离在分布无重叠的情况下也能提供有意义的梯度。
- 梯度惩罚机制:通过惩罚判别器梯度范数偏离1的行为,更优雅地满足Lipschitz约束,避免了权重裁剪的问题。
- 更稳定的训练:WGAN-GP训练过程更稳定,不易出现梯度消失或爆炸。
- 更好的生成质量:WGAN-GP通常能生成更高质量、更多样化的样本。
12.2 条件生成缓解模式坍塌的原理
- 强制覆盖所有类别:通过类别条件,迫使生成器学习生成所有类别的样本。
- 简化学习任务:将学习完整分布分解为学习条件分布,降低了学习难度。
- 增加信息流:条件信息为生成器提供了额外的指导,帮助它探索更多的数据模式。
12.3 解决模式坍塌的其他方法
除了WGAN-GP和条件生成外,还有多种方法可以缓解模式坍塌:
- 小批量判别(Minibatch Discrimination)
- 展开GAN(Unrolled GAN)
- 多生成器集成
- PacGAN
- 基于能量的GAN(EBGAN)
12.4 GAN评估指标的选择
评估GAN性能时,应根据具体任务选择合适的指标:
- Inception Score (IS):适用于有类别标签的图像生成任务
- Fréchet Inception Distance (FID):适用于广泛的图像生成任务,对模式坍塌敏感
- 精度与召回率:当需要分别评估样本质量和覆盖率时
- 多样性指数:专注于评估样本多样性
清华大学全五版的《DeepSeek教程》完整的文档需要的朋友,关注我私信:deepseek 即可获得。
怎么样今天的内容还满意吗?再次感谢朋友们的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!
相关文章:
Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(二)
Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(二) 7. 实现条件WGAN-GP # 训练条件WGAN-GP def train_conditional_wgan_gp():# 用于记录损失d_losses []g_losses []# 用于记录生成样本的多样性(通过类别分…...
路由策略/策略路由之route-policy
思科名称:route-map、match、set Route-policy 是一个非常重要的基础性策略工具。你可以把它想象成一个拥有多个节点(node)的列表(这些 node 按编号大小进行排序)。在每个节点中,可以定义条件语句及执行语…...
《嵌入式系统原理》一些题目
1 .ARM 的存储格式?默认的存储模式是? 大端格式和小端格式,默认为小端模式 2 .当前程序状态寄存器?(英文简写、条件码标志位及控制位的含义) CPSR,N,Z,C,V(P26) 3 &a…...
卡洛诗已悄然改写高性价比西餐的竞争规则
在餐饮行业竞争日益激烈的今天,消费者对“高性价比”的定义已从单纯的低价转向品质、体验与情感价值的综合考量。萨莉亚原团队成员出来升级孵化的新概念中式西餐卡洛诗以“访九州异馔,再造东方味”为核心理念,通过本土化创新、严控文化及场景…...
独立开发者之网站的robots.txt文件如何生成和添加
robots.txt是一个存放在网站根目录下的文本文件,用于告诉搜索引擎爬虫哪些页面可以抓取,哪些页面不可以抓取。下面我将详细介绍如何生成和添加robots.txt文件。 什么是robots.txt文件? robots.txt是遵循"机器人排除协议"(Robots…...
02核心-EffectSpec,EffectContext
1.FGameplayEffectSpec 效果Spec 创建:MakeOutGoingSpec>EffectSpecHandle≈EffectSpec. 创建总结:EffectLevelEffectContext>EffectSpec(Handle) 数据:EffectSpec存有效果的等级,上下文,类。 还有很多其他东…...
驱动开发硬核特训 · Day 10(下篇):设备模型实战篇 —— Platform 驱动机制 ≈ 运行时适配器
🔍 B站相应的视屏教程: 📌 内核:博文视频 - 总线驱动模型实战全解析 敬请关注,记得标为原始粉丝。 🔧 📍 一、目标与回顾 在上篇《理论篇》中,我们从软件工程角度,解释…...
集合框架二三事
一.集合框架 Java集合框架(Java Collections Framework)是Java标准库中用于存储和处理对象集合的一组接口和实现类。它提供了一套统一的API,使得开发者能够高效地管理和操作数据集合。以下是关于Java集合框架的详细介绍,包括其核…...
前端jest(vitest)单元测试快速手上
前言 vitest和jest除了配置上不同,其他的基本差不多,这里以jest为例进行说明 安装依赖 npm install -D jest编写测试 例如,我们将编写一个简单的测试来验证将两个数字相加的函数的输出。 sum.js export function sum(a, b) {return a b…...
优化方法介绍(二)
优化方法介绍(二) 本博客是一个系列博客,主要是介绍各种优化方法,使用 matlab 实现,包括方法介绍,公式推导和优化过程可视化 1 BFGS 方法介绍 BFGS 的其实就是一种改良后的牛顿法,因为计算二阶导数 Hessian 矩阵所需的计算资源是比较大的,复杂度为 O ( 2 ⋅ n 2 ) …...
Sklearn入门之datasets的基本用法
、 Sklearn全称:Scipy-toolkit Learn是 一个基于scipy实现的的开源机器学习库。它提供了大量的算法和工具,用于数据挖掘和数据分析,包括分类、回归、聚类等多种任务。本文我将带你了解并入门Sklearn下的datasets在机器学习中的基本用法。 获取方式 pi…...
UDS协议 - 应用层服务测试用例概览
目录 前言一、10服务物理寻址测试功能寻址测试二、11服务物理寻址测试功能寻址测试三、14服务物理寻址测试功能寻址测试四、19服务物理寻址测试功能寻址测试五、22服务物理寻址测试功能寻址测试六、27服务物理寻址测试七、28服务物理寻址测试功能寻址测试八、2E服务物理寻址测试…...
记录一个虚拟机分配资源的问题
Virtualize Intel VT - x/EPT or AMD - V/RVI:若物理机的 CPU 支持对应的硬件虚拟化技术(Intel VT - x 或 AMD - V),强烈建议开启。该功能可显著提升虚拟机的性能,让虚拟机更高效地利用物理 CPU 资源,改善卡…...
(即插即用模块-特征处理部分) 三十一、(2024) CDFA 对比度驱动的特征聚合模块
文章目录 1、Contrast-Driven Feature Aggregation module2、代码实现 paper:ConDSeg: A General Medical Image Segmentation Framework via Contrast-Driven Feature Enhancement Code:https://github.com/Mengqi-Lei/ConDSeg 1、Contrast-Driven Feat…...
机械革命 无界15X 自带的 有线网卡 YT6801 debian12下 的驱动方法
这网卡是国货啊。。。 而且人家发了驱动程序 Motorcomm Microelectronics. YT6801 Gigabit Ethernet Controller [1f0a:6801] 网卡YT6801在Linux环境中的安装方法 下载网址 yt6801-linux-driver-1.0.29.zip 我不知道别的系统是否按照说明安装就行了 但是debian12不行&…...
TypeScript 的 interface 接口
TypeScript 的 interface 接口 简介 interface 是对象的模板,可以看作是一种类型约定,中文译为“接口”。使用了某个模板的对象,就拥有了指定的类型结构。 interface Person {firstName: string;lastName: string;age: number;} 上面示例中…...
SpringBoot3-web开发笔记(下)
内容协商 实现:一套系统适配多端数据返回 多端内容适配: 1. 默认规则 SpringBoot 多端内容适配。 基于请求头内容协商:(默认开启) 客户端向服务端发送请求,携带HTTP标准的Accept请求头。 Accept: applica…...
关于无线网络安全的基础知识,涵盖常见威胁、防护措施和实用建议
无线网络(WiFi)的普及极大地方便了我们的生活,但其开放性也带来了诸多安全隐患。以下是关于无线网络安全的基础知识,涵盖常见威胁、防护措施和实用建议: 一、无线网络常见安全威胁 窃听(Eavesdropping) 攻击者通过监听无线信号,截获未加密的数据(如登录密码、聊天记录…...
《基于 RNN 的股票预测模型代码优化:从重塑到直接可视化》
在深度学习领域,使用循环神经网络(RNN)进行股票价格预测是一个常见且具有挑战性的任务。本文将围绕一段基于 RNN 的股票预测代码的改动前后差别展开,深入剖析代码的优化思路和效果。 原始代码思路与问题 原始代码实现了一个完整…...
【leetcode刷题日记】lc.347-前 K 个高频元素
目录 1.题目 2.代码 1.题目 给你一个整数数组 nums 和一个整数 k ,请你返回其中出现频率前 k 高的元素。你可以按 任意顺序 返回答案。 示例 1: 输入: nums [1,1,1,2,2,3], k 2 输出: [1,2]示例 2: 输入: nums [1], k 1 输出: [1] 提示: 1 <…...
进程I·介绍、查看、创建与状态
目录 介绍 PCB(进程控制块) task_struct 查看、创建进程 进程状态 小知识 介绍 进程:PCB(process control block)(内核数据结构) 代码和数据 进程创建:操作系统将其相关属性信…...
[k8s]随笔- spec内容整理
面对 Kubernetes 中 spec 字段的复杂性,关键在于建立 层次化的分类逻辑 和 功能导向的归纳方法。以下是具体的规整思路和实践步骤,帮助你理清脉络、高效使用: 一、按资源类型分层:先分“大类”,再钻“细节” K8s 资源…...
程序化广告行业(81/89):行业术语解析与日常交流词汇指南
程序化广告行业(81/89):行业术语解析与日常交流词汇指南 在程序化广告这个不断发展的行业中,持续学习和知识共享是我们紧跟潮流、提升能力的关键。一直以来,我都希望能和大家一起探索这个领域,共同进步。今…...
层归一化(Layer Normalization) vs 批量归一化(Batch Normalization)
层归一化和批量归一化都是 归一化方法,目的是让训练更稳定、收敛更快,但应用场景和工作方式大不相同。 名称一句话解释BatchNorm对 同一通道、不同样本之间 做归一化,适合图像任务,依赖 Batch Size。LayerNorm对 每个样本自身所有特征维度 做归一化,适合序列任务,不依赖 …...
【杂谈】-开源 AI 的复兴:Llama 4 引领潮流
开源 AI 的复兴:Llama 4 引领潮流 文章目录 开源 AI 的复兴:Llama 4 引领潮流一、Llama 4:开源 AI 的挑战者二、真实利他还是战略布局?三、对开发者、企业和人工智能未来的启示 在过去的几年里,AI 领域发生了重大转变。…...
instructor 库实现缓存
目录 代码代码解释1. 基础设置2. 客户端初始化3. 数据模型定义4. 缓存设置5. 缓存装饰器6. 示例函数工作流程 示例类似例子 代码 import functools import inspect import instructor import diskcachefrom openai import OpenAI, AsyncOpenAI from pydantic import BaseModel…...
【日志链路】⭐️SpringBoot 整合 TraceId 日志链路追踪!
💥💥✈️✈️欢迎阅读本文章❤️❤️💥💥 🏆本篇文章阅读大约耗时6分钟。 ⛳️motto:不积跬步、无以千里 📋📋📋本文目录如下:🎁🎁&am…...
QT6 源(16):存储 QT 里元对象的类信息的类 QMetaClassInfo 的类,只有两个成员函数 name()、value(),比元对象属性简单多了
(1)所在头文件 qmetaobject.h : class Q_CORE_EXPORT QMetaClassInfo { private: //private 属性里包含了至关重要的数据成员的定义,放前面struct Data {enum { Size 2 };const uint * d; //包含了数组的起始地址uint name ()…...
deskflow使用教程:一个可以让两台电脑鼠标键盘截图剪贴板共同使用的开源项目
首先去开源网站下载:Release v1.21.2 deskflow/deskflow 两台电脑都要下载这个文件 下载好后直接打开找到你想要的exe desflow.exe 然后你打开他,将两台电脑的TLS都关掉 下面步骤两台电脑都要完成: 电脑点开edit-》preferences 把这个取…...
波束形成(BF)从算法仿真到工程源码实现-第六节-广义旁瓣消除算法(GSC)
一、概述 本节我们讨论广义旁瓣消除算法(GSC),包括原理分析及代码实现。 更多资料和代码可以进入https://t.zsxq.com/qgmoN ,同时欢迎大家提出宝贵的建议,以共同探讨学习。 二、原理分析 广义旁瓣消除(GSC)算法 GSC算法是与LCMV算法等效的&…...
企业数字化转型需要注重的深层维度:生成式AI时代的战略重构
企业数字化转型正在经历从"技术适配"到"基因重组"的质变。生成式AI技术的突破性发展,要求企业超越传统信息化框架,构建全新的数字化转型认知体系。本文将从战略认知、技术融合、组织进化、伦理治理、生态协作五个维度,系统解构企业数字化转型需注重的核…...
图论之并查集——含例题
目录 介绍 秩是什么 例子——快速入门 例题 使用路径压缩,不使用秩合并 使用路径压缩和秩合并 无向图和有向图 介绍 并查集是一种用于 处理不相交集合的合并与查询问题的数据结构。它主要涉及以下基本概念和操作: 基本概念: 集合&…...
解释型语言和编译型语言的区别
Python 的执行过程通常涉及字节码,而不是直接将代码编译为机器码。以下是详细的解释: ### **Python 的执行过程** 1. **源代码到字节码**: - Python 源代码(.py 文件)首先被编译为字节码(.pyc 文件&…...
零基础上手Python数据分析 (14):DataFrame 数据分组与聚合 - 玩转数据透视,从明细到洞察
写在前面 —— 像搭积木一样分析数据,掌握Pandas GroupBy,轻松实现分组统计与聚合 回顾一下,上篇博客我们学习了如何使用 Pandas 合并与连接多个 DataFrame,将分散的数据整合到一起。 现在,我们拥有了更完整、更丰富的数据视图。 接下来,一个非常常见的分析需求就是 对…...
Honor of Kings (S39) 13-win streak
Honor of Kings (S39) 13-win streak S39赛季13连胜,庄周,廉颇硬辅助,对面有回血就先出红莲斗盆,有遇到马克没带净化的,出【冰霜冲击】破他大招 S39,庄周廉颇前排硬辅助全肉全堆血13连胜_哔哩哔哩bilibi…...
输出流-----超级详细的在程序中向文件.txt中写入内容
1.使用Fileoutputstream对象,如果在目录中已经存在该文件,那么将不会在创建,如果该目录中没有该文件,那么将会自动创建文件。 2.在目录中a.txt文件中写入一个h字符,这种方式是写入单个字符。 //在目录中a.txt文件中写入…...
【Mysql】死锁问题详解
【Mysql】死锁问题详解 【一】Mysql中锁分类和加锁情况【1】按锁的粒度分类(1)全局锁(2)表级锁1、表共享读锁(Table Read Lock)2、表独占写锁(Table Write Lock)3、元数据锁…...
C语言实现用户管理系统
以下是一个简单的C语言用户管理系统示例,它实现了用户信息的添加、删除、修改和查询功能。代码中包含了详细的注释和解释,帮助你理解每个部分的作用。 #include <stdio.h> #include <stdlib.h> #include <string.h>#define MAX_USERS…...
SAP BDC:企业数据管理的新纪元
2025年4月,SAP在纽约发布了其全新企业数据平台——Business Data Cloud(BDC),标志着企业数据管理和AI集成战略的重大升级。BDC不仅整合了SAP内部和外部的结构化与非结构化数据,还借助与Databricks的合作,推…...
数学建模学习资料免费分享:历年赛题与优秀论文、算法课程、数学软件等
本文介绍并分享自己当初准备数学建模比赛时,收集的所有资料,包括历年赛题与论文、排版模板、算法讲解课程与书籍、评分标准、数学建模软件等各类资料。 最近,准备将自己在学习过程中,到处收集到的各类资料都整理一下,并…...
计算机的运算方式
1. 计算机运算的基本概念 计算机的运算由 算术逻辑单元(ALU) 执行,其核心功能是完成 算术运算 和 逻辑运算。所有运算均基于二进制编码,通过硬件电路实现高速计算。 运算对象:二进制数(定点数、浮点数&am…...
Tkinter事件与绑定
在Tkinter中,事件和事件绑定是实现用户交互的核心机制。通过事件机制,您可以捕捉用户的操作(例如鼠标点击、键盘输入等)并执行相应的回调函数。事件绑定是将事件与处理该事件的函数(或方法)关联起来。掌握事件和绑定技术是开发交互式应用程序的关键。 5.1 事件概述 事件…...
CAD 像素点显示图片——CAD二次开发 OpenCV实现
效果如下: 部分代码如下: public class Opencv1{[CommandMethod("xx1")]public void Opencv(){Document doc Application.DocumentManager.MdiActiveDocument;Database db doc.Database;Editor ed doc.Editor;// 设置采样精度,这…...
即梦+剪映:三国演义变中国好声音制作详解!
最近在刷抖音时,发现这种电影人物唱歌视频比较火热,今天手把手教大家如何制作这种让电影人物唱歌的视频! 一、素材准备 1、准备好视频或人物图片素材 这里需要准备一张人物截图或者电影视频片段,大家可以去各大视频网站找原始素…...
04-线程
一、线程的概念 1、进程是系统分配资源的最少单位,操作系统会为每一个进程分配一块虚拟内存空间! 线程是系统调度最少的单位,操作系统分配时间片的过程,就是系统调度! 线程也会占用时间片! 2、线程的内存资源 线程的内存资源是…...
7.渐入佳境 -- 优雅的断开套接字连接
前言 本章将讨论如何优雅地断开相互连接的套接字。之前用的方法不够优雅是因为,我们是调用close或closesocket函数单方面断开连接的。 一、基于TCP的半关闭 TCP中的断开连接过程比建立连接过程更重要,因为连接过程中一般不会出现大的变数,…...
Django3 - 开启Django Hello World
一、开启Django Hello World 要学习Django首先需要了解Django的操作指令,了解了每个指令的作用,才能在MyDjango项目里编写Hello World网页,然后通过该网页我们可以简单了解Django的开发过程。 1.1 Django的操作指令 无论是创建项目还是创建项…...
JavaScript 基础特性
一、变量特性 1.1 变量提升 console.log(temp); // undefined(变量提升但未初始化) var temp hello; 现象:var声明的变量会提升至作用域顶部,但赋值不提升 建议:改用 let/const 避免变量提升问题 1.2 变量泄露 fo…...
MATLAB遇到内部问题,需要关闭,Crash Decoding : Disabled - No sandbox or build area path
1.故障界面 MATLAB运行时突然中断,停止运行。故障界面如图: MATLAB Log File: C:\Users\wei\AppData\Local\Temp\matlab_crash_dump.21720-1 ------------------------------------------------ MATLAB Log File -----------------------------------…...
L1-5 吉老师的回归
题目 L1-078 吉老师的回归(15分) 曾经在天梯赛大杀四方的吉老师决定回归天梯赛赛场啦! 为了简化题目,我们不妨假设天梯赛的每道题目可以用一个不超过 500 的、只包括可打印符号的字符串描述出来,如:Probl…...