Dilateformer实战:使用Dilateformer实现图像分类任务(二)
文章目录
- 训练部分
- 导入项目使用的库
- 设置随机因子
- 设置全局参数
- 图像预处理与增强
- 读取数据
- 设置Loss
- 设置模型
- 设置优化器和学习率调整策略
- 设置混合精度,DP多卡,EMA
- 定义训练和验证函数
- 训练函数
- 验证函数
- 调用训练和验证方法
- 运行以及结果查看
- 测试
- 完整的代码
在上一篇文章中完成了前期的准备工作,见链接:
Dilateformer实战:使用Dilateformer实现图像分类任务(一)
前期的工作主要是数据的准备,安装库文件,数据增强方式的讲解,模型的介绍和实验效果等内容。接下来,这篇主要是讲解如何训练和测试
训练部分
完成上面的步骤后,就开始train脚本的编写,新建train.py
导入项目使用的库
在train.py导入
import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.utils import accuracy, AverageMeter, ModelEma
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from models.dilateformer import dilateformer_tiny
from torchvision import datasetstorch.backends.cudnn.benchmark = False
import warnings
torch.autograd.set_detect_anomaly(True)
# warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
当您需要在具有多个GPU的机器上指定用于训练的GPU时,可以通过设置环境变量CUDA_VISIBLE_DEVICES
来实现。这个环境变量的值是一个由逗号分隔的GPU索引列表,索引从0开始。例如,如果您的机器上有8块GPU,并且您希望仅使用前两块GPU(即索引为0和1的GPU)进行训练,您应该设置:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
这样,只有索引为0和1的GPU会被系统识别并用于训练。类似地,如果您希望使用第三块(索引为2)和第六块(索引为5)GPU进行训练,您应该相应地设置:
os.environ['CUDA_VISIBLE_DEVICES'] = "2,5"
通过这种方式,您可以灵活地选择任意数量的GPU进行训练,而无需担心其他GPU的干扰。
设置随机因子
def seed_everything(seed=42):# 设置Python的哈希种子os.environ['PYTHONHASHSEED'] = str(seed)# 设置PyTorch的CPU随机种子torch.manual_seed(seed)# 如果使用CUDA,设置CUDA的随机种子if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed) # 如果你的代码在多个GPU上运行# 启用CUDA的确定性行为(对卷积等操作的确定性有帮助)torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = True# 使用示例
seed_everything(42)
这里有一些额外的说明和注意事项:
-
torch.cuda.manual_seed_all(seed)
:这个调用是可选的,但如果你在多GPU环境中工作(比如使用DataParallel
或DistributedDataParallel
),它确保所有GPU上的随机操作都将从相同的种子开始。如果你的代码只在一个GPU上运行,这个调用不是必需的,但也不会造成问题。 -
torch.backends.cudnn.benchmark = False
:当设置为True
时,cuDNN会在运行时自动选择算法来优化性能。然而,这可能会导致每次运行时的行为不完全相同,因为算法的选择可能会基于输入数据的形状和大小而变化。为了实验的可重复性,最好将其设置为False
。 -
图片加载顺序:虽然设置随机种子有助于确保模型的随机操作(如初始化权重、dropout等)是可重复的,但它本身并不直接控制图片加载的顺序。图片加载顺序通常由数据集加载器(如
DataLoader
)的shuffle
参数控制。如果你想要固定的加载顺序,确保在创建DataLoader
时将shuffle=False
。 -
其他随机性来源:请注意,即使你设置了这些随机种子,还可能存在其他随机性来源,如操作系统级别的调度或硬件层面的差异(如GPU的浮点精度差异)。在极端情况下,这些差异可能会影响结果的精确可重复性。然而,在大多数情况下,上述设置应该足以确保实验在相同的软件和环境配置下是可重复的。
设置全局参数
if __name__ == '__main__':# 创建保存模型的文件夹file_dir = 'checkpoints/Dilateformer/'if os.path.exists(file_dir):print('true')os.makedirs(file_dir, exist_ok=True)else:os.makedirs(file_dir)# 设置全局参数model_lr = 1e-4BATCH_SIZE = 16EPOCHS = 300DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')use_amp = True # 是否使用混合精度use_dp = True # 是否开启dp方式的多卡训练classes = 12resume = NoneCLIP_GRAD = 5.0Best_ACC = 0 # 记录最高得分use_ema = Falsemodel_ema_decay = 0.9998start_epoch = 1seed = 1seed_everything(seed)
创建一个名为 ‘checkpoints/Dilateformer/’ 的文件夹,用于保存训练过程中的模型。如果该文件夹已经存在,则不会再次创建,否则会创建该文件夹。
设置训练模型的全局参数,包括学习率、批次大小、训练轮数、设备选择(是否使用 GPU)、是否使用混合精度、是否开启数据并行等。
注:建议使用GPU,CPU太慢了。
参数的详细解释:
model_lr:学习率,根据实际情况做调整。
BATCH_SIZE:batchsize,根据显卡的大小设置。
EPOCHS:epoch的个数,一般300够用。
use_amp:是否使用混合精度。
use_dp :是否开启dp方式的多卡训练?如果您打算使用多GPU训练将use_dp 设置为 True。
classes:类别个数。
resume:再次训练的模型路径,如果不为None,则表示加载resume指向的模型继续训练。
CLIP_GRAD:梯度的最大范数,在梯度裁剪里设置。
Best_ACC:记录最高ACC得分。
use_ema:是否使用ema,如果没有使用预训练模型,直接打开use_ema会造成不上分的情况。可以先关闭ema训练几个epoch,然后,将训练的权重赋值到resume,再将启用ema
model_ema_decay:设置了EMA的衰减率。衰减率决定了当前模型权重和之前的EMA权重在更新新的EMA权重时的相对贡献。具体来说,每次更新EMA权重时,都会按照以下公式进行:
newemaweight = decay × oldemaweight + ( 1 − decay ) × currentmodelweight \text{newemaweight} = \text{decay} \times \text{oldemaweight} + (1 - \text{decay}) \times \text{currentmodelweight} newemaweight=decay×oldemaweight+(1−decay)×currentmodelweight
例如,衰减率被设置为0.9998。这意味着在更新EMA权重时,大约99.98%的权重来自之前的EMA权重,而剩下的0.02%来自当前的模型权重。由于衰减率非常接近1,EMA权重会更多地依赖于之前的EMA权重,而不是当前的模型权重。这有助于平滑模型权重的波动,并减少噪声对最终模型性能的影响。
start_epoch:开始的epoch,默认是1,如果重新训练时,需要给start_epoch重新赋值。
SEED:随机因子,数值可以随意设定,但是设置后,不要随意更改,更改后,图片加载的顺序会改变,影响测试结果。
file_dir = 'checkpoints/Dilateformer/'
这是存放Dilateformer模型的路径。
图像预处理与增强
# 数据预处理7transform = transforms.Compose([transforms.RandomRotation(10),transforms.GaussianBlur(kernel_size=(5,5),sigma=(0.1, 3.0)),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.3281186, 0.28937867, 0.20702125], std= [0.09407319, 0.09732835, 0.106712654])])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.3281186, 0.28937867, 0.20702125], std= [0.09407319, 0.09732835, 0.106712654])])mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=classes)
数据处理和增强比较简单,加入了随机10度的旋转、高斯模糊、色彩饱和度明亮度的变化、Mixup等比较常用的增强手段,做了Resize和归一化。
transforms.Normalize(mean=[0.3281186, 0.28937867, 0.20702125], std= [0.09407319, 0.09732835, 0.106712654])
这里设置为计算mean和std。
这里注意下Resize的大小,由于选用的模型输入是224×224的大小,所以要Resize为224×224。
数据预处理流程结合了多种常用的数据增强技术,包括随机旋转、高斯模糊、色彩抖动(ColorJitter)、Resize以及归一化,还引入了Mixup和可能的CutMix技术来进一步增强模型的泛化能力。参数详解:
- transforms.RandomRotation(10): 随机旋转图像最多10度,有助于模型学习旋转不变性。
- transforms.GaussianBlur(kernel_size=(5,5), sigma=(0.1, 3.0)): 应用高斯模糊,模拟图像的模糊情况,增强模型对模糊图像的鲁棒性。
- transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5): 调整图像的亮度、对比度和饱和度,增加数据的多样性。
- transforms.Resize((224, 224)): 将图像大小调整为224x224,以符合模型的输入要求。
- transforms.ToTensor(): 将PIL Image或NumPy ndarray转换为FloatTensor,并归一化到[0.0, 1.0]。
- transforms.Normalize(mean, std): 使用指定的均值和标准差对图像进行归一化处理,有助于模型训练。
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=classes)
定义了一个 Mixup 函数。Mixup 是一种在图像分类任务中常用的数据增强技术,它通过将两张图像以及其对应的标签进行线性组合来生成新的数据和标签。
Mixup 是一种正则化技术,通过混合输入数据和它们的标签来增强模型的泛化能力。在您的代码中,Mixup 类还包含了 CutMix 的参数,但具体实现可能需要根据您使用的库(如 timm 或自定义实现)来确定。参数详解:
mixup_alpha: Mixup 中用于Beta分布的α参数,控制混合强度的分布。 cutmix_alpha: CutMix
中用于Beta分布的α参数,同样控制混合强度的分布。 cutmix_minmax: CutMix 中裁剪区域的最小和最大比例,但在这里设为
None,可能表示使用默认的或根据 cutmix_alpha 自动计算的比例。 prob: 应用Mixup或CutMix的概率。
switch_prob: 在Mixup和CutMix之间切换的概率(如果Mixup和CutMix都被启用)。 mode:
指定Mixup是在整个批次上进行还是在单个样本之间进行。 label_smoothing: 标签平滑参数,用于减少模型对硬标签的过度自信。
num_classes: 类别数,用于标签平滑计算。
读取数据
# 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)with open('class.txt', 'w') as file:file.write(str(dataset_train.class_to_idx))with open('class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(dataset_train.class_to_idx))# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE,num_workers=8,pin_memory=True,shuffle=True,drop_last=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
-
使用pytorch默认读取数据的方式,然后将dataset_train.class_to_idx打印出来,预测的时候要用到。
-
对于train_loader ,drop_last设置为True,因为使用了Mixup数据增强,必须保证每个batch里面的图片个数为偶数(不能为零),如果最后一个batch里面的图片为奇数,则会报错,所以舍弃最后batch的迭代,pin_memory设置为True,可以加快运行速度,num_workers多进程加载图像,不要超过CPU 的核数。
-
将dataset_train.class_to_idx保存到txt文件或者json文件中。
class_to_idx的结果:
{'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3, 'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}
设置Loss
# 设置loss函数
# 训练的loss函数为SoftTargetCrossEntropy,用于处理具有软目标(soft targets)的训练场景
criterion_train = SoftTargetCrossEntropy() # 验证的loss函数为nn.CrossEntropyLoss(),适用于多分类问题的标准交叉熵损失
criterion_val = torch.nn.CrossEntropyLoss()
设置loss函数,训练的loss为:SoftTargetCrossEntropy,验证的loss:nn.CrossEntropyLoss()。
设置模型
# 设置模型model_ft = dilateformer_tiny(pretrained=False)print(model_ft)num_fr = model_ft.head.in_featuresmodel_ft.head = nn.Linear(num_fr, classes)nn.init.xavier_uniform_(model_ft.head.weight)print(model_ft)if resume:model = torch.load(resume)print(model['state_dict'].keys())model_ft.load_state_dict(model['state_dict'])Best_ACC = model['Best_ACC']start_epoch = model['epoch'] + 1model_ft.to(DEVICE)
-
设置模型为dilateformer_tiny,然后,找到head的in_features,修改为数据集的类别,也就是classes。
-
如果resume设置为已经训练的模型的路径,则加载模型接着resume指向的模型接着训练,使用模型里的Best_ACC初始化Best_ACC,使用epoch参数初始化start_epoch。
-
如果模型输出是classes的长度,则表示修改正确了。
设置优化器和学习率调整策略
# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.AdamW(model_ft.parameters(),lr=model_lr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-6)
- 优化器设置为adamW。
- 学习率调整策略选择为余弦退火。
设置混合精度,DP多卡,EMA
if use_amp:scaler = torch.cuda.amp.GradScaler()if torch.cuda.device_count() > 1 and use_dp:print("Let's use", torch.cuda.device_count(), "GPUs!")model_ft = torch.nn.DataParallel(model_ft)if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device=DEVICE,resume=resume)else:model_ema=None
定义训练和验证函数
训练函数
class ConvAWS2d(nn.Conv2d):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))def _get_weight(self, weight):weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)weight = weight - weight_meanstd = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)weight = weight / stdweight = self.weight_gamma * weight + self.weight_betareturn weightdef forward(self, x):weight = self._get_weight(self.weight)return super()._conv_forward(x, weight, None)def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs):self.weight_gamma.data.fill_(-1)super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,missing_keys, unexpected_keys, error_msgs)if self.weight_gamma.data.mean() > 0:returnweight = self.weight.dataweight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,keepdim=True).mean(dim=3, keepdim=True)self.weight_beta.data.copy_(weight_mean)std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)self.weight_gamma.data.copy_(std)class SAConv2d(ConvAWS2d):def __init__(self,in_channels,out_channels,kernel_size,s=1,p=None,g=1,d=1,act=True,bias=True):super().__init__(in_channels,out_channels,kernel_size,stride=s,padding=autopad(kernel_size, p, d),dilation=d,groups=g,bias=bias)self.switch = torch.nn.Conv2d(self.in_channels,1,kernel_size=1,stride=s,bias=True)self.switch.weight.data.fill_(0)self.switch.bias.data.fill_(1)self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))self.weight_diff.data.zero_()self.pre_context = torch.nn.Conv2d(self.in_channels,self.in_channels,kernel_size=1,bias=True)self.pre_context.weight.data.fill_(0)self.pre_context.bias.data.fill_(0)self.post_context = torch.nn.Conv2d(self.out_channels,self.out_channels,kernel_size=1,bias=True)self.post_context.weight.data.fill_(0)self.post_context.bias.data.fill_(0)self.bn = nn.BatchNorm2d(out_channels)self.act = Conv.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):# pre-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)avg_x = self.pre_context(avg_x)avg_x = avg_x.expand_as(x)x = x + avg_x# switchavg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)switch = self.switch(avg_x)# sacweight = self._get_weight(self.weight)out_s = super()._conv_forward(x, weight, None)ori_p = self.paddingori_d = self.dilationself.padding = tuple(3 * p for p in self.padding)self.dilation = tuple(3 * d for d in self.dilation)weight = weight + self.weight_diffout_l = super()._conv_forward(x, weight, None)out = switch * out_s + (1 - switch) * out_lself.padding = ori_pself.dilation = ori_d# post-contextavg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)avg_x = self.post_context(avg_x)avg_x = avg_x.expand_as(out)out = out + avg_xreturn self.act(self.bn(out))
训练的主要步骤:
1、使用AverageMeter保存自定义变量,包括loss,ACC1,ACC5。
2、进入循环,将data和target放入device上,non_blocking设置为True。如果pin_memory=True的话,将数据放入GPU的时候,也应该把non_blocking打开,这样就只把数据放入GPU而不取出,访问时间会大大减少。
如果pin_memory=False时,则将non_blocking设置为False。
3、将数据输入mixup_fn生成mixup数据。
4、将第三部生成的mixup数据输入model,输出预测结果,然后再计算loss。
5、 optimizer.zero_grad() 梯度清零,把loss关于weight的导数变成0。
6、如果使用混合精度,则
- with torch.cuda.amp.autocast(),开启混合精度。
- 计算loss。torch.nan_to_num将输入中的NaN、正无穷大和负无穷大替换为NaN、posinf和neginf。默认情况下,nan会被替换为零,正无穷大会被替换为输入的dtype所能表示的最大有限值,负无穷大会被替换为输入的dtype所能表示的最小有限值。
- scaler.scale(loss).backward(),梯度放大。
- torch.nn.utils.clip_grad_norm_,梯度裁剪,放置梯度爆炸。
- scaler.step(optimizer) ,首先把梯度值unscale回来,如果梯度值不是inf或NaN,则调用optimizer.step()来更新权重,否则,忽略step调用,从而保证权重不更新。
- 更新下一次迭代的scaler。
否则,直接反向传播求梯度。torch.nn.utils.clip_grad_norm_函数执行梯度裁剪,防止梯度爆炸。
7、如果use_ema为True,则执行model_ema的updata函数,更新模型。
8、 torch.cuda.synchronize(),等待上面所有的操作执行完成。
9、接下来,更新loss,ACC1,ACC5的值。
等待一个epoch训练完成后,计算平均loss和平均acc
验证函数
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):global Best_ACCmodel.eval()loss_meter = AverageMeter()acc1_meter = AverageMeter()acc5_meter = AverageMeter()total_num = len(test_loader.dataset)print(total_num, len(test_loader))val_list = []pred_list = []for data, target in test_loader:for t in target:val_list.append(t.data.item())data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)output = model(data)loss = criterion_val(output, target)_, pred = torch.max(output.data, 1)for p in pred:pred_list.append(p.data.item())acc1, acc5 = accuracy(output, target, topk=(1, 5))loss_meter.update(loss.item(), target.size(0))acc1_meter.update(acc1.item(), target.size(0))acc5_meter.update(acc5.item(), target.size(0))acc = acc1_meter.avgprint('\nVal set: Average loss: {:.4f}\tAcc1:{:.3f}%\tAcc5:{:.3f}%\n'.format(loss_meter.avg, acc, acc5_meter.avg))if acc > Best_ACC:if isinstance(model, torch.nn.DataParallel):torch.save(model.module, file_dir + '/' + 'best.pth')else:torch.save(model, file_dir + '/' + 'best.pth')Best_ACC = accif isinstance(model, torch.nn.DataParallel):state = {'epoch': epoch,'state_dict': model.module.state_dict(),'Best_ACC': Best_ACC}if use_ema:state['state_dict_ema'] = model.module.state_dict()torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')else:state = {'epoch': epoch,'state_dict': model.state_dict(),'Best_ACC': Best_ACC}if use_ema:state['state_dict_ema'] = model.state_dict()torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')return val_list, pred_list, loss_meter.avg, acc
验证集和训练集大致相似,主要步骤:
1、在val的函数上面添加@torch.no_grad(),作用:所有计算得出的tensor的requires_grad都自动设置为False。即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。
2、定义参数:
loss_meter: 测试的loss
acc1_meter:top1的ACC。
acc5_meter:top5的ACC。
total_num:总的验证集的数量。
val_list:验证集的label。
pred_list:预测的label。
3、进入循环,迭代test_loader:将label保存到val_list。
将data和target放入device上,non_blocking设置为True。
将data输入到model中,求出预测值,然后输入到loss函数中,求出loss。
调用torch.max函数,将预测值转为对应的label。
将输出的预测值的label存入pred_list。
调用accuracy函数计算ACC1和ACC5
更新loss_meter、acc1_meter、acc5_meter的参数。
4、本次epoch循环完成后,求得本次epoch的acc、loss。
5、接下来是保存模型的逻辑
如果ACC比Best_ACC高,则保存best模型
判断模型是否为DP方式训练的模型。如果是DP方式训练的模型,模型参数放在model.module,则需要保存model.module。
否则直接保存model。
注:保存best模型,我们采用保存整个模型的方式,这样保存的模型包含网络结构,在预测的时候,就不用再重新定义网络了。6、接下来保存每个epoch的模型。
判断模型是否为DP方式训练的模型。如果是DP方式训练的模型,模型参数放在model.module,则需要保存model.module.state_dict()。
新建个字典,放置Best_ACC、epoch和 model.module.state_dict()等参数。然后将这个字典保存。判断是否是使用EMA,如果使用,则还需要保存一份ema的权重。
否则,新建个字典,放置Best_ACC、epoch和 model.state_dict()等参数。然后将这个字典保存。判断是否是使用EMA,如果使用,则还需要保存一份ema的权重。注意:对于每个epoch的模型只保存了state_dict参数,没有保存整个模型文件。
调用训练和验证方法
# 训练与验证is_set_lr = Falselog_dir = {}train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch_list = [], [], [], [], []if resume and os.path.isfile(file_dir+"result.json"):with open(file_dir+'result.json', 'r', encoding='utf-8') as file:logs = json.load(file)train_acc_list = logs['train_acc']train_loss_list = logs['train_loss']val_acc_list = logs['val_acc']val_loss_list = logs['val_loss']epoch_list = logs['epoch_list']for epoch in range(start_epoch, EPOCHS + 1):epoch_list.append(epoch)log_dir['epoch_list'] = epoch_listtrain_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch,model_ema)train_loss_list.append(train_loss)train_acc_list.append(train_acc)log_dir['train_acc'] = train_acc_listlog_dir['train_loss'] = train_loss_listif use_ema:val_list, pred_list, val_loss, val_acc = val(model_ema.ema, DEVICE, test_loader)else:val_list, pred_list, val_loss, val_acc = val(model_ft, DEVICE, test_loader)val_loss_list.append(val_loss)val_acc_list.append(val_acc)log_dir['val_acc'] = val_acc_listlog_dir['val_loss'] = val_loss_listlog_dir['best_acc'] = Best_ACCwith open(file_dir + '/result.json', 'w', encoding='utf-8') as file:file.write(json.dumps(log_dir))print(classification_report(val_list, pred_list, target_names=dataset_train.class_to_idx))if epoch < 600:cosine_schedule.step()else:if not is_set_lr:for param_group in optimizer.param_groups:param_group["lr"] = 1e-6is_set_lr = Truefig = plt.figure(1)plt.plot(epoch_list, train_loss_list, 'r-', label=u'Train Loss')# 显示图例plt.plot(epoch_list, val_loss_list, 'b-', label=u'Val Loss')plt.legend(["Train Loss", "Val Loss"], loc="upper right")plt.xlabel(u'epoch')plt.ylabel(u'loss')plt.title('Model Loss ')plt.savefig(file_dir + "/loss.png")plt.close(1)fig2 = plt.figure(2)plt.plot(epoch_list, train_acc_list, 'r-', label=u'Train Acc')plt.plot(epoch_list, val_acc_list, 'b-', label=u'Val Acc')plt.legend(["Train Acc", "Val Acc"], loc="lower right")plt.title("Model Acc")plt.ylabel("acc")plt.xlabel("epoch")plt.savefig(file_dir + "/acc.png")plt.close(2)
调用训练函数和验证函数的主要步骤:
1、定义参数:
- is_set_lr,是否已经设置了学习率,当epoch大于一定的次数后,会将学习率设置到一定的值,并将其置为True。
- log_dir:记录log用的,将有用的信息保存到字典中,然后转为json保存起来。
- train_loss_list:保存每个epoch的训练loss。
- val_loss_list:保存每个epoch的验证loss。
- train_acc_list:保存每个epoch的训练acc。
- val_acc_list:保存么每个epoch的验证acc。
- epoch_list:存放每个epoch的值。
如果是接着上次的断点继续训练则读取log文件,然后把log取出来,赋值到对应的list上。
循环epoch1、调用train函数,得到 train_loss, train_acc,并将分别放入train_loss_list,train_acc_list,然后存入到logdir字典中。
2、调用验证函数,判断是否使用EMA?
如果使用EMA,则传入model_ema.ema,否则,传入model_ft。得到val_list, pred_list, val_loss, val_acc。将val_loss, val_acc分别放入val_loss_list和val_acc_list中,然后存入到logdir字典中。3、保存log。
4、打印本次的测试报告。
5、如果epoch大于600,将学习率设置为固定的1e-6。
6、绘制loss曲线和acc曲线。
运行以及结果查看
完成上面的所有代码就可以开始运行了。点击右键,然后选择“run train.py”即可,运行结果如下:
在每个epoch测试完成之后,打印验证集的acc、recall等指标。
Dilateformer测试结果:
测试
测试,我们采用一种通用的方式。
测试集存放的目录如下图:
Dilateformer_Demo
├─test
│ ├─1.jpg
│ ├─2.jpg
│ ├─3.jpg
│ ├ ......
└─test.py
import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import osclasses = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed','Common wheat', 'Fat Hen', 'Loose Silky-bent','Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model=torch.load('checkpoints/Dilateformer/best.pth')
model.eval()
model.to(DEVICE)path = 'test/'
testList = os.listdir(path)
for file in testList:img = Image.open(path + file)img = transform_test(img)img.unsqueeze_(0)img = Variable(img).to(DEVICE)out = model(img)# Predict_, pred = torch.max(out.data, 1)print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))
测试的主要逻辑:
1、定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!
2、定义transforms,transforms和验证集的transforms一样即可,别做数据增强。
3、 torch.load加载model,然后将模型放在DEVICE里,
4、循环 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。循环里面的主要逻辑:
- 使用Image.open读取图片
- 使用transform_test对图片做归一化和标椎化。
- img.unsqueeze_(0) 增加一个维度,由(3,224,224)变为(1,3,224,224)
- Variable(img).to(DEVICE):将数据放入DEVICE中。
- model(img):执行预测。
- _, pred = torch.max(out.data, 1):获取预测值的最大下角标。
运行结果:
完整的代码
完整的代码:
相关文章:
Dilateformer实战:使用Dilateformer实现图像分类任务(二)
文章目录 训练部分导入项目使用的库设置随机因子设置全局参数图像预处理与增强读取数据设置Loss设置模型设置优化器和学习率调整策略设置混合精度,DP多卡,EMA定义训练和验证函数训练函数验证函数调用训练和验证方法 运行以及结果查看测试完整的代码 在上…...
Kubernetes 镜像拉取策略全解析:如何根据需求选择最佳配置?
在Kubernetes集群里,拉取容器镜像是一个非常关键的步骤。这些镜像包含了应用程序及其所有需要的依赖项,Kubernetes通过拉取这些镜像来启动Pod中的容器。为了提升集群的稳定性、速度和安全性,Kubernetes提供了几种不同的镜像拉取策略。这篇文章…...
上海AI中心记录
1、js事件循环 调用栈(Call Stack): JavaScript 是单线程的,所有的代码执行都是在调用栈中进行的。当函数被调用时,进入栈中;执行完毕后,从栈中弹出。 任务队列(Task Queueÿ…...
oscp学习之路,Kioptix Level2靶场通关教程
oscp学习之路,Kioptix Level2靶场通关教程 靶场下载:Kioptrix Level 2.zip 链接: https://pan.baidu.com/s/1gxVRhrzLW1oI_MhcfWPn0w?pwd1111 提取码: 1111 搭建好靶场之后输入ip a看一下攻击机的IP。 确定好本机IP后,使用nmap扫描网段&…...
基于SpringBoot的蜗牛兼职网的设计与实现
一、项目背景 随着社会的快速发展,计算机的影响是全面且深入的。人们生活水平的不断提高,日常生活中人们对蜗牛兼职网方面的要求也在不断提高,需要兼职工作的人数更是不断增加,使得蜗牛兼职网的开发成为必需而且紧迫的事情。蜗牛…...
C++软件设计模式之装饰器模式
装饰器模式(Decorator Pattern)是C软件设计模式中的一种结构型设计模式,主要用于解决在不改变现有对象结构的情况下动态地给对象添加新功能的问题。通过使用装饰器模式,可以在运行时为对象添加新的行为,而不需要修改其…...
Spring Boot项目接收前端参数的11种方式
大家好,我是袁庭新。在前后端项目交互中,前端传递的数据可以通过HTTP请求发送到后端, 后端在Spring Boot中如何接收各种复杂的前端数据呢?这篇文章总结了11种在Spring Boot中接收前端数据的方式。 1 搭建项目 1.通过Spring Init…...
通过GRE协议组建VPN网络
GRE(Generic Routing Encapsulation,通用路由封装协议)协议是一种简单而有效的封装协议,它在网络中的广泛应用,比如在构建VPN网络。 GRE是一种封装协议,它允许网络层协议(如IP)的…...
you-get使用cookies下载B站视频
B站视频更换BV号以后,使用you-get不能下载了。 首先更新你的you-get pip install --upgrade you-get 更新完成后再次使用you-get -u 命令会显示使用cookies才能下载更多清晰度的视频 使用Edge浏览器,添加插件 Cookie-Editor 点击上图的导出按钮&am…...
使用Excel制作通达信自定义“序列数据“
序列数据的视频教程演示 Excel制作通达信自定义序列数据 1.序列数据的制作方法:删掉没有用的数据(行与列)和股代码格式处理,是和外部数据的制作方法是相同,自己上面看历史博文。只需要判断一下,股代码跟随的…...
基于 Nginx 的网站服务器与 LNMP 平台搭建指南
一,Nginx概述 (一)Nginx的作用 Nginx在网络服务器架构中扮演着多面的角色。其初始设定专注于静态网络数据的处理,能高效地为用户提供诸如HTML,CSS,JavaScript等静态资源。当面对动态数据时,借助php - fpm模块,Nginx能够解析php源代码,实现动态页面的生成与展示。在处理…...
OpenCV计算机视觉 03 椒盐噪声的添加与常见的平滑处理方式(均值、方框、高斯、中值)
上一篇文章:OpenCV计算机视觉 02 图片修改 图像运算 边缘填充 阈值处理 添加椒盐噪声 def add_peppersalt_noise(image, n10000):result image.copy()h, w image.shape[:2] # 获取图片的高和宽for i in range(n): # 生成n个椒盐噪声x np.random.randint(…...
WPF自定义窗口 输入验证不生效
WPF自定义窗口 输入验证不生效 WPF ValidationRule 不生效 WPF ValidationRule 不生效 解决方案:在WindowStyle的Template中添加AdornerDecorator标签。 <Style x:Key"WindowStyle1" TargetType"{x:Type Window}"><Setter Property&…...
【MySQL】 SQL优化讲解
一、优化前的思考 在定位到慢查询后,面试官常问如何优化或分析慢查询的SQL语句。若存在聚合查询、多表查询,可尝试优化SQL语句结构,如多表查询可新增临时表;若表数据量过大,可添加索引,但添加索引后仍慢则…...
05.HTTPS的实现原理-HTTPS的握手流程(TLS1.2)
05.HTTPS的实现原理-HTTPS的握手流程(TLS1.2) 简介1. TLS握手过程概述2. TLS握手过程细化3. 主密钥(对称密钥)生成过程4. 密码规范变更 简介 主要讲述了混合加密流程完成后,客户端和服务器如何共同获得相同的对称密钥…...
Java获取自身被调用点
1. 场景 打印日志的时候,需要获取是在哪个地方被调用了,把调用点的信息一并打印出来。 2. 获取自身被调用点的方法 可以通过获取线程的调用栈,遍历后找到调用点。 3. 代码实现 import java.text.SimpleDateFormat; import java.util.Dat…...
有序之美:C++ Set的哲学与诗意
文章目录 前言一.C set 的概念1.1 set 的定义1.2 set 的特点二. set 的构造方法2.1 常见构造函数2.1.1 示例:不同构造方法 2.2 相关文档 三.set 的常用操作3.1 插入操作详解3.1.1 使用 insert() 插入元素3.1.2 使用 emplace() 插入元素3.1.3 插入区间元素 3.2 查找操…...
22. 仿LISP运算
题目描述 LISP语言唯一的语法就是括号要配对 形如(OP P1 P2 ...),括号内元素由单个空格分割。其中第一个元素OP为操作符,后续元素均为其参数,参数个数取决于操作符类型。注意:参数P1,P2也有可能是另外一个嵌套的(OP P1 P2...),当前…...
大模型应用技术系列(三): 深入理解大模型应用中的Cache:GPTCache
前言 无论在什么技术栈中,缓存都是比较重要的一部分。在大模型技术栈中,缓存存在于技术栈中的不同层次。本文将主要聚焦于技术栈中应用层和底层基座之间中间件层的缓存(个人定位),以开源项目GPTCache(LLM的语义缓存)为例,深入讲解这部分缓存的结构和关键实现。 完整技术…...
MATLAB语言的网络编程
标题:MATLAB中的网络编程:深入探索与实践 一、引言 在现代科学和工程领域中,网络编程已经成为了数据处理、信号分析、模型构建等众多任务中不可或缺的一环。MATLAB作为一款强大的数学计算软件,不仅提供了丰富的数值计算功能&…...
边缘计算收益稳定
要使自己的PCDN(Personal Content Delivery Network,个人内容分发网络)收益更稳定,可以从以下几个方面进行努力: 一、选择合适的PCDN平台 平台稳定性:选择技术成熟、稳定性高的PCDN平台,确保内…...
计算机网络 (7)物理层下面的传输媒体
一、定义与位置 物理层是计算机网络体系结构的最低层,它位于传输媒体(传输介质)之上,主要作用是为数据链路层提供一个原始比特流的物理连接。这里的“比特流”是指数据以一个个0或1的二进制代码形式表示。物理层并不是特指某种传输…...
【GoPL】1.2 命令行参数
1.2 命令行参数 24-12-26 大部分程序处理输入,然后产生一些输出,这大概有点像计算的定义 但是程序怎么操作输入的数据?(用参数来操作)输入可能来自文件、网络连接、用户的键盘输入、命令行参数(不同的编程范式) os包提供函数和其他值来处理…...
高精度问题
目录 算法实现基础 高精度加法AB 测试链接 源代码 代码重点 高精度减法A-B 测试链接 源代码 代码重点 高精度乘法A*b和A*B 测试链接 源代码 代码重点 高精度除法A/b和A/B 测试链接 源代码 代码重点 高精度求和差积商余 算法实现基础 本算法调用STL…...
【无线通信】蜂窝系统——干扰与系统容量
干扰是蜂窝无线系统性能的主要限制因素。干扰来源包括同一小区中的其他移动终端、邻近小区正在进行的通话、其他基站在同一频段内的工作信号,或者任何不属于蜂窝系统的设备偶然向蜂窝频段泄漏信号。语音信道中的干扰会导致串音,使得用户在通话时听到背景…...
深入探索仓颉编程语言:函数与结构类型的终极指南
引言 仓颉编程语言是一种现代化、语法精炼的编程语言,其设计目标是提供高度的灵活性与高性能的执行效率。函数与结构类型是仓颉语言的两大基础模块,也是开发者需要掌握的核心。本文将详细讲解仓颉语言中函数和结构类型的特性,辅以代码实例和…...
010-spring-后置处理器(重要)
org.mybatis.spring.mapper.MapperScannerConfigurer...
SQL实现新年倒计时功能
马上就到 2025 年了,给大家分享一个使用 SQL 实现的新年倒计时功能。 以下是 PostgreSQL 语法: DO $$ DECLAREdiff INTERVAL; BEGINRAISE NOTICE 2025新年倒计时开始:;LOOP-- 计算当前时间距离2025年的时间间隔diff age(timestamp 2025-01…...
list模拟实现
目录 节点结构 构造函数 insert erase push_back push_front pop_front pop_back 拷贝构造 析构函数 赋值重载 正向迭代器实现 clear 反向迭代器实现 测试list 附完整代码 参照数据结构篇: 带头双向循环链表 节点结构 namespace dck {template <class T&g…...
JVM【Java虚拟机】基础知识(五)
1. 双亲委派机制 由于Java虚拟机中有多个类加载器,双亲委派机制的核心是解决一个类到底由谁加载的问题。 💡双亲委派机制有什么用? 1.保证类加载的安全性 通过双亲委派机制避免恶意代码替换JDK中的核心类库,比如java.lang.Str…...
阿尔萨斯(JVisualVM)JVM监控工具
文章目录 前言阿尔萨斯(JVisualVM)JVM监控工具1. 阿尔萨斯的功能2. JVisualVM启动3. 使用 前言 如果您觉得有用的话,记得给博主点个赞,评论,收藏一键三连啊,写作不易啊^ _ ^。 而且听说点赞的人每天的运气都不会太差ÿ…...
Vue BPMN Modeler流程图
1、参考地址 git clone https://github.com/evanyangg/vue-bpmn-modeler.git 2、安装bpmn.js npm install bpmn-js --save 3、使用bpmn.js <template><div class"containers"><div class"canvas" ref"canvas"></div&g…...
python通过正则匹配SQL
pattern r"(?:[^;]|(?:\\.|[^])*);" sql_list [match.group().strip() for match in re.finditer(pattern, execute_sql) if match.group().strip()]for sql in sql_list:print(sql)(?:[^;]|(?:\\.|[^])*); 匹配 连续的非分号内容 或 单引号包裹的字符串&#…...
设置首选网络类型以及调用Android框架层的隐藏API
在Android SDK中提供的framework.jar是阉割版本的,比如有些类标记为hide,这些类不会被打包到这个jar中,而有些只是类中的某个方法或或属性被标记为hide,则这些类或属性会被打包到framework.jar,但是我们无法调用&#…...
观察者模式和发布-订阅模式有什么异同?它们在哪些情况下会被使用?
大家好,我是锋哥。今天分享关于【观察者模式和发布-订阅模式有什么异同?它们在哪些情况下会被使用?】面试题。希望对大家有帮助; 观察者模式和发布-订阅模式有什么异同?它们在哪些情况下会被使用? 1000道 …...
如何保证mysql数据库到ES的数据一致性
1.同步双写方案 在代码中对数据库和ES进行双写操作,确保先更新数据后更新ES。 优点: 数据一致性:双写策略可以保证在MySql和Elasticsearch之间数据的强一致性,因为每次数据库的变更都会在Elasticsearch同步反映。实时性…...
RabbitMQ 的7种工作模式
RabbitMQ 共提供了7种⼯作模式,进⾏消息传递,. 官⽅⽂档:RabbitMQ Tutorials | RabbitMQ 1.Simple(简单模式) P:⽣产者,也就是要发送消息的程序 C:消费者,消息的接收者 Queue:消息队列,图中⻩⾊背景部分.类似⼀个邮箱,可以缓存消息;⽣产者向其中投递消息,消费者从其中取出消息…...
红黑树 Red-Black Tree介绍
1. 红黑树的定义 红黑树是一种具有如下性质的二叉搜索树: 每个节点是红色或黑色。根节点是黑色。所有叶子节点都是黑色的空节点(NIL节点),即哨兵节点。如果一个节点是红色,那么它的子节点一定是黑色。(不存…...
我的创作纪念日—致敬未来的自己
机缘 为什么想去写文章呢? 1、想把自己学的知识和技能做一个总结。 2、想给多年后的自己留下一些财富。 3、希望自己分享的知识和经验也能帮到其他有需要的人 收获 在创作的过程中都有哪些收获? 1、每次对知识的总结,都让我的技能更加的…...
Android Studio IDE环境配置
需要安装哪些东西: Java jdk Java Downloads | OracleAndroid Studio 下载 Android Studio 和应用工具 - Android 开发者 | Android DevelopersAndroid Sdk 现在的Android Studio版本安装时会自动安装,需要注意下安装的路径Android Studio插件…...
matlab中的cell
在MATLAB中,cell 是一种非常重要的数据类型,它能够存储不同类型和大小的数据,这使得它非常灵活,适用于处理复杂的数据结构。 1. 基本介绍 cell 类型的变量可以存储不同类型的数据,如数值、字符、结构体、甚至其他的 …...
Vue项目中env文件的作用和配置
在实际项目的开发中,我们一般会经历项目的开发阶段、测试阶段和最终上线阶段,每一个阶段对于项目代码的要求可能都不尽相同,那么我们如何能够游刃有余的在不同阶段下使我们的项目呈现不同的效果,使用不同的功能呢?这里…...
基于致远OA+慧集通平台的企业主数据管理设计方案(一)
目标 1、实现集团组织主数据的集中统一管理,包括到主数据在致远中的审批新增、编辑、分发等操作; 2、实现集团用户系统权限的集中管理,统一在致远平台中为用户配置各系统中的权限,配置完成后,可以自动或手动的分发到…...
vue前端实现同步发送请求,可设置并发数量【已封装】
新建 TaskManager.js export default class TaskManager {constructor(maxConcurrentTasks 1) {// 最大并发任务数// to do// 并发任务数大于1 接口开始有概率返回空值,推测是后端问题this.maxConcurrentTasks maxConcurrentTasks;this.currentTasks 0;this.tas…...
vue3使用vant日历组件(calendar),自定义日历下标的两种方法
在vue3中使用vant日历组件(calendar)自定义下标的两种方法,推荐使用第二种: 日期下方加小圆点: 一、使用伪元素样式实现(::after伪元素小圆点样式会被覆盖,只能添加一个小圆点) 代码如下(示例…...
Java线程池面试题
为什么要用线程池 降低资源消耗:通过重复利用已创建的线程降低线程创建和销毁造成的消耗提高响应速度:当任务到达时,任务可以不需要等到线程创建就能立即执行方便管理线程:线程是稀缺资源,如果无条件地创建࿰…...
我的 2024 年终总结
2024 年,我离开了待了两年的互联网公司,来到了一家聚焦教育机器人和激光切割机的公司,没错,是一家硬件公司,从未接触过的领域,但这还不是我今年最重要的里程碑事件 5 月份的时候,正式提出了离职…...
Mysql8 数据库安装及主从配置
一、MySQL8 安装 下载 MySQL 8 的安装包并将其上传到服务器。将安装包解压到指定的目录,例如 /opt/mysql8。创建一个名为 mysql 的用户组和一个名为 mysql 的用户,并将用户添加到组中。同时,设置用户密码并更改用户的主目录和默认 shell。配…...
Unity中UGUI的Button动态绑定引用问题
Unity中UGUI的Button动态绑定引用问题 问题代码修改代码如下总结 问题代码 Button动态绑定几个连续的按钮事件时使用for循环的i做按钮的id发现按钮点击对应不上。如下代码 for (int i 0; i < 10; i) {btn[i].onClick.AddListener(() >{Click(i);}); }/// <summary&…...
测试基础之测试分类
软件测试是确保软件产品满足预期功能、性能和用户体验要求的关键环节。它的主要目的是通过系统化的方法发现并修复软件中的缺陷,从而提高软件的质量和可靠性。在软件开发生命周期的不同阶段执行测试,以尽早发现潜在的错误或类型,早期发现缺陷…...