PyTorch深度学习框架60天进阶学习计划 - 第46天:自动化模型设计(一)
PyTorch深度学习框架60天进阶学习计划 - 第46天:自动化模型设计(一)
第一部分:使用ENAS算法生成图像分类网络
大家好!欢迎来到我们PyTorch深度学习框架60天进阶学习计划的第46天。今天我们要深入探讨一个话题——使用高效神经架构搜索(Efficient Neural Architecture Search, ENAS)算法来自动设计图像分类网络。
1. ENAS简介
ENAS是由谷歌大脑团队在2018年提出的一种高效神经架构搜索方法。与传统NAS和DARTS相比,ENAS的主要创新在于引入了参数共享机制,大大提高了搜索效率。
ENAS的核心思想是:将整个搜索空间视为一个超大的计算图(supergraph),所有可能的模型架构都是这个超图的子图。通过让不同的子网络共享参数,ENAS避免了为每个候选架构单独训练的巨大计算开销。
与DARTS相比,ENAS的主要区别在于:DARTS使用可微分放松使得架构可以用梯度下降优化,而ENAS使用强化学习来学习如何从超图中采样高性能架构。
2. ENAS算法原理
ENAS算法由两个交替进行的步骤组成:
- 子模型采样与训练:使用控制器(Controller)从搜索空间中采样子模型架构,然后训练这些子模型。
- 控制器更新:基于子模型的性能,使用策略梯度(Policy Gradient)方法更新控制器,使其更有可能采样出高性能的架构。
这个过程可以形象地比作"建筑师(控制器)"和"工人(子模型训练)"的协作:建筑师提供设计图纸,工人根据图纸建造并给出反馈,建筑师根据反馈改进设计图纸。
3. ENAS搜索空间
ENAS的搜索空间通常是一个有向无环图(DAG),节点表示特征图,边表示操作。对于卷积网络,常见的候选操作包括:
操作类型 | 描述 | PyTorch实现 |
---|---|---|
3x3 标准卷积 | 基本卷积操作 | nn.Conv2d(C, C, 3, padding=1) |
5x5 标准卷积 | 更大感受野的卷积 | nn.Conv2d(C, C, 5, padding=2) |
3x3 深度可分离卷积 | 参数更少的卷积 | SepConv(C, C, 3, 1) |
5x5 深度可分离卷积 | 更大感受野的深度可分离卷积 | SepConv(C, C, 5, 2) |
3x3 最大池化 | 特征下采样 | nn.MaxPool2d(3, stride=1, padding=1) |
3x3 平均池化 | 另一种下采样方式 | nn.AvgPool2d(3, stride=1, padding=1) |
恒等映射 | 直接连接 | Identity() |
4. ENAS的控制器设计
ENAS控制器通常是一个循环神经网络(RNN),特别是LSTM,它学习生成网络架构。控制器根据当前状态输出每一步的动作概率,然后根据这些概率采样一个动作(比如选择哪种操作)。
以下是控制器的基本设计:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Controller(nn.Module):"""ENAS控制器,用于生成网络架构"""def __init__(self, num_nodes, num_ops, lstm_size=100, lstm_num_layers=1):super(Controller, self).__init__()self.num_nodes = num_nodesself.num_ops = num_opsself.lstm_size = lstm_sizeself.lstm_num_layers = lstm_num_layers# 输入嵌入self.embed = nn.Embedding(num_nodes + num_ops, lstm_size)# LSTM控制器self.lstm = nn.LSTMCell(lstm_size, lstm_size)# 节点选择器self.node_selector = nn.Linear(lstm_size, num_nodes)# 操作选择器self.op_selector = nn.Linear(lstm_size, num_ops)# 存储架构决策self.sampled_arch = []self.sampled_probs = []def forward(self, temperature=1.0):"""生成一个架构样本"""# 初始化LSTM隐藏状态h = torch.zeros(1, self.lstm_size).cuda()c = torch.zeros(1, self.lstm_size).cuda()# 初始化输入x = torch.zeros(1).long().cuda()# 清空之前的采样self.sampled_arch = []self.sampled_probs = []# 循环生成每个节点的连接和操作for node_idx in range(2, self.num_nodes): # 跳过输入节点# 为当前节点选择连接的前驱节点for i in range(node_idx):# 更新LSTM状态embed = self.embed(x)h, c = self.lstm(embed, (h, c))# 计算前驱节点的概率logits = self.node_selector(h) / temperatureprobs = F.softmax(logits, dim=-1)# 采样前驱节点prev_node = torch.multinomial(probs, 1).item()self.sampled_arch.append(prev_node)self.sampled_probs.append(probs[0, prev_node])# 为连接选择操作x = torch.tensor([prev_node]).cuda()embed = self.embed(x)h, c = self.lstm(embed, (h, c))# 计算操作的概率logits = self.op_selector(h) / temperatureprobs = F.softmax(logits, dim=-1)# 采样操作op_id = torch.multinomial(probs, 1).item()self.sampled_arch.append(op_id)self.sampled_probs.append(probs[0, op_id])# 更新输入x = torch.tensor([op_id + self.num_nodes]).cuda()return self.sampled_arch, torch.stack(self.sampled_probs)
5. ENAS的子模型设计
ENAS的子模型是从超图中采样出的具体架构。下面是子模型的基本实现:
class ENASModel(nn.Module):"""ENAS子模型"""def __init__(self, arch, num_nodes, num_ops, C):super(ENASModel, self).__init__()self.arch = archself.num_nodes = num_nodesself.num_ops = num_opsself.C = C # 通道数# 定义候选操作列表self.OPS = nn.ModuleList([nn.Sequential(nn.Conv2d(C, C, 3, padding=1, bias=False),nn.BatchNorm2d(C),nn.ReLU(inplace=False)), # 3x3 标准卷积nn.Sequential(nn.Conv2d(C, C, 5, padding=2, bias=False),nn.BatchNorm2d(C),nn.ReLU(inplace=False)), # 5x5 标准卷积SepConv(C, C, 3, 1), # 3x3 深度可分离卷积SepConv(C, C, 5, 2), # 5x5 深度可分离卷积nn.MaxPool2d(3, stride=1, padding=1), # 3x3 最大池化nn.AvgPool2d(3, stride=1, padding=1), # 3x3 平均池化nn.Identity() # 恒等映射])# 节点特征初始化self.nodes = nn.ModuleList([nn.Sequential(nn.Conv2d(3, C, 3, padding=1, bias=False),nn.BatchNorm2d(C),nn.ReLU(inplace=False)), # 节点0(输入处理)nn.Sequential(nn.Conv2d(C, C, 3, padding=1, bias=False),nn.BatchNorm2d(C),nn.ReLU(inplace=False)) # 节点1])# 分类器self.classifier = nn.Linear(C, 10) # CIFAR-10分类def forward(self, x):# 初始化所有节点的特征node_features = [None] * self.num_nodesnode_features[0] = self.nodes[0](x)node_features[1] = self.nodes[1](node_features[0])# 根据架构描述构建计算图idx = 0for node_idx in range(2, self.num_nodes):# 每个节点的所有输入node_inputs = []for i in range(node_idx):# 获取连接的前驱节点和操作prev_node = self.arch[idx]op_id = self.arch[idx + 1]idx += 2# 计算特征node_input = self.OPS[op_id](node_features[prev_node])node_inputs.append(node_input)# 节点特征为所有输入的和node_features[node_idx] = sum(node_inputs)# 全局平均池化out = F.adaptive_avg_pool2d(node_features[-1], 1)out = out.view(out.size(0), -1)# 分类logits = self.classifier(out)return logits# 定义可分离卷积
class SepConv(nn.Module):def __init__(self, C_in, C_out, kernel_size, padding):super(SepConv, self).__init__()self.op = nn.Sequential(nn.ReLU(inplace=False),nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),nn.BatchNorm2d(C_out),nn.ReLU(inplace=False),nn.Conv2d(C_out, C_out, kernel_size=kernel_size, stride=1, padding=padding, groups=C_out, bias=False),nn.Conv2d(C_out, C_out, kernel_size=1, padding=0, bias=False),nn.BatchNorm2d(C_out))def forward(self, x):return self.op(x)
6. ENAS的训练过程
ENAS训练包括两个交替的阶段:子模型训练和控制器训练。下面是一个简化的ENAS训练过程:
def train_enas(controller, shared_model, train_queue, valid_queue, controller_optimizer, shared_optimizer, epochs, device='cuda'):"""ENAS训练主循环"""for epoch in range(epochs):# 1. 训练共享参数shared_model.train()controller.eval()for step, (x, target) in enumerate(train_queue):x, target = x.to(device), target.to(device)# 采样架构with torch.no_grad():arch, _ = controller()# 使用采样架构进行前向计算shared_optimizer.zero_grad()logits = shared_model(x, arch)loss = F.cross_entropy(logits, target)loss.backward()shared_optimizer.step()# 2. 训练控制器shared_model.eval()controller.train()# 在验证集上评估采样架构for step, (x, target) in enumerate(valid_queue):x, target = x.to(device), target.to(device)# 采样架构并记录概率controller_optimizer.zero_grad()arch, probs = controller()# 使用采样架构进行前向计算with torch.no_grad():logits = shared_model(x, arch)reward = compute_reward(logits, target) # 如验证准确率# 使用REINFORCE算法更新控制器log_prob = torch.sum(torch.log(probs))loss = -log_prob * rewardloss.backward()controller_optimizer.step()# 打印当前最佳架构with torch.no_grad():best_arch, _ = controller()print(f"Epoch {epoch}, Best Architecture: {best_arch}")def compute_reward(logits, target):"""计算架构奖励(通常是验证准确率)"""_, predicted = torch.max(logits, 1)correct = (predicted == target).sum().item()reward = correct / target.size(0)return reward
7. ENAS架构搜索流程图
下面是ENAS架构搜索的流程图:
8. ENAS相比DARTS的优势
ENAS和DARTS都是高效的神经架构搜索方法,但它们有不同的特点:
特性 | ENAS | DARTS |
---|---|---|
搜索方法 | 强化学习(离散) | 梯度下降(连续) |
参数共享 | 完全共享 | 软权重共享 |
计算效率 | 高(0.5 GPU天) | 中(1-4 GPU天) |
内存需求 | 低 | 高(需要二阶导数) |
架构离散化 | 不需要 | 需要(连续→离散) |
实现复杂度 | 中等 | 较高 |
第二部分:搜索空间设计对模型性能的影响
9. 搜索空间设计的重要性
搜索空间设计是神经架构搜索中最关键的因素之一。一个好的搜索空间应该满足以下条件:
- 覆盖性:包含足够多样化的架构,包括潜在的高性能架构
- 高效性:避免过于广泛导致搜索困难
- 结构合理性:符合神经网络设计的基本原则和先验知识
搜索空间设计对最终模型性能有直接影响。如果搜索空间不包含高性能架构,即使有最好的搜索算法也无法找到好的模型;反之,如果搜索空间过大,搜索效率会大大降低。
10. 常见的搜索空间类型
我们可以根据设计方式将搜索空间分为以下几类:
- 宏搜索空间(Macro):直接搜索整个网络的连接方式和操作类型,灵活性最高,但搜索难度也最大。
- 微搜索空间(Micro):预定义网络的宏观结构(如层数),只搜索重复单元(cell)的内部结构,平衡了灵活性和搜索效率。
- 层级搜索空间(Hierarchical):结合宏观和微观搜索,按层次化方式定义搜索空间。
下面是不同搜索空间类型的对比:
搜索空间类型 | 灵活性 | 搜索效率 | 典型方法 | 适用场景 |
---|---|---|---|---|
宏搜索空间 | 高 | 低 | 早期NAS | 特定任务定制 |
微搜索空间 | 中 | 高 | ENAS/DARTS | 通用视觉任务 |
层级搜索空间 | 高 | 中 | Auto-DeepLab | 复杂任务 |
11. ENAS不同搜索空间的实现
让我们实现几种不同的ENAS搜索空间,并对比它们的影响:
11.1 基于链式结构的搜索空间
class ChainSearchSpace:"""链式结构搜索空间,每个节点只连接到前一个节点"""def __init__(self, num_layers, num_ops):self.num_layers = num_layersself.num_ops = num_opsdef sample_arch(self, controller):"""从控制器采样架构"""# 只需要采样每层的操作类型arch = []for i in range(self.num_layers):op_id = controller.sample_op()arch.append(op_id)return archdef build_model(self, arch, C, num_classes):"""根据架构构建模型"""layers = []in_channels = 3 # 输入图像通道数# 干细胞层layers.append(nn.Conv2d(in_channels, C, 3, padding=1))layers.append(nn.BatchNorm2d(C))layers.append(nn.ReLU(inplace=True))# 构建主要层for i, op_id in enumerate(arch):# 定义当前层操作if op_id == 0: # 3x3 卷积layers.append(nn.Conv2d(C, C, 3, padding=1))layers.append(nn.BatchNorm2d(C))layers.append(nn.ReLU(inplace=True))elif op_id == 1: # 5x5 卷积layers.append(nn.Conv2d(C, C, 5, padding=2))layers.append(nn.BatchNorm2d(C))layers.append(nn.ReLU(inplace=True))elif op_id == 2: # 3x3 最大池化layers.append(nn.MaxPool2d(3, stride=1, padding=1))elif op_id == 3: # 3x3 平均池化layers.append(nn.AvgPool2d(3, stride=1, padding=1))# 可以添加更多操作类型# 每隔几层下采样if i > 0 and i % 3 == 0:layers.append(nn.MaxPool2d(2, stride=2))# 分类头layers.append(nn.AdaptiveAvgPool2d(1))# 构建序列模型model = nn.Sequential(*layers)model.add_module('classifier', nn.Linear(C, num_classes))return model
11.2 基于单元(Cell)的搜索空间
class CellSearchSpace:"""基于单元的搜索空间,搜索重复单元的内部结构"""def __init__(self, num_cells, num_nodes, num_ops):self.num_cells = num_cells # 单元数量self.num_nodes = num_nodes # 每个单元中的节点数self.num_ops = num_ops # 候选操作数量def sample_arch(self, controller):"""从控制器采样架构"""arch = []for i in range(self.num_cells):cell_arch = []# 为单元中的每个节点采样for j in range(2, self.num_nodes): # 跳过前两个输入节点# 为当前节点采样前驱节点和操作for k in range(j):prev_node = controller.sample_node(k)op_id = controller.sample_op()cell_arch.extend([prev_node, op_id])arch.append(cell_arch)return archdef build_model(self, arch, C, num_classes):"""根据架构构建模型"""model = CellBasedNetwork(arch, self.num_cells, self.num_nodes, self.num_ops, C, num_classes)return modelclass CellBasedNetwork(nn.Module):"""基于单元的网络模型"""def __init__(self, arch, num_cells, num_nodes, num_ops, C, num_classes):super(CellBasedNetwork, self).__init__()self.arch = archself.num_cells = num_cellsself.num_nodes = num_nodesself.num_ops = num_opsself.C = C# 定义干细胞网络self.stem = nn.Sequential(nn.Conv2d(3, C, 3, padding=1, bias=False),nn.BatchNorm2d(C))# 定义单元self.cells = nn.ModuleList()C_prev, C_curr = C, Cfor i in range(num_cells):# 每隔几个单元进行下采样if i in [num_cells//3, 2*num_cells//3]:C_curr *= 2reduction = Trueelse:reduction = Falsecell = Cell(arch[i], C_prev, C_curr, reduction, num_nodes, num_ops)self.cells.append(cell)C_prev = C_curr * num_nodes # 单元输出通道数# 分类器self.global_pooling = nn.AdaptiveAvgPool2d(1)self.classifier = nn.Linear(C_prev, num_classes)def forward(self, x):# 干细胞处理x = self.stem(x)# 通过所有单元for cell in self.cells:x = cell(x)# 分类out = self.global_pooling(x)out = out.view(out.size(0), -1)logits = self.classifier(out)return logitsclass Cell(nn.Module):"""网络中的基本单元"""def __init__(self, arch, C_in, C_out, reduction, num_nodes, num_ops):super(Cell, self).__init__()self.arch = archself.reduction = reductionself.num_nodes = num_nodes# 预处理输入stride = 2 if reduction else 1self.preprocess = nn.Sequential(nn.ReLU(inplace=False),nn.Conv2d(C_in, C_out, 1, stride=stride, bias=False),nn.BatchNorm2d(C_out))# 定义候选操作self.ops = nn.ModuleList()for i in range(num_ops):if i == 0: # 3x3 卷积op = nn.Sequential(nn.ReLU(inplace=False),nn.Conv2d(C_out, C_out, 3, padding=1, bias=False),nn.BatchNorm2d(C_out))elif i == 1: # 5x5 卷积op = nn.Sequential(nn.ReLU(inplace=False),nn.Conv2d(C_out, C_out, 5, padding=2, bias=False),nn.BatchNorm2d(C_out))elif i == 2: # 3x3 可分离卷积op = SepConv(C_out, C_out, 3, 1)elif i == 3: # 5x5 可分离卷积op = SepConv(C_out, C_out, 5, 2)elif i == 4: # 3x3 最大池化op = nn.MaxPool2d(3, stride=1, padding=1)elif i == 5: # 3x3 平均池化op = nn.AvgPool2d(3, stride=1, padding=1)elif i == 6: # 恒等映射op = nn.Identity()self.ops.append(op)def forward(self, x):# 预处理输入x = self.preprocess(x)# 初始化所有节点的特征nodes = [x]# 根据架构构建计算图idx = 0for i in range(2, self.num_nodes):# 为当前节点计算所有输入node_inputs = []for j in range(i):prev_node = self.arch[idx]op_id = self.arch[idx + 1]idx += 2# 计算该输入的特征node_input = self.ops[op_id](nodes[prev_node])node_inputs.append(node_input)# 节点特征为所有输入的和nodes.append(sum(node_inputs))# 连接所有中间节点output = torch.cat(nodes[1:], dim=1)return output
11.3 基于分层搜索空间
class HierarchicalSearchSpace:"""分层搜索空间,同时搜索网络架构和单元结构"""def __init__(self, num_blocks, num_cells_per_block, num_nodes, num_ops):self.num_blocks = num_blocksself.num_cells_per_block = num_cells_per_blockself.num_nodes = num_nodesself.num_ops = num_opsdef sample_arch(self, controller):"""从控制器采样架构"""# 采样整体网络架构network_arch = []for i in range(self.num_blocks):# 为每个块采样单元数量(可变)num_cells = controller.sample_cells_count()# 为每个块采样下采样策略downsample = controller.sample_downsample()network_arch.append((num_cells, downsample))# 采样单元内部结构cell_arch = []for j in range(2, self.num_nodes):# 为当前节点采样前驱节点和操作for k in range(j):prev_node = controller.sample_node(k)op_id = controller.sample_op()cell_arch.extend([prev_node, op_id])return (network_arch, cell_arch)def build_model(self, arch, C, num_classes):"""根据架构构建模型"""network_arch, cell_arch = archmodel = HierarchicalNetwork(network_arch, cell_arch, self.num_nodes, self.num_ops, C, num_classes)return modelclass HierarchicalNetwork(nn.Module):"""分层搜索的网络模型"""def __init__(self, network_arch, cell_arch, num_nodes, num_ops, C, num_classes):super(HierarchicalNetwork, self).__init__()self.network_arch = network_archself.cell_arch = cell_archself.num_nodes = num_nodesself.num_ops = num_opsself.C = C# 干细胞网络self.stem = nn.Sequential(nn.Conv2d(3, C, 3, padding=1, bias=False),nn.BatchNorm2d(C))# 构建网络主体self.blocks = nn.ModuleList()in_channels = Cfor block_id, (num_cells, downsample) in enumerate(network_arch):block = nn.ModuleList()# 确定是否下采样及通道数if block_id > 0 and downsample:stride = 2out_channels = in_channels * 2else:stride = 1out_channels = in_channels# 添加该块中的所有单元for cell_id in range(num_cells):# 第一个单元可能需要下采样if cell_id == 0 and stride == 2:cell = Cell(cell_arch, in_channels, out_channels, True, num_nodes, num_ops)else:cell = Cell(cell_arch, out_channels, out_channels, False, num_nodes, num_ops)block.append(cell)self.blocks.append(block)in_channels = out_channels# 分类器self.global_pooling = nn.AdaptiveAvgPool2d(1)self.classifier = nn.Linear(in_channels, num_classes)def forward(self, x):# 干细胞处理x = self.stem(x)# 通过所有块和单元for block in self.blocks:for cell in block:x = cell(x)# 分类out = self.global_pooling(x)out = out.view(out.size(0), -1)logits = self.classifier(out)return logits
12. 不同搜索空间的对比实验
让我们设计一个实验来对比不同搜索空间对ENAS性能的影响。我们将使用CIFAR-10数据集,并在三种不同的搜索空间上运行ENAS算法:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import numpy as npdef compare_search_spaces():"""对比不同搜索空间的性能"""# 设置参数C = 36 # 初始通道数num_classes = 10 # CIFAR-10epochs = 50# 定义数据加载器transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)valid_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)# 划分训练集和验证集indices = list(range(len(train_data)))np.random.shuffle(indices)split = int(0.8 * len(indices))train_indices, valid_indices = indices[:split], indices[split:]train_queue = DataLoader(train_data, batch_size=128,sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))valid_queue = DataLoader(train_data, batch_size=128,sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_indices))test_queue = DataLoader(valid_data, batch_size=128)# 定义搜索空间search_spaces = {'chain': ChainSearchSpace(num_layers=15, num_ops=7),'cell': CellSearchSpace(num_cells=8, num_nodes=7, num_ops=7),'hierarchical': HierarchicalSearchSpace(num_blocks=3, num_cells_per_block=3, num_nodes=7, num_ops=7)}# 对比结果存储results = {}# 对每个搜索空间运行ENASfor name, search_space in search_spaces.items():print(f"Testing search space: {name}")# 创建控制器controller = Controller(num_nodes=7, num_ops=7, lstm_size=100, lstm_num_layers=1).cuda()# 创建优化器controller_optimizer = optim.Adam(controller.parameters(),lr=0.001)# 记录时间和最佳准确率start_time = time.time()# 运行ENAS搜索best_arch, best_acc = run_enas_search(controller, search_space,train_queue, valid_queue, controller_optimizer,epochs,C,num_classes)# 计算搜索时间search_time = time.time() - start_time# 从头训练最佳架构final_model = search_space.build_model(best_arch, C, num_classes).cuda()final_acc = train_from_scratch(final_model, train_queue, test_queue, epochs=100)# 记录结果results[name] = {'search_time': search_time,'search_acc': best_acc,'final_acc': final_acc}# 打印结果print("\nResults:")print("-" * 50)print(f"{'Search Space':<15} {'Search Time(h)':<15} {'Search Acc(%)':<15} {'Final Acc(%)':<15}")print("-" * 50)for name, result in results.items():print(f"{name:<15} {result['search_time']/3600:<15.2f} {result['search_acc']:<15.2f} {result['final_acc']:<15.2f}")return resultsdef run_enas_search(controller, search_space, train_queue, valid_queue, controller_optimizer, epochs, C, num_classes):"""运行ENAS搜索过程"""best_arch = Nonebest_acc = 0# 初始化共享参数shared_model = SharedModel(search_space, C, num_classes).cuda()shared_optimizer = optim.SGD(shared_model.parameters(),lr=0.05,momentum=0.9,weight_decay=3e-4)for epoch in range(epochs):# 训练共享参数for step, (x, target) in enumerate(train_queue):shared_model.train()controller.eval()x, target = x.cuda(), target.cuda(non_blocking=True)# 采样架构with torch.no_grad():arch, _ = controller()# 构建临时模型model = search_space.build_model(arch, C, num_classes).cuda()model.load_state_dict(shared_model.state_dict(), strict=False)# 前向计算和优化shared_optimizer.zero_grad()logits = model(x)loss = nn.CrossEntropyLoss()(logits, target)loss.backward()shared_optimizer.step()# 更新共享模型参数shared_model.load_state_dict(model.state_dict(), strict=False)# 训练控制器controller.train()shared_model.eval()# 采样多个架构并评估sampled_archs = []accuracies = []for _ in range(10): # 采样10个架构arch, probs = controller()sampled_archs.append(arch)# 构建临时模型model = search_space.build_model(arch, C, num_classes).cuda()model.load_state_dict(shared_model.state_dict(), strict=False)# 在验证集上评估model.eval()correct = 0total = 0with torch.no_grad():for x, target in valid_queue:x, target = x.cuda(), target.cuda(non_blocking=True)logits = model(x)_, predicted = torch.max(logits, 1)total += target.size(0)correct += (predicted == target).sum().item()acc = 100 * correct / totalaccuracies.append(acc)# 更新最佳架构if acc > best_acc:best_acc = accbest_arch = arch# 更新控制器controller_optimizer.zero_grad()baseline = sum(accuracies) / len(accuracies)# 计算所有采样架构的损失loss = 0for i, (arch, acc) in enumerate(zip(sampled_archs, accuracies)):_, probs = controller(arch=arch)log_prob = torch.sum(torch.log(probs))reward = acc - baselineloss -= log_prob * rewardloss = loss / len(sampled_archs)loss.backward()controller_optimizer.step()print(f"Epoch {epoch}: best_acc={best_acc:.2f}%")return best_arch, best_accclass SharedModel(nn.Module):"""共享参数模型"""def __init__(self, search_space, C, num_classes):super(SharedModel, self).__init__()self.search_space = search_spaceself.C = Cself.num_classes = num_classes# 初始化共享参数self.shared_params = nn.ParameterDict()# 初始化干细胞层参数self.shared_params['stem.weight'] = nn.Parameter(torch.zeros(C, 3, 3, 3))self.shared_params['stem.bn.weight'] = nn.Parameter(torch.ones(C))self.shared_params['stem.bn.bias'] = nn.Parameter(torch.zeros(C))# 初始化操作参数for i in range(7): # 7种操作self.shared_params[f'op{i}.weight'] = nn.Parameter(torch.zeros(C, C, 3, 3))self.shared_params[f'op{i}.bn.weight'] = nn.Parameter(torch.ones(C))self.shared_params[f'op{i}.bn.bias'] = nn.Parameter(torch.zeros(C))# 初始化分类器参数self.shared_params['classifier.weight'] = nn.Parameter(torch.zeros(num_classes, C))self.shared_params['classifier.bias'] = nn.Parameter(torch.zeros(num_classes))def forward(self, x):# 必须提供具体架构才能前向计算raise NotImplementedError("SharedModel需要具体架构才能前向计算")def state_dict(self):return self.shared_paramsdef train_from_scratch(model, train_queue, test_queue, epochs=100):"""从头训练最终模型"""criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(),lr=0.025,momentum=0.9,weight_decay=3e-4)scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)best_acc = 0for epoch in range(epochs):# 训练model.train()for step, (x, target) in enumerate(train_queue):x, target = x.cuda(), target.cuda(non_blocking=True)optimizer.zero_grad()logits = model(x)loss = criterion(logits, target)loss.backward()optimizer.step()# 测试model.eval()correct = 0total = 0with torch.no_grad():for x, target in test_queue:x, target = x.cuda(), target.cuda(non_blocking=True)logits = model(x)_, predicted = torch.max(logits, 1)total += target.size(0)correct += (predicted == target).sum().item()acc = 100 * correct / totalif acc > best_acc:best_acc = accscheduler.step()if epoch % 10 == 0:print(f"Epoch {epoch}: acc={acc:.2f}%, best_acc={best_acc:.2f}%")return best_acc
13. 搜索空间对比结果分析
根据实验结果,我们可以分析不同搜索空间的优缺点:
通过对比实验结果,我们可以得出以下结论:
-
搜索空间复杂度与性能的权衡:
- 链式结构搜索空间最简单,搜索速度最快,但最终性能有限
- 分层结构搜索空间最复杂,搜索时间最长,但也能找到性能最好的模型
- 基于单元的搜索空间在搜索效率和模型性能之间取得了良好的平衡
-
模型大小与计算复杂度:
- 分层搜索通常会产生更大的模型,参数量和推理延迟相对较高
- 链式结构模型最小,但表达能力有限
- 实际应用时需要根据部署环境限制选择适当的搜索空间
-
搜索稳定性:
- 基于单元的搜索空间通常更稳定,不同运行之间性能波动小
- 分层结构搜索由于空间大,可能需要更多次搜索才能找到最优架构
清华大学全五版的《DeepSeek教程》完整的文档需要的朋友,关注我私信:deepseek 即可获得。
怎么样今天的内容还满意吗?再次感谢朋友们的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!
相关文章:
PyTorch深度学习框架60天进阶学习计划 - 第46天:自动化模型设计(一)
PyTorch深度学习框架60天进阶学习计划 - 第46天:自动化模型设计(一) 第一部分:使用ENAS算法生成图像分类网络 大家好!欢迎来到我们PyTorch深度学习框架60天进阶学习计划的第46天。今天我们要深入探讨一个话题——使用…...
【上海大学计算机系统结构实验报告】多机环境下MPI并行编程
实验目的 学习编制多进程并行程序实现如下功能: 创建多进程,输出进程号和进程数。运行多进程并行例子程序。编程实现大规模矩阵的并行计算。 实验过程及结果分析 实验环境 操作系统:Ubuntu 20.04开发工具:GCC 9.3.0、OpenMPI…...
实用电脑工具,轻松实现定时操作
软件介绍 如果你的电脑有时候需要像个听话的小助手一样,按照你的指令在特定时间做些事情,比如到了点就关机、开机,或者自动打开某个软件,那你可得了解下这个小帮手啦! 小巧功能却不少 程序定时器是一款超实用的电脑…...
jQuery — 动画和事件
介绍 jQuery动画与事件是提升网页交互的核心工具。动画方面,jQuery通过简洁API实现平滑过渡效果,提供预设方法如slideUp(),支持.animate()自定义CSS属性动画,并内置队列系统实现动画链式执行。开发者可精准控制动画速度、回调时机…...
Kubernetes相关的名词解释kube-proxy插件(3)
什么是kube-proxy? kube-proxy 是一个网络代理组件,运行在每个节点(Node)上,是 Kubernetes 服务(Service)功能的核心实现之一。它的主要职责是通过维护网络规则,实现集群内服务&…...
第3章 垃圾收集器与内存分配策略《深入理解Java虚拟机:JVM高级特性与最佳实践(第3版)》
第3章 垃圾收集器与内存分配策略 3.2 对象已死 Java世界中的所有对象实例,垃圾收集器进行回收前就是确定对象哪些是活着的,哪些已经死去。 3.2.1 引用计数算法 常见的回答是:给对象中添加一个引用计数器,有地方引用࿰…...
MCP是什么?为什么突然那么火?
什么是MCP? MCP全称为Model Context Protocol(模型上下文协议),是由Anthropic公司在2024年11月推出的一个开源协议。Anthropic是一家以其开发的Claude大语言模型而闻名的公司。MCP旨在提供一个通用的开放标准,以简化大型语言模型…...
与终端同居日记:Linux指令の进阶撩拨手册
前情提要: 当你和终端的关系从「早安打卡」进阶到「深夜代码同居」,那些曾经高冷的指令开始展露致命の反差萌—— man 是那个永远在线的钢铁直男说明书,只会说:"想懂我?自己看文档!"(…...
STM32单片机入门学习——第42节: [12-2] BKP备份寄存器RTC实时时钟
写这个文章是用来学习的,记录一下我的学习过程。希望我能一直坚持下去,我只是一个小白,只是想好好学习,我知道这会很难,但我还是想去做! 本文写于:2025.04.19 STM32开发板学习——第42节: [12-2] BKP备份寄存器&RTC实时时钟 前言开发板说…...
AI 驱动抗生素发现:从靶点到化合物测试
AI 驱动抗生素发现:从靶点到化合物测试 目录 基于 AI 驱动的研发流程发现抗生素,整合靶点选择和深度学习分子生成,显著提升了候选药物发现效率。结合数据平衡技术,机器学习和 AutoML 能有效提升不平衡数据集分类性能。RibbonFold 是一种新的 AI 模型,可以准确预测淀粉样蛋…...
群晖威联通飞牛等nas如何把宿主机硬盘挂接到可道云docker容器中
可道云系统是用户常用的一款面向个人用户的轻量级私有云存储工具,以高效管理和安全存储为核心,打造便捷的数字化办公体验。但是用户希望把原有其他磁盘中文件挂接到这个新系统中有很大的难度,主要是对linux文件系统理解有很大的误区,认为目录结构是固定的…...
用 R 语言打造交互式叙事地图:讲述黄河源区生态变化的故事
目录 🌟 项目背景:黄河源头的生态变迁 🧰 技术栈介绍 🗺️ 最终效果预览 💻 项目构建步骤 1️⃣ 数据准备 2️⃣ 构建 Leaflet 地图 3️⃣ 使用 scrollama 实现滚动触发事件 4️⃣ 使用 R Markdown / Quarto 打包发布 🎬 效果展示截图 📦 完整代码仓库 …...
opencv(双线性插值原理)
双线性插值是一种图像缩放、旋转或平移时进行像素值估计的插值方法。当需要对图像进行变换时,特别是尺寸变化时,原始图像的某些像素坐标可能不再是新图像中的整数位置,这时就需要使用插值算法来确定这些非整数坐标的像素值。 双线性插值的工…...
Flutter 弹窗队列管理:实现一个线程安全的通用弹窗队列系统
在开发复杂的 Flutter 应用时,弹窗的管理往往是一个令人头疼的问题。尤其是在多个弹窗需要按顺序显示,或者弹窗的显示需要满足特定条件时,手动管理弹窗的显示和隐藏不仅繁琐,还容易出错。为了解决这个问题,我们可以实现…...
Linux压缩与解压命令完全指南:tar.gz、zip等格式详解
Linux压缩与解压命令完全指南:tar.gz、zip等格式详解 在Linux系统中,文件压缩和解压是日常操作中不可或缺的一部分。本文将全面介绍Linux下常用的压缩和解压命令,包括tar.gz、tar、zip等格式的区别和使用方法,帮助你高效管理文件…...
doris/clickhouse常用sql
一、doris常用SQL 1、doris统计数据库的总大小(单位:MB) SELECT table_schema AS database_name,ROUND(SUM(data_length) / 1024 / 1024, 2) AS database_size_MB FROM information_schema.tables WHERE table_schema NOT IN (information…...
实现AWS Lambda函数安全地请求企业内部API返回数据
需要编写一个Lambda函数在AWS云上运行,它需要访问企业内部的API获取JSON格式的数据,企业有网关和防火墙,API有公司的okta身份认证,通过公司的域账号来授权访问,现在需要创建一个专用的域账号,让Lambda函数访…...
【Easylive】Interact与Web服务调用实例及网关安全拦截机制解析
【Easylive】项目常见问题解答(自用&持续更新中…) 汇总版 easylive-cloud-interacteasylive-cloud-web 1. 不同服务(web和interact)之间的调用方式 调用流程 • 角色分工: • easylive-cloud-web:作…...
【HDFS】EC重构过程中的校验功能:DecodingValidator
一、动机 DecodingValidator是在HDFS-15759中引入的一个用于校验EC数据重构正确性的组件。 先说下引入DecodingValidator的动机,据很多已知的ISSUE(如HDFS-14768, HDFS-15186, HDFS-15240,这些目前都已经fix了)反馈, EC在重构的时候可能会有各种各样的问题,导致数据错误…...
Chromium 134 编译指南 macOS篇:编译优化技巧(六)
1. 引言 在Chromium 134的开发过程中,优化编译速度是提升开发效率的关键因素。本文将重点介绍如何使用ccache工具来加速C/C代码的编译过程,特别是在频繁切换分支和修改代码时。通过合理配置和使用这些工具,您将能够显著减少编译时间…...
FPGA——基于DE2_115实现DDS信号发生器
FPGA——基于DE2_115实现DDS信号发生器 文章目录 FPGA——基于DE2_115实现DDS信号发生器一、实验要求二、实现过程(1)新建工程 二、波形存储器ROM(1)方波模块(2)正弦波形存储器(3)锁…...
PHP中的ReflectionClass讲解【详细版】
快餐: ReflectionClass精简版 在PHP中,ReflectionClass是一个功能强大的反射类,它就像是一个类的“X光透视镜”,能让我们在程序运行时深入了解类的内部结构和各种细节。 一、反射类的基本概念和重要性 反射是指在程序运行期间获…...
嵌入式面试题解析:常见基础知识点详解
在嵌入式领域的面试中,基础知识点的考察尤为重要。下面对一些常见面试题进行详细解析,帮助新手一步步理解。 一、原码、反码、补码及补码的好处 题目 什么叫原码、反码、补码?计算机学科引入补码有什么好处? 在计算机科学中&a…...
GPU渲染阶段介绍+Shader基础结构实现
GPU是什么 (CPU)Center Processing Unit:逻辑编程 (GPU)Graphics Processing Unit:图形处理(矩阵运算,数据公式运算,光栅化) 渲染管线 渲染管线也称为渲染流水线&#x…...
08-DevOps-向Harbor上传自定义镜像
harbor创建完成,往harbor镜像仓库中上传自定义的镜像,包括新建项目、docker配置镜像地址、镜像重命名、登录harbor、推送镜像这几个步骤,具体操作如下: harbor中新建项目 访问级别公开,代表任何人都可以拉取仓库中的镜…...
C++学习之路,从0到精通的征途:vector类的模拟实现
目录 一.vector的介绍 二.vector的接口实现 1.成员变量 2.迭代器 (1)begin (2)end 3.容量操作 (1)size,capacity (2)reserve (3)resize…...
嵌入式软件--stm32 DAY 2
大家学习嵌入式的时候,多多学习用KEIL写代码,虽然作为编译器,大家常用vscode等常用工具关联编码,但目前keil仍然是主流工具之一,学习掌握十分必要。 1.再次创建项目 1.1编译器自动生成文件 1.2初始文件 这样下次创建新…...
多模态大语言模型arxiv论文略读(二十九)
Temporal Insight Enhancement: Mitigating Temporal Hallucination in Multimodal Large Language Models ➡️ 论文标题:Temporal Insight Enhancement: Mitigating Temporal Hallucination in Multimodal Large Language Models ➡️ 论文作者:Li Su…...
【人工智能学习-01-01】20250419《数字图像处理》复习材料的word合并PDF,添加页码
前情提要 20250419今天是上师大继续教育人工智能专升本第一学期的第一次线下课。 三位老师把视频课的内容提炼重点再面授。(我先看了一遍视频,但是算法和图像都看不懂,后来就直接挂分刷满时间,不看了) 今天是面对面授…...
B端APP设计:打破传统限制,为企业开启便捷新通道
B端APP设计:打破传统限制,为企业开启便捷新通道 在数字化转型浪潮中,企业级移动应用正突破传统管理系统的功能边界,演变为连接产业链各环节的核心枢纽。本文从技术架构革新、交互模式进化、安全防护升级三个维度,系统…...
【多线程5】面试常考锁知识点
文章目录 悲观/乐观锁挂起等待锁/自旋锁偏向锁轻量级/重量级锁锁升级CASCAS引发的ABA问题解决方案 原子类 公平/不公平锁可重入锁ReentrantLock读写锁 Callable接口 这里的“悲观”“乐观”“挂起等待”“自旋”“轻量级”“重量级”“公平”“非公平”“可重入”仅代表某个锁的…...
Linux第一个系统程序——进度条
1.回车与换行 回车(CR, \r): 作用:将光标移动到当前行的行首(最左侧),但不换到下一行。 历史来源:源自打字机的“回车”操作——打字机的滑架(Carriage)需…...
C 语 言 --- 指 针 3
C 语 言 --- 指 针 3 函 数 指 针函 数 指 针 数 组代 码 解 释回 调 函 数 - - - qsort模 拟 实 现 qsort 函 数 总结 💻作 者 简 介:曾 与 你 一 样 迷 茫,现 以 经 验 助 你 入 门 C 语 言 💡个 人 主 页:笑口常开x…...
蓝桥杯之递归
1.数字三角形 题目描述 上图给出了一个数字三角形。从三角形的顶部到底部有很多条不同的路径。对于每条路径,把路径上面的数加起来可以得到一个和,你的任务就是找到最大的和(路径上的每一步只可沿左斜线向下或右斜线向下走)。 输…...
学习笔记十八——Rust 封装
🧱 Rust 封装终极指南:结构体、模块、Trait、目录结构与模块引用 🧭 目录导航 什么是封装?Rust 的封装理念Rust 的封装工具总览模块(mod)和访问控制(pub)详解结构体和枚举ÿ…...
【面试向】点积与注意力机制,逐步编码理解自注意力机制
点积(dot product)两个向量点积的数学公式点积(dot product)与 Attention 注意力机制(Attention)注意力机制的核心思想注意力机制中的缩放点积自注意力机制中,谁注意谁? 逐步编码理解…...
基础数学知识-线性代数
1. 矩阵相乘 c i j = a i k ∗ b k j c_{ij} = a_{ik} * b_{kj} cij=aik∗bkj 1. 范数 1. 向量的范数 任意一组向量设为 x ⃗ = ( x 1 , x 2 , . . . , x N ) \vec{x}=(x_1,x_2,...,x_N) x =(x1,x2,...,xN) 如下: 向量的1范数: 向量的各个元素的绝对值之和∥ …...
【KWDB 创作者计划】_上位机知识篇---Docker容器
文章目录 前言1. Docker 容器是什么?隔离性轻量级可移植性可复用性 2. Docker 核心概念镜像容器仓库Dockerfile 3. Docker 基本使用(1) 安装 Docker(2) 容器生命周期管理(3) 镜像管理(4) 进入容器内部(5) 数据持久化(挂载卷)(6) 网络管理 4. …...
指针函数和函数指针
指针函数本质是一个函数,只是函数的返回值是指针类型 函数指针本质是一个指针,只是这个指针指向的是一个函数 指针函数 函数有很多类型的返回值,例如 short funcA(参数列表) // 表示该函数返回值是一个short类型 void funcA(参数列表) // 表…...
案例驱动的 IT 团队管理:创新与突破之路:第六章 组织进化:从案例沉淀到管理体系-6.1 案例库建设方法论-6.1.2案例分级与标签体系
👉 点击关注不迷路 👉 点击关注不迷路 👉 点击关注不迷路 文章大纲 案例分级与标签体系构建方法论:IT团队知识管理的结构化实践1. 案例库建设的战略价值与核心挑战1.1 案例管理的战略定位1.2 分级标签体系的核心价值 2. 案例分级体…...
sqlilabs-Less之HTTP头部参数的注入——基础篇
Less-18 user-agent报错注入 这一关的代码漏洞点出现在了insert语句,因为这里没有对user-agent和ip_address进行过滤,,并且输出了mysql的错误信息 补充知识点 PHP里用来获取客户端IP的变量 $_SERVER[HTTP_CLIENT_IP] #这个很少使用…...
java多线程相关内容
java线程创建的方式 一共有四种方式 继承 Thread 类:本质上是实现了 Runnable 接口的一个实例,代表一个线程的实例 启动线程的唯一方 法就是通过 **Thread 类的 start()**实例方法。start()方法是一个 native 方法,它将启动一个新线 程&…...
Windows Server .NET Core 应用程序部署到 IIS 解决首次访问加载慢的问题
第一篇: Windows .NET Core 应用程序部署到 IIS 解决首次访问加载慢的问题 第二篇:Windows Server .NET Core 应用程序部署到 IIS 解决首次访问加载慢的问题 第三篇:Windows .NET Core 应用程序部署到 IIS 解决首次访问加载慢的问题 设置…...
ubuntu24.04上使用qemu+buildroot+uboot+linux+tftp+nfs模拟搭建vexpress-ca9嵌入式linux开发环境
1 准备工作 1.1 安装依赖工具 sudo apt-get update && sudo apt-get install build-essential git bc flex libncurses5-dev libssl-dev device-tree-compiler1.2 安装arm交叉编译工具链 sudo apt install gcc-arm-linux-gnueabihf安装之后,在终端输入ar…...
Cocos Creater打包安卓App添加隐私弹窗详细步骤+常见问题处理
最终演示效果,包含所有代码内容 + 常见错误问题处理 点击服务协议、隐私政策,跳转到相关网页, 点击同意进入游戏,不同意关闭应用 一,添加Activity,命名为MyLaunchActivity 二,编写MyLaunchActivity.java的内容 package com.cocos.game.launch;import android.os.Bund…...
UI文件上传
1、文件上传:文件上传是自动化中比较麻烦棘手的部分。 有些场景我们需要上传本地文件到项目里。这种比较麻烦,因为需要点开文件上传的窗口后,打开的是windows的文件选择窗口, 而selenium是无法操作这个窗口的。 selenium只能操作…...
2.凸包优化求解
1.减而治之(Decrease and Conquer) 插入排序 典型的减而治之算法就是插入排序方法 插入排序法: 在未排序中选择一个元素,插入到已经排序号的序列中 将凸包也采用减而治之的方法 2.In-Convex-Polygon Test 怎么判断引入的极点存在于多边形里面还是外面࿱…...
从0开发一个unibest+vue3项目,使用vscode编辑器开发,总结vue2升vue3项目开始,小白前期遇到的问题
开头运行可看官网 链接: unibest官网 一:vscode中vue3代码显示报错标红波浪线 去查看扩展商店发现一些插件都弃用了,例如h5的插件以及vue老插件 解决办法:下载Vue - Official插件(注意:横杠两边是要加空格的ÿ…...
jmeter中文乱码问题解决
修改jmeter.properties配置文件 进入JMeter安装目录的bin文件夹,找到jmeter.properties文件。搜索参数sampleresult.default.encodingUTF-8,取消注释(删除行首的#),并将其值改为UTF-8。保存文件并重启JMeter生效…...
额外篇 非递归之美:归并排序与快速排序的创新实现
个人主页:strive-debug 快速排序非递归版本 非递归版本的快速排序是为了解决在空间不够的情况下,利用栈来模拟递归的过程。 递归版本的快速排序是空间换时间,好实现。 实现思路: 1. 创建一个栈,将数组的右边界下标和…...