MindSpore框架学习项目-ResNet药物分类-构建模型
目录
2.构建模型
2.1定义模型类
2.1.1 基础块ResidualBlockBase
ResidualBlockBase代码解析
2.1.2 瓶颈块ResidualBlock
ResidualBlock代码解释
2.1.3 构建层
构建层代码说明
2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现
ResNet组建类代码解析
2.1.5 实例化resnet_xx网络
实例化resnet_xx网络代码分析
2.2模型初始化
模型初始化代码解析
本项目可以在华为云modelart上租一个实例进行,也可以在配置至少为单卡3060的设备上进行
https://console.huaweicloud.com/modelarts/
Ascend环境也适用,但是注意修改device_target参数
需要本地编译器的一些代码传输、修改等可以勾上ssh远程开发
说明:项目使用的数据集来自华为云的数据资源。项目以深度学习任务构建的一般流程展开(数据导入、处理 > 模型选择、构建 > 模型训练 > 模型评估 > 模型优化)。
主线为‘一般流程’,同时代码中会标注出一些要点(# 要点1-1-1:设置使用的设备
)作为支线,帮助学习mindspore框架在进行深度学习任务时一些与pytorch的差异。
可以只看目录中带数字标签的部分来快速查阅代码。
2.构建模型
2.1定义模型类
要求:
补充如下代码的空白处
主要完成:
1. 实现1个卷积层和1个ReLU激活函数的定义
2. 实现ResidualBlockBase和ResidualBlock模块的残差连接,并补全self.layer4的参数
导入mindspore训练环节(包括模型构建、激活函数、反向传播、损失函数等需要的库)
from mindspore import Model
from mindspore import context
import mindspore.ops as ops
from mindspore import Tensor, nn, set_context, GRAPH_MODE, train
from mindspore import load_checkpoint, load_param_into_net
from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal
初始化:weight_init = Normal(mean=0, sigma=0.02) 用于初始化卷积层;
gamma_init = Normal(mean=1, sigma=0.02) 用于初始化批归一化层
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
2.1.1 基础块ResidualBlockBase
conv1 和 conv2 的参数设置:
conv1 负责处理输入数据的空间下采样(通过 stride 参数)或通道数变换(通过 out_channels),同时进行第一次特征提取。
conv2 固定为 3×3 卷积,不改变空间尺寸(默认 stride=1),仅对 conv1 的输出进一步提取特征。
(卷积层当池化层用)
class ResidualBlockBase(nn.Cell):
expansion: int = 1def __init__(self, in_channel: int, out_channel: int,
stride: int = 1, norm: Optional[nn.Cell] = None,
down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:
self.norm = nn.BatchNorm2d(out_channel)else:
self.norm = norm# 要点2-1-1:实现1个卷积层和一个ReLU激活函数的定义# 1. Conv2d:# in_channels (int) - Conv2d层输入Tensor的空间维度。# out_channels (int) - Conv2d层输出Tensor的空间维度。# kernel_size (Union[int, tuple[int]]) - 指定二维卷积核的高度和宽度, 卷积核大小为3X3;# stride (Union[int, tuple[int]],可选) - 二维卷积核的移动步长。
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride,
weight_init=weight_init)
self.conv2 = nn.Conv2d(in_channel, out_channel,
kernel_size=3, weight_init=weight_init)# 2. ReLU:逐元素计算ReLU(Rectified Linear Unit activation function)修正线性单元激活函数。需要调用MindSpore的相关API.
self.relu = nn.ReLU()
self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""
identity = x out = self.conv1(x)
out = self.norm(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm(out)if self.down_sample is not None:
identity = self.down_sample(x)# 要点2-1-2: # 1. 实现ResidualBlockBase模块的残差连接
out = out+identity # 输出为主分支与shortcuts之和
out = self.relu(out)return out
ResidualBlockBase代码解析
核心类定义:ResidualBlockBase
作用:实现残差网络的基础块(Basic Block),包含主分支(卷积路径)和短路连接(Shortcut),解决深层网络梯度消失问题。
输入:
in_channel :输入特征图通道数
out_channel :输出特征图通道数
stride :卷积步长(控制特征图尺寸变化,用于下采样)
norm :归一化层(默认使用 BatchNorm2d)
down_sample :下采样模块(用于调整短路连接的维度,确保与主分支输出维度一致)
要点 2-1-1:定义卷积层和 ReLU 激活函数
Conv2d 层实现
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride, # 3x3卷积核,步长由参数控制
weight_init=weight_init) # 权重初始化(正态分布,sigma=0.02)
self.conv2 = nn.Conv2d(in_channel, out_channel, # 第二个卷积层输入通道仍为in_channel(残差块基础版)
kernel_size=3, weight_init=weight_init)
关键参数:
kernel_size=3:固定使用 3x3 卷积核,符合残差块基础设计(如 ResNet-18/34 的 Basic Block)。
stride=stride:第一个卷积层的步长由外部传入(用于下采样),第二个卷积层步长固定为 1(保持尺寸)。
weight_init=weight_init:使用正态分布初始化权重(Normal(mean=0, sigma=0.02)),避免梯度爆炸 / 消失。
ReLU 激活函数
self.relu = nn.ReLU() # 直接调用MindSpore的ReLU模块,逐元素计算max(0, x)
作用:引入非线性,避免网络退化为线性层,同时缓解梯度消失。
要点 2-1-2:实现残差连接
if self.down_sample is not None:
identity = self.down_sample(x) # 下采样:调整短路连接的维度(通道数/尺寸)
out = out + identity # 残差连接核心:主分支输出与短路连接相加
out = self.relu(out) # 最后一次ReLU激活,输出非线性特征
残差连接逻辑:
短路连接(Identity Mapping):当输入x的维度(通道数 / 尺寸)与主分支输出out一致时,直接相加(identity = x)。
若维度不一致(如通道数增加或尺寸缩小),通过down_sample模块对x进行下采样(通常是 1x1 卷积 + 步长调整),确保形状匹配。
相加操作:
核心公式:输出 = 主分支输出 + 短路连接,强制保留原始输入信息,使梯度能直接回传至浅层。
激活函数位置:相加后再进行一次 ReLU 激活,确保输出为非线性特征,符合 ResNet 设计规范。
关键模块解析
1. 归一化层(Norm)处理
if not norm:
self.norm = nn.BatchNorm2d(out_channel) # 默认使用BatchNorm2d
else:
self.norm = norm # 支持自定义归一化层(如GroupNorm)
作用:对卷积输出进行归一化,加速训练并提升模型鲁棒性。
位置:每个卷积层后立即接归一化层,再接 ReLU 激活(Conv→Norm→ReLU 顺序)。
2. 下采样模块(down_sample)
self.down_sample = down_sample # 由外部传入,通常是1x1卷积+步长
触发场景:当in_channel ≠ out_channel或stride > 1时,需通过下采样调整短路连接的维度。
典型实现:
down_sample = nn.SequentialCell([
nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, weight_init=weight_init),
nn.BatchNorm2d(out_channel)
])
通过 1x1 卷积调整通道数,步长调整尺寸,确保与主分支输出形状一致。
3. 权重初始化策略
weight_init = Normal(mean=0, sigma=0.02) # 卷积层权重初始化
gamma_init = Normal(mean=1, sigma=0.02) # BatchNorm的γ参数初始化(未在当前代码中使用)
正态分布初始化:较小的标准差(σ=0.02)避免初始权重过大导致激活值饱和,符合深度学习框架的常见实践(如 PyTorch 的默认初始化)。
2.1.2 瓶颈块ResidualBlock
class ResidualBlock(nn.Cell):
expansion = 4 def __init__(self, in_channel: int, out_channel: int,
stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d(in_channel, out_channel,
kernel_size=1, weight_init=weight_init)
self.norm1 = nn.BatchNorm2d(out_channel)
self.conv2 = nn.Conv2d(out_channel, out_channel,
kernel_size=3, stride=stride,
weight_init=weight_init)
self.norm2 = nn.BatchNorm2d(out_channel)
self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
kernel_size=1, weight_init=weight_init)
self.norm3 = nn.BatchNorm2d(out_channel * self.expansion) self.relu = nn.ReLU()
self.down_sample = down_sample def construct(self, x): identity = x out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.norm3(out) if self.down_sample is not None:
identity = self.down_sample(x)
# 2. 实现ResidualBlock模块的残差连接
out = out+identity # 输出为主分支与shortcuts之和
out = self.relu(out) return out
ResidualBlock代码解释
核心类定义:ResidualBlock(瓶颈块)
作用:实现深层残差网络的瓶颈结构,通过 “降维 - 特征提取 - 升维” 减少计算量,支持构建更深的网络(如 50 层以上)。
核心参数:
expansion=4 :升维因子(固定为 4,符合 ResNet 设计规范),即最后一个 1x1 卷积将通道数扩展为out_channel×4。
in_channel :输入特征图通道数
out_channe :中间层特征图通道数(经 1x1 卷积降维后的通道数)
stride :3x3 卷积的步长(控制特征图尺寸变化,用于下采样)
down_sample :下采样模块(调整短路连接的维度,确保与主分支输出维度一致)
要点:瓶颈块的结构设计
瓶颈块通过三层卷积实现 “降维→特征提取→升维”,显著减少计算量(对比基础块的两层 3x3 卷积):
第一层:1x1 卷积(降维)
self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, weight_init=weight_init)
作用:将输入通道数从in_channel降为out_channel(如输入 256→输出 64),减少后续 3x3 卷积的计算量。
卷积核大小:1x1,仅改变通道数,不改变特征图尺寸(stride=1,无填充)。
第二层:3x3 卷积(特征提取)
self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride, weight_init=weight_init)
作用:在降维后的低维空间提取空间特征(如边缘、纹理)。
关键参数:stride=stride:支持下采样(如 stride=2 时特征图尺寸减半),由外部传入(用于构建不同 stage 的残差块)。
kernel_size=3:保持 3x3 卷积核,确保感受野与基础块一致。
第三层:1x1 卷积(升维)
self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion, kernel_size=1, weight_init=weight_init)
作用:将通道数从out_channel升至out_channel×expansion(如 64→256),与短路连接维度匹配(因 ResNet 的 stage 设计中,输出通道数通常是输入的 4 倍)。
核心公式:输出通道数 = out_channel × expansion(此处expansion=4是 ResNet 瓶颈块的固定设计)。
要点:残差连接与维度匹配
if self.down_sample is not None:
identity = self.down_sample(x) # 调整短路连接的维度
out = out + identity # 残差连接核心:主分支输出与短路连接相加
out = self.relu(out) # 最后一次ReLU激活
触发下采样的场景:
当以下任意条件成立时,需通过down_sample调整短路连接:输入通道数in_channel ≠ 输出通道数out_channel×expansion(升维导致通道数不匹配)。
stride > 1(特征图尺寸缩小,短路连接需同步下采样)。
下采样模块实现(通常由外部传入):
down_sample = nn.SequentialCell([
nn.Conv2d(in_channel, out_channel*self.expansion, kernel_size=1, stride=stride, weight_init=weight_init),
nn.BatchNorm2d(out_channel*self.expansion)
])
通过 1x1 卷积调整通道数,步长调整尺寸,确保identity与主分支输出out形状一致(通道数、高度、宽度均相同)。
归一化层与激活函数的顺序
# 每一层的处理流程:Conv → BatchNorm → ReLU
out = self.conv1(x) # 1x1卷积(降维)
out = self.norm1(out) # BatchNorm2d
out = self.relu(out) # ReLU激活
out = self.conv2(out) # 3x3卷积(特征提取)
out = self.norm2(out) # BatchNorm2d
out = self.relu(out) # ReLU激活
out = self.conv3(out) # 1x1卷积(升维)
out = self.norm3(out) # BatchNorm2d(升维后归一化)
设计原则:符合 ResNet 的 “Post-Normalization” 架构,即在卷积后立即归一化,再激活,确保每一层输入处于稳定分布,加速训练收敛。
权重初始化策略
weight_init = Normal(mean=0, sigma=0.02) # 与基础块一致,小方差初始化避免梯度爆炸
作用:对 1x1 和 3x3 卷积的权重进行正态分布初始化,确保初始权重较小,激活值不会因过大输入导致饱和(如 ReLU 的负数区域失活)。
与基础残差块(ResidualBlockBase)的区别
2.1.3 构建层
根据给定参数构建由指定数量残差块组成的网络层,包括处理下采样及层间连接等
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
channel: int, block_nums: int, stride: int = 1):
down_sample = None if stride != 1 or last_out_channel != channel * block.expansion: down_sample = nn.SequentialCell([
nn.Conv2d(last_out_channel, channel * block.expansion,
kernel_size=1, stride=stride, weight_init=weight_init),
nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
]) layers = []
layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample)) in_channel = channel * block.expansion for _ in range(1, block_nums): layers.append(block(in_channel, channel)) return nn.SequentialCell(layers)
构建层代码说明
功能定位
ResNet 的整体架构就是通过make_layer函数不断堆叠残差块,形成多个 stage,每个 stage 内部保持相同的通道数,相邻 stage 之间通过下采样调整尺寸和通道数,最终构建出深度神经网络。
输入参数:
last_out_channel :上一层输出的特征图通道数(用于判断是否需要下采样)。
block :残差块类型(ResidualBlockBase基础块或ResidualBlock瓶颈块,通过Type[Union]支持两种类型)。
channel :当前 stage 的基础通道数(瓶颈块中为降维后的通道数,基础块中为输出通道数)。
block_nums :当前 stage 包含的残差块数量(如 ResNet-50 的每个 stage 包含 3/4/6/3 个瓶颈块)。
stride :当前 stage 第一个残差块的卷积步长(控制下采样,默认 1 表示不采样)。
输出:
由多个残差块组成的nn.SequentialCell序列(可直接作为网络的一个 stage,如 ResNet 的layer1、layer2等)。
核心代码逻辑解析
1. 下采样模块(down_sample)的条件判断与创建(核心考点)
if stride != 1 or last_out_channel != channel * block.expansion:
down_sample = nn.SequentialCell([
nn.Conv2d(last_out_channel, channel * block.expansion,
kernel_size=1, stride=stride, weight_init=weight_init),
nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
])
触发条件(满足任意一条即需下采样):
stride != 1:需要对特征图尺寸进行下采样(如 stride=2 时尺寸减半)。
last_out_channel != channel * block.expansion:输入通道数与当前 stage 输出通道数不一致(瓶颈块中输出通道是channel×4,基础块中是channel×1)。
下采样实现:
通过1x1 卷积调整通道数(从last_out_channel到channel×block.expansion)。
卷积步长设为stride,同步调整特征图尺寸(与主分支的 3x3 卷积步长一致)。
接 BatchNorm 层归一化,确保短路连接的输出分布稳定。
核心作用:保证短路连接(identity)的维度与主分支输出一致,使out + identity操作可行。
2. 残差块序列的构建
layers = []
# 添加第一个残差块(可能包含下采样)
layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))
# 更新输入通道为当前stage的输出通道(block.expansion倍)
in_channel = channel * block.expansion
# 添加后续残差块(无下采样,stride=1,通道数已对齐)
for _ in range(1, block_nums):
layers.append(block(in_channel, channel)) # 输入通道为上一个块的输出通道
第一个块的特殊性:
传入stride和down_sample,处理当前 stage 的下采样和通道对齐(如 ResNet 中layer2的第一个块 stride=2,实现尺寸减半)。
若无需下采样(stride=1 且通道数匹配),down_sample=None,短路连接直接使用输入x。
后续块的一致性:
输入通道in_channel固定为channel×block.expansion(即上一个块的输出通道)。
不再传入stride(默认 1)和down_sample(无需下采样,通道数已对齐),所有后续块仅做恒等残差连接。
2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现
from mindspore import load_checkpoint, load_param_into_net
from mindspore import ops
class ResNet(nn.Cell):
def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
layer_nums: List[int], num_classes: int, input_channel: int) -> None:
super(ResNet, self).__init__() self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
self.norm = nn.BatchNorm2d(64)
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = make_layer(64, block, 64, layer_nums[0])
self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2) # 要点2-1-3:layer4的输出通道参数‘512’的含义
self.avg_pool = ops.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes) def construct(self, x): x = self.conv1(x)
x = self.norm(x)
x = self.relu(x)
x = self.max_pool(x) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x) x = self.avg_pool(x,(2,3)) x = self.flatten(x)
x = self.fc(x) return x
ResNet组建类代码解析
1. 类定义与核心参数
class ResNet(nn.Cell):
def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
layer_nums: List[int], num_classes: int, input_channel: int) -> None:
super(ResNet, self).__init__()
block:残差块类型(基础块ResidualBlockBase或瓶颈块ResidualBlock),决定网络层数和计算复杂度。
layer_nums:各 stage 的残差块数量(如[3, 4, 6, 3]对应 ResNet-50)。
num_classes:分类任务的类别数(如中药材分类的 12 类)。
input_channel:全连接层输入通道数(由最后一个 stage 的输出通道决定,如瓶颈块下为512×4=2048)。
2. 主干网络结构
输入层与初始特征提取
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init) # 7x7卷积
self.norm = nn.BatchNorm2d(64) # 归一化
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') # 最大池化
7x7 卷积:输入 3 通道(RGB 图像),输出 64 通道,步长 2,初步提取特征并降采样(尺寸减半)。
最大池化:核大小 3x3,步长 2,pad_mode='same'保持空间尺寸对称减半(如 224→112→56)。
四个 stage(layer1-layer4)
self.layer1 = make_layer(64, block, 64, layer_nums[0]) # stride=1(默认)
self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2) # 下采样
self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2) # 关键考点:补全layer4参数
make_layer功能:动态构建残差块序列,每个 stage:第一个块通过stride=2实现下采样(layer2-layer4),通道数翻倍(如 64→128→256→512)。
block.expansion控制通道升维(基础块 = 1,瓶颈块 = 4),例如瓶颈块下64×4=256作为下 stage 输入。
输出层
self.avg_pool = ops.ReduceMean(keep_dims=True) # 全局平均池化
self.flatten = nn.Flatten() # 展平特征
self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes) # 全连接分类
全局平均池化:替代全连接层前的全连接操作,减少参数数量,输出特征图尺寸为(batch, 512×expansion, 1, 1)。
全连接层:将特征映射到num_classes维空间,输出分类概率。
3. 前向传播逻辑
def construct(self, x):
x = self.conv1(x) → self.norm(x) → self.relu(x) → self.max_pool(x) # 初始特征提取
x = self.layer1(x) → self.layer2(x) → self.layer3(x) → self.layer4(x) # 四级残差块特征提取
x = self.avg_pool(x, (2, 3)) # 对空间维度(H=2, W=3,假设输入为7x7)做平均池化
x = self.flatten(x) # 展平为一维向量(shape: [batch, input_channel])
x = self.fc(x) # 分类输出
return x
空间尺寸变化:假设输入 224x224,经过conv1(stride=2)和max_pool(stride=2)后尺寸为 56x56,每层 stage 若stride=2则尺寸减半(56→28→14→7),最终layer4输出 7x7。
通道数变化:随 stage 递增(64→128→256→512),经block.expansion后瓶颈块通道数为 256→512→1024→2048。
4. 核心要点与设计原则
layer4参数补全(题目要求):输入通道为256×block.expansion(上一 stage 输出),当前 stage 基础通道512,stride=2实现最后一次下采样。
残差块类型兼容性:通过block参数支持基础块(浅层)和瓶颈块(深层),expansion自动适配通道逻辑(无需为不同块编写独立代码)。
下采样策略:每个 stage 的第一个块通过stride=2和1x1卷积调整通道 / 尺寸,保证残差连接维度匹配。
计算效率:瓶颈块通过1x1卷积降维减少 3x3 卷积计算量,使深层网络(如 ResNet-152)训练可行。
5. 代码关键作用
模块化构建:通过make_layer和残差块组合,快速搭建不同深度的 ResNet(如 50 层、101 层)。
特征提取流程:从浅层边缘检测到深层语义特征,逐层抽象,适应图像分类任务。
维度匹配:自动处理残差连接的通道和尺寸对齐,避免手动计算错误。
2.1.5 实例化resnet_xx网络
实例化resnet50
def _resnet(block: Type[Union[ResidualBlockBase, ResidualBlock]],
layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
input_channel: int):
model = ResNet(block, layers, num_classes, input_channel)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):
resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
pretrained, resnet50_ckpt, 2048)
实例化resnet_xx网络代码分析
ps:代码中‘ return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
pretrained, resnet50_ckpt, 2048)’
为什么能跨函数识别到 pretrained参数?
在 Python 中,这是因为作用域的规则 。在resnet50函数中,pretrained是该函数的参数,属于局部作用域。当调用_resnet函数时,pretrained作为参数传递给_resnet函数,所以_resnet函数能够识别并使用这个参数
1. 通用模型构建函数 _resnet
def _resnet(block: Type[Union[ResidualBlockBase, ResidualBlock]],
layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
input_channel: int):
model = ResNet(block, layers, num_classes, input_channel)return model
功能:通用 ResNet 模型构建接口,通过参数化残差块类型、层数、分类数等,灵活生成不同配置的 ResNet 模型。
参数解析:block:残差块类型(ResidualBlockBase基础块或ResidualBlock瓶颈块)。
layers:各 stage 的残差块数量列表(如[3,4,6,3]对应 ResNet-50 的四个 stage)。
num_classes:分类任务的类别数(如中药材的 12 类)。
pretrained:是否加载预训练权重(布尔值,True表示加载)。
pretrained_ckpt:预训练权重文件路径(如"./LoadPretrainedModel/resnet50_224_new.ckpt")。
input_channel:全连接层输入维度(由最后一个 stage 的输出通道决定,如 ResNet-50 为 2048)。
2. 特定模型:ResNet-50 的封装 resnet50
def resnet50(num_classes: int = 1000, pretrained: bool = False):
resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
pretrained, resnet50_ckpt, 2048)
功能:直接生成 ResNet-50 模型,固定了 ResNet-50 的核心配置(瓶颈块、各 stage 块数、输入通道)。
关键参数固定值:block=ResidualBlock:使用瓶颈块(Bottleneck Block),适用于深层网络(ResNet-50/101/152)。
layers=[3,4,6,3]:ResNet-50 的标准配置(四个 stage 分别包含 3、4、6、3 个瓶颈块)。
input_channel=2048:最后一个 stage 输出通道数(512 基础通道 × 瓶颈块 expansion=4)。
预训练支持:resnet50_ckpt指定了预训练权重路径(如用户需要加载 ImageNet 预训练权重,可通过pretrained=True启用)。
3. 代码设计核心价值
模块化与复用性:
_resnet作为通用构建函数,通过参数化残差块类型和层数,可扩展生成 ResNet-18(基础块 +[2,2,2,2])、ResNet-101([3,4,23,3])等变体,避免重复代码。用户友好性:
resnet50函数封装了 ResNet-50 的具体配置,用户只需指定num_classes(分类数)和pretrained(是否加载预训练),即可快速获取模型,降低使用门槛。
2.2模型初始化
要求:
对定义的ResNet50模型进行实例化
实例化一个用于12分类的resnet50模型
# 要点2-2-1: 对定义的ResNet50模型进行实例化
network = resnet50(num_classes=12)
num_class = 12
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=num_class)
network.fc = fcfor param in network.get_parameters():
param.requires_grad = True
模型初始化代码解析
1. 实例化 ResNet50 模型
network = resnet50(num_classes=12)
作用:调用resnet50函数创建 ResNet50 模型实例,指定分类数为 12(如中药材的 12 类)。
内部逻辑:
resnet50函数通过_resnet生成 ResNet50 模型,默认使用瓶颈块(ResidualBlock)和标准层数[3,4,6,3],并将原 1000 类的全连接层(fc 层)初始化为 12 类输出(但需后续调整,见下文)。2. 替换全连接层适配新任务
num_class = 12 # 新任务的分类数(如中药材的12类)
in_channel = network.fc.in_channels # 获取原fc层的输入通道数(ResNet50为2048)
fc = nn.Dense(in_channels=in_channel, out_channels=num_class) # 新建12类输出的全连接层
network.fc = fc # 替换原模型的fc层
背景:预训练 ResNet50 的 fc 层通常输出 1000 类(ImageNet 任务),需替换为新任务的分类数(12 类)。
关键操作:获取原 fc 层输入维度(in_channel=2048,由 ResNet50 的全局平均池化输出决定)。
新建全连接层fc,输入维度保持 2048,输出维度改为 12。
替换原模型的 fc 层,完成模型输出适配。
3. 启用所有参数训练 -- 全量微调
for param in network.get_parameters():
param.requires_grad = True
作用:将模型所有参数的梯度计算标志(requires_grad)设为True,允许训练时更新所有参数。
场景意义:若使用预训练模型(pretrained=True),此操作表示 “端到端微调”(所有层参数均参与训练),适合新数据集与预训练数据分布差异较大的场景(如中药材分类 vs ImageNet 通用分类)。
若未使用预训练(pretrained=False),则模型从头开始训练,所有参数自然需要梯度更新。
相关文章:
MindSpore框架学习项目-ResNet药物分类-构建模型
目录 2.构建模型 2.1定义模型类 2.1.1 基础块ResidualBlockBase ResidualBlockBase代码解析 2.1.2 瓶颈块ResidualBlock ResidualBlock代码解释 2.1.3 构建层 构建层代码说明 2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现 ResNet组建类代码解析…...
ChatTempMail - AI驱动的免费临时邮箱服务
在当今数字世界中,保护在线隐私的需求日益增长。ChatTempMail应运而生,作为一款融合人工智能技术的新一代临时邮箱服务,它不仅提供传统临时邮箱的基本功能,还通过AI技术大幅提升了用户体验。 核心功能与特性 1. AI驱动的智能邮件…...
(leetcode) 力扣100 9.找到字符串中所有字母异位词(滑动窗口)
题目 给定两个字符串 s 和 p,找到 s 中所有 p 的 异位词 的子串,返回这些子串的起始索引。不考虑答案输出的顺序。 数据范围 1 < s.length, p.length < 3 * 104 s 和 p 仅包含小写字母 样例 示例 1: 输入: s "cbaebabacd", p &quo…...
深入了解 Stable Diffusion:AI 图像生成的奥秘
一、引言 AI 艺术与图像生成技术的兴起改变了我们创造和体验视觉内容的方式。在过去几年里,深度学习模型已经能够创造出令人惊叹的艺术作品,这些作品不仅模仿了人类艺术家的风格,甚至还能创造出前所未有的新风格。在这个领域,Sta…...
场外期权平值期权 实值期权 虚值期权有什么区别?收益如何计算?
期权汇 场外期权按价值状态分为平值、虚值、实值期权。 01|实值期权对于看涨期权而言,如果行权价格低于标的市场价格,则该期权处于实值状态;对于看跌期权,如果行权价格高于标的市场价格,则处于实值状态…...
微软系统 红帽系统 网络故障排查:ping、traceroute、netstat
在微软(Windows)和红帽(Red Hat Enterprise Linux,RHEL)等系统中,网络故障排查是确保系统正常运行的重要环节。 ping、traceroute(在Windows中为tracert)和netstat是三个常用的网络…...
HOT 100 | 【子串】76.最小覆盖子串、【普通数组】53.最大子数组和、【普通数组】56.合并区间
一、【子串】76.最小覆盖子串 1. 解题思路 定义两个哈希表分别用于 t 统计字符串 t 的字符个数,另一个sub_s用于统计字符串 t 在 s 的子串里面字符出现的频率。 为了降低时间复杂度,定义一个变量t_count用于统计 t 哈希表中元素的个数。哈希表sub_s是一…...
基于CNN的猫狗图像分类系统
一、系统概述 本系统是基于PyTorch框架构建的智能图像分类系统,专门针对CIFAR-10数据集中的猫(类别3)和狗(类别5)进行分类任务。系统采用卷积神经网络(CNN)作为核心算法,结合图形用…...
《时序数据库全球格局:国产与国外主流方案的对比分析》
引言 时序数据库(Time Series Database, TSDB)是专门用于存储、查询和分析时间序列数据的数据库系统,广泛应用于物联网(IoT)、金融、工业监控、智能运维等领域。近年来,随着大数据和物联网技术的发展&…...
力扣-2.两数相加
题目描述 给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。 请你将两个数相加,并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外,这两个数都…...
富乐德传感技术盘古信息 | 锚定“未来工厂”新坐标,开启传感器制造行业数字化转型新征程
在数字化浪潮下,制造业正经历深刻变革。 传感器作为智能制造的核心基础部件,正面临着质量精度要求升级、交付周期缩短、成本管控严苛等多重挑战。传统依赖人工纸质管理、设备数据孤岛化的生产模式,已成为制约高端传感器制造突破“高精度、高…...
RT-Thread 深入系列 Part 2:RT-Thread 内核核心机制深度剖析
摘要: 本文从线程管理、调度器原理、中断处理与上下文切换、IPC 同步机制、内存管理五大核心模块出发,深入剖析 RT-Thread 内核实现细节,并辅以源码解读、流程图、时序图与性能数据。 目录 线程管理与调度器原理 1.1 线程控制块(T…...
uni-app,小程序自定义导航栏实现与最佳实践
文章目录 前言为什么需要自定义导航栏?基本实现方案1. 关闭原生导航栏2. 自定义导航栏组件结构3. 获取状态栏高度4. 样式设置 内容区域适配跨平台适配要点iOS与Android差异处理 常见导航栏效果实现1. 透明导航栏2. 滚动渐变导航栏3. 自定义返回逻辑 解决常见问题1. …...
小程序消息订阅的整个实现流程
以下是微信小程序消息订阅的完整实现流程,分为 5个核心步骤 和 3个关键注意事项: 一、消息订阅完整流程 步骤1:配置订阅消息模板 登录微信公众平台进入「功能」→「订阅消息」选择公共模板或申请自定义模板,获取模板IDÿ…...
istio in action之Gateway流量入口与安全
入口网关,简单来说,就是如何让外部世界和我们精心构建的集群内部服务顺畅地对话。在网络安全领域,有一个词叫流量入口,英文叫Ingress。这指的是那些从我们自己网络之外,比如互联网,发往我们内部网络的流量。…...
LeetCode 1722. 执行交换操作后的最小汉明距离 题解
示例: 输入:source [1,2,3,4], target [2,1,4,5], allowedSwaps [[0,1],[2,3]] 输出:1 解释:source 可以按下述方式转换: - 交换下标 0 和 1 指向的元素:source [2,1,3,4] - 交换下标 2 和 3 指向的元…...
区块链详解
1. 引言 1.1 背景 在数字化时代,信息的存储、传输和验证面临诸多挑战,如数据篡改、信任缺失、中心化风险等。区块链技术应运而生,作为一种分布式账本技术,它通过去中心化、去信任化、不可篡改等特性,为解决这些问题提…...
申能集团笔试1
目录 注意 过程 注意 必须开启摄像头和麦克风 只能用网页编程,不能用本地环境 可以用Index进行测试 过程 我还以为是编程,没想到第一次是企业人际关系、自我评价的选择题,哈哈哈有点轻松,哦对他要求不能泄漏题目,…...
机器人手臂的坐标变换:一步步计算齐次矩阵过程 [特殊字符]
大家好!今天我们来学习如何计算机器人手臂的坐标变换。别担心,我会用最简单的方式解释这个过程,就像搭积木一样简单! 一、理解问题 我们有一个机器人手臂,由多个关节组成。每个关节都有自己的坐标系,我们需要计算从世界坐标系(W)到末端执行器(P₃)的完整变换。 二、已…...
神经元和神经网络定义
在深度学习中,神经元和神经网络是构成神经网络模型的基本元素。让我们从基础开始,逐步解释它们的含义和作用。 1️⃣ 神经元是什么? 神经元是神经网络中的基本计算单元,灵感来自于生物神经系统中的神经元。每个人的脑中有数以亿…...
Vue——Axios
一、Axios 是什么 Axios 是一个基于 promise 网络请求库,作用于 node.js 和浏览器中。 它是 isomorphic 的 ( 即同一套代 码可以运行在浏览器和 node.js 中 ) 。在服务端它使用原生 node.js http 模块 , 而在客户端 ( 浏览端 ) 则使 用 XMLHttpRequest…...
力扣:轮转数组
题目 给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。 例子 示例 1: 输入: nums [1,2,3,4,5,6,7], k 3 输出: [5,6,7,1,2,3,4] 解释: 向右轮转 1 步: [7,1,2,3,4,5,6] 向右轮转 2 步: [6,7,1,2,3,4,5] 向右轮转 3 步: [5…...
TCP/IP协议的体系结构
文章目录 前言数据链路层网络层传输层应用层 前言 TCP/IP通信体系主要分为四个层次,从底至上分别为: 数据链路层 >网络层 > 传输层 >应用层 该体系的工作原理主要依靠封装与分用的使用完成对信息的传递与解析。 1. 所谓封装,就是上层…...
Vue3 中 ref 与 reactive 的区别及底层原理详解
一、核心区别 1. 数据类型与使用场景 • ref 可定义基本类型(字符串、数字、布尔值)和对象类型的响应式数据。对于对象类型,ref 内部会自动调用 reactive 将其转换为响应式对象。 语法特点:需通过 .value 访问或修改数据&#…...
MySQL 与 Elasticsearch 数据一致性方案
MySQL 与 Elasticsearch 数据一致性方案 前言一、同步双写(Synchronous Dual Write)🔄二、异步双写(Asynchronous Dual Write)📤三、定时同步(Scheduled Synchronization)ǵ…...
rust-candle学习笔记11-实现一个简单的自注意力
参考:about-pytorch 定义ScaledDotProductAttention结构体: use candle_core::{Result, Device, Tensor}; use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};struct ScaledDotProductAttention {wq: Linear,wk: Linear,wv: …...
RabbitMQ-运维
文章目录 前言运维-集群介绍多机多节点单机多节点 多机多节点下载配置hosts⽂件配置Erlang Cookie启动节点构建集群查看集群状态 单机多节点安装启动两个节点再启动两个节点验证RabbitMQ启动成功搭建集群把rabbit2, rabbit3添加到集群 宕机演示仲裁队列介绍raft算法协议 raft基…...
101 alpha——8 学习
alpha (-1 * rank(((sum(open, 5) * sum(returns, 5)) - delay((sum(open, 5) * sum(returns, 5)),这里我们操作符都明白,现在来看金融意义 金融意义 里层是这个 (sum(open, 5) * sum(returns, 5)) - delay((sum(open, 5) * sum(returns, 5)), 10 这里是两个相减…...
YOLOv1模型架构、损失值、NMS极大值抑制
文章目录 前言一、YOLO系列v11、核心思想2、流程解析 二、损失函数1、位置误差2、置信度误差3、类别概率损失 三、NMS(非极大值抑制)总结YOLOv1的优缺点 前言 YOLOv1(You Only Look Once: Unified, Real-Time Object Detection)由…...
webpack代理天地图瓦片
1.安装 npm install http-proxy-middleware --save-dev2.webpack代理 const { createProxyMiddleware } require(http-proxy-middleware);module.exports {devServer: {port: 8080, // 改为你需要的端口https: false, // 如果你启用了 https,这里要对应before(a…...
RabbitMQ-高级特性1
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言消息确认机制介绍手动确认方法代码前言代码编写消息确认机制的演示自动确认automanual 持久化介绍交换机持久化队列持久化消息持久化 持久化代码持久化代码演示…...
Git_idea界面进行分支合并到主分支详细操作
最近闲着也是闲着,再来讲一下Git合并分支的操作吧。基本上咱们干开发的都会用到git吧,比如我们在大数据开发中,有一个主分支master,还有其他的诸多分支dev1.1.0,dev1.2.0......等。 以我近期开发的代码来讲,在开发分支开发完毕后&…...
MOS关断时波形下降沿振荡怎么解决
问题阐述: 一个直流电机控制电路,部分原理图如下: 波形如下: 原因分析: L:线路寄生电感(如PCB走线、MOS管引脚电感)。 C:MOS管输出电容(Coss)、…...
【Day 23】HarmonyOS开发实战:从AR应用到元宇宙交互
一、空间感知开发实战 1. 环境语义建模(NEXT增强) // 构建3D空间语义地图 spatialMapper.createMap({mode: SEMANTIC, // 语义分割模式objectTypes: [WALL, FLOOR, TABLE, DOOR ],onUpdate: (mesh) > {this.arScene.updateMesh(mesh) // 实时更新3D…...
ZYNQ笔记(十九):VDMA VGA 输出分辨率可调
版本:Vivado2020.2(Vitis) 任务:以 VDAM IP 为核心实现 VGA 彩条图像显示,同时支持输出分辨率可调。 (PS 端写入彩条数据到 DDR 通过 VDMA 读取出来输出给 VGA 进行显示) 目录 一、介绍 二、硬…...
江西同为科技有限公司受邀参展2025长江流域跨博会
江西同为科技有限公司是一家专注于电力保护设备研发与生产的高新技术企业,深耕于电气联接与保护领域,同时产品远销海外,在国内国际市场与客户保持长期稳定的合作。江西同为在跨境电子商务领域运营多年,有着深厚、丰富的行业经验&a…...
2025 SD省集总结
文章目录 DAY1时间安排题解T1. 花卉港湾T2. 礎石花冠T3.磷磷开花 DAY2时间安排题解T1. MEX 求和T2.最大异或和T3.前缀最值 DAY3时间安排题解T1.重建: 地下铁道T2.走过安眠地的花丛T3.昔在、今在、永在的题目 DAY4时间安排题解T1.崩坏世界的歌姬T2.色彩褪去之后T3.每个人的结局 …...
代码随想论图论part06冗余连接
图论part06 冗余连接 代码随想录 冗余边就是已经边已经在并查集里了,从图的角度来说构成了环(冗余连接2要用到这个概念) 代码其他部分为:并查集初始化,查根,判断是否在集合里,加入集合 冗余…...
SCADA|KIO程序导出变量错误处理办法
哈喽,你好啊,我是雷工! 最近在用KingSCADA3.52版本的软件做程序时,在导出变量进行批量操作时遇到问题,现将解决办法记录如下。 以下为解决过程。 01 问题描述 在导出KIO变量时,选择*.xls格式和*.xlsx时均会报错: 报如下错误: Unknown error 0x800A0E7A ADODB Connectio…...
AUTOSAR图解==>AUTOSAR_SWS_V2XBasicTransport
AUTOSAR V2X 基础传输协议 (V2XBasicTransport) 详解 AUTOSAR经典平台中V2X通信基础传输层的规范解析 目录 1. 引言与功能概述 1.1 架构概述1.2 功能概述 2. V2XBtp模块架构 2.1 AUTOSAR架构中的V2XBtp位置2.2 主要组件与职责 3. V2XBtp模块接口 3.1 接口结构3.2 数据类型和依…...
从代码学习深度学习 - 区域卷积神经网络(R-CNN)系列 PyTorch版
文章目录 前言R-CNNFast R-CNN兴趣区域汇聚层 (RoI Pooling)代码示例:兴趣区域汇聚层 (RoI Pooling) 的计算方法Faster R-CNNMask R-CNN双线性插值 (Bilinear Interpolation) 与兴趣区域对齐 (RoI Align)兴趣区域对齐层的输入输出全卷积网络 (FCN) 的作用掩码输出形状总结前言…...
RT-THREAD RTC组件中Alarm功能驱动完善
使用Rt-Thread的目的为了更快的搭载工程,使用Rt-Thread丰富的组件和第三方包资源,解耦硬件,在更换芯片时可以移植应用层代码。你是要RTT的目的什么呢? 文章项目背景 以STM32L475RCT6为例 RTC使用的为LSE外部低速32 .756k Hz 的…...
VSCode如何解决打开html页面中文乱码的问题
VSCode如何解决打开html页面中文乱码的问题 (1)打开扩展商店: (2)点击左侧菜单栏的扩展图标(或使用快捷键CtrlShiftX)。 (3)搜索并安装插件: …...
Java学习手册:单体架构到微服务演进
一、单体架构概述 单体架构是一种传统的软件架构风格,所有的功能模块都构建在一个统一的部署单元中。这种架构的优点是简单直接,便于开发、测试和部署。然而,随着应用规模的增长和需求的复杂化,单体架构的弊端逐渐显现࿰…...
android动态调试
在 Android 应用逆向工程中,动态调试 Smali 代码是分析应用运行时行为的重要手段。以下是详细的步骤和注意事项: 1. 准备工作 工具准备: Apktool:反编译 APK 生成 Smali 代码。Android Studio/IntelliJ IDEA:安装 smal…...
Google的A2A和MCP什么关系
作者:蛙哥 原文:https://zhuanlan.zhihu.com/p/1893738350252385035 Agent2Agent和MCP在功能上各有侧重,A2A专注于Agent之间的协作,MCP关注于Agent与外部数据源的集成。因此,MCP并不完全覆盖 A2A 的能力场景࿰…...
计算几何图形算法经典问题整理
几何算法经典问题 文章目录 几何算法经典问题一、几何基础问题1. 判断两条线段是否相交2. 判断点是否在多边形内3. 凸包计算4. 判断一个有序点集的方向(顺时针 or 逆时针)5. 求多边形面积和重心 二、 高阶图形问题6. 最小外接矩形(Minimum Bo…...
系分论文《论多云架构治理的分析和应用》
系统分析师论文范文系列 【摘要】 2022年3月,我所在公司承接了某金融机构“混合云资源管理与优化平台”的设计与实施项目。我作为系统分析师,主导了多云架构的规划与治理工作。该项目旨在构建一个兼容多家公有云及私有云资源的统一管理平台,解…...
(三)毛子整洁架构(Infrastructure层/DapperHelper/乐观锁)
文章目录 项目地址一、Infrastructure Layer1.1 创建Application层需要的服务1. Clock服务2. Email 服务3. 注册服务 1.2 数据库服务1. 表配置Configurations2. Respository实现3. 数据库链接Factory实现4. Dapper的DataOnly服务实现5. 所有数据库服务注册 1.3 基于RowVersion的…...
Femap许可使用数据分析
在当今竞争激烈的市场环境中,企业对资源使用效率和成本控制的关注日益增加。Femap作为一款业界领先的有限元分析软件,其许可使用数据分析功能为企业提供了深入洞察和智能决策的支持。本文将详细介绍Femap许可使用数据分析工具的特点、优势以及如何应用这…...