从零实现基于扩散模型的文本到视频生成系统:技术详解与Pytorch代码实现
本文详细介绍了基于扩散模型构建的文本到视频生成系统,展示了在MSRV-TT和Shutterstock视频标注数据集上训练的模型输出结果。以下是模型在不同提示词下的生成示例。
首先展示一些模型生成效果展示
提示词:“A person holding a camera”(训练10K步)
拿相机的人物场景
提示词:“Spaceship crossing the bridge”(训练10K步)
飞船穿过桥梁场景
提示词:“News Reporter speaking”(训练10K步)
新闻记者讲话场景
在Moving Mnist数据集上训练的模型(训练5K步)
使用训练模型生成的合成Moving Mnist数据
扩散模型理论基础
扩散模型是当前文本到视频生成领域的主流架构,被广泛应用于OpenAI Sora、Stable Video Diffusion等系统中。本文将从基本原理出发,实现一个基于扩散原理的架构,构建能够根据文本提示生成视频或GIF的模型系统。
扩散架构核心组件
3D U-Net结构:专为视频处理设计,能够有效处理随时间变化的帧序列。此U-Net架构融合了多重注意力机制:
- 时间注意力:负责捕捉帧间的时序关联
- 空间注意力:处理每一帧内的区域关系
- 这些注意力层与特殊功能模块协同工作,从视频数据中提取关键特征
扩散过程原理:模型的工作机制可以简述为以下步骤:
- 向训练视频添加噪声直至其变为纯随机分布
- 模型学习逆向过程,即从噪声中恢复原始信号
- 生成阶段,从随机噪声出发,模型逐步去除噪声
- 文本提示通过BERT转换为嵌入向量,引导UNet的去噪方向
- 通过反复迭代,最终生成与文本语义匹配的视频内容
为了更直观地理解系统架构,下面是一个简化的模型框架图:
架构工作流程详解
- 输入视频处理:系统以原始视频或随机噪声作为起点
- UNet3D编码阶段:视频通过UNet3D编码器处理,该过程逐步降低空间维度并提取关键特征
- UNet3D瓶颈处理:在最小空间维度处理提取的特征,融合全局信息
- UNet3D解码阶段:处理后的特征被送入解码器,逐步恢复空间维度,重建视频结构
- 文本条件引导:输入的文本提示转换为语义嵌入向量,在UNet3D的各个层级提供引导信号
- 扩散过程实现:- 训练阶段:向视频添加噪声,模型学习预测并去除噪声- 生成阶段:从纯噪声开始,模型利用UNet3D逐步去除噪声,生成目标视频
- 输出视频生成:最终输出基于输入条件(噪声和文本提示)生成的视频序列
项目结构
text2video-from-scratch/
├── configs/
│ └── default.yaml # 训练参数和超参数的配置文件
├── src/
│ ├── architecture/
│ │ ├── attention.py # 包含用于注意力机制的Attention和EinopsToAndFrom类
│ │ ├── blocks.py # 包含Block、ResnetBlock和SpatialLinearAttention类(UNet的构建块)
│ │ ├── common.py # 包含架构中使用的常见层和实用工具
│ │ ├── unet.py # 包含主要的Unet3D模型定义
│ │ └── relative_position_bias.py # 包含用于位置编码的RelativePositionBias类
│ ├── data/
│ │ ├── dataset.py # 定义用于加载和预处理视频数据的Dataset类
│ │ └── utils.py # 处理视频和图像数据的实用函数
│ ├── diffusion/
│ │ └── gaussian_diffusion.py # 包含实现扩散过程的GaussianDiffusion类
│ ├── text/
│ │ └── text_handler.py # 使用预训练BERT模型处理文本输入的函数(标记化,嵌入)
│ ├── trainer/
│ │ └── trainer.py # 包含Trainer类,处理训练循环、优化、EMA、保存和采样
│ └── utils/
│ └── helper_functions.py # 通用辅助函数(exists, noop, is_odd, default, cycle等)
├── train.py # 主训练脚本:加载配置,创建模型,扩散,训练器,并开始训练 ├── generate.py # 主生成脚本:加载配置,创建模型,扩散,训练器,并开始生成
核心文件说明:
train.py
是训练模型的主脚本,负责加载配置、初始化模型组件并启动训练流程generate.py
用于利用训练好的模型根据文本提示生成视频内容src
目录包含模型架构、数据处理、扩散过程实现、文本处理和训练工具的核心代码
成功实现本项目需要对面向对象编程(OOP)和神经网络(NN)有深入理解。熟悉PyTorch框架将有助于理解和修改代码实现。源代码地址在文章最后
环境配置
首先需要克隆项目仓库并安装依赖:
git clone https://github.com/FareedKhan-dev/text2video-from-scratch cd text2video-from-scratchpip install -r requirements.txt
库导入
本项目使用的主要库及其功能:
# 系统操作相关
import os # 文件系统操作
import yaml # 配置文件解析
from pathlib import Path # 跨平台路径处理
import subprocess # 执行系统命令
import zipfile # ZIP文件处理# 数据处理相关
import pandas as pd # 结构化数据处理
from tqdm import tqdm # 进度显示# 图像、视频处理
from PIL import Image # 图像处理
from moviepy.editor import VideoFileClip # 视频编辑
from datasets import load_dataset # 数据集加载# PyTorch核心组件
import torch # 深度学习框架
from torch import nn, einsum # 神经网络构建与张量运算
from torch.nn import functional as F # 函数式API
from torch.utils import data # 数据加载工具# 张量操作工具
from einops import rearrange # 张量重排
from einops_exts import rearrange_many, check_shape # 扩展张量操作
from rotary_embedding_torch import RotaryEmbedding # 旋转位置编码# 自然语言处理
from transformers import BertModel, BertTokenizer # BERT模型与分词器# 其他工具
import copy # 对象复制
from torch.optim import Adam # 优化器
from torch.cuda.amp import autocast, GradScaler # 混合精度训练
import math # 数学函数import colorsys # 颜色空间转换
训练数据准备
为了构建高质量的文本到视频生成模型,我们需要多样化的带标注视频数据集。MSR-VTT(Microsoft Research Video to Text)是理想选择,它包含来自20个类别的10,000个视频剪辑,每个剪辑都有对应的英文描述标注。以下是数据获取与处理的实现方法:
# 使用Kaggle API下载数据集
def download_kaggle_dataset(dataset_name: str, download_dir: str) -> None: # 确保目录存在 Path(download_dir).mkdir(parents=True, exist_ok=True) # 调用Kaggle命令行工具下载数据集command = f"kaggle datasets download {dataset_name} -p {download_dir}" subprocess.run(command, shell=True, check=True) # 解压下载的数据文件
def unzip_file(zip_path: str, extract_dir: str) -> None: with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(extract_dir) # 可视化数据集中的随机视频样本
def visualize_random_videos(videos_dir: str, num_videos: int = 8) -> None: # 获取所有MP4文件video_files = [f for f in os.listdir(videos_dir) if f.endswith('.mp4')] # 随机抽样指定数量的视频random_videos = random.sample(video_files, num_videos) # 创建显示网格fig, axes = plt.subplots(2, 4, figsize=(12, 6)) axes = axes.ravel() # 处理并显示每个视频的首帧for i, video_file in enumerate(random_videos): video_path = os.path.join(videos_dir, video_file) # 加载视频并提取前2秒帧clip = VideoFileClip(video_path).subclip(0, 2) # 获取首帧并显示frame = clip.get_frame(0) axes[i].imshow(frame) axes[i].axis('off') axes[i].set_title(f"Video {i+1}") plt.tight_layout() plt.show()
实际执行数据下载与可视化:
# 数据集下载与解压
kaggle_dataset_name = 'vishnutheepb/msrvtt'
download_dir = './msrvtt_data'
unzip_dir = './msrvtt_data/msrvtt' download_kaggle_dataset(kaggle_dataset_name, download_dir)
zip_file_path = os.path.join(download_dir, 'msrvtt.zip')
unzip_file(zip_file_path, unzip_dir) # 可视化8个随机视频样本
videos_dir = os.path.join(unzip_dir, 'TrainValVideo') visualize_random_videos(videos_dir)
MSRVTT数据集样本展示
为便于训练,我们需要将视频转换为更高效的格式。我们选择将MP4文件转换为GIF格式并创建对应的文本描述文件:
# 创建训练数据集,转换视频格式并准备标注
def create_training_data(videos_dir: str, output_dir: str, size=(64, 64), duration=2) -> None: Path(output_dir).mkdir(parents=True, exist_ok=True) video_files = [f for f in os.listdir(videos_dir) if f.endswith('.mp4')] for video_file in video_files: video_path = os.path.join(videos_dir, video_file) base_name = os.path.splitext(video_file)[0] gif_path = os.path.join(output_dir, f"{base_name}.gif") txt_path = os.path.join(output_dir, f"{base_name}.txt") # 视频转GIF处理clip = VideoFileClip(video_path).subclip(0, duration) clip = clip.resize(size) clip.write_gif(gif_path, program='ffmpeg') # 创建文本描述文件with open(txt_path, "w") as txt_file: txt_file.write(f"{base_name}") print(f"已处理: {video_file} -> {base_name}.gif and {base_name}.txt") # 执行数据转换
videos_dir = "./msrvtt_data/msrvtt/TrainValVideo"
output_dir = "./training_data" create_training_data(videos_dir, output_dir)
经过处理,我们的训练数据将采用以下结构:
training_data/ |── video1.gif |── video1.txt |── video2.gif |── video2.txt ...
此格式便于模型训练过程中高效地加载和处理数据。
辅助函数定义
为了提高代码的可维护性和模块化程度,我们定义了一系列辅助函数,用于处理常见任务如参数检查、数据迭代和掩码生成等:
# 检查变量是否存在(非None)
def exists(x: Union[None, object]) -> bool: return x is not None # 空操作函数,接受任意参数但不执行任何操作
def noop(*args, **kwargs) -> None: pass # 检查整数是否为奇数
def is_odd(n: int) -> bool: return (n % 2) == 1 # 返回值或默认值
def default(val: Union[None, object], d: Union[object, Callable[[], object]]) -> object: if exists(val): return val return d() if callable(d) else d # 数据加载器循环迭代器
def cycle(dl: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader: while True: for data in dl: yield data # 将总数划分为指定大小的组
def num_to_groups(num: int, divisor: int) -> List[int]: groups = num // divisor remainder = num % divisor arr = [divisor] * groups if remainder > 0: arr.append(remainder) return arr # 生成概率掩码
def prob_mask_like(shape: Tuple[int, ...], prob: float, device: torch.device) -> torch.Tensor: if prob == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif prob == 0: return torch.zeros(shape, device=device, dtype=torch.bool) else: return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob # 检查列表或元组是否只包含字符串
def is_list_str(x: Union[List[object], Tuple[object, ...]]) -> bool: if not isinstance(x, (list, tuple)): return False return all([type(el) == str for el in x])
注意力机制实现
注意力机制是扩散视频生成模型的核心组件,使模型能够有选择地关注输入的重要部分,无论是空间区域、时间帧还是文本信息。以下是注意力模块的实现:
class EinopsToAndFrom(nn.Module): def __init__(self, from_einops: str, to_einops: str, fn: Callable[[torch.Tensor], torch.Tensor]) -> None: super().__init__() self.from_einops = from_einops # 输入形状描述self.to_einops = to_einops # 目标形状描述self.fn = fn # 要应用的变换函数def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: shape = x.shape # 创建重建原始形状所需的参数字典reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))# 重排张量到目标形状x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') # 应用变换函数x = self.fn(x, **kwargs)# 恢复原始形状x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) return x class Attention(nn.Module): def __init__( self, dim: int, heads: int = 4, dim_head: int = 32, rotary_emb: Optional[nn.Module] = None ) -> None: super().__init__() self.scale = dim_head ** -0.5 # 注意力缩放因子self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb # 可选的旋转位置编码self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) # QKV投影self.to_out = nn.Linear(hidden_dim, dim, bias=False) # 输出投影def forward( self, x: torch.Tensor, pos_bias: Optional[torch.Tensor] = None, focus_present_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) # 分离QKV向量# 如果focus_present_mask全部激活,直接返回值向量if exists(focus_present_mask) and focus_present_mask.all(): values = qkv[-1] return self.to_out(values) # 重排QKV用于多头处理q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) q = q * self.scale # 缩放查询向量# 应用旋转位置编码(如提供)if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # 计算注意力分数sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # 应用位置偏置(如提供)if pos_bias is not None: sim = sim + pos_bias # 处理焦点掩码if focus_present_mask is not None and not (~focus_present_mask).all(): attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) mask = torch.where( rearrange(focus_present_mask, 'b -> b 1 1 1 1'), rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), ) sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # 掩码应用# 数值稳定性优化sim = sim - sim.amax(dim=-1, keepdim=True).detach() # 计算注意力权重attn = sim.softmax(dim=-1) # 加权汇总值向量out = einsum('... h i j, ... h j d -> ... h i d', attn, v) # 重排输出out = rearrange(out, '... h n d -> ... n (h d)') return self.to_out(out) # 应用输出投影
EinopsToAndFrom 类提供了张量形状转换的灵活机制,使模型能够在不同表示之间高效切换。而 Attention 类实现了多头注意力机制的核心算法,支持位置编码、掩码操作和数值稳定性优化,这些都是高质量视频生成的关键要素。
视频生成模型的基础构建模块
在深度学习视频生成架构中,基础构建模块是整个网络功能的关键组件。我们实现的U-Net架构采用了层次化结构,由多个专用模块组成,每个模块都具有特定的数据转换功能。
class Block(nn.Module): def __init__(self, dim: int, dim_out: int) -> None: super().__init__() # 初始化具有内核大小(1, 3, 3)和填充(0, 1, 1)的3D卷积层self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) # 初始化输出维度的层归一化self.norm = nn.LayerNorm(dim_out) # 初始化SiLU激活函数(Sigmoid线性单元)self.act = nn.SiLU() def forward(self, x: torch.Tensor, scale_shift: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor: # 对输入张量应用3D卷积x = self.proj(x) # 对张量应用层归一化x = self.norm(x) # 如果提供了缩放和偏移值,则应用它们if exists(scale_shift): scale, shift = scale_shift # 对张量应用缩放和偏移x = x * (scale + 1) + shift # 应用SiLU激活函数return self.act(x) class ResnetBlock(nn.Module): def __init__(self, dim: int, dim_out: int, *, time_emb_dim: Optional[int] = None) -> None: super().__init__() # 如果指定了time_emb_dim,则创建一个MLP来生成缩放和偏移值self.mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None # 初始化定义的Block类的两个顺序块self.block1 = Block(dim, dim_out) self.block2 = Block(dim_out, dim_out) # 如果输入和输出维度不同,则应用1x1卷积用于残差连接self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x: torch.Tensor, time_emb: Optional[torch.Tensor] = None) -> torch.Tensor: scale_shift = None # 如果定义了MLP,则处理时间嵌入以生成缩放和偏移因子if exists(self.mlp): assert exists(time_emb), 'time_emb must be passed in when time_emb_dim is defined' time_emb = self.mlp(time_emb) # 通过MLP传递time_embtime_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # 重塑以启用广播scale_shift = time_emb.chunk(2, dim=1) # 将时间嵌入分为缩放和偏移# 应用第一个块,带有可选的缩放/偏移h = self.block1(x, scale_shift=scale_shift) # 应用第二个块,不带缩放/偏移h = self.block2(h) # 返回带有残差连接的结果return h + self.res_conv(x) class SpatialLinearAttention(nn.Module): def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None: super().__init__() # 基于头维度的注意力分数的缩放因子self.scale = dim_head ** -0.5 self.heads = heads # 注意力头的数量hidden_dim = dim_head * heads # 多头注意力的总维度# 1x1卷积,用于生成查询、键和值张量self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) # 注意力计算后用于输出投影的1x1卷积self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x: torch.Tensor) -> torch.Tensor: # 获取输入张量的形状b, c, f, h, w = x.shape # 重新排列输入张量以在注意力机制中处理x = rearrange(x, 'b c f h w -> (b f) c h w') # 应用1x1卷积计算查询、键和值qkv = self.to_qkv(x).chunk(3, dim=1) # 重新排列Q、K、V用于多头注意力q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) # 对查询应用softmax(跨空间位置)q = q.softmax(dim=-2) # 对键应用softmax(跨特征)k = k.softmax(dim=-1) # 缩放查询q = q * self.scale # 基于键和值计算上下文(加权和)context = torch.einsum('b h d n, b h e n -> b h d e', k, v) # 通过将查询应用于上下文来计算注意力输出out = torch.einsum('b h d e, b h d n -> b h e n', context, q) # 将输出重新排列回原始空间格式out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w) # 应用输出卷积,投影回输入维度out = self.to_out(out) # 将输出重新排列回原始批量大小和帧数return rearrange(out, '(b f) c h w -> b c f h w', b=b)
视频生成架构的基础构建模块包括三个关键组件:
标准卷积块 (Block):作为神经网络的基础单元,实现了输入特征的基本转换处理。该模块采用3D卷积层、层归一化和SiLU激活函数的组合,支持通过scale-shift操作进行条件处理,使其能够根据不同的条件输入(如时间嵌入)动态调整特征响应。
残差块 (ResnetBlock):通过引入跳跃连接增强了网络训练稳定性,有效缓解了深层网络中的梯度消失问题。此模块包含两个串联的标准卷积块,并支持时间条件嵌入,通过MLP网络将时间信息转换为特征调制信号,从而使模型能够根据时间步信息调整特征表示。
空间线性注意力模块 (SpatialLinearAttention):在视频帧的空间维度上实现高效注意力计算,使模型能够捕获图像内的长距离依赖关系。该模块采用了多头注意力设计,分别对查询和键应用softmax操作,实现了对输入特征的自适应加权处理。
这些基础模块具有不同但相互补充的功能:Block模块处理局部特征转换,ResnetBlock增强训练稳定性并支持时间条件,而SpatialLinearAttention则捕获远距离空间依赖关系。通过组合这些模块,模型能够有效处理视频数据的时空复杂性。
视频生成模型的通用组件
视频生成架构需要多种辅助组件来增强模型功能和训练稳定性。这些组件包括参数平滑机制、归一化层和位置编码等,为模型提供了必要的数据处理能力。
class EMA: def __init__(self, beta: float) -> None: super().__init__() # 存储用于更新移动平均的衰减因子(beta)self.beta = beta def update_model_average(self, ma_model: nn.Module, current_model: nn.Module) -> None: # 使用当前模型的参数更新移动平均模型for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): # 从移动平均模型获取旧权重,从当前模型获取新权重old_weight, up_weight = ma_params.data, current_params.data # 更新移动平均权重ma_params.data = self.update_average(old_weight, up_weight) def update_average(self, old: Optional[torch.Tensor], new: torch.Tensor) -> torch.Tensor: # 如果不存在旧值,则返回新值if old is None: return new # 基于beta和新值更新移动平均return old * self.beta + (1 - self.beta) * new class Residual(nn.Module): def __init__(self, fn: nn.Module) -> None: super().__init__() # 存储要在残差块中使用的函数self.fn = fn def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: # 应用函数并将输入张量添加到其中,形成残差连接return self.fn(x, *args, **kwargs) + x class SinusoidalPosEmb(nn.Module): def __init__(self, dim: int) -> None: super().__init__() # 存储位置嵌入的维度self.dim = dim def forward(self, x: torch.Tensor) -> torch.Tensor: # 获取输入张量的设备device = x.device # 正弦和余弦嵌入的一半维度half_dim = self.dim // 2 # 嵌入范围的缩放因子emb_scale = math.log(10000) / (half_dim - 1) # 通过计算缩放arange值的指数创建正弦嵌入emb = torch.exp(torch.arange(half_dim, device=device) * -emb_scale) # 根据输入x应用位置编码(正弦)emb = x[:, None] * emb[None, :] # 连接嵌入的正弦和余弦变换emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb def Upsample(dim: int) -> nn.ConvTranspose3d: # 返回用于上采样的3D转置卷积层return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) def Downsample(dim: int) -> nn.Conv3d: # 返回用于下采样的3D卷积层return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) class LayerNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5) -> None: super().__init__() # 存储归一化中的数值稳定性的epsilonself.eps = eps # 创建可学习的缩放参数(gamma)self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) def forward(self, x: torch.Tensor) -> torch.Tensor: # 计算输入张量在通道维度上的方差和均值(dim=1)var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) # 通过减去均值并除以方差进行归一化# 用gamma进行可学习的缩放return (x - mean) / (var + self.eps).sqrt() * self.gamma class RMSNorm(nn.Module): def __init__(self, dim: int) -> None: super().__init__() # 基于输入维度(dim)计算缩放因子self.scale = dim ** 0.5 # 创建可学习的缩放参数(gamma)self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1)) def forward(self, x: torch.Tensor) -> torch.Tensor: # 在维度1(通道)上归一化输入张量并应用缩放return F.normalize(x, dim=1) * self.scale * self.gamma class PreNorm(nn.Module): def __init__(self, dim: int, fn: nn.Module) -> None: super().__init__() # 存储归一化后要使用的函数self.fn = fn # 用指定维度初始化层归一化self.norm = LayerNorm(dim) def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: # 对输入张量应用归一化x = self.norm(x) # 将归一化的张量传递给函数(例如,注意力或MLP)return self.fn(x, **kwargs)
视频生成架构的通用组件提供了多种功能增强机制:
指数移动平均 (EMA):实现了模型参数的平滑更新策略,有效降低训练过程中的参数波动。通过对历史权重和当前权重进行加权平均,提高了模型的泛化能力和推理稳定性。这在视频生成任务中尤为重要,因为它帮助模型产生更连贯的时序输出。
残差连接包装器 (Residual):为任意函数模块添加跳跃连接,促进梯度在深层网络中的高效传播。这种设计不仅简化了残差路径的实现,还提高了整个架构的训练效率。
正弦位置嵌入 (SinusoidalPosEmb):为一维输入(如时间步)生成周期性位置编码,使模型能够区分不同的时间位置。该组件采用正弦和余弦函数的组合,创建了一种具有良好插值特性的嵌入表示。
上采样和下采样函数:分别通过3D转置卷积和标准3D卷积实现特征图的分辨率变换,维持通道数不变的同时改变空间维度。这些操作在U-Net结构的编码器和解码器部分扮演关键角色。
自定义归一化层:实现了多种特化的归一化策略,包括:
- LayerNorm:实现通道维度上的特征归一化,带有可学习的缩放参数
- RMSNorm:提供基于均方根的归一化,计算复杂度更低
- PreNorm:在应用任意函数前执行归一化预处理,稳定深层网络中的信号传播
这些通用组件协同工作,确保了视频生成模型的训练稳定性和生成质量,为处理高维时空数据提供了必要的计算基础。
相对位置编码机制
相对位置编码在处理序列数据时至关重要,它使模型能够理解和利用输入元素之间的相对位置关系,而不依赖于绝对位置信息。这在视频生成等需要捕捉时空依赖关系的任务中尤为关键。
class RelativePositionBias(nn.Module): def __init__( self, heads: int = 8, # 注意力头的数量num_buckets: int = 32, # 相对位置编码的桶数量max_distance: int = 128 # 要考虑的最大相对距离) -> None: super().__init__() # 存储相对位置偏置的桶数和最大距离self.num_buckets = num_buckets self.max_distance = max_distance # 为相对注意力偏置初始化嵌入层self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket( relative_position: torch.Tensor, num_buckets: int = 32, max_distance: int = 128 ) -> torch.Tensor: # 初始化结果变量(从零开始)ret = 0 # 取相对位置的负数(处理两个方向)n = -relative_position # 将桶数量减半num_buckets //= 2 # 如果位置为负,将其分配到桶的后半部分ret += (n < 0).long() * num_buckets # 获取相对位置的绝对值n = torch.abs(n) # 桶的一半将对应确切的距离max_exact = num_buckets // 2 # 小距离的标志is_small = n < max_exact # 对于较大的距离,使用对数尺度计算桶值val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() # 确保大距离的值不超过最大桶索引val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) # 根据距离是小还是大来更新结果ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n: int, device: torch.device) -> torch.Tensor: # 创建查询位置(q_pos)张量,范围从0到n-1q_pos = torch.arange(n, dtype=torch.long, device=device) # 创建键位置(k_pos)张量,范围从0到n-1k_pos = torch.arange(n, dtype=torch.long, device=device) # 计算每个键相对于每个查询的相对位置(形状:n x n)rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') # 为每对查询和键位置计算相对位置桶rp_bucket = self._relative_position_bucket( rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance ) # 从嵌入层获取相应的相对位置偏置values = self.relative_attention_bias(rp_bucket) # 重新排列值以匹配预期的输出形状(h, i, j)return rearrange(values, 'i j h -> h i j')
RelativePositionBias实现了一种高效的相对位置编码机制,通过桶化处理将相对位置映射到可学习的嵌入空间,具有以下特点:
自适应位置表示:与固定的位置嵌入不同,相对位置偏置学习不同位置元素之间的关系,而非绝对位置,这使模型在泛化到不同长度序列时更加灵活。
非线性桶化机制:通过
_relative_position_bucket
方法实现了一种智能的相对位置量化策略,对近距离位置使用线性分桶(提供精确表示),对远距离位置使用对数分桶(提供高效表示)。这种设计兼顾了计算效率和表示能力。
方向感知:通过区分正负相对距离,模型能够感知序列元素的相对方向,对于捕捉视频中的时序关系至关重要。
多头兼容设计:通过为每个注意力头生成独立的位置偏置,允许不同注意力头专注于不同类型的位置模式和依赖关系。
可扩展性:通过
num_buckets
和
max_distance
参数,该机制可以灵活调整以适应不同长度和复杂度的序列数据。
此位置编码组件通过为自注意力层提供精确的位置信息,使视频生成模型能够更好地理解时间和空间结构,在保持计算效率的同时提高生成质量。
视频数据处理工具
# 将通道数映射到相应图像模式的字典
CHANNELS_TO_MODE = { 1: 'L', # 1通道对应灰度模式('L'模式) 3: 'RGB', # 3通道对应RGB颜色模式 4: 'RGBA' # 4通道对应RGBA颜色模式(带透明度)
} # 生成器函数,用于从多帧图像(如GIF)中提取所有图像
def seek_all_images(img: Image.Image, channels: int = 3): # 确保指定的通道数有效 assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' # 获取通道数对应的模式 mode = CHANNELS_TO_MODE[channels] i = 0 while True: try: # 寻找图像中的第i帧 img.seek(i) # 转换图像帧为所需模式并返回 yield img.convert(mode) except EOFError: # 帧结束(EOF),跳出循环 break i += 1 # 将视频张量转换为GIF并保存到指定路径的函数
def video_tensor_to_gif( tensor: torch.Tensor, path: str, duration: int = 120, loop: int = 0, optimize: bool = True
): # 将视频张量中的每一帧转换为PIL图像 images = map(T.ToPILImage(), tensor.unbind(dim=1)) # 解包第一张图像和其余图像 first_img, *rest_imgs = images # 保存GIF并指定参数 first_img.save( path, save_all=True, # 将所有帧保存为GIF的一部分 append_images=rest_imgs, # 将其他帧附加到GIF中 duration=duration, # 设置每帧的持续时间(毫秒) loop=loop, # 设置GIF的循环次数(0表示无限循环) optimize=optimize # 启用GIF文件优化 ) # 返回图像列表作为结果 return images # 将GIF转换为张量(帧序列)的函数
def gif_to_tensor( path: str, channels: int = 3, transform: T.Compose = T.ToTensor()
) -> torch.Tensor: # 从给定路径打开GIF图像 img = Image.open(path) # 将GIF中的所有帧转换为张量,应用转换 tensors = tuple(map(transform, seek_all_images(img, channels=channels))) # 沿帧维度将张量堆叠成单个张量 return torch.stack(tensors, dim=1) # 恒等函数:返回不变的输入张量
def identity(t, *args, **kwargs): return t # 将图像张量归一化到[-1, 1]范围的函数
def normalize_img(t: torch.Tensor) -> torch.Tensor: # 通过将张量值从[0, 1]缩放到[-1, 1]来归一化 return t * 2 - 1 # 将图像张量反归一化回[0, 1]范围的函数
def unnormalize_img(t: torch.Tensor) -> torch.Tensor: # 通过将张量值从[-1, 1]缩放到[0, 1]来反归一化 return (t + 1) * 0.5 # 确保张量具有指定帧数的函数
def cast_num_frames(t: torch.Tensor, *, frames: int) -> torch.Tensor: # 获取张量中当前的帧数 f = t.shape[1] if f == frames: # 如果帧数已经符合要求,则返回不变的张量 return t if f > frames: # 如果帧数超过需要,则截取张量到所需帧数 return t[:, :frames] # 如果帧数不足,则用零填充张量(无新帧) return torch.nn.functional.padding(t, (0, 0, 0, 0, 0, frames - f))
上述工具函数集实现了对视频和帧序列数据的多种处理能力。
seek_all_images
函数提供了从多帧图像(如GIF)中逐帧提取单帧的机制,而
video_tensor_to_gif
允许将模型生成的视频张量转换为标准GIF格式并保存至存储设备。与此相对应,
gif_to_tensor
则实现了从GIF文件到PyTorch张量的转换,便于后续的神经网络处理。
此外,辅助函数如
normalize_img
和
unnormalize_img
提供了图像值域在标准化区间([-1,1])与图像标准区间([0,1])之间的双向转换,这是深度学习模型处理图像数据的常见需求。
cast_num_frames
函数则灵活处理视频帧数调整,通过截取或填充操作确保视频序列符合模型的输入要求。
数据集转换实现
为了支持扩散模型的训练过程,我们实现了专用的数据集类,用于高效处理视频数据与对应的文本描述。该类确保数据以PyTorch可用的批次形式加载,同时实现了必要的预处理转换。
# 用于处理GIF或视频文件的自定义数据集类
class Dataset(data.Dataset): # 使用所需参数初始化数据集 def __init__( self, folder: str, # 存储数据集的文件夹路径 image_size: int, # 每个图像调整到的大小 channels: int = 3, # 颜色通道数(默认为3,RGB) num_frames: int = 16, # 每个视频提取的帧数(默认为16) horizontal_flip: bool = False, # 是否应用水平翻转增强 force_num_frames: bool = True, # 是否强制视频张量具有确切的`num_frames`帧数 exts: List[str] = ['gif'] # 要查找的文件扩展名列表(默认为['gif']) ) -> None: # 调用父构造函数(来自PyTorch的Dataset) super().__init__() # 初始化数据集属性 self.folder = folder self.image_size = image_size self.channels = channels # 获取文件夹(及子文件夹)中与给定扩展名匹配的所有文件路径 self.paths = [ p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}') ] # 定义必要时用于转换帧数的函数 # 如果`force_num_frames`为True,我们应用`cast_num_frames`函数,否则使用恒等函数 self.cast_num_frames_fn = partial(cast_num_frames, frames=num_frames) if force_num_frames else identity # 定义要应用于每个图像的转换(调整大小、随机翻转、裁剪和转换为张量) self.transform = T.Compose([ T.Resize(image_size), # 将图像调整到目标大小 T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), # 如果指定则应用随机水平翻转 T.CenterCrop(image_size), # 中心裁剪图像到目标大小 T.ToTensor() # 将图像转换为PyTorch张量 ]) # 返回数据集中样本的总数 def __len__(self) -> int: return len(self.paths) # 通过索引获取特定样本(图像及其对应文本,如果有) def __getitem__(self, index: int) -> Tuple[torch.Tensor, Optional[str]]: # 获取给定索引的样本文件路径 path = self.paths[index] # 使用`gif_to_tensor`函数将GIF(或视频)转换为张量 # 应用先前定义的转换 tensor = gif_to_tensor(path, self.channels, transform=self.transform) # 将张量转换为具有正确帧数(如果需要) tensor = self.cast_num_frames_fn(tensor) # 检查此图像是否有对应的文本文件(相同名称,.txt扩展名) text_path = path.with_suffix(".txt") if text_path.exists(): # 如果文本文件存在,读取其内容 with open(text_path, 'r') as f: text = f.read() # 返回张量和文件中的文本 return tensor, text else: # 如果没有文本文件,返回张量,文本为`None` return tensor, None
Dataset
类提供了完整的视频数据加载解决方案,通过初始化时接收的参数灵活配置数据处理行为。它在实例化过程中扫描指定文件夹中符合条件的视频文件,并建立转换管道用于图像处理。该类支持数据集大小查询和基于索引的数据获取,核心功能包括:
- 通过
__init__
方法配置数据源位置、图像尺寸、通道数和预处理参数 - 利用
__len__
方法返回数据集的样本总数 - 通过
__getitem__
方法检索指定索引的视频数据及其对应文本描述
此数据集类支持数据增强(如水平翻转),帧数调整,并处理视频与文本的自动关联匹配,为模型训练提供高质量的数据流。
高斯扩散过程
高斯扩散过程是本系统生成视频的核心机制,负责实现视频的逐步去噪生成。该过程模拟了从随机噪声逐渐恢复有意义视频内容的过程,通过深度神经网络指导每一步去噪。
高斯扩散原理示意图
# 根据时间步从张量中提取值的辅助函数
def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor: b, *_ = t.shape # 获取批量大小 out = a.gather(-1, t) # 根据时间步提取值 return out.reshape(b, *((1,) * (len(x_shape) - 1))) # 重塑以匹配输入形状 # 为beta创建余弦调度的函数
def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor: steps = timesteps + 1 x = torch.linspace(0, timesteps, steps, dtype=torch.float64) # 创建时间网格 alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 # 余弦函数 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] # 归一化 betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) # 计算beta值 return torch.clip(betas, 0, 0.9999) # 确保beta值保持在范围内 # 高斯扩散模型的主类
class GaussianDiffusion(nn.Module): def __init__(self, denoise_fn: nn.Module, *, image_size: int, num_frames: int, timesteps: int = 1000): super().__init__() self.denoise_fn = denoise_fn self.image_size = image_size self.num_frames = num_frames betas = cosine_beta_schedule(timesteps) # 获取beta调度 # 初始化模型计算的各种张量 alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) timesteps, = betas.shape self.num_timesteps = int(timesteps) # 注册缓冲区(不通过梯度下降更新的张量) register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) register_buffer('betas', betas) register_buffer('alphas_cumprod', alphas_cumprod) register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) # 更多初始化,用于各种系数(用于计算后验和前向过程) register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) # 计算q分布的均值、方差和对数方差的函数 def q_mean_variance(self, x_start: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract(1. - self.alphas_cumprod, t, x_start.shape) log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance # 从噪声数据预测图像起点的函数 def predict_start_from_noise(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) # 计算后验分布的函数 def q_posterior(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped # 使用模型预测进行去噪的函数 def p_mean_variance(self, x: torch.Tensor, t: torch.Tensor, clip_denoised: bool) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn(x)) if clip_denoised: # 如果需要,对去噪后的图像进行裁剪 x_recon = x_recon.clamp(-1., 1.) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance # 单个去噪步骤的函数 @torch.inference_mode() def p_sample(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=True) noise = torch.randn_like(x) # 添加随机噪声 return model_mean + noise * (0.5 * model_log_variance).exp() # 返回去噪后的图像 # 生成样本的函数(整个循环) @torch.inference_mode() def p_sample_loop(self, shape: torch.Size) -> torch.Tensor: img = torch.randn(shape, device=self.device) # 从随机噪声开始 for t in reversed(range(self.num_timesteps)): # 迭代去噪 img = self.p_sample(img, t) return (img + 1) * 0.5 # 返回适当范围内的最终图像 # 生成一批样本的函数 @torch.inference_mode() def sample(self, batch_size: int = 16) -> torch.Tensor: return self.p_sample_loop((batch_size, self.channels, self.num_frames, self.image_size, self.image_size)) # 计算噪声图像和去噪图像之间的损失(如L1损失)的函数 def p_losses(self, x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor = None) -> torch.Tensor: x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # 向图像添加噪声 x_recon = self.denoise_fn(x_noisy, t) # 使用模型去噪 # 计算噪声输出和去噪输出之间的损失 return F.l1_loss(noise, x_recon) # 模型的前向传递 def forward(self, x: torch.Tensor) -> torch.Tensor: t = torch.randint(0, self.num_timesteps, (x.shape[0],)) # 随机时间步 return self.p_losses(x, t) # 计算损失
GaussianDiffusion
类实现了扩散过程的核心机制,它基于指定的去噪网络构建完整的视频生成流程。该类在初始化时计算扩散过程所需的各种参数,包括基于余弦调度的噪声水平控制。主要组件包括:
- 前向扩散过程:通过
q_sample
方法向原始视频添加噪声 - 反向去噪过程:通过
p_sample
和p_sample_loop
方法从纯噪声逐步恢复视频内容 - 损失计算:
p_losses
方法衡量预测噪声与实际噪声之间的差异 - 条件生成:
sample
方法支持基于文本条件的视频生成
扩散过程的精妙之处在于通过数学建模将视频生成分解为多步有监督去噪任务,每一步都基于神经网络对噪声分布的准确预测,最终实现从纯噪声到高质量视频的演化。
文本处理模块
为了实现文本条件引导的视频生成,我们需要将文本转换为神经网络可处理的向量表示。下面实现了基于BERT的文本编码系统:
# 检查值是否存在(不为None)的函数
def exists(val: Optional[Union[torch.Tensor, any]]) -> bool: return val is not None # 将模型和标记器变量初始化为None
MODEL: Optional[BertModel] = None
TOKENIZER: Optional[BertTokenizer] = None
BERT_MODEL_DIM: int = 768 # BERT模型输出的维度大小 # 获取BERT模型标记器的函数
def get_tokenizer() -> BertTokenizer: global TOKENIZER if not exists(TOKENIZER): # 如果标记器尚未加载 TOKENIZER = BertTokenizer.from_pretrained('bert-base-cased') # 加载标记器 return TOKENIZER # 获取BERT模型的函数
def get_bert() -> BertModel: global MODEL if not exists(MODEL): # 如果模型尚未加载 MODEL = BertModel.from_pretrained('bert-base-cased') # 加载BERT模型 if torch.cuda.is_available(): # 如果GPU可用 MODEL = MODEL.cuda() # 将模型移至GPU return MODEL # 对输入文本(单个字符串或字符串列表)进行分词的函数
def tokenize(texts: Union[str, List[str], Tuple[str]]) -> torch.Tensor: if not isinstance(texts, (list, tuple)): # 如果输入是单个字符串,将其转换为列表 texts = [texts] tokenizer = get_tokenizer() # 获取标记器 encoding = tokenizer.batch_encode_plus( texts, # 输入文本 add_special_tokens=True, # 为BERT添加特殊标记 padding=True, # 将序列填充到相同长度 return_tensors='pt' # 作为PyTorch张量返回 ) return encoding.input_ids # 返回标记ID(数值表示) # 从标记ID获取BERT嵌入(特征)的函数
@torch.no_grad() # 推理时不需要跟踪梯度
def bert_embed( token_ids: torch.Tensor, return_cls_repr: bool = False, # 是否仅返回[CLS]标记表示 eps: float = 1e-8, # 小值,防止除零 pad_id: int = 0 # 填充标记ID(BERT通常为0)
) -> torch.Tensor: model = get_bert() # 获取BERT模型 mask = token_ids != pad_id # 为填充标记创建掩码(忽略它们) if torch.cuda.is_available(): # 如果GPU可用,将张量移至GPU token_ids = token_ids.cuda() mask = mask.cuda() # 运行BERT模型并获取输出(所有层的隐藏状态) outputs = model( input_ids=token_ids, attention_mask=mask, # 只关注非填充标记 output_hidden_states=True # 获取所有层的隐藏状态 ) hidden_state = outputs.hidden_states[-1] # 获取最后的隐藏状态(最后一层) if return_cls_repr: # 如果我们需要[CLS]标记表示,则返回它 return hidden_state[:, 0] # 如果没有掩码,返回所有隐藏状态的平均值 if not exists(mask): return hidden_state.mean(dim=1) # 如果有掩码,计算忽略填充标记的平均值 mask = mask[:, 1:] # 移除第一个标记的填充 mask = rearrange(mask, 'b n -> b n 1') # 重新排列以便广播 numer = (hidden_state[:, 1:] * mask).sum(dim=1) # 对掩码标记求和 denom = mask.sum(dim=1) # 计算非填充标记的数量 masked_mean = numer / (denom + eps) # 计算掩码平均值(避免除零) return masked_mean # 返回最终嵌入(平均值或[CLS]表示)
文本处理模块提供了从自然语言文本到语义向量的转换能力,是文本条件视频生成的关键组件。该模块采用预训练的BERT模型将文本转换为上下文感知的嵌入向量,实现了以下核心功能:
- 惰性加载机制:通过全局变量和初始化检查,确保BERT模型和分词器只加载一次
- 文本分词:将原始文本转换为BERT可处理的标记ID序列
- 向量嵌入提取:从BERT模型获取文本的语义表示,支持多种嵌入方式- 句子级表示:使用[CLS]标记作为整个文本的表示- 平均池化表示:对所有非填充标记的隐藏状态进行平均
这些文本嵌入随后被注入到扩散模型中,引导视频生成过程朝向与文本语义一致的方向发展,实现"文本到视频"的精确转换。
训练器设计
训练器模块封装了模型训练的完整流程,包括数据加载、优化、模型保存等关键环节。下面是训练器的核心实现:
class Trainer: def __init__(self, diffusion_model: nn.Module, folder: str, *, ema_decay: float = 0.995, train_batch_size: int = 32, train_lr: float = 1e-4, train_num_steps: int = 100000, gradient_accumulate_every: int = 2, amp: bool = False, save_model_every: int = 1000, results_folder: str = './results'): # 初始化训练器、数据集、优化器和其他配置 self.model = diffusion_model # 扩散模型 self.ema = EMA(ema_decay) # 用于平均权重的EMA模型 self.ema_model = copy.deepcopy(self.model) # EMA的副本 self.batch_size = train_batch_size # 批量大小 self.train_num_steps = train_num_steps # 总训练步数 self.ds = Dataset(folder, image_size=diffusion_model.image_size) # 视频数据集 self.dl = cycle(torch.utils.data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True)) # 数据加载器 self.opt = Adam(diffusion_model.parameters(), lr=train_lr) # 优化器 self.step = 0 # 步数计数器 self.amp = amp # 混合精度标志 self.scaler = GradScaler(enabled=amp) # 用于混合精度的缩放器 self.results_folder = Path(results_folder) # 保存结果的文件夹 self.results_folder.mkdir(exist_ok=True, parents=True) # 如果不存在则创建结果文件夹 def reset_parameters(self): # 重置EMA模型以匹配模型的参数 self.ema_model.load_state_dict(self.model.state_dict()) def step_ema(self): # 如果训练步数超过阈值,则更新EMA模型 if self.step >= 2000: # 2000步后开始更新EMA self.ema.update_model_average(self.ema_model, self.model) def save(self, milestone: int): # 在里程碑保存模型、EMA模型和优化器状态 torch.save({'step': self.step, 'model': self.model.state_dict(), 'ema': self.ema_model.state_dict(), 'scaler': self.scaler.state_dict()}, self.results_folder / f'model-{milestone}.pt') def load(self, milestone: int): # 从检查点加载模型 data = torch.load(self.results_folder / f'model-{milestone}.pt') self.step = data['step'] self.model.load_state_dict(data['model']) self.ema_model.load_state_dict(data['ema']) self.scaler.load_state_dict(data['scaler']) def train(self, log_fn: Callable[[dict], None] = noop): # 训练循环 while self.step < self.train_num_steps: for _ in range(self.gradient_accumulate_every): # 在多个步骤上累积梯度 data = next(self.dl) # 加载数据 video_data, text_data = data[0].cuda(), data[1] if len(data) == 2 else None # 将数据移至GPU with autocast(enabled=self.amp): # 混合精度 loss = self.model(video_data, cond=text_data) # 前向传递 self.scaler.scale(loss / self.gradient_accumulate_every).backward() # 反向传播损失 print(f'{self.step}: {loss.item()}') # 打印损失 if self.step % 10 == 0: # 每10步更新EMA self.step_ema() # 如有必要,使用梯度裁剪的优化器步骤 self.scaler.unscale_(self.opt) nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.scaler.step(self.opt) self.scaler.update() self.opt.zero_grad() # 每指定步数保存模型 if self.step % self.save_model_every == 0: self.save(self.step // self.save_model_every) log_fn({'loss': loss.item()}) # 记录损失 self.step += 1 # 增加步数 print('训练完成。')
Trainer
类提供了扩散模型训练的完整框架,集成了现代深度学习训练技术,包括:
- 模型权重管理:通过指数移动平均(EMA)机制维护模型参数的平滑版本
- 优化控制:支持混合精度训练、梯度累积和梯度裁剪,提高训练效率和稳定性
- 检查点机制:定期保存模型状态,支持训练中断后的恢复
- 数据流管理:高效处理视频数据与文本描述的批次加载
训练器类的设计充分考虑了扩散模型训练的特殊需求,通过配置参数如批量大小、学习率和训练步数等,可以灵活适应不同的训练场景和计算资源。整体训练流程由
train
方法协调,它实现了数据加载、前向传播、损失计算、反向传播和参数更新的完整循环。
配置系统
为确保模型训练和推理过程的可复现性和灵活性,我们实现了配置系统来管理所有相关参数:
# 内容:用于使用文本到视频扩散训练TinySora模型的默认配置文件 training_data_dir: "./training_data" # 包含训练数据(文本和视频帧)的目录 model: dim: 64 # 模型的维度(嵌入大小) use_bert_text_cond: True # 为模型输入启用基于BERT的文本条件 dim_mults: [1, 2, 4, 8] # 每个模型块的缩放因子(增加深度) init_dim: null # 初始维度,未指定(默认为None) init_kernel_size: 7 # 初始层的内核大小(通常用于卷积层) use_sparse_linear_attn: True # 为效率启用稀疏注意力机制 block_type: "basic" # 模型块的类型(例如,'basic'、'resnet'等) diffusion: image_size: 32 # 视频帧的高度和宽度 num_frames: 5 # 视频中的帧数(序列长度) timesteps: 10 # 训练期间使用的扩散时间步数 loss_type: "l1" # 优化的损失函数('l1'表示L1损失) use_dynamic_thres: False # 是否在训练期间使用动态阈值 dynamic_thres_percentile: 0.9 # 用于动态阈值的阈值百分比 trainer: ema_decay: 0.995 # 模型权重的指数移动平均衰减率 train_batch_size: 2 # 训练期间每批的样本数 train_lr: 0.0001 # 训练学习率 train_num_steps: 10000 # 总训练步数(轮次) gradient_accumulate_every: 1 # 梯度累积频率(1表示不累积) amp: False # 是否使用自动混合精度进行训练(默认:False) step_start_ema: 2000 # 开始应用EMA平滑的步骤 update_ema_every: 10 # 更新EMA权重的频率(每10步) save_model_every: 10 # 每10步保存一次模型 results_folder: "./saved_models" # 保存结果(模型、样本)的文件夹 num_sample_rows: 4 # 采样期间显示的行数(可视化) max_grad_norm: null # 用于裁剪的最大梯度范数(null表示不裁剪)
配置系统采用YAML格式,将参数组织为三个主要部分:
- 模型配置:定义了架构规模和特性,如维度、注意力机制类型和文本条件使用
- 扩散过程配置:设置了图像尺寸、帧数和扩散步数等关键参数
- 训练器配置:管理训练流程的各方面,包括批量大小、学习率和保存频率
这种参数化设计使研究者能够在不修改代码的情况下轻松调整模型行为,便于进行实验和比较不同配置的效果。
模型初始化与训练
完成所有组件的定义后,我们可以初始化模型并启动训练过程:
# 使用模型的配置参数初始化3D U-Net模型。
# 此模型被移动到GPU(cuda)。
model = Unet3D(**config['model']).cuda() # 以U-Net模型作为去噪函数初始化GaussianDiffusion模型。
# 从`config['diffusion']`加载扩散过程的其他配置参数。
# 模型被移动到GPU(cuda)。
diffusion = GaussianDiffusion( denoise_fn = model, # 该模型将用于在扩散过程中去除噪声图像。 **config['diffusion'] # 其他扩散设置,如时间步、噪声调度等。
).cuda() # 使用扩散模型、训练配置和包含训练数据的文件夹初始化Trainer类。
# 这也被移动到GPU。
trainer = Trainer( diffusion_model = diffusion, # 要训练的扩散模型。 **config['trainer'], # 训练过程的配置设置(例如,学习率、批量大小)。 folder = config['training_data_dir'] # 存储训练数据的目录。
) # 开始训练 trainer.train()
训练过程会在每次迭代后显示损失值,输出类似于:
0: 0.9512512 1: 0.5235211 ...
训练完成后,模型权重将保存在配置的
results_folder
目录中,随时可用于视频生成。
视频生成实现
要使用训练好的模型生成视频,我们需要实现以下辅助函数:
def generate_video(diffusion: GaussianDiffusion, text: str, batch_size: int, cond_scale: float) -> torch.Tensor: """使用训练好的扩散模型生成视频。""" with torch.no_grad(): video = diffusion.sample(cond=[text], batch_size=batch_size, cond_scale=cond_scale) return video def save_video_as_gif_pil(video_tensor: torch.Tensor, output_path: str) -> None: video_np = (video_tensor.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() * 255).astype(np.uint8) frames = [Image.fromarray(frame) for frame in video_np] frames[0].save(output_path, save_all=True, append_images=frames[1:], duration=100, loop=0) print(f"已保存GIF:{output_path}")
以下是使用训练好的模型生成视频的完整流程:
# 对于推理,我们应该加载预训练模型
DEFAULT_MODEL_PATH = "./saved_models"
DEFAULT_OUTPUT_DIR = "./results" # 找到最新的模型检查点
model_path = DEFAULT_MODEL_PATH
if os.path.isdir(model_path): checkpoint_files = [f for f in os.listdir(model_path) if f.endswith(".pt")] if not checkpoint_files: raise FileNotFoundError(f"在{model_path}中未找到模型检查点") checkpoint_files.sort() model_path = os.path.join(model_path, checkpoint_files[-1])
print('从路径加载模型:', model_path) trainer.load(milestone=-1) # 加载最新模型 # 生成视频
text_prompt = "News Reporter talking"
batch_size = 1
cond_scale = 2.0
generated_video = generate_video(diffusion, text_prompt, batch_size, cond_scale) # 保存视频
gif_filename = sanitize_filename(text_prompt) + ".gif"
output_path = os.path.join(DEFAULT_OUTPUT_DIR, gif_filename) # 如果输出目录不存在,则创建
Path(DEFAULT_OUTPUT_DIR).mkdir(parents=True, exist_ok=True) save_video_as_gif_pil(generated_video, output_path)
模型生成的视频示例(经过10K轮训练):
"新闻记者讲话"提示生成的视频
未来研究方向
本项目可通过以下方向进一步拓展和优化:
- 数据多样化:除MSR-VTT外,可使用COCO或ActivityNet等数据集进行训练,或针对特定应用场景收集定制数据
- 架构改进:- 增加时间分辨率,生成更长的视频序列- 提高空间分辨率,生成更高清的视频内容
- 训练策略优化:- 实现渐进式训练,先学习低分辨率再提升到高分辨率- 探索不同的噪声调度方案
- 新的条件控制:- 添加姿态引导或布局控制- 结合图像和文本作为混合条件
通过以上改进,可以构建更强大、生成更高质量视频的系统,进一步推动文本到视频生成技术的发展。
github地址:
https://avoid.overfit.cn/post/88567712b4f547469d74113f6d0810e0
—— 作者:Fareed Khan
相关文章:
从零实现基于扩散模型的文本到视频生成系统:技术详解与Pytorch代码实现
本文详细介绍了基于扩散模型构建的文本到视频生成系统,展示了在MSRV-TT和Shutterstock视频标注数据集上训练的模型输出结果。以下是模型在不同提示词下的生成示例。 首先展示一些模型生成效果展示 提示词:“A person holding a camera”(训练…...
每天学一个 Linux 命令(14):cat
Linux 文件查看与合并命令:cat cat(全称 concatenate)是 Linux 中用于查看文件内容、合并文件或创建简单文件的基础命令。它操作简单但功能灵活,是日常文件处理的常用工具。 1. 命令作用 查看文件内容:直接输出文件内容到终端。合并文件:将多个文件内容合并输出或保存到…...
05--MQTT物联网协议
一、MQTT的概念 MQTT 协议快速入门 2025:基础知识和实用教程 | EMQ 1.MQTT(Message Queuing Telemetry Transport)是一种轻量级、基于发布-订阅模式的消息传输协议,适用于资源受限的设备和低带宽、高延迟或不稳定的网络环境。它…...
免费下载 | 2025天津大学:智能制造与数字孪生技术:面向可持续制造方向发展
一、新一代智能制造模式下的思考 当代智能制造的发展阶段 智能制造定义:智能制造是基于新一代信息通信技术与先进制造技术深度融合,贯穿于设计、生产、管理、服务等制造活动的各个环节,具有自感知、自学习、自决策、自执行、自适应等功能的新…...
考研单词笔记 2025.04.12
aware a知道的,意识到的,警觉的 awareness n意识,了解,觉察 conscious a有意识的,意识到的,有意的,刻意的,神志清醒的,慎重的,关注的 unconscious a无意识…...
八股总结(Java)持续更新!
八股总结(java) ArrayList和LinkedList有什么区别 ArrayList底层是动态数组,LinkedList底层是双向链表;前者利于随机访问,后者利于头尾插入;前者内存连续分配,后者通过指针连接多块不连续的内存…...
SpringBoot3快速入门笔记
springboot3简介 SpringBoot 帮我们简单、快速地创建一个独立的、生产级别的 Spring 应用(说明:SpringBoot底层是Spring) 大多数 SpringBoot 应用只需要编写少量配置即可快速整合 Spring 平台以及第三方技术 特性: ● 快速创建…...
vue3中,element-plus中el-input的v-model和value的用法示例
el-input的v-model,邦定响应式变量 <el-col :span"6"><el-form-item label"检验类别" prop"verifyType"><el-input v-model"applyAllInfo.applyBasicInfo.verifyTypeName" readonly /></el-form-item…...
python求π近似值
【问题描述】用公式π/4≈1-1/31/5-1/7..1/(2*N-1).求圆周率PI的近似值。 从键盘输入一个整数N值,利用上述公式计算出π的近似值,然后输出π值,保留小数后8位。 【样例输入】1000 【样例输出】3.14059265 def countpi(N):p0040nowid0for i i…...
Gerapy二次开发:搜索器组件设计开发与应用(Vue父子组件通信)
搜索器组件设计开发与应用 写在前面搜索器字段定义与样式设计具体实现components/Search.vuedeploy/Index.vue后端views.py运行效果总结欢迎加入Gerapy二次开发教程专栏! 本专栏专为新手开发者精心策划了一系列内容,旨在引领你深入探索Gerapy框架的二次迭代之旅。 本专栏将全…...
深入解析Python爬虫技术:从基础到实战的功能工具开发指南
一、引言:Python 爬虫技术的核心价值 在数据驱动的时代,网络爬虫作为获取公开数据的重要工具,正发挥着越来越关键的作用。Python 凭借其简洁的语法、丰富的生态工具以及强大的扩展性,成为爬虫开发的首选语言。根据 Stack Overflow 2024 年开发者调查,68% 的专业爬虫开发者…...
Python爬虫-爬取全球股市涨跌幅和涨跌额数据
前言 本文是该专栏的第52篇,后面会持续分享python爬虫干货知识,记得关注。 本文中,笔者将基于Python爬虫,实现批量采集全球股市行情(亚洲,美洲,欧非,其他等)的各股市“涨跌幅”以及“涨跌额”数据。 具体实现思路和详细逻辑,笔者将在正文结合完整代码进行详细介绍。…...
【NLP 59、大模型应用 —— BPE 算法】
你和生生不息的河流,生动了我人生中的美好瞬间 —— 25.4.11 一、词表的构造问题 为了nlp模型训练,词表(字表)是必要的 统计训练语料中的所有字符(或词)是一种做法,但是容易出现一些问题&…...
SQL基础入门:从CRUD到JOIN再到索引(通俗易懂版)
一、为什么需要SQL? 想象你在管理一个图书馆: 传统方法:手动记录每本书的位置、借阅者、归还日期SQL方法:用数据库系统自动管理,快速查询《Java编程思想》在哪个书架 SQL(Structured Query Language&…...
系统编程1(进程的概念与原理)
进程的概念与原理 计算机组成部分一般遵循冯诺依曼结构,也就是由控制器、运算器、存储器、输入设备、输出设备五个部分组成。 ⦁ 程序的编译 一般在编写出程序之后,并不能直接运行,而是需要把程序通过编译器进行编译,生成可执行…...
Git基础知识
Git基础知识 目录 一、Git简介 1.1 什么是Git?1.2 基本概念1.3 Git与其他版本控制系统的区别 二、Git安装与配置 2.1 安装Git2.2 基础配置2.3 高级配置2.4 多账户配置 三、基本操作 3.1 创建仓库3.2 基本工作流3.3 分支操作3.4 查看历史 四、高级操作 4.1 撤销修改…...
【Flink运行时架构】核心组件
在Flink的运行架构中,有两大比较重要的组件:作业管理器(JobManager)和任务管理器(TaskManager)。 Flink的作业提交与任务处理时的系统如下图所示。 其中,客户端并不是处理系统的一部分ÿ…...
【区块链安全 | 第四十篇】合约审计之delegatecall(二)
文章目录 漏洞代码代码分析攻击流程攻击代码前文重现修复建议审计思路 在阅读本文之前,请确保已先行阅读:【区块链安全 | 第三十九篇】合约审计之delegatecall(一) 漏洞代码 存在一漏洞代码如下: // 库合约…...
Redis实现分布式定时任务
设计思路 任务表示:每个任务通过一个特定格式的键来表示。键名可以包含任务ID等信息,值可以是任务的具体内容或指向任务详情的引用。过期机制:利用Redis的EXPIRE命令为任务设置过期时间,当到达设定的时间点时,Redis会…...
ERC20合约的基本调用
文章目录 ERC20合约的基本调用合约功能compile.js 代码读取文件 进行合约编译获取二进制对象导出对象 index.js 代码编译合约读取私钥设置收款账户构造 web3 对象获取账户地址获取 abi 和 bin创建合约交易部署合约构造转账交易验证转账后余额 测试项目目录执行查询 ERC20合约的…...
『Kubernetes(K8S) 入门进阶实战』实战入门 - Pod 详解
『Kubernetes(K8S) 入门进阶实战』实战入门 - Pod 详解 Pod 结构 每个 Pod 中都可以包含一个或者多个容器,这些容器可以分为两类 用户程序所在的容器,数量可多可少Pause 容器,这是每个 Pod 都会有的一个根容器,它的作用有两个 可…...
【React框架】什么是 Vite?如何使用vite自动生成react的目录?
什么是 Vite? Vite 是一个基于原生 ES Modules 开发的前端构建工具,由 Evan You(Vue 的作者)开发。它最大的特点包括: 极速冷启动:因为利用了浏览器原生的 ES Modules,所以在开发时无需等待整…...
JS实现文件点击或者拖拽上传
B站看到了渡一大师课的切片,自己实现了一下,做下记录 效果展示 分为上传前、上传中和上传后 实现 分为两步 界面交互网络请求 源码如下 upload.html <!DOCTYPE html> <html lang"zh-CN"><head><meta charset&q…...
【Vue #3】指令补充样式绑定
一、指令修饰符 Vue 的指令修饰符(Directive Modifiers)是 Vue 模板语法中的重要特性,它们以半角句号 . 开头,用于对指令的绑定行为进行特殊处理 修饰符作用如下: 简化事件处理(如阻止默认行为、停止冒泡…...
Vue.js组件安全工程化演进:从防御体系构建到安全性能融合
——百万级流量场景下的安全组件架构与源码级解决方案 文章目录 总起:安全工程化的组件革命 分论: 一、现存组件架构的七宗罪与安全改造路径 1.1 组件生态安全赤字现状 1.2 架构级安全缺陷深度剖析 1.3 性能与安全的死亡螺旋 二、百万级…...
LINUX基础 [二] - Linux常见指令
目录 💻前言 💻指令 🎮ls指令 🎮pwd指令 🎮whoami指令 🎮cd指令 🎮clear指令 🎮touch指令 🎮mkdir指令 🎮rmdir指令 🎮rm指令 &#…...
Linux进阶命令
目录 一、touch 1. 基本语法 2. 常用选项 二、which 1. 基本语法 2. 主要功能 3. 常用选项 三、find 1. 基本语法 2. 常用选项和表达式 四、more 1. 基本语法 2. 常用操作 3. 对比 more 和 less 五、grep 1. 基本语法 2. 常用选项 六、wc 1. 基本语法 2. 常…...
【Spring Boot 过滤器】
文章目录 前言一、什么是过滤器 Filter?二、Spring Boot 中使用 Filter 的方式1. 使用 Component 注解2. 使用 FilterRegistrationBean 显式注册 三、自定义过滤器示例1. 引入必要依赖2. 创建一个自定义 Filter3. 使用 FilterRegistrationBean 显式注册 四、多个 Fi…...
SPI通讯的软硬件NSS SSM SSI
学习自记: 1. NSS(Slave Select,从设备选择) 功能: NSS是SPI通信中用于选择从设备的信号线。主设备通过拉低NSS信号选中某个从设备,使其参与通信。通信结束后,主设备释放NSS&#…...
Java基础:集合List、Map、Set(超详细版)
集合体系概述 Collection常用方法 补充:addAll() Collection的遍历方式 迭代器 增强for(空集合可以,null不可以) lambda 集合对象存储对象原理 遍历方式的区别 List集合 特点、特有方法 遍历方式 (同上)…...
vue+leaflet 区域划分_反向遮罩层
leaflet 区域划分_遮罩层 geojson在线生成器网址:(https://datav.aliyun.com/portal/school/atlas/area_selector) 点击前往阿里云geojson生成器 效果图: 实现下面效果,只需要把addSateLayer函数的调用取消掉就好了. //添加遮罩层代码function addMask() {var latlngs;var fe…...
聊一聊原子操作和弱内存序
1、原子操作概念 在并发编程中,原子操作(Atomic Operation)是实现线程安全的基础机制之一。从宏观上看,原子操作是“不可中断”的单元,但若深入微观层面,其本质是由底层处理器提供的一组特殊指令来保证其原…...
免费送源码:Java+ssm+MySQL 校园二手书销售平台设计与实现 计算机毕业设计原创定制
摘 要 信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对校园二手书销售平台等问题,对校…...
DAPP实战篇:使用ethersjs连接智能合约并输入地址查询该地址余额
本系列目录 专栏:区块链入门到放弃查看目录-CSDN博客文章浏览阅读400次。为了方便查看将本专栏的所有内容列出目录,按照顺序查看即可。后续也会在此规划一下后续内容,因此如果遇到不能点击的,代表还没有更新。声明:文中所出观点大多数源于笔者多年开发经验所总结,如果你…...
14.【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--CAP
CAP 是一款专为 .NET 生态设计的开源框架,其核心目标是解决微服务中跨服务数据一致性问题。在分布式系统中,传统事务无法跨服务保证数据一致性,CAP 通过本地事务与消息记录绑定,再利用消息中间件(如 RabbitMQ、Kafka 等…...
智能资源管理机制-重传机制
一、发送端资源管理的核心机制 1. 滑动窗口(Sliding Window) 这是TCP协议的核心优化设计: 窗口动态滑动:发送端不需要保留所有已发送的分组,只需维护一个"发送窗口"窗口大小:由接收方通告的接…...
【Linux网络与网络编程】08.传输层协议 UDP
传输层协议负责将数据从发送端传输到接收端。 一、再谈端口号 端口号标识了一个主机上进行通信的不同的应用程序。在 TCP/IP 协议中,用 "源IP","源端口号","目的 IP","目的端口号"&…...
局域网下ESP32-S3 LED灯的UDP控制
在局域网下通过IP地址控制ESP32-S3上的LED,可以使用UDP或TCP协议。以下是一个基于UDP协议的完整示例,包括ESP32-S3的服务器代码和一个简单的Python客户端代码。 ESP32-S3 服务器代码 import socket import time import network import machineled Non…...
call、bind、apply
call、bind、apply它们三个都是函数的方法,都可以用于改变this的指向问题。 var person "liangxiao" let obj {name:"张三",say:function() {console.log(this.name);} }obj.say(); setTimeout(function() {obj.say(); },1000) obj.say()打…...
Redis 哨兵模式 搭建
1 . 哨兵模式拓扑 与 简介 本文介绍如何搭建 单主双从 多哨兵模式的搭建 哨兵有12个作用 。通过发送命令,让Redis服务器返回监控其运行状态,包括主服务器和从服务器。 当哨兵监测到master宕机,会自动将slave切换成master,然后通过…...
客户端负载均衡与服务器端负载均衡详解
客户端负载均衡与服务器端负载均衡详解 1. 客户端负载均衡(Client-Side Load Balancing) 核心概念 定义:负载均衡逻辑在客户端实现,客户端主动选择目标服务实例。典型场景:微服务内部调用(如Spring Cloud…...
Ningx负载均衡
Ningx负载均衡 upstream(上游)配置负载均衡1、weight(加权轮询)2、ip_hash(负载均衡)3、url hash负载均衡4、least_conn(最小连接负载均衡) upstream(上游)配置负载均衡 Nginx负载均衡 参考: nginx从安装…...
头歌软件工程导论UML画图题(基于starUML)
一.结构化分析方法-数据流图 本关卡需要画图的一共有5关,直接将此图画好每关提交一次即可,以下的所有图均以此方法提交 二.面向对象分析之用例图 三.面向对象分析之类图 注意此处创建Class之后,双击Class出现以下选项 点击相应的选项创建属性…...
智能车摄像头开源—9 动态权、模糊PID、速度决策、路径优化
目录 一、前言 二、动态权 1.概述 2.偏差值加动态权 三、模糊PID 四、速度决策 1.曲率计算 2.速度拟合 3.速度控制 五、路径 六、国赛视频 一、前言 在前中期通过识别直道、弯道等元素可进行加减速操作实现速度的控制,可进一步缩减一圈的运行速度ÿ…...
java基础 this和super的介绍
this和super this关键字的用法super关键字的用法this与super的区别和注意事项 this关键字的用法 this是自身的一个对象,代表对象本身,可以理解为:指向对象本身的一个指针 class Person{private String name;private int age;public String …...
《Python星球日记》第25天:Pandas 数据分析
名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 订阅专栏:《Python星球日记》 目录 一、引言二、数据分组与聚合1. 分组操…...
C++在Linux上生成动态库并调用接口测试
加减乘除demo代码 项目结构 CPP/ ├── calculator.cpp ├── calculator.h ├── main.cpp 头文件 #ifndef CALCULATOR_H #define CALCULATOR_H#ifdef __cplusplus extern "C" {#endifdouble add(double a, double b);double subtract(double a, double b…...
Cesium.js(6):Cesium相机系统
Camera表示观察场景的视角。通过操作摄像机,可以控制视图的位置、方向和角度。 帮助文档:Camera - Cesium Documentation 1 setView setView 方法允许你指定相机的目标位置和姿态。你可以通过 Cartesian3 对象来指定目标位置,并通过 orien…...
机器学习中的数学(PartⅡ)——线性代数:概述
首先引入代数和线性代数的概念: 在将一些直观的、基于经验或直觉的概念转化为严格的数学或逻辑定义时,一种常用方法是构建一组对象和一组操作这些对象的规则,这就是代数。线性代数是研究向量和某些操作向量的规则。 其次从更广泛的意义上定…...
基于双闭环PID控制器的永磁同步电机控制系统匝间故障Simulink仿真
欢迎微♥关注“电击小子程高兴的MATLAB小屋”获取巨额优惠 1.模型简介 本仿真模型基于MATLAB/Simulink(版本MATLAB 2013Rb)软件。建议采用matlab2013 Rb及以上版本打开。(若需要其他版本可联系代为转换,高于该版本的matlab均可正…...