从开始实现扩散概率模型 PyTorch 实现
目录
一、说明
二、从头开始实施
三、线性噪声调度器
四、时间嵌入
五、下层DownBlock类块
六、中间midBlock类块
七、UpBlock上层类块
八、UNet 架构
九、训练
十、采样
十一、配置(Default.yaml)
十二、数据集 (MNIST)
keyword: Diffusion Probabilistic Models
一、说明
扩散过程由前向阶段组成,其中图像通过在每个步骤中添加高斯噪声逐渐损坏。经过许多步骤后,图像实际上变得与从正态分布中采样的随机噪声无法区分。这是通过在每个时间步骤 xₜ 应用过渡函数来实现的,其中 β 表示在 t-1 时添加到图像中的预定噪声量,以产生 t 时的图像。
在前面的讨论中,我们确定设置 α=1−β 并计算每个时间步骤中这些 α 值的累积乘积,使我们能够在任何给定步骤 t 直接从原始图像过渡到噪声版本。在反向过程中,模型被训练以近似反向分布。由于正向和反向过程都是高斯的,因此目标是让模型预测反向分布的均值和方差。
通过详细的推导,从最大化观测数据的对数似然性这一目标出发,我们得出需要最小化真实去噪分布(以 x₀ 为条件)与模型预测分布之间的 KL 散度(以特定均值和方差为特征)。方差固定为与目标分布的方差匹配,而均值则以相同形式重写。最小化 KL 散度简化为最小化预测噪声与实际噪声样本之间的平方差。
训练过程包括对图像进行采样、选择时间步长 t,以及添加从正态分布中采样的噪声。然后将 t 处的噪声图像传递给模型。从噪声时间表得出的累积乘积项确定随时间增加的噪声。损失函数是原始噪声样本与模型预测之间的均方误差 (MSE)。
二、从头开始实施
对于图像生成,我们从学习到的反向分布中进行采样,从正态分布中的随机噪声样本 xₜ 开始。使用与 xₜ 和预测噪声相同的公式计算平均值,方差与地面真实去噪分布相匹配。使用重新参数化技巧,我们反复从这个反向分布中采样以生成 x₀。在 x₀ 处,没有添加额外的噪声;相反,平均值直接作为最终输出返回。
为了实现扩散过程,我们需要处理正向和反向阶段的计算。我们将创建一个噪声调度程序来管理这些任务。在正向过程中,给定一个图像、一个噪声样本和一个时间步长 t,调度程序将使用正向方程返回图像的噪声版本。为了优化效率,它将预先计算并存储 α(1−β) 的值以及所有时间步长中 α 的累积乘积。
作者采用了线性噪声调度,其中 β 在 1,000 个时间步骤内从 1×10⁻⁴ 线性缩放到 0.02。调度程序还处理反向过程:给定 xt 和模型预测的噪声,它将通过从反向分布中采样来计算 xₜ₋₁。这涉及使用各自的方程计算均值和方差,并通过重新参数化技巧生成样本。
为了支持这些计算,调度程序还将存储 1-αₜ、1-累积乘积项以及该项的平方根的预先计算的值。
三、线性噪声调度器
import torchclass LinearNoiseScheduler:def __init__(self, num_timesteps, beta_start, beta_end):self.num_timesteps = num_timestepsself.beta_start = beta_startself.beta_end = beta_endself.betas = torch.linspace(beta_start, beta_end, num_timesteps)self.alphas = 1. - self.betasself.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
使用传递给此类的参数初始化所有参数后,我们将定义 β 值从起始范围到结束范围线性增加,确保 βₜ 从 0 进展到最后的时间步骤。接下来,我们将设置正向和反向过程方程所需的所有变量。
def add_noise(self, original, noise, t):original_shape = original.shapebatch_size = original_shape[0]sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)# Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)for _ in range(len(original_shape) - 1):sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)for _ in range(len(original_shape) - 1):sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)# Apply and Return Forward process equationreturn (sqrt_alpha_cum_prod.to(original.device) * original+ sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
该add_noise()
函数表示正向过程。它以原始图像、噪声样本和时间步长 ttt 作为输入。图像和噪声的维度为 b×h×w,而时间步长为大小为 b 的一维张量。对于正向过程,我们计算给定时间步长的累积乘积项的平方根和 1-累积乘积项。这些值被重新整形为维度 b×1×1×1。最后,我们应用正向过程方程来生成噪声图像。
def sample_prev_timestep(self, xt, noise_pred, t):x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))x0 = torch.clamp(x0, -1., 1.)mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])if t == 0:return mean, x0else:variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])variance = variance * self.betas.to(xt.device)[t]sigma = variance ** 0.5z = torch.randn(xt.shape).to(xt.device)return mean + sigma * z, x0
调度程序类中的下一个函数处理反向过程。它使用噪声图像 xₜ、模型的噪声预测和时间步长 t 作为输入,从学习到的反向分布中生成样本。我们保存原始图像预测 x₀ 以供可视化,它是通过重新排列正向过程方程以使用噪声预测而不是实际噪声来计算 x₀ 获得的。
对于逆向过程中的采样,我们使用逆均值方程计算均值。在 t=0 时,我们只需返回均值。对于其他时间步骤,噪声会添加到均值中,方差与以 x₀ 为条件的地面真实去噪分布的方差相同。最后,我们使用计算出的均值和方差从高斯分布中采样,应用重新参数化技巧来生成结果。
这样就完成了噪声调度程序,它管理添加噪声的正向过程和采样的反向过程。对于扩散模型,我们可以灵活地选择任何架构,只要它满足两个关键要求。第一,输入和输出形状必须相同,第二,必须有一种方法可以整合时间步长信息。
作者图片
无论是在训练期间还是采样期间,时间步长信息始终是可访问的。包含此信息有助于模型更好地预测原始噪声,因为它表明输入图像中有多少是噪声。我们不仅向模型提供图像,还提供相应的时间步长。
对于模型架构,我们将使用 UNet,这也是原作者的选择。为了确保一致性,我们将复制 Hugging Face 的 Diffusers 管道中使用的稳定扩散 UNet 中实现的块、激活、规范化和其他组件的精确规格。
作者图片
时间步长由时间嵌入块处理,该块采用大小为b(批次大小)的时间步长的一维张量,并输出批次中每个时间步长的大小为t_emb_dim的表示。此块首先通过嵌入空间将整数时间步长转换为矢量表示。然后,此嵌入通过中间带有激活函数的两个线性层,产生最终的时间步长表示。对于嵌入空间,作者使用了 Transformers 中常用的正弦位置嵌入方法。在整个架构中,使用的激活函数是 S 形线性单元 (SiLU),但也可以选择其他激活函数。
作者图片
UNet架构遵循简单的编码器-解码器设计。编码器由多个下采样块组成,每个块都会减少输入的空间维度(通常减半),同时增加通道数量。最终下采样块的输出由中间块的几层处理,所有层都以相同的空间分辨率运行。随后,解码器采用上采样块,逐步增加空间维度并减少通道数量,最终匹配原始输入大小。在解码器中,上采样块通过残差跳过连接以相同的分辨率集成相应下采样块的输出。虽然大多数扩散模型都遵循这种通用的 UNet 架构,但它们在各个块内的具体细节和配置上有所不同。
作者图片
大多数变体中的下行块通常由ResNet 块、后跟自注意力块和下采样层组成。每个 ResNet 块都使用一系列操作构建:组归一化、激活层和卷积层。此序列的输出将通过另一组归一化、激活和卷积层。通过将第一个归一化层的输入与第二个卷积层的输出相结合来添加残差连接。这个完整的序列形成ResNet 块,可以将其视为通过残差连接连接的两个卷积块。
在 ResNet 块之后,有一个规范化步骤、一个自注意力层和另一个残差连接。虽然模型通常使用多个 ResNet 层和自注意力层,但为简单起见,我们的实现将只使用每个层的一层。
为了整合时间信息,每个 ResNet 块都包含一个激活层,后面跟着一个线性层,用于处理时间嵌入表示。时间嵌入表示为大小为t_emb_dim的张量,通过此线性层将其投影到与卷积层输出具有相同大小和通道数的张量中。这样就可以通过在空间维度上复制时间步长表示,将时间嵌入添加到卷积层的输出中。
作者图片
另外两个块使用相同的组件,只是略有不同。上块完全相同,只是它首先将输入上采样为两倍空间大小,然后在整个通道维度上集中相同空间分辨率的下块输出。然后我们有相同的 resnet 层和自注意力块。中间块的层始终将输入保持为相同的空间分辨率。hugging face 版本首先有一个 resnet 块,然后是自注意力层和 resnet 层。对于这些 resnet 块中的每一个,我们都有一个时间步长投影层。现有的时间步长表示会经过这些块,然后被添加到 resnet 的第一个卷积层的输出中。
四、时间嵌入
import torch
import torch.nn as nndef get_time_embedding(time_steps, temb_dim):assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"# factor = 10000^(2i/d_model)factor = 10000 ** ((torch.arange(start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)))# pos / factor# timesteps B -> B, 1 -> B, temb_dimt_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factort_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)return t_emb
第一个函数为给定的时间步长get_time_embedding
生成时间嵌入。它受到 Transformer 模型中使用的正弦位置嵌入的启发。time_steps
:时间步长值的张量(形状:[B]
其中B
是批次大小)。每个值代表批次元素的一个离散时间步长。temb_dim
:时间嵌入的维数。这决定了每个时间步长的生成嵌入的大小。
确保这temb_dim
是均匀的,因为正弦嵌入需要将嵌入分成两半,分别表示正弦和余弦分量。无缝扩展以处理任何批量大小或嵌入维度。
五、下层DownBlock类块
class DownBlock(nn.Module):def __init__(self, in_channels, out_channels, t_emb_dim,down_sample=True, num_heads=4, num_layers=1):super().__init__()self.num_layers = num_layersself.down_sample = down_sampleself.resnet_conv_first = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, in_channels if i == 0 else out_channels),nn.SiLU(),nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,kernel_size=3, stride=1, padding=1),)for i in range(num_layers)])self.t_emb_layers = nn.ModuleList([nn.Sequential(nn.SiLU(),nn.Linear(t_emb_dim, out_channels))for _ in range(num_layers)])self.resnet_conv_second = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, out_channels),nn.SiLU(),nn.Conv2d(out_channels, out_channels,kernel_size=3, stride=1, padding=1),)for _ in range(num_layers)])self.attention_norms = nn.ModuleList([nn.GroupNorm(8, out_channels)for _ in range(num_layers)])self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True)for _ in range(num_layers)])self.residual_input_conv = nn.ModuleList([nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)for i in range(num_layers)])self.down_sample_conv = nn.Conv2d(out_channels, out_channels,4, 2, 1) if self.down_sample else nn.Identity()def forward(self, x, t_emb):out = xfor i in range(self.num_layers):# Resnet block of Unetresnet_input = outout = self.resnet_conv_first[i](out)out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]out = self.resnet_conv_second[i](out)out = out + self.residual_input_conv[i](resnet_input)# Attention block of Unetbatch_size, channels, h, w = out.shapein_attn = out.reshape(batch_size, channels, h * w)in_attn = self.attention_norms[i](in_attn)in_attn = in_attn.transpose(1, 2)out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)out = out + out_attnout = self.down_sample_conv(out)return out
DownBlock 类结合了ResNet 块、自注意力块和可选的下采样,并集成了时间嵌入来整合时间步长信息。将卷积层与残差连接相结合,以实现更好的梯度流和更高效的学习。将时间步长表示投影到特征空间中,使模型能够整合时间相关信息。通过对所有空间位置之间的关系进行建模来捕获长距离依赖关系。减少空间维度以专注于更深层中更大规模的特征。
参数:
in_channels
:输入通道数。out_channels
:输出通道数。t_emb_dim
:时间嵌入的维度。down_sample
:布尔值,确定是否在块末尾应用下采样。num_heads
:多头注意力层中的注意力头的数量。num_layers
:此块中的 ResNet + 注意力层的数量。
ResNet块:
resnet_conv_first
:ResNet 块的第一个卷积层。t_emb_layers
:时间嵌入投影层。resnet_conv_second
:ResNet 块的第二个卷积层。residual_input_conv
:用于残差连接的 1x1 卷积。
自注意力模块:
attention_norms
:在注意力机制之前对规范化层进行分组。attentions
:多头注意力层。
下采样:
down_sample_conv
:应用卷积来减少空间维度(如果down_sample=True
)。
Forward Pass 方法定义了如何x
通过块处理输入张量:out
初始化为输入x
。对于每一层,我们都有 ResNet Block 和 Self-Attention Block。
在 ResNet Block 中,我们有第一个 卷积层,它应用 GroupNorm、SiLU 激活和 3x3 卷积,以及一个时间嵌入函数,它将时间嵌入传递t_emb
到线性层(投影到out_channels
),并将此投影时间嵌入添加到out
(在空间维度上广播)。然后我们有第二个卷积和一个残差连接,它将原始输入(resnet_input
)添加到第二个卷积的输出。
在自注意力模块中,我们将空间维度扁平化为一个维度(h * w
)以用于注意力机制。规范化输入并转置以匹配注意力层输入格式。多头注意力in_attn
使用查询、键和值执行自注意力。重塑回转置并重塑回原始空间维度。残差连接和下采样。
六、中间midBlock类块
class MidBlock(nn.Module):def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):super().__init__()self.num_layers = num_layersself.resnet_conv_first = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, in_channels if i == 0 else out_channels),nn.SiLU(),nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,padding=1),)for i in range(num_layers+1)])self.t_emb_layers = nn.ModuleList([nn.Sequential(nn.SiLU(),nn.Linear(t_emb_dim, out_channels))for _ in range(num_layers + 1)])self.resnet_conv_second = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, out_channels),nn.SiLU(),nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),)for _ in range(num_layers+1)])self.attention_norms = nn.ModuleList([nn.GroupNorm(8, out_channels)for _ in range(num_layers)])self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True)for _ in range(num_layers)])self.residual_input_conv = nn.ModuleList([nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)for i in range(num_layers+1)])def forward(self, x, t_emb):out = x# First resnet blockresnet_input = outout = self.resnet_conv_first[0](out)out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]out = self.resnet_conv_second[0](out)out = out + self.residual_input_conv[0](resnet_input)for i in range(self.num_layers):# Attention Blockbatch_size, channels, h, w = out.shapein_attn = out.reshape(batch_size, channels, h * w)in_attn = self.attention_norms[i](in_attn)in_attn = in_attn.transpose(1, 2)out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)out = out + out_attn# Resnet Blockresnet_input = outout = self.resnet_conv_first[i+1](out)out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None]out = self.resnet_conv_second[i+1](out)out = out + self.residual_input_conv[i+1](resnet_input)return out
该类MidBlock
是位于扩散模型中 U-Net 架构中间的模块。它由ResNet 块和自注意力层组成,并集成了时间嵌入来处理时间信息。这是用于去噪扩散等任务的模型的重要组成部分。此外,我们还有:
- 时间嵌入:通过将时间信息(例如,扩散模型中的去噪步骤)投影到特征空间并将其添加到卷积特征中来合并时间信息。
- 层迭代:在注意力和ResNet 块之间交替,按
num_layers
这些组合的顺序处理输入。
七、UpBlock上层类块
class UpBlock(nn.Module):def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):super().__init__()self.num_layers = num_layersself.up_sample = up_sampleself.resnet_conv_first = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, in_channels if i == 0 else out_channels),nn.SiLU(),nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,padding=1),)for i in range(num_layers)])self.t_emb_layers = nn.ModuleList([nn.Sequential(nn.SiLU(),nn.Linear(t_emb_dim, out_channels))for _ in range(num_layers)])self.resnet_conv_second = nn.ModuleList([nn.Sequential(nn.GroupNorm(8, out_channels),nn.SiLU(),nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),)for _ in range(num_layers)])self.attention_norms = nn.ModuleList([nn.GroupNorm(8, out_channels)for _ in range(num_layers)])self.attentions = nn.ModuleList([nn.MultiheadAttention(out_channels, num_heads, batch_first=True)for _ in range(num_layers)])self.residual_input_conv = nn.ModuleList([nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)for i in range(num_layers)])self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,4, 2, 1) \if self.up_sample else nn.Identity()def forward(self, x, out_down, t_emb):x = self.up_sample_conv(x)x = torch.cat([x, out_down], dim=1)out = xfor i in range(self.num_layers):resnet_input = outout = self.resnet_conv_first[i](out)out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]out = self.resnet_conv_second[i](out)out = out + self.residual_input_conv[i](resnet_input)batch_size, channels, h, w = out.shapein_attn = out.reshape(batch_size, channels, h * w)in_attn = self.attention_norms[i](in_attn)in_attn = in_attn.transpose(1, 2)out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)out = out + out_attnreturn out
该类UpBlock
是 U-Net 类架构的解码器阶段的一部分,通常用于扩散模型或其他图像生成/分割任务。它结合了上采样、跳过连接、ResNet 块和自注意力来重建输出图像,同时保留早期编码器阶段的细粒度细节。
- 上采样:通过转置卷积(
ConvTranspose2d
)实现,以增加特征图的空间分辨率。 - 跳过连接:允许解码器重用编码器的详细特征,帮助重建。
- ResNet Block:使用卷积层处理输入,集成时间嵌入,并添加残差连接以实现高效的梯度流。
- 自我注意力:捕获远程空间依赖关系以保留全局上下文。
- 时间嵌入:对时间信息进行编码并将其注入特征图,这对于处理动态数据的模型(如扩散模型)至关重要。
八、UNet 架构
class Unet(nn.Module):def __init__(self, model_config):super().__init__()im_channels = model_config['im_channels']self.down_channels = model_config['down_channels']self.mid_channels = model_config['mid_channels']self.t_emb_dim = model_config['time_emb_dim']self.down_sample = model_config['down_sample']self.num_down_layers = model_config['num_down_layers']self.num_mid_layers = model_config['num_mid_layers']self.num_up_layers = model_config['num_up_layers']assert self.mid_channels[0] == self.down_channels[-1]assert self.mid_channels[-1] == self.down_channels[-2]assert len(self.down_sample) == len(self.down_channels) - 1# Initial projection from sinusoidal time embeddingself.t_proj = nn.Sequential(nn.Linear(self.t_emb_dim, self.t_emb_dim),nn.SiLU(),nn.Linear(self.t_emb_dim, self.t_emb_dim))self.up_sample = list(reversed(self.down_sample))self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))self.downs = nn.ModuleList([])for i in range(len(self.down_channels)-1):self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,down_sample=self.down_sample[i], num_layers=self.num_down_layers))self.mids = nn.ModuleList([])for i in range(len(self.mid_channels)-1):self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,num_layers=self.num_mid_layers))self.ups = nn.ModuleList([])for i in reversed(range(len(self.down_channels)-1)):self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16,self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers))self.norm_out = nn.GroupNorm(8, 16)self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)def forward(self, x, t):# Shapes assuming downblocks are [C1, C2, C3, C4]# Shapes assuming midblocks are [C4, C4, C3]# Shapes assuming downsamples are [True, True, False]# B x C x H x Wout = self.conv_in(x)# B x C1 x H x W# t_emb -> B x t_emb_dimt_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)t_emb = self.t_proj(t_emb)down_outs = []for idx, down in enumerate(self.downs):down_outs.append(out)out = down(out, t_emb)# down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]# out B x C4 x H/4 x W/4for mid in self.mids:out = mid(out, t_emb)# out B x C3 x H/4 x W/4for up in self.ups:down_out = down_outs.pop()out = up(out, down_out, t_emb)# out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]out = self.norm_out(out)out = nn.SiLU()(out)out = self.conv_out(out)# out B x C x H x Wreturn out
该类是U-Net 架构Unet
的实现,专为图像处理任务而设计,例如分割或生成,通常用于扩散模型。该网络包括下采样、中级处理和上采样阶段。它利用时间嵌入执行动态任务(例如扩散模型),利用跳过连接保留空间信息,利用 GroupNorm 进行归一化。
作者图片
- 时间嵌入:实现时间动态。
- 跳过连接:通过连接将细粒度的空间细节集成到解码器中。
- 灵活的架构:允许通过
model_config
不同的深度、分辨率和功能丰富度进行定制。 - 规范化和激活:GroupNorm 确保稳定的训练,而 SiLU 激活则改善非线性。
- 输出一致性:确保输出图像保留原始的空间尺寸和通道数。
九、训练
import torch
import yaml
import argparse
import os
import numpy as np
from tqdm import tqdm
from torch.optim import Adam
from dataset.mnist_dataset import MnistDataset
from torch.utils.data import DataLoader
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseSchedulerdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def train(args):with open(args.config_path, 'r') as file:try:config = yaml.safe_load(file)except yaml.YAMLError as exc:print(exc)print(config)diffusion_config = config['diffusion_params']dataset_config = config['dataset_params']model_config = config['model_params']train_config = config['train_params']# Create the noise schedulerscheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],beta_start=diffusion_config['beta_start'],beta_end=diffusion_config['beta_end'])# Create the datasetmnist = MnistDataset('train', im_path=dataset_config['im_path'])mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4)# Instantiate the modelmodel = Unet(model_config).to(device)model.train()# Create output directoriesif not os.path.exists(train_config['task_name']):os.mkdir(train_config['task_name'])# Load checkpoint if foundif os.path.exists(os.path.join(train_config['task_name'],train_config['ckpt_name'])):print('Loading checkpoint as found one')model.load_state_dict(torch.load(os.path.join(train_config['task_name'],train_config['ckpt_name']), map_location=device))# Specify training parametersnum_epochs = train_config['num_epochs']optimizer = Adam(model.parameters(), lr=train_config['lr'])criterion = torch.nn.MSELoss()# Run trainingfor epoch_idx in range(num_epochs):losses = []for im in tqdm(mnist_loader):optimizer.zero_grad()im = im.float().to(device)# Sample random noisenoise = torch.randn_like(im).to(device)# Sample timestept = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)# Add noise to images according to timestepnoisy_im = scheduler.add_noise(im, noise, t)noise_pred = model(noisy_im, t)loss = criterion(noise_pred, noise)losses.append(loss.item())loss.backward()optimizer.step()print('Finished epoch:{} | Loss : {:.4f}'.format(epoch_idx + 1,np.mean(losses),))torch.save(model.state_dict(), os.path.join(train_config['task_name'],train_config['ckpt_name']))print('Done Training ...')if __name__ == '__main__':parser = argparse.ArgumentParser(description='Arguments for ddpm training')parser.add_argument('--config', dest='config_path',default='config/default.yaml', type=str)args = parser.parse_args()train(args)
加载配置:从 YAML 文件读取训练配置(如数据集路径、超参数和模型设置)。
设置组件:
- 初始化噪声调度器,用于在不同的时间步添加噪声。
- 创建一个MNIST 数据集加载器。
- 实例化U-Net模型。
检查点管理:检查现有检查点,如果可用则加载。创建保存检查点和输出所需的目录。
训练循环:每个时期:
- 遍历数据集,根据采样的时间步长向图像添加噪声。
- 使用模型预测噪声并计算损失(预测噪声和实际噪声之间的 MSE)。
- 使用反向传播更新模型参数并保存模型检查点。
优化:使用 Adam 优化器和 MSE 损失函数来训练模型。
完成:打印 epoch 损失并在每个 epoch 结束时保存模型。
十、采样
import torch
import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from tqdm import tqdm
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseSchedulerdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def sample(model, scheduler, train_config, model_config, diffusion_config):xt = torch.randn((train_config['num_samples'],model_config['im_channels'],model_config['im_size'],model_config['im_size'])).to(device)for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):# Get prediction of noisenoise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))# Use scheduler to get x0 and xt-1xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))# Save x0ims = torch.clamp(xt, -1., 1.).detach().cpu()ims = (ims + 1) / 2grid = make_grid(ims, nrow=train_config['num_grid_rows'])img = torchvision.transforms.ToPILImage()(grid)if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):os.mkdir(os.path.join(train_config['task_name'], 'samples'))img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))img.close()def infer(args):# Read the config file #with open(args.config_path, 'r') as file:try:config = yaml.safe_load(file)except yaml.YAMLError as exc:print(exc)print(config)diffusion_config = config['diffusion_params']model_config = config['model_params']train_config = config['train_params']# Load model with checkpointmodel = Unet(model_config).to(device)model.load_state_dict(torch.load(os.path.join(train_config['task_name'],train_config['ckpt_name']), map_location=device))model.eval()# Create the noise schedulerscheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],beta_start=diffusion_config['beta_start'],beta_end=diffusion_config['beta_end'])with torch.no_grad():sample(model, scheduler, train_config, model_config, diffusion_config)if __name__ == '__main__':parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')parser.add_argument('--config', dest='config_path',default='config/default.yaml', type=str)args = parser.parse_args()infer(args)
加载配置:从 YAML 文件读取模型、扩散和训练参数。
模型设置:加载训练好的 U-Net 模型检查点。初始化噪声调度程序以指导反向扩散过程。
采样过程:
- 从随机噪声开始,并在指定的时间步内迭代地对其进行去噪。
- 在每个时间步:
- 使用模型预测噪音。
- 使用调度程序计算去噪图像(
x0
)并更新当前噪声图像(xt
)。 - 将中间去噪图像作为 PNG 文件保存在输出目录中。
推理:执行采样过程并保存结果而不改变模型。
十一、配置(Default.yaml)
dataset_params:im_path: 'data/train/images'diffusion_params:num_timesteps : 1000beta_start : 0.0001beta_end : 0.02model_params:im_channels : 1im_size : 28down_channels : [32, 64, 128, 256]mid_channels : [256, 256, 128]down_sample : [True, True, False]time_emb_dim : 128num_down_layers : 2num_mid_layers : 2num_up_layers : 2num_heads : 4train_params:task_name: 'default'batch_size: 64num_epochs: 40num_samples : 100num_grid_rows : 10lr: 0.0001ckpt_name: 'ddpm_ckpt.pth'
该配置文件提供了扩散模型的训练和推理的设置。
数据集参数im_path
:指定训练图像的路径( )。
扩散参数:设置扩散过程的时间步数和噪声参数的范围(beta_start
和beta_end
)。
模型参数:
- 定义模型架构,包括:
- 输入图像通道(
im_channels
)和大小(im_size
)。 - 下采样、中间处理和上采样的通道数。
- 每一级是否发生下采样(
down_sample
)。 - 各种块的嵌入尺寸和层数。
训练参数:
- 指定训练配置,如任务名称、批量大小、时期、学习率和检查点文件名。
- 包括采样设置,例如用于可视化的样本数量和网格行数。
十二、数据集 (MNIST)
import glob
import osimport torchvision
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Datasetclass MnistDataset(Dataset):self.split = splitself.im_ext = im_extself.images, self.labels = self.load_images(im_path)def load_images(self, im_path):assert os.path.exists(im_path), "images path {} does not exist".format(im_path)ims = []labels = []for d_name in tqdm(os.listdir(im_path)):for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))):ims.append(fname)labels.append(int(d_name))print('Found {} images for split {}'.format(len(ims), self.split))return ims, labelsdef __len__(self):return len(self.images)def __getitem__(self, index):im = Image.open(self.images[index])im_tensor = torchvision.transforms.ToTensor()(im)# Convert input to -1 to 1 range.im_tensor = (2 * im_tensor) - 1return im_tensor
初始化:采用分割名称、图像文件扩展名(im_ext
)和图像路径(im_path
)。调用load_images
以加载图像路径及其相应的标签。
图像加载:load_images
遍历 处的目录结构im_path
,假设子目录已标记(例如,数字类别的0
、1
、...)。收集图像文件路径并根据文件夹名称分配标签。
数据集长度:__len__
返回图像的总数。
数据检索:__getitem__
通过索引检索图像,将其转换为张量,并将像素值缩放到范围 -1,1-1,1-1,1。
相关文章:
从开始实现扩散概率模型 PyTorch 实现
目录 一、说明 二、从头开始实施 三、线性噪声调度器 四、时间嵌入 五、下层DownBlock类块 六、中间midBlock类块 七、UpBlock上层类块 八、UNet 架构 九、训练 十、采样 十一、配置(Default.yaml) 十二、数据集 (MNIST) keyword: Diffusion…...
LabVIEW智能焊接系统
焊接作为制造业中的核心工艺,直接影响到产品的性能与可靠性。传统的焊接过程通常依赖操作工的经验控制参数,导致质量波动较大,效率低下且容易产生人为误差。随着工业自动化和智能制造的不断发展,传统焊接方法的局限性愈加明显。本…...
如何快速排查 Wi-Fi 的 TPUT 问题?
1. 如何排查 Wi-Fi TPUT 问题 掌握每个 Wi-Fi 协议下的 Wi-Fi TPUT 的计算方法 一文让你轻松理解WLAN物理层速率计算方式_wifi速率计算公式-CSDN博客配查 CPU 的资源占用率:interrupt、CPU loading Linux/Android 系统使用 mpstat 工具 具体工具的使用方法ÿ…...
C语言单链表、双链表专题及应用
1.链表的概念及结构 概念:链表是一种物理存储结构上非连续,非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链接次序实现的 链表的结构跟火车车厢相似,淡季时车次的车厢会相应减少,旺季时车次的车厢会额外增…...
C++4--类
目录 1.类的引入 2.类的定义 3.类的访问限定符及封装 3.1访问的限定符 3.2封装 4.类的作用域 5.类的实体化 1.类的引入 C语言结构体中只能定义变量,在C中,结构体内不仅可以定义变量,也可以定义函数。比如:之间在数据结构中&…...
紫光展锐5G融云方案,开启云终端新时代
近年来,云终端凭借便捷、高效、高性价比的优势正逐步在各行各业渗透。研究机构IDC的数据显示,2024上半年,中国云终端市场总体出货量达到166.3万台,同比增长22.4%,销售额29亿元人民币,同比增长24.9%…...
雪泥鸿爪和屈指可数
paw这个单词,表示“爪或手”,是一个和hoof相对的单词: hoof n.(马等动物的)蹄paw n.爪子;(动物的)爪;(人的)手 v.挠,抓;动手动脚 所以,当你理解了 paw 和 hoof 是相对的概念时&…...
C++并发与多线程(高级函数async)
async 在 C 中,async 关键字用于实现异步编程,它允许你定义异步操作,这些操作可以在后台执行,而不会阻塞当前线程。这是 C11 引入的特性,与 std::async 函数和 std::future 类一起使用。与thread函数模板的区别在于as…...
LeetCode 力扣 热题 100道(二十)三数之和(C++)
给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三元组。 注意:答案中不可以包含重复的三元组。 如下代码…...
类和对象(4)
大家好,今天来给大家介绍一下this引用,在学习类和对象的时候大家一定有一点疑惑吧,类为什么能知道我们传入的是哪个对象,又是怎么实例化我们的成员的,那么我们便来了解一下。 四.this引用 4.1为什么要有this引用 在…...
php基础:正则表达式
1.正则表达式 正则表达式是用于描述字符排列和匹配模式的一种语法规则。它主要用于字符串的模式分割、匹配、查找及替换操作。到目前为止,我们前面所用过的精确(文本)匹配也是一种正则表达式。 在PHP中,正则表达式一般是由正规字…...
Vue3动态表单实现
实现方法:通过<component />标签动实现动态表单渲染 component标签: 在vue中 component 标签用于动态组件标签的渲染。它允许在同一个挂载点上条件渲染不同的组件,通过is属性可以渲染指定的属性 在上面的例子中,通过调用…...
【网络取证篇】取证实战之PHP服务器镜像网站重构及绕密分析
【网络取证篇】取证实战之PHP服务器镜像网站重构及绕密分析 在裸聊敲诈、虚假理财诈骗案件类型中,犯罪分子为了能实现更低成本、更快部署应用的目的,其服务器架构多为常见的初始化网站架构,也称为站库同体服务器!也就是说网站应用…...
高数 | 用简单的话讲考研数学知识点(第一集:充分和必要)
目录 一、前言 二、充分和必要 三、基础符号 四、符号拓展 五、符号进阶 六、符号进阶拓展 七、本集总结 一、前言 up最近想去上学,就想考个研究生读一读,那就要复习高数,光复习挺没意思的,所以就想着边复习边写文章吧&…...
前端学习-操作元素内容(二十二)
目录 前言 目标 对象.innerText 属性 对象.innerHTML属性 案例 年会抽奖 需求 方法一 方法二 总结 前言 曾经沧海难为水,除却巫山不是云。 目标 能够修改元素的文本更换内容 DOM对象都是根据标签生成的,所以操作标签,本质上就是操作DOM对象,…...
PostgreSql-学习06-libpq之同步命令处理
目录 一、环境 二、介绍 三、函数 1、PQsetdbLogin (1)作用 (2)声明 (3)参数介绍 (4)检测成功与否 2、PQfinish (1)作用 (2࿰…...
Python `str.strip()` 的高级用法详解
Python str.strip 的高级用法详解 1. str.strip() 的基本用法2. str.strip() 的高级用法2.1 移除指定字符2.2 移除多个指定字符2.3 移除换行符和制表符2.4 结合正则表达式的高级处理 3. lstrip() 和 rstrip() 的用法3.1 lstrip():移除左端字符3.2 rstrip()ÿ…...
Vue 3 中的 `update:modelValue` 事件详解
在 Vue 3 中,update:modelValue 事件通常与 v-model 指令一起使用,以实现自定义组件的双向数据绑定。以下是对该事件的详细分析: 事件定义 首先,我们需要在组件中定义 update:modelValue 事件。可以使用 defineEmits 函…...
AI 助力医学伦理知情同意书的完善:守护受试者权益
在医学研究中,知情同意书是保障受试者权益的核心文件,其质量直接关系到研究的伦理合规性。一份完善的知情同意书应清晰、准确且全面地向受试者传达研究的关键信息,确保他们在充分理解的基础上自愿做出参与决策。然而,在实际撰写过…...
【信息系统项目管理师-论文真题】2017上半年论文详解(包括解题思路和写作要点)
更多内容请见: 备考信息系统项目管理师-专栏介绍和目录 文章目录 试题一:论信息系统项目的范围管理解题思路写作要点试题二:论项目采购管理解题思路写作要点试题一:论信息系统项目的范围管理 实施项目范围管理的目的是包括确保项目做且制作所需的全部工作,以顺利完成项目…...
rpc设计的再次思考20251215(以xdb为核心构建游戏框架)
1.服务提供者注册的方式 // 表明这是一个服务提供者,ServerType 和 ServerId从application.properties中读取 // 而且只有当当前服务是Game时,才生效。 或者 条件注解??? RpcProvider(typeServerType.Game) public class GameProvider{MsgReceiver…...
mysql 查看并设置 innodb_flush_log_at_trx_commit 参数
mysql 查看并设置 innodb_flush_log_at_trx_commit 参数 innodb_flush_log_at_trx_commit 是 MySQL 中的一个系统变量,用于控制 InnoDB 存储引擎的日志刷新行为。该变量有三个可选的值: 0:每隔一秒钟,日志缓冲被刷新到日志文件&a…...
spring使用rabbitmq当rabbitmq集群节点挂掉 spring rabbitmq怎么保证高可用,rabbitmq网络怎么重新连接
##spring rabbitmq代码示例 Controller代码 import com.alibaba.fastjson.JSONObject; import com.newland.mi.config.RabbitDMMQConfig; import org.springframework.amqp.core.Message; import org.springframework.amqp.core.MessageProperties; import org.springframewo…...
Java BigDecimal
1. BigDecimal 用于解决浮点型运算时,出现结果失真的问题。 2. BigDecimal创建的构造器、常用方法 构造器说明public BigDecimal(double val)---不推荐将double 类型转为BigDecimalpublic BigDecimal(String val)---推荐将String 类型转为BigDecimal 方法说明pub…...
RFMiD:多疾病检测的视网膜图像分析挑战|文献速递-生成式模型与transformer在医学影像中的应用
Title 题目 RFMiD: Retinal Image Analysis for multi-Disease Detection challenge RFMiD:多疾病检测的视网膜图像分析挑战 01 文献速递介绍 眼部疾病的普遍性与上升趋势 根据世界卫生组织 (WHO) 2019 年《全球视觉报告》,目前全球约有 22 亿人存…...
布隆过滤器
这篇博客我们来说一下布隆过滤器 之前我们在讲redis缓存穿透的时候说可以使用布隆过滤器来解决这个问题 那么我们先来简单复习一下什么时缓存穿透 (一)复习缓存穿透 我们都知道redis可以作为mysql的缓存帮忙抵挡大部分的请求,但是当redis中…...
构建一个rust生产应用读书笔记四(实战6)
本节我们开始使用tracing来记录日志,实际上在生产环境中,更推荐使用tracing作为日志记录的首先,它提供了更丰富的上下文信息和结构化日志记录功能。tracing 不仅可以记录日志信息,还可以跟踪函数调用、异步任务等,适用…...
如何使用git新建本地仓库并关联远程仓库的步骤(详细易懂)
一、新建本地仓库并关联远程仓库的步骤 新建本地仓库 打开终端(在 Windows 上是命令提示符或 PowerShell,在 Linux 和Mac上是终端应用),进入你想要创建仓库的目录。例如,如果你想在桌面上创建一个名为 “my - project”…...
5.最长回文字串
给你一个字符串 s,找到 s 中最长的 回文 子串 。 示例 1: 输入:s "babad" 输出:"bab" 解释:"aba" 同样是符合题意的答案。示例 2: 输入:s "cbbd"…...
数据仓库工具箱—读书笔记02(Kimball维度建模技术概述02、事实表技术基础)
Kimball维度建模技术概述 记录一下读《数据仓库工具箱》时的思考,摘录一些书中关于维度建模比较重要的思想与大家分享🤣🤣🤣 第二章前言部分作者提到:技术的介绍应该通过涵盖各种行业的熟悉的用例展开(赞同…...
【C++】13___STL
一、基本概念 STL(Standard Template Library,标准模板库)STL从广义上分为:容器(container)、算法(algorithm)、迭代器(iterator)容器和算法之间通过迭代器进行无缝连接STL几乎所有的代码都采用了类模板或者函数模板 二、STL六大组件 分别是:容器、算法…...
在 Ubuntu 中启用 root 用户的远程登录权限
1. 概述:为什么需要启用 root 用户远程登录? 在 Ubuntu 中,出于安全原因,默认情况下 root 用户被禁止远程登录。然而,在某些情况下(如需要进行高权限操作的远程管理任务),启用 root…...
android 混淆
前沿 很久没用过混淆功能了,因为之前的包都使用第三方加固了,而且项目开发好几年了,突然要混淆也很麻烦。换了家公司后,感觉还是得混淆代码才行,不然直接暴露源码也太不行了。 启动混淆功能 isMinifyEnabled true …...
6、AI测试辅助-测试报告编写(生成Bug分析柱状图)
AI测试辅助-测试报告编写(生成Bug分析柱状图) 一、测试报告1. 创建测试报告2. 报告补充优化2.1 Bug图表分析 3. 风险评估 总结 一、测试报告 测试报告内容应该包含: 1、测试结论 2、测试执行情况 3、测试bug结果分析 4、风险评估 5、改进措施…...
让人工智能帮我写一个矩阵按键扫描程序
1.前言 嘉立创做了一块编程小车的蓝牙按键遥控器,按键是4*4矩阵的,通过蓝牙发送按键编码值给蓝牙小车(外围设备)。 原理图如下: 板子回来后,因为懒得写按键矩阵扫描程序,想想还是交给人工智能…...
基于MindSpore NLP的PEFT微调
创建notebook 登录控制台 创建notebook 如果出现提示按如下操作 回到列表页面创建notebook参数如下: 配置mindnlp环境 打开GitHub - mindspore-lab/mindnlp: Easy-to-use and high-performance NLP and LLM framework based on MindSpore, compatible with model…...
2024年12月CCF-GESP编程能力等级认证C++编程八级真题解析
本文收录于专栏《C++等级认证CCF-GESP真题解析》,专栏总目录:点这里。订阅后可阅读专栏内所有文章。 一、单选题(每题 2 分,共 30 分) 第 1 题 小杨家响应国家“以旧换新”政策,将自家的汽油车置换为新能源汽车,正在准备自编车牌。自编车牌包括5位数字或英文字母,要求…...
基于微信小程序的小区疫情防控ssm+论文源码调试讲解
第2章 程序开发技术 2.1 Mysql数据库 为了更容易理解Mysql数据库,接下来就对其具备的主要特征进行描述。 (1)首选Mysql数据库也是为了节省开发资金,因为网络上对Mysql的源码都已进行了公开展示,开发者根据程序开发需…...
moment()获取时间
moment 是一个 JavaScript 日期处理类库。 使用: //安装 moment npm install moment -- save引用 //在main.js中全局引入 import moment from "moment"设定moment区域为中国 //import 方式 import moment/locale/zh-cn moment.locale(zh-cn); 挂载全…...
CAD学习 day3
细节问题 快捷键X 分解单独进行操作如果需要制定字体样式选择 gdcbig.shx快捷键AA 算面积 平面布置图 客户沟通 - 会面笔记 - 客户需求(几个人居住、生活方式、功能需求(电竞房、家政柜)、书房、佛龛、儿童房、风格方向)根据客户需求 - 平面方案布置 (建议做三个以上方案) -…...
windows免登录linux
windows 生成秘钥文件 ssh-keygen -t rsa 将公钥传送到服务器 scp C:\Users\xx/.ssh/id_rsa.pub xxxx:/home/ruoyi/id_rsa.pub linux 使用ssh-copy-id -i ~/.ssh/id_rsa.pub userhost 如果禁用root登录,先开启 vim /etc/ssh/sshd_config PermitRootLogin yes …...
边缘计算的方式
做边缘计算这个行业要想赚得到收益,那一定要找到适合自己参与的一种方式。目前参与边缘计算的话,它主要有两个渠道。 第一个就是用盒子来跑,这个盒子的话包括光猫、路由器、摄像头等等,盒子是一条网线带动一个盒子,它…...
Android GO 版本锁屏声音无效问题
问题描述 Android go版本 在设置中打开锁屏音开关,息屏灭屏还是无声音 排查 vendor\mediatek\proprietary\packages\apps\SystemUI\src\com\android\systemui\keyguard\KeyguardViewMediator.java private void setupLocked() {...String soundPath Settings.G…...
Android之RecyclerView显示数据列表和网格
一、RecyclerView的优势 RecyclerView 的最大优势在于,它对大型列表来说非常高效: 默认情况下,RecyclerView 仅会处理或绘制当前显示在屏幕上的项。例如,如果您的列表包含一千个元素,但只有 10 个元素可见࿰…...
汽车发动机电控系统-【传感器】篇
燃油:喷油控制(不多不少) 进气 主传感器MAP:进气压力传感器(微型车)、空气流量传感器MAF 辅助传感器:节气门传感器、水温传感器(提供暖机工况)、进气温度传感器 反馈…...
牛客周赛 Round 72 题解
本次牛客最后一个线段树之前我也没碰到过,等后续复习到线段树再把那个题当例题发出来 小红的01串(一) 思路:正常模拟,从前往后遍历一遍去统计即可 #include<bits/stdc.h> using namespace std; #define int lo…...
Python AI后台服务器
把数据训练放在后台,首先碰到的一个问题是如何高效地从数据库把数据请求下来。 分别试了几个库 modin 号称和pandas能够无缝衔接,试了下,确实pd.read_sql蛮快的,但是下来后数据格式就变了,不太好进行后续处理了conne…...
音视频入门基础:MPEG2-TS专题(19)——FFmpeg源码中,解析TS流中的PES流的实现
一、引言 FFmpeg源码在解析完PMT表后,会得到该节目包含的视频和音频信息,从而找到音视频流。TS流的音视频流包含在PES流中。FFmpeg源码通过调用函数指针tss->u.pes_filter.pes_cb指向的回调函数解析PES流的PES packet: /* handle one TS…...
Qt Q_ENUM enum 转 QString 枚举字符串互转; C++模板应用
Part1: Summary 项目中我们常用到命名,使用 enum 转成 string ,方便简洁;Qt给我们提供了一个很方便的功能 Q_ENUM,可以实现枚举字符串互转; Q_ENUM宏将枚举注册到元对象系统中; QMetaEnum::fromType获取枚…...
Mac配置 Node镜像源的时候报错解决办法
在Mac电脑中配置国内镜像源的时候报错,提示权限问题,无法写入配置文件。本文提供解决方法,青测有效。 一、原因分析 遇到的错误是由于 .npm 目录下的文件被 root 用户所拥有,导致当前用户无法写入相关配置文件。 二、解决办法 在终端输入以下命令,输入管理员密码即可。 su…...