
[Source Code Analysis] Model-Parallel Distributed Training with Megatron (2) --- Overall Architecture


Table of Contents
  • [Source Code Analysis] Model-Parallel Distributed Training with Megatron (2) --- Overall Architecture
    • 0x00 Abstract
    • 0x01 Launch
      • 1.1 Distributed launch
      • 1.2 Building blocks
        • 1.2.1 Getting the model
        • 1.2.2 Getting the datasets
        • 1.2.3 Step function
          • 1.2.3.1 Broadcasting data
    • 0x02 Pretrain
    • 0x03 Initialization
      • 3.1 initialize_megatron
      • 3.2 Initializing the distributed environment
      • 3.3 Process-group global variables
    • 0x04 Setting up the model
      • 4.1 setup_model_and_optimizer
      • 4.2 The model
        • 4.2.1 BertModel
        • 4.2.2 Language model
        • 4.2.3 ParallelTransformer
          • 4.2.3.1 Getting the number of layers
          • 4.2.3.2 Forward pass
      • 4.3 get_model
    • 0x05 Data parallelism
      • 5.1 Setting up the data
      • 5.2 DDP
        • 5.2.1 Definition
        • 5.2.2 Initialization
        • 5.2.3 Memory
        • 5.2.4 Helper functions
        • 5.2.5 Gradient reduction
    • 0x06 Training
      • 6.1 Training loop
      • 6.2 Training step
      • 6.3 Getting the schedule
    • 0xFF References

0x00 Abstract

NVIDIA Megatron is a PyTorch-based distributed training framework for very large Transformer language models. It combines data parallelism, tensor parallelism, and pipeline parallelism to reproduce GPT-3, and its underlying mechanics are worth studying in depth.

This series will have roughly 6 to 7 articles, working through Megatron via its papers and source code. This article walks through Megatron's overall architecture.

Other articles in this series:

[Source Code Analysis] Model-Parallel Distributed Training with Megatron (1) --- Papers & Foundations

0x01 Launch

1.1 Distributed launch

The launch script is examples/pretrain_bert_distributed.sh; it uses torch.distributed.launch to start multiple processes. The actual training code is in pretrain_bert.py.

Since GPUS_PER_NODE is 8, nproc_per_node is 8, so eight processes are started on this machine, each holding one part of the model. The rank of each process is assigned automatically by torch.distributed.launch via its elastic launcher. (A short sketch after the script below shows what each launched process gets to see.)

#!/bin/bash

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=<Specify path and file prefix>_text_sentence
CHECKPOINT_PATH=<Specify path>

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       pretrain_bert.py \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
       --micro-batch-size 4 \
       --global-batch-size 32 \
       --seq-length 512 \
       --max-position-embeddings 512 \
       --train-iters 1000000 \
       --save $CHECKPOINT_PATH \
       --load $CHECKPOINT_PATH \
       --data-path $DATA_PATH \
       --vocab-file bert-vocab.txt \
       --data-impl mmap \
       --split 949,50,1 \
       --distributed-backend nccl \
       --lr 0.0001 \
       --lr-decay-style linear \
       --min-lr 1.0e-5 \
       --lr-decay-iters 990000 \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --lr-warmup-fraction .01 \
       --log-interval 100 \
       --save-interval 10000 \
       --eval-interval 1000 \
       --eval-iters 10 \
       --fp16
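For orientation, here is a minimal sketch (not Megatron code) of what each launched worker receives. It assumes the legacy torch.distributed.launch behavior, where MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE are exported as environment variables and --local_rank is passed as a command-line argument; details vary across PyTorch versions, and Megatron parses --local_rank in its own argument handling.

import argparse
import os

# Illustrative only: how a worker started by torch.distributed.launch
# can recover its identity.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # GPU index on this node
args, _ = parser.parse_known_args()

global_rank = int(os.environ.get('RANK', 0))        # unique across all nodes
world_size = int(os.environ.get('WORLD_SIZE', 1))   # total number of processes
print(f'global rank {global_rank} / world size {world_size}, '
      f'local rank {args.local_rank}')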

1.2 Building blocks

pretrain_bert.py calls pretrain to run the pre-training.

if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

1.2.1 Getting the model

model_provider returns the vanilla version of the model. By "vanilla" we mean a plain CPU model without fp16 or DDP wrappers, but already converted by Megatron into its parallel form.

def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building BERT model ...')

    args = get_args()
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model

1.2.2 Getting the datasets

train_valid_test_datasets_provider receives the sizes of the train/valid/test datasets and returns the train, valid and test datasets.

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds

1.2.3 Step function

forward_step takes a data iterator and the model, and returns a loss scalar together with a dictionary whose key:value pairs are the quantities to monitor during training, for example 'lm loss: value'. It is also expected to add 'batch generator' to the timers class.

def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()

    # Get the batch.
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)

    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                          lm_labels=lm_labels)

    return output_tensor, partial(loss_func, loss_mask, sentence_order)

1.2.3.1 Broadcasting data

forward_step calls get_batch to obtain a batch. Internally, get_batch pulls data from the iterator and then uses broadcast_data to broadcast the input data from rank 0 to all other ranks in the same tensor-model-parallel group.

Note: data parallelism loads different data onto different ranks, whereas every rank inside a tensor-model-parallel group loads the same data.

def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)  # fetch the data
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)  # broadcast the data to the GPUs in the group

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask

Within each model parallel group, broadcast_data sends the data from rank zero of the group to the other members of the same group.

def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:  # only rank 0 packs the data
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel,
                                   device=torch.cuda.current_device(),
                                   dtype=datatype)

    # Broadcast
    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Unpack
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output

get_tensor_model_parallel_src_rank computes the global rank that corresponds to the first local rank in the tensor model parallel group.

def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
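As a quick worked example of this formula (illustrative only): with a tensor-model-parallel world size of 2, ranks 0 and 1 both resolve to source rank 0, and ranks 4 and 5 both resolve to source rank 4, which is why every member of a tensor group receives the broadcast from the same source.

# Illustrative check of the formula above (not Megatron code).
def src_rank(global_rank, tp_world_size=2):
    return (global_rank // tp_world_size) * tp_world_size

assert [src_rank(r) for r in range(6)] == [0, 0, 2, 2, 4, 4]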

The overall logic is shown in the diagram below: the three functions provide different inputs to pre-training and are nicely decoupled from each other.

0x02 Pretrain

BERT training consists of two main phases:

  • Pre-train: pre-training is the foundation of transfer learning; it learns token-level semantic representations.
  • Fine-tuning: starting from the already trained language model, add task-specific parameters for a particular domain (e.g. finance or healthcare) and retrain. For a classification task, for example, you can add a softmax on top of the pre-trained model and fine-tune it on your corpus.

The pretrain function mainly does the following:

  • Initialize Megatron.

  • Set up the model, optimizer and lr schedule using model_provider.

  • Call train_val_test_data_provider to get the train/val/test datasets.

  • Train the model using forward_step_func.

The code is as follows:

def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             extra_args_provider=None,
             args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    # Use model_provider to set up the model, optimizer and lr schedule.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                               model_type)

    # Data stuff. Call train_valid_test_dataset_provider to get the
    # train/valid/test datasets.
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,  # train the model
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)

For our analysis, initialize_megatron is the key piece: this is where Megatron gets initialized.

0x03 Initialization

3.1 initialize_megatron

The initialize_megatron method sets global variables, initializes the distributed environment, and so on.

def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,  # set global variables
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()  # set up the distributed environment

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

    # Set pytorch JIT layer fusion options.
    _set_jit_fusion_options()

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function
        return None

3.2 Initializing the distributed environment

_initialize_distributed lives in megatron/initialize.py. This method:

  • Calls torch.distributed.init_process_group to initialize the distributed environment.
  • Calls mpu.initialize_model_parallel to set up the various process groups for model parallelism, data parallelism, and so on, which we will discuss in detail later.

Once the worker processes have been created, the program needs to know which processes are training the same model; torch.distributed.init_process_group provides exactly that. It creates a process group: processes inside the group train the same model and agree on how to communicate. Each process in the group gets an index, the global rank; in multi-machine training, the processes created on each machine also get a per-machine index, the local rank. For single-machine multi-GPU training, local rank and global rank are identical.
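To make the rank terminology concrete, here is a small illustrative computation (not Megatron code), assuming 2 nodes with 8 GPUs each, matching the example used later in this article:

# Illustrative only: relation between global rank and local rank
# for 2 nodes x 8 GPUs.
GPUS_PER_NODE = 8
for global_rank in (0, 7, 8, 15):
    node_rank = global_rank // GPUS_PER_NODE   # which machine
    local_rank = global_rank % GPUS_PER_NODE   # which GPU on that machine
    print(f'global rank {global_rank}: node {node_rank}, local rank {local_rank}')
# On a single node, local rank equals global rank.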

def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
    else:
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)

        # Call the init process
        torch.distributed.init_process_group(  # initialize the PyTorch distributed environment
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            timeout=timedelta(minutes=10))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            # Initialize model parallelism, i.e. set up the various process groups
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size,
                                          args.pipeline_model_parallel_split_rank)

3.3 Process-group global variables

Because mpu.initialize_model_parallel has been called to set up the model-parallel, data-parallel and other process groups, we can assume the groups are already in place, so the process for each rank has its own set of global variables. Assume we have 16 GPUs on two nodes: ranks 0 to 7 belong to the first node and ranks 8 to 15 to the second. Below, g_i refers to GPU i.

  • _TENSOR_MODEL_PARALLEL_GROUP: the intra-layer model parallel group that the current rank belongs to, i.e. the tensor-parallel group.
    • If each layer is split into two tensor shards, an example is: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15].
  • _PIPELINE_MODEL_PARALLEL_GROUP: the inter-layer model parallel group that the current rank belongs to, i.e. the pipeline group.
    • With a pipeline depth of 4, an example is [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15].
  • _MODEL_PARALLEL_GROUP: the model parallel group that the current rank belongs to; it spans both of the groups above.
    • In our example, the full model is replicated twice, and the two replicas live on GPUs [0, 1, 4, 5, 8, 9, 12, 13] and [2, 3, 6, 7, 10, 11, 14, 15].
  • _EMBEDDING_GROUP: the process group for the embeddings.
  • _DATA_PARALLEL_GROUP: the data parallel group that the current rank belongs to.
    • With a data parallel degree of 2, an example is [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15].

A short sketch after the variable declarations below reproduces these groupings.
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
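The example groupings listed above can be reproduced with a small sketch. This is illustrative only and assumes 16 GPUs with tensor-parallel size 2 and pipeline-parallel size 4; the real group construction is done inside mpu.initialize_model_parallel, which later articles cover in detail.

# Illustrative sketch that reproduces the example groups listed above.
# Assumes world_size=16, tensor parallel t=2, pipeline parallel p=4,
# so the data parallel size is d = 16 / (t * p) = 2.
world_size, t, p = 16, 2, 4
num_pipeline_groups = world_size // p  # 4

tensor_groups = [list(range(i, i + t)) for i in range(0, world_size, t)]
pipeline_groups = [list(range(i, world_size, num_pipeline_groups))
                   for i in range(num_pipeline_groups)]
data_groups = [list(range(start + j, start + num_pipeline_groups, t))
               for start in range(0, world_size, num_pipeline_groups)
               for j in range(t)]

print(tensor_groups)    # [[0, 1], [2, 3], ..., [14, 15]]
print(pipeline_groups)  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
print(data_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7], ..., [12, 14], [13, 15]]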

0x04 Setting up the model

In pretrain, the following call sets up the model, the optimizer, and so on.

# Model, optimizer, and learning rate.
# Use model_provider to set up the model, optimizer and lr schedule.
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                           model_type)

4.1 setup_model_and_optimizer

setup_model_and_optimizer builds the model and the optimizer; the key part is get_model.

def setup_model_and_optimizer(model_provider_func, model_type):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func, model_type)
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler

4.2 The model

4.2.1 BertModel

Let us first look at BertModel's initialization function and skip its other methods. Its main job is to call get_language_model.

class BertModel(MegatronModule):
    """Bert Language model."""

    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        # Get the language model
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:  # extra handling on the last stage
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

4.2.2 Language model

get_language_model builds a TransformerLanguageModel.

def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key

TransformerLanguageModel is the concrete language model; the important part inside it is ParallelTransformer. Its components are created according to the configuration passed in:

  • If this is the first stage, i.e. pre_process is set, an embedding layer is added.
  • For the middle stages, a ParallelTransformer is built as an encoder or a decoder, as appropriate.
  • If this is the last stage, i.e. post_process is set, a Pooler is added; the outer BertModel also does extra handling in this case.
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            # Temporary assertion until we verify correctness of pipeline parallelism
            # implementation of T5.
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

4.2.3 ParallelTransformer

Here, ParallelTransformerLayer is called to build the concrete Transformer layers; we analyze it in a later article.

In other words, a ParallelTransformer contains multiple Transformer layers, and each of these layers is a ParallelTransformerLayer.

class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpoiting flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(  # number of layers owned by this Transformer
            args, args.model_type == ModelType.encoder_and_decoder)

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(  # returns one Transformer layer
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        if args.virtual_pipeline_model_parallel_size is not None:
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(  # build num_layers Transformer layers
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

The logic so far is shown in the figure below; assume there are two transformers.

4.2.3.1 Getting the number of layers

A key point here is getting the number of layers, i.e. how many layers this rank's sub-model should own under parallel execution. If the model has 64 layers in total and the pipeline depth is 16, each pipeline stage has 64 / 16 = 4 layers, so this sub-model owns 4 layers. A small worked example follows the code below.

def get_num_layers(args, is_encoder_and_decoder_model):
    """Compute the number of transformer layers resident on the current rank."""
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None
            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
            if is_pipeline_stage_before_split():
                num_layers = args.num_layers // num_ranks_in_encoder
            else:
                num_layers = args.num_layers // num_ranks_in_decoder
        else:
            num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
    else:
        num_layers = args.num_layers
    return num_layers
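A quick worked example of this arithmetic (illustrative only, reusing the numbers above): with 64 layers and a pipeline-parallel world size of 16, each stage owns 4 layers, and in the non-virtual case the offset computation in ParallelTransformer gives stage r the contiguous block starting right after layer r * 4.

# Illustrative only: 64 layers split over a pipeline of depth 16.
total_layers, pipeline_world_size = 64, 16
num_layers = total_layers // pipeline_world_size               # 4 layers per stage
for rank in (0, 1, 15):
    offset = rank * num_layers                                 # as in ParallelTransformer
    owned = list(range(offset + 1, offset + num_layers + 1))   # layer_number is 1-based
    print(f'pipeline rank {rank} builds layers {owned}')
# pipeline rank 0 builds layers [1, 2, 3, 4]
# pipeline rank 15 builds layers [61, 62, 63, 64]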

get_pipeline_model_parallel_world_size returns the world size of the pipeline group, i.e. the pipeline depth.

def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())

_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is the pipeline depth p, i.e. the model is cut vertically p - 1 times. For example, with 12 layers cut 5 times vertically, there are 6 stages and each stage has 2 layers.

4.2.3.2 Forward pass

Next we look at the forward function. It mostly just calls the forward method of each inner ParallelTransformerLayer, with special handling for the first and last stages.

def forward(self, hidden_states, attention_mask,
            encoder_output=None, enc_dec_attn_mask=None,
            inference_params=None):

    if self.pre_process:
        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            hidden_states = hidden_states.transpose(0, 1).contiguous().float()
        # Otherwise, leave it as is.
        else:
            hidden_states = hidden_states.transpose(0, 1).contiguous()
    else:
        # See set_input_tensor()
        hidden_states = self.input_tensor

    if encoder_output is not None:
        encoder_output = encoder_output.transpose(0, 1).contiguous()

    if self.activations_checkpoint_method is not None:
        hidden_states = self._checkpointed_forward(hidden_states,
                                                   attention_mask,
                                                   encoder_output,
                                                   enc_dec_attn_mask)
    else:
        for index in range(self.num_layers):
            layer = self._get_layer(index)
            hidden_states = layer(  # call ParallelTransformerLayer's forward function
                hidden_states,
                attention_mask,
                encoder_output=encoder_output,
                enc_dec_attn_mask=enc_dec_attn_mask,
                inference_params=inference_params)

    # Final layer norm.
    if self.post_process:
        # Reverting data format change [s b h] --> [b s h].
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        output = self.final_layernorm(hidden_states)
    else:
        output = hidden_states

    return output

4.3 get_model

Now let us go back to get_model and piece together the model construction flow.

BERT contains multiple transformers, so it can be split directly by layer: every layer is an identical transformer layer. As mentioned earlier, our example launches 8 processes, each holding a sub-model, i.e. a subset of the original BERT's layers. But how does each sub-model know how many layers it holds? Because the various process groups have already been created, get_model can decide based on the current group configuration. Within a single process, the model is built as follows:

  • If a virtual pipeline is configured, iterate over the virtual size and build the corresponding number of models (BertModel).
  • Otherwise, if the model type is encoder_and_decoder, configure it according to the split rank.
  • Set the tensor model parallel attributes.
  • Move the model onto the GPU.
  • If data parallelism is needed, wrap the model with DDP.

The code is as follows:

def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.virtual_pipeline_model_parallel_size is not None:  # virtual pipeline is configured, discussed later
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):  # iterate over the virtual chunks
            # Set the virtual rank, mainly to know whether this is the first or the last stage
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(  # build the raw model, here BertModel
                pre_process=pre_process,
                post_process=post_process
            )
            this_model.model_type = model_type
            model.append(this_model)  # append a new BertModel to the model list
    else:
        pre_process = mpu.is_pipeline_first_stage()  # is this the first stage
        post_process = mpu.is_pipeline_last_stage()  # is this the last stage
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank  # is this the first stage
                post_process = (rank == (split_rank - 1)) or (  # is this the last stage
                    rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(  # build the raw model
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(  # build the raw model
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # GPU allocation.
    for model_module in model:  # move this model onto the GPU
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:  # wrap with DDP if data parallelism is requested
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = [torchDDP(model_module, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
                     for model_module in model]
        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
        else:
            raise NotImplementedError('Unknown DDP implementation specified: '
                                      '{}. Exiting.'.format(args.DDP_impl))

    return model

The per-process logic is roughly as shown below. Here, torchDDP means that the module inside BertModel is wrapped by torchDDP.

0x05 Data parallelism

5.1 Setting up the data

build_train_valid_test_data_iterators prepares the data and provides three datasets: train, valid and test.

def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """XXX"""
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:
        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
            args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type

    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(train_dataloader))
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(valid_dataloader))
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(test_dataloader))
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator

5.2 DDP

In get_model, the following code applies DDP.

from megatron.model import DistributedDataParallel as LocalDDP
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

if wrap_with_ddp:
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = [torchDDP(model_module, device_ids=[i], output_device=i,
                          process_group=mpu.get_data_parallel_group())
                 for model_module in model]

    elif args.DDP_impl == 'local':
        model = [LocalDDP(model_module,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_local_ddp)
                 for model_module in model]
    else:
        raise NotImplementedError('Unknown DDP implementation specified: '
                                  '{}. Exiting.'.format(args.DDP_impl))

So let us look at Megatron's own DDP implementation.

5.2.1 Definition

The class definition itself is essentially just the docstring: it uses contiguous memory buffers to store and accumulate gradients; all tensors of a given dtype share one buffer, which can then be all-reduced as a whole.

class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with contiguous buffers options to store and accumulate gradients.
    This class:
        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
          in a type other than the params type (for example fp32)

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
            and the gradient all-reduce all in float32. If this option is
            true, we require `use_contiguous_buffers` to be true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store the
            gradients.
    """

5.2.2 Initialization

The purpose of the initialization method is to store gradients of the same dtype contiguously.

def __init__(self, module,
             accumulate_allreduce_grads_in_fp32,
             use_contiguous_buffers):

    super(DistributedDataParallel, self).__init__(module)

    self.accumulate_allreduce_grads_in_fp32 \
        = accumulate_allreduce_grads_in_fp32
    self.use_contiguous_buffers = use_contiguous_buffers
    # If we are using fp32-accumulate-allreduce explicitly
    # this means we need main grads in a continous buffer.
    if self.accumulate_allreduce_grads_in_fp32:
        assert self.use_contiguous_buffers

    # ===================================
    # Rest of this part applies only to
    # the case we use continuous buffers.
    # ===================================
    self._grad_buffers = None
    if self.use_contiguous_buffers:  # only the contiguous-buffer case is considered here
        self._grad_buffers = {}  # define the buffers

        # Simple function to define buffer type.
        def _get_buffer_type(param):  # return the buffer dtype
            return torch.float if \
                self.accumulate_allreduce_grads_in_fp32 else param.dtype

        # First calculate total number of elements per type.
        type_num_elements = {}
        for param in self.module.parameters():  # iterate over the model parameters
            if param.requires_grad:  # if a gradient is required
                dtype = _get_buffer_type(param)  # get the parameter dtype
                type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                    + param.data.nelement()  # add this parameter's element count to its dtype
        # At this point type_num_elements holds the element count for each dtype.

        # Allocate the buffer.
        for dtype, num_elements in type_num_elements.items():  # iterate over the dtypes
            self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)  # allocate the memory

        # Assume the back prop order is reverse the params order,
        # store the start index for the gradients.
        for param in self.module.parameters():  # iterate over the model parameters
            if param.requires_grad:  # if a gradient is required
                dtype = _get_buffer_type(param)  # get the parameter dtype
                type_num_elements[dtype] -= param.data.nelement()  # shrink the remaining size
                # Locate this parameter's slot inside the MemoryBuffer
                param.main_grad = self._grad_buffers[dtype].get(  # view into the buffer for this parameter
                    param.data.shape, type_num_elements[dtype])

        # Backward hook.
        # Accumalation function for the gradients. We need
        # to store them so they don't go out of scope.
        self.grad_accs = []
        # Loop over all the parameters in the model.
        for param in self.module.parameters():  # iterate over the model parameters
            if param.requires_grad:  # if a gradient is required
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator functtion.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]  # the gradient accumulator for this param
                grad_acc.register_hook(self._make_param_hook(param))  # register the hook
                self.grad_accs.append(grad_acc)  # keep them around, essentially bookkeeping

5.2.3 Memory

MemoryBuffer is the memory abstraction.

class MemoryBuffer:

    def __init__(self, numel, dtype):
        self.numel = numel
        self.dtype = dtype
        self.data = torch.zeros(self.numel,  # allocate the memory
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()  # locate this tensor inside the buffer
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        buffer_tensor = self.data[start_index:end_index]  # take the slice of memory
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor

5.2.4 Helper functions

Below are two helper functions: one copies gradients into the buffer, and the other zeroes the buffer.

def _make_param_hook(self, param):
    """Create the all-reduce hook for backprop."""
    # Hook used for back-prop.
    def param_hook(*unused):
        # Add the gradient to the buffer.
        if param.grad.data is not None:
            param.main_grad.add_(param.grad.data)  # copy the gradient into the contiguous buffer
            # Now we can deallocate grad memory.
            param.grad = None
    return param_hook

def zero_grad_buffer(self):
    """Set the grad buffer data to zero. Needs to be called at the
    beginning of each iteration."""
    assert self._grad_buffers is not None, 'buffers are not initialized.'
    for _, buffer_ in self._grad_buffers.items():
        buffer_.zero()

Suppose the model has 6 parameters, 3 in fp32 and 3 in fp16; they are packed into two contiguous MemoryBuffers, one per dtype.
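To illustrate the dtype grouping done in __init__ above, here is a small sketch with made-up parameter sizes (assuming accumulate_allreduce_grads_in_fp32 is False, so each gradient keeps its parameter's dtype):

import torch

# Illustrative only: (dtype, number of elements) for 6 hypothetical parameters.
param_specs = [
    (torch.float32, 1024), (torch.float32, 4096), (torch.float32, 512),
    (torch.float16, 2048), (torch.float16, 8192), (torch.float16, 256),
]

type_num_elements = {}
for dtype, numel in param_specs:
    type_num_elements[dtype] = type_num_elements.get(dtype, 0) + numel

print(type_num_elements)
# {torch.float32: 5632, torch.float16: 10496}
# -> two contiguous MemoryBuffers are allocated, one per dtype,
#    and each param.main_grad becomes a view into the matching buffer.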

5.2.5 Gradient reduction

allreduce_gradients is the API that DDP exposes; it is called later in train_step.

def allreduce_gradients(self):
    """Reduce gradients across data parallel ranks."""
    # If we have buffers, simply reduce the data in the buffer.
    if self._grad_buffers is not None:
        # contiguous-buffer case
        for _, buffer_ in self._grad_buffers.items():  # iterate over the per-dtype buffers
            buffer_.data /= mpu.get_data_parallel_world_size()
            torch.distributed.all_reduce(  # one all-reduce per buffer
                buffer_.data, group=mpu.get_data_parallel_group())
    else:
        # Otherwise, bucketize and all-reduce.
        buckets = {}  # fall back to per-dtype buckets
        # Pack the buckets.
        for param in self.module.parameters():  # iterate over the parameters
            if param.requires_grad and param.grad is not None:
                tp = param.data.type()
                if tp not in buckets:
                    buckets[tp] = []
                buckets[tp].append(param)  # gradients of the same dtype go into the same bucket
                param.main_grad = param.grad

        # For each bucket, all-reduce and copy all-reduced grads.
        for tp in buckets:
            bucket = buckets[tp]
            grads = [param.grad.data for param in bucket]  # collect the gradients in the bucket
            coalesced = _flatten_dense_tensors(grads)  # flatten into one tensor
            coalesced /= mpu.get_data_parallel_world_size()
            torch.distributed.all_reduce(  # reduce across data-parallel ranks
                coalesced, group=mpu.get_data_parallel_group())
            for buf, synced in zip(grads,
                                   _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)

At runtime, one all-reduce is issued for each of the two contiguous buffers (one per dtype).
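
The divide-then-all-reduce order is what turns the sum into a data-parallel average: with N = mpu.get_data_parallel_world_size() ranks, each rank first scales its local gradient g_i by 1/N, and the subsequent all-reduce (a sum) yields

\bar{g} \;=\; \sum_{i=1}^{N} \frac{g_i}{N} \;=\; \frac{1}{N} \sum_{i=1}^{N} g_i ,

i.e. the gradient averaged over the whole data-parallel group.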

0x06 Training

Inside pretrain, train is called to run the actual training.

if args.do_train and args.train_iters > 0:
    iteration = train(forward_step_func,
                      model, optimizer, lr_scheduler,
                      train_data_iterator, valid_data_iterator)

6.1 Main training loop

train follows the conventional pattern; most of it can be understood from the names alone.

def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,  # one training step
                       train_data_iterator,
                       model,
                       optimizer,
                       lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm,
                                          num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.exit_signal_handler:
            signal_handler = get_signal_handler()
            if any(signal_handler.signals_received()):
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
                sys.exit()

        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            sys.exit()

    return iteration

6.2 Train step

train_step obtains the schedule via get_forward_backward_func. Because pipeline parallelism is involved, a schedule is needed to decide how the forward and backward passes over the microbatches are actually arranged.

def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    # Get the training schedule.
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(  # run the forward and backward passes
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        for model_module in model:
            model_module.allreduce_gradients()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        else:  # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    # Update parameters.
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / \
                                len(losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad

6.3 Getting the schedule

get_forward_backward_func picks the pipeline schedule. For pipelined runs there are two variants, flush (non-interleaved) and interleaving, which we will analyze in later posts; without pipeline parallelism a plain no-pipelining fallback is used.

def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
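
For intuition, the simplest case, forward_backward_no_pipelining, conceptually reduces to a loop over microbatches with gradient accumulation. The sketch below only illustrates that idea and is not Megatron's actual implementation (the real one, among other details, delays torch DDP gradient synchronization until the last microbatch; forward_fn here is a stand-in assumed to return a loss plus a reduced-loss dict for one microbatch):

def no_pipelining_schedule(forward_fn, data_iterator, model, num_microbatches):
    """Conceptual sketch: forward + backward per microbatch, gradients accumulate;
    the optimizer step is performed afterwards by the caller (train_step)."""
    losses_reduced = []
    for _ in range(num_microbatches):
        loss, loss_reduced = forward_fn(data_iterator, model)  # one microbatch
        losses_reduced.append(loss_reduced)
        loss.backward()  # accumulates into param.grad / param.main_grad
    return losses_reduced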

The training logic therefore roughly expands into the call chain: pretrain → train → train_step → get_forward_backward_func (pick a schedule) → forward/backward over the microbatches → allreduce_gradients → optimizer.step() → lr_scheduler.step().

This concludes the analysis of Megatron's overall architecture; the next post covers how model parallelism is set up.

0xFF References

[细读经典]Megatron论文和代码详细分析(2)

[细读经典]Megatron论文和代码详细分析(1)

Megatron-LM源码阅读(一)

Megatron-LM源码阅读(二)

megatron学习总结

GTC 2020: Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

www.DeepL.com/Translator

https://developer.nvidia.com/gtc/2020/slides/s21496-megatron-lm-training-multi-billion-parameter-language-models-using-model-parallelism.pdf

NVIDIA解决方案架构师深度解析大规模参数语言模型Megatron-BERT
