【深度学习项目】语义分割-FCN网络(原理、网络架构、基于Pytorch实现FCN网络)
文章目录
- 介绍
- 深度学习语义分割的关键特点
- 主要架构和技术
- 数据集和评价指标
- 总结
- FCN网络
- FCN 的特点
- FCN 的工作原理
- FCN 的变体和发展
- FCN 的网络结构
- FCN 的实现(基于Pytorch)
- 1. 环境配置
- 2. 文件结构
- 3. 预训练权重下载地址
- 4. 数据集,本例程使用的是PASCAL VOC2012数据集
- 5. 训练方法
- 6. 注意事项
- 7. Pytorch官方实现的FCN网络框架图
- 8. 完整代码
- 8.1 src文件目录代码
- 8.2 train_utils文件目录代码
- 8.3 根目录代码
个人主页:道友老李
欢迎加入社区:道友老李的学习社区
介绍
深度学习语义分割(Semantic Segmentation)是一种计算机视觉任务,它旨在将图像中的每个像素分类为预定义类别之一。与物体检测不同,后者通常只识别和定位图像中的目标对象边界框,语义分割要求对图像的每一个像素进行分类,以实现更精细的理解。这项技术在自动驾驶、医学影像分析、机器人视觉等领域有着广泛的应用。
深度学习语义分割的关键特点
- 像素级分类:对于输入图像的每一个像素点,模型都需要预测其属于哪个类别。
- 全局上下文理解:为了正确地分割复杂场景,模型需要考虑整个图像的内容及其上下文信息。
- 多尺度处理:由于目标可能出现在不同的尺度上,有效的语义分割方法通常会处理多种分辨率下的特征。
主要架构和技术
-
全卷积网络 (FCN):
- FCN是最早的端到端训练的语义分割模型之一,它移除了传统CNN中的全连接层,并用卷积层替代,从而能够接受任意大小的输入并输出相同空间维度的概率图。
-
跳跃连接 (Skip Connections):
- 为了更好地保留原始图像的空间细节,一些模型引入了跳跃连接,即从编码器部分直接传递特征到解码器部分,这有助于恢复细粒度的结构信息。
-
U-Net:
- U-Net是一个专为生物医学图像分割设计的网络架构,它使用了对称的收缩路径(下采样)和扩展路径(上采样),以及丰富的跳跃连接来捕捉局部和全局信息。
-
DeepLab系列:
- DeepLab采用了空洞/膨胀卷积(Atrous Convolution)来增加感受野而不减少特征图分辨率,并通过多尺度推理和ASPP模块(Atrous Spatial Pyramid Pooling)增强了对不同尺度物体的捕捉能力。
-
PSPNet (Pyramid Scene Parsing Network):
- PSPNet利用金字塔池化机制收集不同尺度的上下文信息,然后将其融合用于最终的预测。
-
RefineNet:
- RefineNet强调了高分辨率特征的重要性,并通过一系列细化单元逐步恢复细节,确保输出高质量的分割结果。
-
HRNet (High-Resolution Network):
- HRNet在整个网络中保持了高分辨率的表示,同时通过多尺度融合策略有效地整合了低分辨率但富含语义的信息。
数据集和评价指标
常用的语义分割数据集包括PASCAL VOC、COCO、Cityscapes等。这些数据集提供了标注好的图像,用于训练和评估模型性能。
评价语义分割模型的标准通常包括:
- 像素准确率 (Pixel Accuracy):所有正确分类的像素占总像素的比例。
- 平均交并比 (Mean Intersection over Union, mIoU):这是最常用的评价指标之一,计算每个类别的IoU(交集除以并集),然后取平均值。
- 频率加权交并比 (Frequency Weighted IoU):考虑每个类别的出现频率,对mIoU进行加权。
总结
随着硬件性能的提升和算法的进步,深度学习语义分割已经取得了显著的进展。现代模型不仅能在速度上满足实时应用的需求,还能提供非常精确的分割结果。未来的研究可能会集中在提高模型效率、增强跨域泛化能力以及探索无监督或弱监督的学习方法等方面。
FCN网络
FCN(Fully Convolutional Networks,全卷积网络)是一种用于计算机视觉任务的神经网络架构,尤其擅长处理像素级别的分类问题,例如语义分割。FCN 的核心思想是将传统的 CNN(卷积神经网络)中的全连接层替换为卷积层,这样可以接受任意大小的输入图像,并输出同样大小的概率图,其中每个像素点对应于该位置属于某个类别的概率。
FCN 的特点
- 任意尺寸输入:由于没有全连接层的限制,FCN 可以处理任意尺寸的输入图像。
- 端到端训练:FCN 支持从原始像素到最终预测结果的端到端训练,不需要预先提取特征或者分阶段训练。
- 多尺度上下文信息:通过在不同层级使用跳跃结构(skip architecture),FCN 能够结合低层的精细空间信息和高层的语义信息,从而提升分割精度。
FCN 的工作原理
-
编码器部分:通常基于一个预训练好的分类网络(如 VGG、ResNet 等),移除掉最后的全连接层。这个部分负责提取图像特征,随着网络深度增加,感受野也逐渐扩大,能够捕捉到更大范围内的上下文信息。
-
解码器部分:这部分用来逐步恢复特征图的空间分辨率,直到与输入图像相同大小。这通常是通过上采样操作完成的,比如转置卷积(Deconvolution 或 Fractionally-strided convolution)。在这个过程中,可以加入来自编码器早期层的特征(即跳跃连接),来帮助保持细节。
-
跳跃连接(Skip Connections):为了保留更多的位置信息,FCN 会将编码器中较早层的特征图与解码器中对应的层进行融合。这种做法有助于改善边界区域的分割效果。
FCN 的变体和发展
自 FCN 提出以来,出现了许多改进版本,包括但不限于:
- U-Net:一种具有对称的编码器-解码器结构的网络,广泛应用于医学图像分割领域。
- DeepLab 系列:通过引入空洞卷积(Atrous Convolution)等技术来增强模型捕捉多尺度信息的能力。
- PSPNet (Pyramid Scene Parsing Network):利用金字塔池化模块获取全局上下文信息。
- RefineNet:专注于细节恢复,采用多路径优化策略来传递所有分辨率的信息。
这些模型在不同的应用场景中都有所应用,并且根据特定的任务需求不断进化和发展。
首个端到端的针对像素级预测的全卷积网络
FCN 的网络结构
Conv的参数量:77512*4096=102760448
FCN 的实现(基于Pytorch)
该项目主要是来自pytorch官方torchvision模块中的源码
https://github.com/pytorch/vision/tree/main/torchvision/models/segmentation
1. 环境配置
- Python3.6/3.7/3.8
- Pytorch1.10
- Ubuntu或Centos(Windows暂不支持多GPU训练)
- 最好使用GPU训练
- 详细环境配置见
requirements.txt
2. 文件结构
├── src: 模型的backbone以及FCN的搭建├── train_utils: 训练、验证以及多GPU训练相关模块├── my_dataset.py: 自定义dataset用于读取VOC数据集├── train.py: 以fcn_resnet50(这里使用了Dilated/Atrous Convolution)进行训练├── train_multi_GPU.py: 针对使用多GPU的用户使用├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试├── validation.py: 利用训练好的权重验证/测试数据的mIoU等指标,并生成record_mAP.txt文件└── pascal_voc_classes.json: pascal_voc标签文件
3. 预训练权重下载地址
- 注意:官方提供的预训练权重是在COCO上预训练得到的,训练时只针对和PASCAL VOC相同的类别进行了训练,所以类别数是21(包括背景)
- fcn_resnet50: https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth
- fcn_resnet101: https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth
- 注意,下载的预训练权重记得要重命名,比如在train.py中读取的是
fcn_resnet50_coco.pth
文件,
不是fcn_resnet50_coco-1167a1af.pth
4. 数据集,本例程使用的是PASCAL VOC2012数据集
- Pascal VOC2012 train/val数据集下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
5. 训练方法
- 确保提前准备好数据集
- 确保提前下载好对应预训练模型权重
- 若要使用单GPU或者CPU训练,直接使用train.py训练脚本
- 若要使用多GPU训练,使用
torchrun --nproc_per_node=8 train_multi_GPU.py
指令,nproc_per_node
参数为使用GPU数量 - 如果想指定使用哪些GPU设备可在指令前加上
CUDA_VISIBLE_DEVICES=0,3
(例如我只要使用设备中的第1块和第4块GPU设备) CUDA_VISIBLE_DEVICES=0,3 torchrun --nproc_per_node=2 train_multi_GPU.py
6. 注意事项
- 在使用训练脚本时,注意要将’–data-path’(VOC_root)设置为自己存放’VOCdevkit’文件夹所在的根目录
- 在使用预测脚本时,要将’weights_path’设置为你自己生成的权重路径。
- 使用validation文件时,注意确保你的验证集或者测试集中必须包含每个类别的目标,并且使用时只需要修改’–num-classes’、‘–aux’、‘–data-path’和’–weights’即可,其他代码尽量不要改动
7. Pytorch官方实现的FCN网络框架图
8. 完整代码
8.1 src文件目录代码
- init.py
from .fcn_model import fcn_resnet50, fcn_resnet101
- backbone.py
import torch
import torch.nn as nndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class Bottleneck(nn.Module):# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)# while original implementation places the stride at the first 1x1 convolution(self.conv1)# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.# This variant is also known as ResNet V1.5 and improves accuracy according to# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(Bottleneck, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.)) * groups# Both self.conv2 and self.downsample layers downsample the input when stride != 1self.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64, replace_stride_with_dilation=None,norm_layer=None):super(ResNet, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dself._norm_layer = norm_layerself.inplanes = 64self.dilation = 1if replace_stride_with_dilation is None:# each element in the tuple indicates if we should replace# the 2x2 stride with a dilated convolution insteadreplace_stride_with_dilation = [False, False, False]if len(replace_stride_with_dilation) != 3:raise ValueError("replace_stride_with_dilation should be None ""or a 3-element tuple, got {}".format(replace_stride_with_dilation))self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2,dilate=replace_stride_with_dilation[0])self.layer3 = self._make_layer(block, 256, layers[2], stride=2,dilate=replace_stride_with_dilation[1])self.layer4 = self._make_layer(block, 512, layers[3], stride=2,dilate=replace_stride_with_dilation[2])self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch,# so that the residual branch starts with zeros, and each residual block behaves like an identity.# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, dilate=False):norm_layer = self._norm_layerdownsample = Noneprevious_dilation = self.dilationif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups,self.base_width, previous_dilation, norm_layer))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes, groups=self.groups,base_width=self.base_width, dilation=self.dilation,norm_layer=norm_layer))return nn.Sequential(*layers)def _forward_impl(self, x):# See note [TorchScript super()]x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef forward(self, x):return self._forward_impl(x)def _resnet(block, layers, **kwargs):model = ResNet(block, layers, **kwargs)return modeldef resnet50(**kwargs):r"""ResNet-50 model from`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_Args:pretrained (bool): If True, returns a model pre-trained on ImageNetprogress (bool): If True, displays a progress bar of the download to stderr"""return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)def resnet101(**kwargs):r"""ResNet-101 model from`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_Args:pretrained (bool): If True, returns a model pre-trained on ImageNetprogress (bool): If True, displays a progress bar of the download to stderr"""return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
- fcn_model.py
from collections import OrderedDictfrom typing import Dictimport torch
from torch import nn, Tensor
from torch.nn import functional as F
from .backbone import resnet50, resnet101class IntermediateLayerGetter(nn.ModuleDict):"""Module wrapper that returns intermediate layers from a modelIt has a strong assumption that the modules have been registeredinto the model in the same order as they are used.This means that one should **not** reuse the same nn.Moduletwice in the forward if you want this to work.Additionally, it is only able to query submodules that are directlyassigned to the model. So if `model` is passed, `model.feature1` canbe returned, but not `model.feature1.layer2`.Args:model (nn.Module): model on which we will extract the featuresreturn_layers (Dict[name, new_name]): a dict containing the namesof the modules for which the activations will be returned asthe key of the dict, and the value of the dict is the nameof the returned activation (which the user can specify)."""_version = 2__annotations__ = {"return_layers": Dict[str, str],}def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:if not set(return_layers).issubset([name for name, _ in model.named_children()]):raise ValueError("return_layers are not present in model")orig_return_layers = return_layersreturn_layers = {str(k): str(v) for k, v in return_layers.items()}# 重新构建backbone,将没有使用到的模块全部删掉layers = OrderedDict()for name, module in model.named_children():layers[name] = moduleif name in return_layers:del return_layers[name]if not return_layers:breaksuper(IntermediateLayerGetter, self).__init__(layers)self.return_layers = orig_return_layersdef forward(self, x: Tensor) -> Dict[str, Tensor]:out = OrderedDict()for name, module in self.items():x = module(x)if name in self.return_layers:out_name = self.return_layers[name]out[out_name] = xreturn outclass FCN(nn.Module):"""Implements a Fully-Convolutional Network for semantic segmentation.Args:backbone (nn.Module): the network used to compute the features for the model.The backbone should return an OrderedDict[Tensor], with the key being"out" for the last feature map used, and "aux" if an auxiliary classifieris used.classifier (nn.Module): module that takes the "out" element returned fromthe backbone and returns a dense prediction.aux_classifier (nn.Module, optional): auxiliary classifier used during training"""__constants__ = ['aux_classifier']def __init__(self, backbone, classifier, aux_classifier=None):super(FCN, self).__init__()self.backbone = backboneself.classifier = classifierself.aux_classifier = aux_classifierdef forward(self, x: Tensor) -> Dict[str, Tensor]:input_shape = x.shape[-2:]# contract: features is a dict of tensorsfeatures = self.backbone(x)result = OrderedDict()x = features["out"]x = self.classifier(x)# 原论文中虽然使用的是ConvTranspose2d,但权重是冻结的,所以就是一个bilinear插值x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)result["out"] = xif self.aux_classifier is not None:x = features["aux"]x = self.aux_classifier(x)# 原论文中虽然使用的是ConvTranspose2d,但权重是冻结的,所以就是一个bilinear插值x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)result["aux"] = xreturn resultclass FCNHead(nn.Sequential):def __init__(self, in_channels, channels):inter_channels = in_channels // 4layers = [nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),nn.BatchNorm2d(inter_channels),nn.ReLU(),nn.Dropout(0.1),nn.Conv2d(inter_channels, channels, 1)]super(FCNHead, self).__init__(*layers)def fcn_resnet50(aux, num_classes=21, pretrain_backbone=False):# 'resnet50_imagenet': 'https://download.pytorch.org/models/resnet50-0676ba61.pth'# 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth'backbone = resnet50(replace_stride_with_dilation=[False, True, True])if pretrain_backbone:# 载入resnet50 backbone预训练权重backbone.load_state_dict(torch.load("resnet50.pth", map_location='cpu'))out_inplanes = 2048aux_inplanes = 1024return_layers = {'layer4': 'out'}if aux:return_layers['layer3'] = 'aux'backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)aux_classifier = None# why using aux: https://github.com/pytorch/vision/issues/4292if aux:aux_classifier = FCNHead(aux_inplanes, num_classes)classifier = FCNHead(out_inplanes, num_classes)model = FCN(backbone, classifier, aux_classifier)return modeldef fcn_resnet101(aux, num_classes=21, pretrain_backbone=False):# 'resnet101_imagenet': 'https://download.pytorch.org/models/resnet101-63fe2227.pth'# 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth'backbone = resnet101(replace_stride_with_dilation=[False, True, True])if pretrain_backbone:# 载入resnet101 backbone预训练权重backbone.load_state_dict(torch.load("resnet101.pth", map_location='cpu'))out_inplanes = 2048aux_inplanes = 1024return_layers = {'layer4': 'out'}if aux:return_layers['layer3'] = 'aux'backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)aux_classifier = None# why using aux: https://github.com/pytorch/vision/issues/4292if aux:aux_classifier = FCNHead(aux_inplanes, num_classes)classifier = FCNHead(out_inplanes, num_classes)model = FCN(backbone, classifier, aux_classifier)return model
8.2 train_utils文件目录代码
- init.py
from .train_and_eval import train_one_epoch, evaluate, create_lr_scheduler
from .distributed_utils import init_distributed_mode, save_on_master, mkdir
- distributed_utils.py
from collections import defaultdict, deque
import datetime
import time
import torch
import torch.distributed as distimport errno
import osclass SmoothedValue(object):"""Track a series of values and provide access to smoothed values over awindow or the global series average."""def __init__(self, window_size=20, fmt=None):if fmt is None:fmt = "{value:.4f} ({global_avg:.4f})"self.deque = deque(maxlen=window_size)self.total = 0.0self.count = 0self.fmt = fmtdef update(self, value, n=1):self.deque.append(value)self.count += nself.total += value * ndef synchronize_between_processes(self):"""Warning: does not synchronize the deque!"""if not is_dist_avail_and_initialized():returnt = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')dist.barrier()dist.all_reduce(t)t = t.tolist()self.count = int(t[0])self.total = t[1]@propertydef median(self):d = torch.tensor(list(self.deque))return d.median().item()@propertydef avg(self):d = torch.tensor(list(self.deque), dtype=torch.float32)return d.mean().item()@propertydef global_avg(self):return self.total / self.count@propertydef max(self):return max(self.deque)@propertydef value(self):return self.deque[-1]def __str__(self):return self.fmt.format(median=self.median,avg=self.avg,global_avg=self.global_avg,max=self.max,value=self.value)class ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes = num_classesself.mat = Nonedef update(self, a, b):n = self.num_classesif self.mat is None:# 创建混淆矩阵self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)with torch.no_grad():# 寻找GT中为目标的像素索引k = (a >= 0) & (a < n)# 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)inds = n * a[k].to(torch.int64) + b[k]self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()def compute(self):h = self.mat.float()# 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)acc_global = torch.diag(h).sum() / h.sum()# 计算每个类别的准确率acc = torch.diag(h) / h.sum(1)# 计算每个类别预测与真实目标的iouiu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))return acc_global, acc, iudef reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.mat)def __str__(self):acc_global, acc, iu = self.compute()return ('global correct: {:.1f}\n''average row correct: {}\n''IoU: {}\n''mean IoU: {:.1f}').format(acc_global.item() * 100,['{:.1f}'.format(i) for i in (acc * 100).tolist()],['{:.1f}'.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)class MetricLogger(object):def __init__(self, delimiter="\t"):self.meters = defaultdict(SmoothedValue)self.delimiter = delimiterdef update(self, **kwargs):for k, v in kwargs.items():if isinstance(v, torch.Tensor):v = v.item()assert isinstance(v, (float, int))self.meters[k].update(v)def __getattr__(self, attr):if attr in self.meters:return self.meters[attr]if attr in self.__dict__:return self.__dict__[attr]raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))def __str__(self):loss_str = []for name, meter in self.meters.items():loss_str.append("{}: {}".format(name, str(meter)))return self.delimiter.join(loss_str)def synchronize_between_processes(self):for meter in self.meters.values():meter.synchronize_between_processes()def add_meter(self, name, meter):self.meters[name] = meterdef log_every(self, iterable, print_freq, header=None):i = 0if not header:header = ''start_time = time.time()end = time.time()iter_time = SmoothedValue(fmt='{avg:.4f}')data_time = SmoothedValue(fmt='{avg:.4f}')space_fmt = ':' + str(len(str(len(iterable)))) + 'd'if torch.cuda.is_available():log_msg = self.delimiter.join([header,'[{0' + space_fmt + '}/{1}]','eta: {eta}','{meters}','time: {time}','data: {data}','max mem: {memory:.0f}'])else:log_msg = self.delimiter.join([header,'[{0' + space_fmt + '}/{1}]','eta: {eta}','{meters}','time: {time}','data: {data}'])MB = 1024.0 * 1024.0for obj in iterable:data_time.update(time.time() - end)yield objiter_time.update(time.time() - end)if i % print_freq == 0:eta_seconds = iter_time.global_avg * (len(iterable) - i)eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))if torch.cuda.is_available():print(log_msg.format(i, len(iterable), eta=eta_string,meters=str(self),time=str(iter_time), data=str(data_time),memory=torch.cuda.max_memory_allocated() / MB))else:print(log_msg.format(i, len(iterable), eta=eta_string,meters=str(self),time=str(iter_time), data=str(data_time)))i += 1end = time.time()total_time = time.time() - start_timetotal_time_str = str(datetime.timedelta(seconds=int(total_time)))print('{} Total time: {}'.format(header, total_time_str))def mkdir(path):try:os.makedirs(path)except OSError as e:if e.errno != errno.EEXIST:raisedef setup_for_distributed(is_master):"""This function disables printing when not in master process"""import builtins as __builtin__builtin_print = __builtin__.printdef print(*args, **kwargs):force = kwargs.pop('force', False)if is_master or force:builtin_print(*args, **kwargs)__builtin__.print = printdef is_dist_avail_and_initialized():if not dist.is_available():return Falseif not dist.is_initialized():return Falsereturn Truedef get_world_size():if not is_dist_avail_and_initialized():return 1return dist.get_world_size()def get_rank():if not is_dist_avail_and_initialized():return 0return dist.get_rank()def is_main_process():return get_rank() == 0def save_on_master(*args, **kwargs):if is_main_process():torch.save(*args, **kwargs)def init_distributed_mode(args):if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:args.rank = int(os.environ["RANK"])args.world_size = int(os.environ['WORLD_SIZE'])args.gpu = int(os.environ['LOCAL_RANK'])elif 'SLURM_PROCID' in os.environ:args.rank = int(os.environ['SLURM_PROCID'])args.gpu = args.rank % torch.cuda.device_count()elif hasattr(args, "rank"):passelse:print('Not using distributed mode')args.distributed = Falsereturnargs.distributed = Truetorch.cuda.set_device(args.gpu)args.dist_backend = 'nccl'print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,world_size=args.world_size, rank=args.rank)setup_for_distributed(args.rank == 0)
- train_and_eval.py
import torch
from torch import nn
import train_utils.distributed_utils as utilsdef criterion(inputs, target):losses = {}for name, x in inputs.items():# 忽略target中值为255的像素,255的像素是目标边缘或者padding填充losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)if len(losses) == 1:return losses['out']return losses['out'] + 0.5 * losses['aux']def evaluate(model, data_loader, device, num_classes):model.eval()confmat = utils.ConfusionMatrix(num_classes)metric_logger = utils.MetricLogger(delimiter=" ")header = 'Test:'with torch.no_grad():for image, target in metric_logger.log_every(data_loader, 100, header):image, target = image.to(device), target.to(device)output = model(image)output = output['out']confmat.update(target.flatten(), output.argmax(1).flatten())confmat.reduce_from_all_processes()return confmatdef train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler, print_freq=10, scaler=None):model.train()metric_logger = utils.MetricLogger(delimiter=" ")metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))header = 'Epoch: [{}]'.format(epoch)for image, target in metric_logger.log_every(data_loader, print_freq, header):image, target = image.to(device), target.to(device)with torch.cuda.amp.autocast(enabled=scaler is not None):output = model(image)loss = criterion(output, target)optimizer.zero_grad()if scaler is not None:scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()else:loss.backward()optimizer.step()lr_scheduler.step()lr = optimizer.param_groups[0]["lr"]metric_logger.update(loss=loss.item(), lr=lr)return metric_logger.meters["loss"].global_avg, lrdef create_lr_scheduler(optimizer,num_step: int,epochs: int,warmup=True,warmup_epochs=1,warmup_factor=1e-3):assert num_step > 0 and epochs > 0if warmup is False:warmup_epochs = 0def f(x):"""根据step数返回一个学习率倍率因子,注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法"""if warmup is True and x <= (warmup_epochs * num_step):alpha = float(x) / (warmup_epochs * num_step)# warmup过程中lr倍率因子从warmup_factor -> 1return warmup_factor * (1 - alpha) + alphaelse:# warmup后lr倍率因子从1 -> 0# 参考deeplab_v2: Learning rate policyreturn (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
8.3 根目录代码
- pascal_voc_classes.json
{"aeroplane": 1,"bicycle": 2,"bird": 3,"boat": 4,"bottle": 5,"bus": 6,"car": 7,"cat": 8,"chair": 9,"cow": 10,"diningtable": 11,"dog": 12,"horse": 13,"motorbike": 14,"person": 15,"pottedplant": 16,"sheep": 17,"sofa": 18,"train": 19,"tvmonitor": 20
}
- my_dataset.py
import osimport torch.utils.data as data
from PIL import Imageclass VOCSegmentation(data.Dataset):def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):super(VOCSegmentation, self).__init__()assert year in ["2007", "2012"], "year must be in ['2007', '2012']"root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")assert os.path.exists(root), "path '{}' does not exist.".format(root)image_dir = os.path.join(root, 'JPEGImages')mask_dir = os.path.join(root, 'SegmentationClass')txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)with open(os.path.join(txt_path), "r") as f:file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]assert (len(self.images) == len(self.masks))self.transforms = transformsdef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is the image segmentation."""img = Image.open(self.images[index]).convert('RGB')target = Image.open(self.masks[index])if self.transforms is not None:img, target = self.transforms(img, target)return img, targetdef __len__(self):return len(self.images)@staticmethoddef collate_fn(batch):images, targets = list(zip(*batch))batched_imgs = cat_list(images, fill_value=0)batched_targets = cat_list(targets, fill_value=255)return batched_imgs, batched_targetsdef cat_list(images, fill_value=0):# 计算该batch数据中,channel, h, w的最大值max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))batch_shape = (len(images),) + max_sizebatched_imgs = images[0].new(*batch_shape).fill_(fill_value)for img, pad_img in zip(images, batched_imgs):pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)return batched_imgs# dataset = VOCSegmentation(voc_root="/data/", transforms=get_transform(train=True))
# d1 = dataset[0]
# print(d1)
- validation.py
import os
import torchfrom src import fcn_resnet50
from train_utils import evaluate
from my_dataset import VOCSegmentation
import transforms as Tclass SegmentationPresetEval:def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):self.transforms = T.Compose([T.RandomResize(base_size, base_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])def __call__(self, img, target):return self.transforms(img, target)def main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")assert os.path.exists(args.weights), f"weights {args.weights} not found."# segmentation nun_classes + backgroundnum_classes = args.num_classes + 1# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> val.txtval_dataset = VOCSegmentation(args.data_path,year="2012",transforms=SegmentationPresetEval(520),txt_name="val.txt")num_workers = 8val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=1,num_workers=num_workers,pin_memory=True,collate_fn=val_dataset.collate_fn)model = fcn_resnet50(aux=args.aux, num_classes=num_classes)model.load_state_dict(torch.load(args.weights, map_location=device)['model'])model.to(device)confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)print(confmat)def parse_args():import argparseparser = argparse.ArgumentParser(description="pytorch fcn training")parser.add_argument("--data-path", default="/data/", help="VOCdevkit root")parser.add_argument("--weights", default="./save_weights/model_29.pth")parser.add_argument("--num-classes", default=20, type=int)parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")parser.add_argument("--device", default="cuda", help="training device")parser.add_argument('--print-freq', default=10, type=int, help='print frequency')args = parser.parse_args()return argsif __name__ == '__main__':args = parse_args()if not os.path.exists("./save_weights"):os.mkdir("./save_weights")main(args)
- transforms.py
import numpy as np
import randomimport torch
from torchvision import transforms as T
from torchvision.transforms import functional as Fdef pad_if_smaller(img, size, fill=0):# 如果图像最小边长小于给定size,则用数值fill进行paddingmin_size = min(img.size)if min_size < size:ow, oh = img.sizepadh = size - oh if oh < size else 0padw = size - ow if ow < size else 0img = F.pad(img, (0, 0, padw, padh), fill=fill)return imgclass Compose(object):def __init__(self, transforms):self.transforms = transformsdef __call__(self, image, target):for t in self.transforms:image, target = t(image, target)return image, targetclass RandomResize(object):def __init__(self, min_size, max_size=None):self.min_size = min_sizeif max_size is None:max_size = min_sizeself.max_size = max_sizedef __call__(self, image, target):size = random.randint(self.min_size, self.max_size)# 这里size传入的是int类型,所以是将图像的最小边长缩放到size大小image = F.resize(image, size)# 这里的interpolation注意下,在torchvision(0.9.0)以后才有InterpolationMode.NEAREST# 如果是之前的版本需要使用PIL.Image.NEARESTtarget = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)return image, targetclass RandomHorizontalFlip(object):def __init__(self, flip_prob):self.flip_prob = flip_probdef __call__(self, image, target):if random.random() < self.flip_prob:image = F.hflip(image)target = F.hflip(target)return image, targetclass RandomCrop(object):def __init__(self, size):self.size = sizedef __call__(self, image, target):image = pad_if_smaller(image, self.size)target = pad_if_smaller(target, self.size, fill=255)crop_params = T.RandomCrop.get_params(image, (self.size, self.size))image = F.crop(image, *crop_params)target = F.crop(target, *crop_params)return image, targetclass CenterCrop(object):def __init__(self, size):self.size = sizedef __call__(self, image, target):image = F.center_crop(image, self.size)target = F.center_crop(target, self.size)return image, targetclass ToTensor(object):def __call__(self, image, target):image = F.to_tensor(image)target = torch.as_tensor(np.array(target), dtype=torch.int64)return image, targetclass Normalize(object):def __init__(self, mean, std):self.mean = meanself.std = stddef __call__(self, image, target):image = F.normalize(image, mean=self.mean, std=self.std)return image, target
- train.py
import os
import time
import datetimeimport torchfrom src import fcn_resnet50
from train_utils import train_one_epoch, evaluate, create_lr_scheduler
from my_dataset import VOCSegmentation
import transforms as Tclass SegmentationPresetTrain:def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):min_size = int(0.5 * base_size)max_size = int(2.0 * base_size)trans = [T.RandomResize(min_size, max_size)]if hflip_prob > 0:trans.append(T.RandomHorizontalFlip(hflip_prob))trans.extend([T.RandomCrop(crop_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])self.transforms = T.Compose(trans)def __call__(self, img, target):return self.transforms(img, target)class SegmentationPresetEval:def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):self.transforms = T.Compose([T.RandomResize(base_size, base_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])def __call__(self, img, target):return self.transforms(img, target)def get_transform(train):base_size = 520crop_size = 480return SegmentationPresetTrain(base_size, crop_size) if train else SegmentationPresetEval(base_size)def create_model(aux, num_classes, pretrain=True):model = fcn_resnet50(aux=aux, num_classes=num_classes)if pretrain:weights_dict = torch.load("./src/fcn_resnet50_coco.pth", map_location='cpu')if num_classes != 21:# 官方提供的预训练权重是21类(包括背景)# 如果训练自己的数据集,将和类别相关的权重删除,防止权重shape不一致报错for k in list(weights_dict.keys()):if "classifier.4" in k:del weights_dict[k]missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)if len(missing_keys) != 0 or len(unexpected_keys) != 0:print("missing_keys: ", missing_keys)print("unexpected_keys: ", unexpected_keys)return modeldef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")batch_size = args.batch_size# segmentation nun_classes + backgroundnum_classes = args.num_classes + 1# 用来保存训练以及验证过程中信息results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> train.txttrain_dataset = VOCSegmentation(args.data_path,year="2012",transforms=get_transform(train=True),txt_name="train.txt")# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> val.txtval_dataset = VOCSegmentation(args.data_path,year="2012",transforms=get_transform(train=False),txt_name="val.txt")num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,shuffle=True,pin_memory=True,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=1,num_workers=num_workers,pin_memory=True,collate_fn=val_dataset.collate_fn)model = create_model(aux=args.aux, num_classes=num_classes)model.to(device)params_to_optimize = [{"params": [p for p in model.backbone.parameters() if p.requires_grad]},{"params": [p for p in model.classifier.parameters() if p.requires_grad]}]if args.aux:params = [p for p in model.aux_classifier.parameters() if p.requires_grad]params_to_optimize.append({"params": params, "lr": args.lr * 10})optimizer = torch.optim.SGD(params_to_optimize,lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)scaler = torch.cuda.amp.GradScaler() if args.amp else None# 创建学习率更新策略,这里是每个step更新一次(不是每个epoch)lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)if args.resume:checkpoint = torch.load(args.resume, map_location='cpu')model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])args.start_epoch = checkpoint['epoch'] + 1if args.amp:scaler.load_state_dict(checkpoint["scaler"])start_time = time.time()for epoch in range(args.start_epoch, args.epochs):mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch,lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)val_info = str(confmat)print(val_info)# write into txtwith open(results_file, "a") as f:# 记录每个epoch对应的train_loss、lr以及验证集各指标train_info = f"[epoch: {epoch}]\n" \f"train_loss: {mean_loss:.4f}\n" \f"lr: {lr:.6f}\n"f.write(train_info + val_info + "\n\n")save_file = {"model": model.state_dict(),"optimizer": optimizer.state_dict(),"lr_scheduler": lr_scheduler.state_dict(),"epoch": epoch,"args": args}if args.amp:save_file["scaler"] = scaler.state_dict()torch.save(save_file, "save_weights/model_{}.pth".format(epoch))total_time = time.time() - start_timetotal_time_str = str(datetime.timedelta(seconds=int(total_time)))print("training time {}".format(total_time_str))def parse_args():import argparseparser = argparse.ArgumentParser(description="pytorch fcn training")parser.add_argument("--data-path", default="/data/", help="VOCdevkit root")parser.add_argument("--num-classes", default=20, type=int)parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")parser.add_argument("--device", default="cuda", help="training device")parser.add_argument("-b", "--batch-size", default=4, type=int)parser.add_argument("--epochs", default=30, type=int, metavar="N",help="number of total epochs to train")parser.add_argument('--lr', default=0.0001, type=float, help='initial learning rate')parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum')parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,metavar='W', help='weight decay (default: 1e-4)',dest='weight_decay')parser.add_argument('--print-freq', default=10, type=int, help='print frequency')parser.add_argument('--resume', default='', help='resume from checkpoint')parser.add_argument('--start-epoch', default=0, type=int, metavar='N',help='start epoch')# Mixed precision training parametersparser.add_argument("--amp", default=False, type=bool,help="Use torch.cuda.amp for mixed precision training")args = parser.parse_args()return argsif __name__ == '__main__':args = parse_args()if not os.path.exists("./save_weights"):os.mkdir("./save_weights")main(args)
- train_multi_GPU.py
import time
import os
import datetimeimport torchfrom src import fcn_resnet50
from train_utils import train_one_epoch, evaluate, create_lr_scheduler, init_distributed_mode, save_on_master, mkdir
from my_dataset import VOCSegmentation
import transforms as Tclass SegmentationPresetTrain:def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):min_size = int(0.5 * base_size)max_size = int(2.0 * base_size)trans = [T.RandomResize(min_size, max_size)]if hflip_prob > 0:trans.append(T.RandomHorizontalFlip(hflip_prob))trans.extend([T.RandomCrop(crop_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])self.transforms = T.Compose(trans)def __call__(self, img, target):return self.transforms(img, target)class SegmentationPresetEval:def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):self.transforms = T.Compose([T.RandomResize(base_size, base_size),T.ToTensor(),T.Normalize(mean=mean, std=std),])def __call__(self, img, target):return self.transforms(img, target)def get_transform(train):base_size = 520crop_size = 480return SegmentationPresetTrain(base_size, crop_size) if train else SegmentationPresetEval(base_size)def create_model(aux, num_classes):model = fcn_resnet50(aux=aux, num_classes=num_classes)weights_dict = torch.load("./fcn_resnet50_coco.pth", map_location='cpu')if num_classes != 21:# 官方提供的预训练权重是21类(包括背景)# 如果训练自己的数据集,将和类别相关的权重删除,防止权重shape不一致报错for k in list(weights_dict.keys()):if "classifier.4" in k:del weights_dict[k]missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)if len(missing_keys) != 0 or len(unexpected_keys) != 0:print("missing_keys: ", missing_keys)print("unexpected_keys: ", unexpected_keys)return modeldef main(args):init_distributed_mode(args)print(args)device = torch.device(args.device)# segmentation nun_classes + backgroundnum_classes = args.num_classes + 1# 用来保存coco_info的文件results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))VOC_root = args.data_path# check voc rootif os.path.exists(os.path.join(VOC_root, "VOCdevkit")) is False:raise FileNotFoundError("VOCdevkit dose not in path:'{}'.".format(VOC_root))# load train data set# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> train.txttrain_dataset = VOCSegmentation(args.data_path,year="2012",transforms=get_transform(train=True),txt_name="train.txt")# load validation data set# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> val.txtval_dataset = VOCSegmentation(args.data_path,year="2012",transforms=get_transform(train=False),txt_name="val.txt")print("Creating data loaders")if args.distributed:train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)test_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)else:train_sampler = torch.utils.data.RandomSampler(train_dataset)test_sampler = torch.utils.data.SequentialSampler(val_dataset)train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,sampler=train_sampler, num_workers=args.workers,collate_fn=train_dataset.collate_fn, drop_last=True)val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1,sampler=test_sampler, num_workers=args.workers,collate_fn=train_dataset.collate_fn)print("Creating model")# create model num_classes equal background + 20 classesmodel = create_model(aux=args.aux, num_classes=num_classes)model.to(device)if args.sync_bn:model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)model_without_ddp = modelif args.distributed:model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])model_without_ddp = model.moduleparams_to_optimize = [{"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},{"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},]if args.aux:params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]params_to_optimize.append({"params": params, "lr": args.lr * 10})optimizer = torch.optim.SGD(params_to_optimize,lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)scaler = torch.cuda.amp.GradScaler() if args.amp else None# 创建学习率更新策略,这里是每个step更新一次(不是每个epoch)lr_scheduler = create_lr_scheduler(optimizer, len(train_data_loader), args.epochs, warmup=True)# 如果传入resume参数,即上次训练的权重地址,则接着上次的参数训练if args.resume:# If map_location is missing, torch.load will first load the module to CPU# and then copy each parameter to where it was saved,# which would result in all processes on the same machine using the same set of devices.checkpoint = torch.load(args.resume, map_location='cpu') # 读取之前保存的权重文件(包括优化器以及学习率策略)model_without_ddp.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])args.start_epoch = checkpoint['epoch'] + 1if args.amp:scaler.load_state_dict(checkpoint["scaler"])if args.test_only:confmat = evaluate(model, val_data_loader, device=device, num_classes=num_classes)val_info = str(confmat)print(val_info)returnprint("Start training")start_time = time.time()for epoch in range(args.start_epoch, args.epochs):if args.distributed:train_sampler.set_epoch(epoch)mean_loss, lr = train_one_epoch(model, optimizer, train_data_loader, device, epoch,lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)confmat = evaluate(model, val_data_loader, device=device, num_classes=num_classes)val_info = str(confmat)print(val_info)# 只在主进程上进行写操作if args.rank in [-1, 0]:# write into txtwith open(results_file, "a") as f:# 记录每个epoch对应的train_loss、lr以及验证集各指标train_info = f"[epoch: {epoch}]\n" \f"train_loss: {mean_loss:.4f}\n" \f"lr: {lr:.6f}\n"f.write(train_info + val_info + "\n\n")if args.output_dir:# 只在主节点上执行保存权重操作save_file = {'model': model_without_ddp.state_dict(),'optimizer': optimizer.state_dict(),'lr_scheduler': lr_scheduler.state_dict(),'args': args,'epoch': epoch}if args.amp:save_file["scaler"] = scaler.state_dict()save_on_master(save_file,os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))total_time = time.time() - start_timetotal_time_str = str(datetime.timedelta(seconds=int(total_time)))print('Training time {}'.format(total_time_str))if __name__ == "__main__":import argparseparser = argparse.ArgumentParser(description=__doc__)# 训练文件的根目录(VOCdevkit)parser.add_argument('--data-path', default='/data/', help='dataset')# 训练设备类型parser.add_argument('--device', default='cuda', help='device')# 检测目标类别数(不包含背景)parser.add_argument('--num-classes', default=20, type=int, help='num_classes')# 每块GPU上的batch_sizeparser.add_argument('-b', '--batch-size', default=4, type=int,help='images per gpu, the total batch size is $NGPU x batch_size')parser.add_argument("--aux", default=True, type=bool, help="auxilier loss")# 指定接着从哪个epoch数开始训练parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')# 训练的总epoch数parser.add_argument('--epochs', default=20, type=int, metavar='N',help='number of total epochs to run')# 是否使用同步BN(在多个GPU之间同步),默认不开启,开启后训练速度会变慢parser.add_argument('--sync_bn', type=bool, default=False, help='whether using SyncBatchNorm')# 数据加载以及预处理的线程数parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',help='number of data loading workers (default: 4)')# 训练学习率,这里默认设置成0.0001,如果效果不好可以尝试加大学习率parser.add_argument('--lr', default=0.0001, type=float,help='initial learning rate')# SGD的momentum参数parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum')# SGD的weight_decay参数parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,metavar='W', help='weight decay (default: 1e-4)',dest='weight_decay')# 训练过程打印信息的频率parser.add_argument('--print-freq', default=20, type=int, help='print frequency')# 文件保存地址parser.add_argument('--output-dir', default='./multi_train', help='path where to save')# 基于上次的训练结果接着训练parser.add_argument('--resume', default='', help='resume from checkpoint')# 不训练,仅测试parser.add_argument("--test-only",dest="test_only",help="Only test the model",action="store_true",)# 分布式进程数parser.add_argument('--world-size', default=1, type=int,help='number of distributed processes')parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')# Mixed precision training parametersparser.add_argument("--amp", default=False, type=bool,help="Use torch.cuda.amp for mixed precision training")args = parser.parse_args()# 如果指定了保存文件地址,检查文件夹是否存在,若不存在,则创建if args.output_dir:mkdir(args.output_dir)main(args)
- predict.py
import os
import time
import jsonimport torch
from torchvision import transforms
import numpy as np
from PIL import Imagefrom src import fcn_resnet50def time_synchronized():torch.cuda.synchronize() if torch.cuda.is_available() else Nonereturn time.time()def main():aux = False # inference time not need aux_classifierclasses = 20weights_path = "./save_weights/model_2.pth"img_path = "./test.jpeg"palette_path = "./palette.json"assert os.path.exists(weights_path), f"weights {weights_path} not found."assert os.path.exists(img_path), f"image {img_path} not found."assert os.path.exists(palette_path), f"palette {palette_path} not found."with open(palette_path, "rb") as f:pallette_dict = json.load(f)pallette = []for v in pallette_dict.values():pallette += v# get devicesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))# create modelmodel = fcn_resnet50(aux=aux, num_classes=classes+1)# delete weights about aux_classifierweights_dict = torch.load(weights_path, map_location='cpu')['model']for k in list(weights_dict.keys()):if "aux" in k:del weights_dict[k]# load weightsmodel.load_state_dict(weights_dict)model.to(device)# load imageoriginal_img = Image.open(img_path)# from pil image to tensor and normalizedata_transform = transforms.Compose([transforms.Resize(520),transforms.ToTensor(),transforms.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225))])img = data_transform(original_img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)model.eval() # 进入验证模式with torch.no_grad():# init modelimg_height, img_width = img.shape[-2:]init_img = torch.zeros((1, 3, img_height, img_width), device=device)model(init_img)t_start = time_synchronized()output = model(img.to(device))t_end = time_synchronized()print("inference+NMS time: {}".format(t_end - t_start))prediction = output['out'].argmax(1).squeeze(0)prediction = prediction.to("cpu").numpy().astype(np.uint8)mask = Image.fromarray(prediction)mask.putpalette(pallette)mask.save("test_result.png")if __name__ == '__main__':main()
测试图片:
预测结果:
相关文章:
【深度学习项目】语义分割-FCN网络(原理、网络架构、基于Pytorch实现FCN网络)
文章目录 介绍深度学习语义分割的关键特点主要架构和技术数据集和评价指标总结 FCN网络FCN 的特点FCN 的工作原理FCN 的变体和发展FCN 的网络结构FCN 的实现(基于Pytorch)1. 环境配置2. 文件结构3. 预训练权重下载地址4. 数据集,本例程使用的…...
集群、分布式及微服务间的区别与联系
目录 单体架构介绍集群和分布式架构集群和分布式集群和分布式区别和联系 微服务架构的引入微服务带来的挑战 总结 单体架构介绍 早期很多创业公司或者传统企业会把业务的所有功能实现都打包在一个项目中,这种方式就称为单体架构 以我们都很熟悉的电商系统为例&…...
ConvBERT:通过基于跨度的动态卷积改进BERT
摘要 像BERT及其变体这样的预训练语言模型最近在各种自然语言理解任务中取得了令人印象深刻的性能。然而,BERT严重依赖于全局自注意力机制,因此存在较大的内存占用和计算成本。尽管所有的注意力头都从全局角度查询整个输入序列以生成注意力图࿰…...
C# 网络协议第三方库Protobuf的使用
为什么要使用二进制数据 通常我们写一个简单的网络通讯软件可能使用的最多的是字符串类型,比较简单,例如发送格式为(head)19|Msg:Heart|100,x,y,z…,在接收端会解析收到的socket数据。 这样通常是完全可行的,但是随着数据量变大&…...
「2024 博客之星」自研Java框架 Sunrays-Framework 使用教程
文章目录 0.序言我的成长历程遇到挫折,陷入低谷重拾信心,迎接未来开源与分享我为何如此看重这次评选最后的心声 1.概述1.主要功能2.相关链接 2.系统要求构建工具框架和语言数据库与缓存消息队列与对象存储 3.快速入门0.配置Maven中央仓库1.打开settings.…...
【Elasticsearch】Springboot编写Elasticsearch的RestAPI
RestAPI 初始化RestClient创建索引库Mapping映射 判断索引库是否存在删除索引库总结 ES官方提供了各种不同语言的客户端,用来操作ES。这些客户端的本质就是组装DSL语句,通过http请求发送给ES。 官方文档地址 由于ES目前最新版本是8.8,提供了全…...
Swift语言的学习路线
Swift语言的学习路线 引言 在现代程序开发中,Swift语言逐渐成为了移动应用程序开发的重要工具,尤其是在iOS和macOS平台上。自2014年发布以来,Swift以其易读性和强大的功能,受到越来越多开发者的青睐。对于初学者而言,…...
63,【3】buuctf web Upload-Labs-Linux 1
进入靶场 点击pass1 查看提示 既然是上传文件,先构造一句话木马,便于用蚁剑连接 <?php eval($_POST[123])?> 将这两处的检查函数删掉 再上传木马 文件后缀写为.php 右键复制图片地址 打开蚁剑连接 先点击测试连接,显示成功后&…...
Leetcode:2239
1,题目 2,思路 循环遍历满足条件就记录,最后返回结果值 3,代码 public class Leetcode2239 {public static void main(String[] args) {System.out.println(new Solution2239().findClosestNumber(new int[]{-4, -2, 1, 4, 8})…...
自然语言处理与NLTK环境配置
自然语言处理(Natural Language Processing, NLP)是人工智能的重要分支,专注于计算机如何理解、分析和生成自然语言。自然语言是人类用于交流的语言,如中文、英文等,这使得自然语言处理成为沟通人与计算机的桥梁。近年来,NLP在诸多领域得到广泛应用,包括文本分析、语言翻…...
分布式搜索引擎02
1. DSL查询文档 elasticsearch的查询依然是基于JSON风格的DSL来实现的。 1.1. DSL查询分类 Elasticsearch提供了基于JSON的DSL(Domain Specific Language)来定义查询。常见的查询类型包括: 查询所有:查询出所有数据,…...
使用 Logback 的最佳实践:`logback.xml` 与 `logback-spring.xml` 的区别与用法
在开发 Spring Boot 项目时,日志是调试和监控的重要工具。Spring Boot 默认支持 Logback 作为日志系统,并提供了 logback.xml 和 logback-spring.xml 两种配置方式。这篇文章将详细介绍这两者的区别、各自的优缺点以及最佳实践。 目录 一、什么是 Logbac…...
【爬虫开发】爬虫开发从0到1全知识教程第12篇:scrapy爬虫框架,介绍【附代码文档】
本教程的知识点为:爬虫概要 爬虫基础 爬虫概述 知识点: 1. 爬虫的概念 requests模块 requests模块 知识点: 1. requests模块介绍 1.1 requests模块的作用: 数据提取概要 数据提取概述 知识点 1. 响应内容的分类 知识点:…...
鸿蒙UI(ArkUI-方舟UI框架)-开发布局
文章目录 开发布局1、布局概述1)布局结构2)布局元素组成3)如何选择布局4)布局位置5)对子元素的约束 2、构建布局1)线性布局 (Row/Column)概述布局子元素在排列方向上的间距布局子元素在交叉轴上的对齐方式(…...
代码随想录_字符串
字符串 344.反转字符串 344. 反转字符串 编写一个函数,其作用是将输入的字符串反转过来。输入字符串以字符数组 s 的形式给出。 不要给另外的数组分配额外的空间,你必须**原地修改输入数组**、使用 O(1) 的额外空间解决这一问题。 思路: 双指针 代…...
2025年1月17日(点亮三色LED)
系统信息: Raspberry Pi Zero 2W 系统版本: 2024-10-22-raspios-bullseye-armhf Python 版本:Python 3.9.2 已安装 pip3 支持拍摄 1080p 30 (1092*1080), 720p 60 (1280*720), 60/90 (640*480) 已安装 vim 已安装 git 学习目标:…...
Spring Boot自动配置原理:如何实现零配置启动
引言 在现代软件开发中,Spring 框架已经成为 Java 开发领域不可或缺的一部分。而 Spring Boot 的出现,更是为 Spring 应用的开发带来了革命性的变化。Spring Boot 的核心优势之一就是它的“自动配置”能力,它极大地简化了 Spring 应用的配置…...
React技术栈搭配(全栈)(MERN栈、PERN栈)
文章目录 1. MERN 栈2. PERN 栈3. React Next.js Node.js4. JAMstack (JavaScript, APIs, Markup)5. React GraphQL Node.js6. React Native Node.js结论 React作为前端框架已经成为了现代web开发的重要组成部分。在全栈开发中,React通常…...
Linux - 线程池
线程池 什么是池? 池化技术的核心就是"提前准备并重复利用资源". 减少资源创建和销毁的成本. 那么线程池就是提前准备好一些线程, 当有任务来临时, 就可以直接交给这些线程运行, 当线程完成这些任务后, 并不会被销毁, 而是继续等待任务. 那么这些线程在程序运行过程…...
以Python构建ONE FACE管理界面:从基础至进阶的实战探索
一、引言 1.1 研究背景与意义 在人工智能技术蓬勃发展的当下,面部识别技术凭借其独特优势,于安防、金融、智能终端等众多领域广泛应用。在安防领域,可助力监控系统精准识别潜在威胁人员,提升公共安全保障水平;金融行业中,实现刷脸支付、远程开户等便捷服务,优化用户体…...
使用Sum计算Loss和解决梯度累积(Gradient Accumulation)的Bug
使用Sum计算Loss和解决梯度累积的Bug 学习 https://unsloth.ai/blog/gradient:Bugs in LLM Training - Gradient Accumulation Fix 这篇文章的记录。 在深度学习训练过程中,尤其是在大批量(large batch)训练中,如何高…...
mfc操作json示例
首先下载cJSON,加入项目; 构建工程,如果出现, fatal error C1010: unexpected end of file while looking for precompiled head 在cJSON.c文件的头部加入#include "stdafx.h"; 看情况,可能是加到.h或者是.cpp文件的头部,它如果有包含头文件, #include &…...
C语言练习(18)
一个班10个学生的成绩,存放在一个一维数组中,要求找出其中成绩最高的学生成绩和该生的序号。 #include <stdio.h>#define STUDENT_NUM 10 // 定义学生数量int main() {int scores[STUDENT_NUM]; // 定义存储学生成绩的一维数组int i;// 输入10个…...
LeetCode 热题 100_全排列(55_46_中等_C++)(递归(回溯))
LeetCode 热题 100_两数之和(55_46) 题目描述:输入输出样例:题解:解题思路:思路一(递归(回溯)): 代码实现代码实现(思路一(…...
编译chromium笔记
编译环境: windows10 powershell7.2.24 git 2.47.1 https://storage.googleapis.com/chrome-infra/depot_tools.zip 配置git git config --global user.name "John Doe" git config --global user.email "jdoegmail.com" git config --global …...
PHP语言的数据库编程
PHP语言的数据库编程 引言 随着互联网的发展,动态网站已成为主流,而动态网站的核心就是与数据库进行交互。PHP(超文本预处理器)是一种流行的开源服务器端脚本语言,被广泛用于Web开发。它以其简单易学和功能强大而受到…...
【PGCCC】PostgreSQL 中表级锁的剖析
本博客解释了 PostgreSQL 中的锁定机制,重点关注数据定义语言 (DDL) 操作所需的表级锁定。 锁定还是解锁的艺术? 人们通常将数据库锁与物理锁进行比较,这甚至可能导致您订购有关锁的历史、波斯锁和撬锁技术的书籍。我们大多数人可能都是通过…...
1.10 自洽性(Self-Consistency):多路径推理的核心力量
自洽性(Self-Consistency):多路径推理的核心力量 随着人工智能尤其是大规模语言模型的不断进化,如何提升其推理能力和决策准确性成为了研究的重点。在这一背景下,**自洽性(Self-Consistency)**作为一种新的推理方法,逐渐展现出其强大的潜力。自洽性方法通过多路径推理…...
【24】Word:小郑-准考证❗
目录 题目 准考证.docx 邮件合并-指定考生生成准考证 Word.docx 表格内容居中表格整体相较于页面居中 考试时一定要做一问保存一问❗ 题目 准考证.docx 插入→表格→将文本转换成表格→✔制表符→确定选中第一列→单击右键→在第一列的右侧插入列→布局→合并单元格&#…...
Linux 信号(Signal)详解
信号(Signal)是 Linux 系统中用于进程间通信的一种机制。它是一种异步通知,用于通知进程发生了某个事件。信号可以来自内核、其他进程或进程自身。 信号的基本概念 信号的作用: 通知进程发生了某个事件(如用户按下 Ct…...
【数据分享】1929-2024年全球站点的逐年最低气温数据(Shp\Excel\免费获取)
气象数据是在各项研究中都经常使用的数据,气象指标包括气温、风速、降水、湿度等指标!说到气象数据,最详细的气象数据是具体到气象监测站点的数据! 有关气象指标的监测站点数据,之前我们分享过1929-2024年全球气象站点…...
app版本控制java后端接口版本管理
java api version 版本控制 java接口版本管理 1 自定义 AppVersionHandleMapping 自定义AppVersionHandleMapping实现RequestMappingHandlerMapping里面的方法 public class AppVersionHandleMapping extends RequestMappingHandlerMapping {Overrideprotected RequestCondit…...
2024年度总结-CSDN
2024年CSDN年度总结 Author:OnceDay Date:2025年1月21日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 文章目录 2024年CSDN年度总结1. 整体回顾2…...
基于python的博客系统设计与实现
摘要:目前,对于信息的获取是十分的重要,我们要做到的不是裹足不前,而是应该主动获取和共享给所有人。博客系统就能够实现信息获取与分享的功能,博主在发表文章后,互联网上的其他用户便可以看到,…...
服务器日志自动上传到阿里云OSS备份
背景 公司服务器磁盘空间有限,只能存近15天日志,但是有时需要查看几个月前的日志,需要将服务器日志定时备份到某个地方,需要查询的时候有地方可查。 针对这个问题,想到3个解决方法: 1、买一个配置比较低…...
优化使用 Flask 构建视频转 GIF 工具
优化使用 Flask 构建视频转 GIF 工具 优化后的项目概述 在优化后的版本中,我们将实现以下功能: 可设置每个 GIF 的帧率和大小:用户可以选择 GIF 的帧率和输出大小。改进的用户界面:使用更现代的设计使界面更美观、整洁。自定义…...
leetcode:511. 游戏玩法分析 I
难度:简单 SQL Schema > Pandas Schema > 活动表 Activity: ----------------------- | Column Name | Type | ----------------------- | player_id | int | | device_id | int | | event_date | date | | games_playe…...
windows git bash 使用zsh 并集成 oh my zsh
参考了 这篇文章 进行配置,记录了自己的踩坑过程,并增加了 zsh-autosuggestions 插件的集成。 主要步骤: 1. git bash 这个就不说了,自己去网上下,windows 使用git时候 命令行基本都有它。 主要也是用它不方便&…...
【Python运维】Python与网络监控:如何编写网络探测与流量分析工具
《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 随着互联网技术的快速发展,网络性能的监控与分析成为保障信息系统稳定运行的关键环节。本文深入探讨了如何利用Python语言构建高效的网络探…...
OpenCV相机标定与3D重建(61)处理未校准的立体图像对函数stereoRectifyUncalibrated()的使用
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 为未校准的立体相机计算一个校正变换。 cv::stereoRectifyUncalibrated 是 OpenCV 库中的一个函数,用于处理未校准的立体图像对。该函…...
字玩FontPlayer开发笔记12 Vue3撤销重做功能
字玩FontPlayer开发笔记12 Vue3撤销重做功能 字玩FontPlayer是笔者开源的一款字体设计工具,使用Vue3 ElementUI开发,源代码:github | gitee 笔记 撤销重做功能是设计工具必不可少的模块,以前尝试使用成熟的库实现撤销重做功能…...
无人机图传模块:深入理解其工作原理与实际效用
无人机图传模块作为无人机系统的关键组成部分,承担着将无人机拍摄的图像和视频实时传输至地面控制站或接收设备的重任。本文将深入探讨无人机图传模块的工作原理及其在实际应用中的效用,帮助读者更好地理解这一技术的奥秘。 一、无人机图传模块的工作原…...
PDF文件提取开源工具调研总结
概述 PDF是一种日常工作中广泛使用的跨平台文档格式,常常包含丰富的内容:包括文本、图表、表格、公式、图像。在现代信息处理工作流中发挥了重要的作用,尤其是RAG项目中,通过将非结构化数据转化为结构化和可访问的信息࿰…...
Linux(Centos 7.6)命令详解:dos2unix
1.命令作用 将Windows格式文件件转换为Unix、Linux格式的文件(也可以转换成其他格式的) 2.命令语法 Usage: dos2unix [options] [file ...] [-n infile outfile ...] 3.参数详解 options: -c, --convmode,转换方式,支持ascii, 7bit, iso, mac,默认…...
梯度提升决策树树(GBDT)公式推导
### 逻辑回归的损失函数 逻辑回归模型用于分类问题,其输出是一个概率值。对于二分类问题,逻辑回归模型的输出可以表示为: \[ P(y 1 | x) \frac{1}{1 e^{-F(x)}} \] 其中 \( F(x) \) 是一个线性组合函数,通常表示为ÿ…...
跨域问题分析及解决方案
1、跨域 指的是浏览器不能执行其他网站的脚本。它是由浏览器的同源策略造成的,是浏览器对javascript施加的安全限制。 2、同源策略:是指协议,域名,端口都要相同,其中有一个不同都会产生跨域; 3、跨域流程…...
【三国游戏——贪心、排序】
题目 代码 #include <bits/stdc.h> using namespace std; using ll long long; const int N 1e510; int a[N], b[N], c[N]; int w[4][N]; int main() {int n;cin >> n;for(int i 1; i < n; i)cin >> a[i];for(int i 1; i < n; i)cin >> b[i…...
深入理解 Java 的数据类型与运算符
Java学习资料 Java学习资料 Java学习资料 在 Java 编程中,数据类型与运算符是构建程序的基础元素。它们决定了数据在程序中的存储方式以及如何对数据进行各种操作。 一、数据类型 (一)基本数据类型 整型: 用于存储整数数值&…...
WOA-CNN-GRU-Attention、CNN-GRU-Attention、WOA-CNN-GRU、CNN-GRU四模型对比多变量时序预测
WOA-CNN-GRU-Attention、CNN-GRU-Attention、WOA-CNN-GRU、CNN-GRU四模型对比多变量时序预测 目录 WOA-CNN-GRU-Attention、CNN-GRU-Attention、WOA-CNN-GRU、CNN-GRU四模型对比多变量时序预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 基于WOA-CNN-GRU-Attention、…...
(二叉树)
我们今天就开始引进一个新的数据结构了:我们所熟知的:二叉树; 但是我们在引进二叉树之前我们先了解一下树; 树 树的概念和结构: 树是⼀种⾮线性的数据结构,它是由 n ( n>0 ) …...