当前位置: 首页 > news >正文

YK人工智能(六)——万字长文学会基于Torch模型网络可视化

1. 可视化网络结构

随着深度神经网络做的的发展,网络的结构越来越复杂,我们也很难确定每一层的输入结构,输出结构以及参数等信息,这样导致我们很难在短时间内完成debug。因此掌握一个可以用来可视化网络结构的工具是十分有必要的。类似的功能在另一个深度学习库Keras中可以调用一个叫做model.summary()的API来很方便地实现,调用后就会显示我们的模型参数,输入大小,输出大小,模型的整体参数等,但是在PyTorch中没有这样一种便利的工具帮助我们可视化我们的模型结构。

为了解决这个问题,人们开发了torchinfo工具包 ( torchinfo是由torchsummary和torchsummaryX重构出的库) 。本节我们将介绍如何使用torchinfo来可视化网络结构。

7.1.1 使用print函数打印模型基础信息

我们使用ResNet18的结构进行展示。

import torchvision.models as models
model = models.resnet18()

通过上面的两步,我们就得到resnet18的模型结构。在学习torchinfo之前,让我们先看下直接print(model)的结果。

print(model)
ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(1): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): BasicBlock((conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer3): Sequential((0): BasicBlock((conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer4): Sequential((0): BasicBlock((conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=512, out_features=1000, bias=True)
)

我们可以发现单纯的print(model),只能得出基础构件的信息,既不能显示出每一层的shape,也不能显示对应参数量的大小,为了解决这些问题,我们就需要介绍出我们今天的主人公torchinfo

1.2 使用torchinfo可视化网络结构

  • torchinfo的安装
安装方法一
pip install torchinfo 
安装方法二
conda install -c conda-forge torchinfo
  • torchinfo的使用

trochinfo的使用也是十分简单,我们只需要使用torchinfo.summary()就行了,必需的参数分别是model,input_size[batch_size,channel,h,w],更多参数可以参考documentation,下面让我们一起通过一个实例进行学习。

import torchvision.models as models
from torchinfo import summary
resnet18 = models.resnet18() # 实例化模型
summary(resnet18, (1, 3, 224, 224)) # 1:batch_size 3:图片的通道数 224: 图片的高宽
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-3                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-13                 [1, 128, 28, 28]          73,728
│    │    └─BatchNorm2d: 3-14            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-15                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-16                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-17            [1, 128, 28, 28]          256
│    │    └─Sequential: 3-18             [1, 128, 28, 28]          8,448
│    │    └─ReLU: 3-19                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-20                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-21            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-22                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-23                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-24            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-5                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-26                 [1, 256, 14, 14]          294,912
│    │    └─BatchNorm2d: 3-27            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-28                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-29                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-30            [1, 256, 14, 14]          512
│    │    └─Sequential: 3-31             [1, 256, 14, 14]          33,280
│    │    └─ReLU: 3-32                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-6                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-33                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-34            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-35                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-36                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-37            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-38                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-7                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-39                 [1, 512, 7, 7]            1,179,648
│    │    └─BatchNorm2d: 3-40            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-41                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-42                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-43            [1, 512, 7, 7]            1,024
│    │    └─Sequential: 3-44             [1, 512, 7, 7]            132,096
│    │    └─ReLU: 3-45                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-8                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-46                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-47            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-48                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-49                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-50            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-51                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 1000]                 513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 1.81
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 39.75
Params size (MB): 46.76
Estimated Total Size (MB): 87.11
==========================================================================================

我们可以看到torchinfo提供了更加详细的信息,包括模块信息(每一层的类型、输出shape和参数量)、模型整体的参数量、模型大小、一次前向或者反向传播需要的内存大小等

注意

但你使用的是colab或者jupyter notebook时,想要实现该方法,summary()一定是该单元(即notebook中的cell)的返回值,否则我们就需要使用print(summary(...))来可视化。

2 CNN可视化

卷积神经网络(CNN)是深度学习中非常重要的模型结构,它广泛地用于图像处理,极大地提升了模型表现,推动了计算机视觉的发展和进步。但CNN是一个“黑盒模型”,人们并不知道CNN是如何获得较好表现的,由此带来了深度学习的可解释性问题。如果能理解CNN工作的方式,人们不仅能够解释所获得的结果,提升模型的鲁棒性,而且还能有针对性地改进CNN的结构以获得进一步的效果提升。

理解CNN的重要一步是可视化,包括可视化特征是如何提取的、提取到的特征的形式以及模型在输入数据上的关注点等。本节我们就从上述三个方面出发,介绍如何在PyTorch的框架下完成CNN模型的可视化。

经过本节的学习,你将收获:

  • 可视化CNN卷积核的方法
  • 可视化CNN特征图的方法
  • 可视化CNN显著图(class activation map)的方法

2.1 CNN卷积核可视化

卷积核在CNN中负责提取特征,可视化卷积核能够帮助人们理解CNN各个层在提取什么样的特征,进而理解模型的工作原理。例如在Zeiler和Fergus 2013年的paper中就研究了CNN各个层的卷积核的不同,他们发现靠近输入的层提取的特征是相对简单的结构,而靠近输出的层提取的特征就和图中的实体形状相近了,如下图所示:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在PyTorch中可视化卷积核也非常方便,核心在于特定层的卷积核即特定层的模型权重,可视化卷积核就等价于可视化对应的权重矩阵。下面给出在PyTorch中可视化卷积核的实现方案,以torchvision自带的VGG11模型为例。

首先加载模型,并确定模型的层信息:

import torch
from torchvision.models import vgg11model = vgg11(pretrained=True)
print(dict(model.features.named_children()))
/Users/yaliu/miniconda3/envs/llm/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG11_Weights.IMAGENET1K_V1`. You can also use `weights=VGG11_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg){'0': Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), '1': ReLU(inplace=True), '2': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), '3': Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), '4': ReLU(inplace=True), '5': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), '6': Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), '7': ReLU(inplace=True), '8': Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), '9': ReLU(inplace=True), '10': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), '11': Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), '12': ReLU(inplace=True), '13': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), '14': ReLU(inplace=True), '15': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), '16': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), '17': ReLU(inplace=True), '18': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), '19': ReLU(inplace=True), '20': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)}
{'0': Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),'1': ReLU(inplace=True),'2': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),'3': Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),'4': ReLU(inplace=True),'5': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),'6': Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),'7': ReLU(inplace=True),'8': Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),'9': ReLU(inplace=True),'10': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),'11': Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),'12': ReLU(inplace=True),'13': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),'14': ReLU(inplace=True),'15': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),'16': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),'17': ReLU(inplace=True),'18': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),'19': ReLU(inplace=True),'20': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)}

卷积核对应的应为卷积层(Conv2d),这里以第“3”层为例,可视化对应的参数:

import matplotlib.pyplot as pltconv1 = dict(model.features.named_children())['3']
kernel_set = conv1.weight.detach()
num = len(conv1.weight.detach())
print(kernel_set.shape)
for i in range(0,num):i_kernel = kernel_set[i]plt.figure(figsize=(20, 17))if (len(i_kernel)) > 1:for idx, filer in enumerate(i_kernel):plt.subplot(9, 9, idx+1) plt.axis('off')plt.imshow(filer[ :, :].detach(),cmap='bwr')
torch.Size([128, 64, 3, 3])/var/folders/z7/ll9p_xgn76l2f7pqtx3c44jr0000gn/T/ipykernel_51575/1203810060.py:9: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). Consider using `matplotlib.pyplot.close()`.plt.figure(figsize=(20, 17))

torch.Size([128, 64, 3, 3])

由于第“3”层的特征图由64维变为128维,因此共有128*64个卷积核,其中部分卷积核可视化效果如下图所示:

在这里插入图片描述

2.2 CNN特征图可视化方法

与卷积核相对应,输入的原始图像经过每次卷积层得到的数据称为特征图,可视化卷积核是为了看模型提取哪些特征,可视化特征图则是为了看模型提取到的特征是什么样子的。

获取特征图的方法有很多种,可以从输入开始,逐层做前向传播,直到想要的特征图处将其返回。尽管这种方法可行,但是有些麻烦了。在PyTorch中,提供了一个专用的接口使得网络在前向传播过程中能够获取到特征图,这个接口的名称非常形象,叫做hook。可以想象这样的场景,数据通过网络向前传播,网络某一层我们预先设置了一个钩子,数据传播过后钩子上会留下数据在这一层的样子,读取钩子的信息就是这一层的特征图。具体实现如下:

class Hook(object):def __init__(self):self.module_name = []self.features_in_hook = []self.features_out_hook = []def __call__(self,module, fea_in, fea_out):print("hooker working", self)self.module_name.append(module.__class__)self.features_in_hook.append(fea_in)self.features_out_hook.append(fea_out)return Nonedef plot_feature(model, idx, inputs):hh = Hook()model.features[idx].register_forward_hook(hh)# forward_model(model,False)model.eval()_ = model(inputs)print(hh.module_name)print((hh.features_in_hook[0][0].shape))print((hh.features_out_hook[0].shape))out1 = hh.features_out_hook[0]total_ft  = out1.shape[1]first_item = out1[0].cpu().clone()    plt.figure(figsize=(20, 17))for ftidx in range(total_ft):if ftidx > 99:breakft = first_item[ftidx]plt.subplot(10, 10, ftidx+1) plt.axis('off')#plt.imshow(ft[ :, :].detach(),cmap='gray')plt.imshow(ft[ :, :].detach())

这里我们首先实现了一个hook类,之后在plot_feature函数中,将该hook类的对象注册到要进行可视化的网络的某层中。model在进行前向传播的时候会调用hook的__call__函数,我们也就是在那里存储了当前层的输入和输出。这里的features_out_hook 是一个list,每次前向传播一次,都是调用一次,也就是features_out_hook 长度会增加1。

2.3 CNN class activation map可视化方法

class activation map (CAM)的作用是判断哪些变量对模型来说是重要的,在CNN可视化的场景下,即判断图像中哪些像素点对预测结果是重要的。除了确定重要的像素点,人们也会对重要区域的梯度感兴趣,因此在CAM的基础上也进一步改进得到了Grad-CAM(以及诸多变种)。CAM和Grad-CAM的示例如下图所示:

在这里插入图片描述

相比可视化卷积核与可视化特征图,CAM系列可视化更为直观,能够一目了然地确定重要区域,进而进行可解释性分析或模型优化改进。CAM系列操作的实现可以通过开源工具包pytorch-grad-cam来实现。

  • 安装
pip install grad-cam
  • 一个简单的例子
import torch
from torchvision.models import vgg11,resnet18,resnet101,resnext101_32x8d
import matplotlib.pyplot as plt
from PIL import Image
import numpy as npmodel = vgg11(pretrained=True)
img_path = './dog.png'
# resize操作是为了和传入神经网络训练图片大小一致
img = Image.open(img_path).resize((224,224))
# 需要将原始图片转为np.float32格式并且在0-1之间 
rgb_img = np.float32(img)/255
plt.imshow(img)

在这里插入图片描述

from pytorch_grad_cam import GradCAM,ScoreCAM,GradCAMPlusPlus,AblationCAM,XGradCAM,EigenCAM,FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image# 将图片转为tensor
img_tensor = torch.from_numpy(rgb_img).permute(2,0,1).unsqueeze(0)target_layers = [model.features[-1]]
# 选取合适的类激活图,但是ScoreCAM和AblationCAM需要batch_size
cam = GradCAM(model=model,target_layers=target_layers)
targets = [ClassifierOutputTarget(preds)]   
# 上方preds需要设定,比如ImageNet有1000类,这里可以设为200
grayscale_cam = cam(input_tensor=img_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]
cam_img = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
print(type(cam_img))
Image.fromarray(cam_img)

在这里插入图片描述

2.4 使用FlashTorch快速实现CNN可视化

聪明的你可能要问了,已经2024年了,难道还要我们手把手去写各种CNN可视化的代码吗?答案当然是否定的。随着PyTorch社区的努力,目前已经有不少开源工具能够帮助我们快速实现CNN可视化。这里我们介绍其中的一个——FlashTorch。

(注:使用中发现该package对环境有要求,如果下方代码运行报错,请参考作者给出的配置或者Colab运行环境:https://github.com/MisaOgura/flashtorch/issues/39)

  • 安装
pip install flashtorch
import torchvision.models as models
from flashtorch.activmax import GradientAscentmodel = models.vgg16(pretrained=True)
g_ascent = GradientAscent(model.features)# specify layer and filter info
conv5_1 = model.features[24]
conv5_1_filters = [45, 271, 363, 489]g_ascent.visualize(conv5_1, conv5_1_filters, title="VGG16: conv5_1")

在这里插入图片描述

3. 使用TensorBoard可视化训练过程

训练过程的可视化在深度学习模型训练中扮演着重要的角色。学习的过程是一个优化的过程,我们需要找到最优的点作为训练过程的输出产物。一般来说,我们会结合训练集的损失函数和验证集的损失函数,绘制两条损失函数的曲线来确定训练的终点,找到对应的模型用于测试。那么除了记录训练中每个epoch的loss值,能否实时观察损失函数曲线的变化,及时捕捉模型的变化呢?

此外,我们也希望可视化其他内容,如输入数据(尤其是图片)、模型结构、参数分布等,这些对于我们在debug中查找问题来源非常重要(比如输入数据和我们想象的是否一致)。

TensorBoard作为一款可视化工具能够满足上面提到的各种需求。TensorBoard由TensorFlow团队开发,最早和TensorFlow配合使用,后来广泛应用于各种深度学习框架的可视化中来。本节我们探索TensorBoard的强大功能,希望帮助读者“从入门到精通”。

经过本节的学习,你将收获:

  • 安装TensorBoard工具
  • 了解TensorBoard可视化的基本逻辑
  • 掌握利用TensorBoard实现训练过程可视化
  • 掌握利用TensorBoard完成其他内容的可视化

3.1 TensorBoard安装

在已安装PyTorch的环境下使用pip安装即可:

pip install tensorboardXpip install torch-tb-profiler

也可以使用PyTorch自带的tensorboard工具,此时不需要额外安装tensorboard。

3.2 TensorBoard可视化的基本逻辑

我们可以将TensorBoard看做一个记录员,它可以记录我们指定的数据,包括模型每一层的feature map,权重,以及训练loss等等。TensorBoard将记录下来的内容保存在一个用户指定的文件夹里,程序不断运行中TensorBoard会不断记录。记录下的内容可以通过网页的形式加以可视化。

3.3 TensorBoard的配置与启动

在使用TensorBoard前,我们需要先指定一个文件夹供TensorBoard保存记录下来的数据。然后调用tensorboard中的SummaryWriter作为上述“记录员”

from tensorboardX import SummaryWriterwriter = SummaryWriter('./runs')

上面的操作实例化SummaryWritter为变量writer,并指定writer的输出目录为当前目录下的"runs"目录。也就是说,之后tensorboard记录下来的内容都会保存在runs。

如果使用PyTorch自带的tensorboard,则采用如下方式import:

from torch.utils.tensorboard import SummaryWriter

这里聪明的你可能发现了,是否可以手动往runs文件夹里添加数据用于可视化,或者把runs文件夹里的数据放到其他机器上可视化呢?答案是可以的。只要数据被记录,你可以将这个数据分享给其他人,其他人在安装了tensorboard的情况下就会看到你分享的数据。

启动tensorboard也很简单,在命令行中输入

tensorboard --logdir=/path/to/logs/ --port=xxxx

其中“path/to/logs/"是指定的保存tensorboard记录结果的文件路径(等价于上面的“./runs",port是外部访问TensorBoard的端口号,可以通过访问ip:port访问tensorboard,这一操作和jupyter notebook的使用类似。如果不是在服务器远程使用的话则不需要配置port。

有时,为了tensorboard能够不断地在后台运行,也可以使用nohup命令或者tmux工具来运行tensorboard。大家可以自行搜索,这里不展开讨论了。

下面,我们将模拟深度学习模型训练过程,来介绍如何利用TensorBoard可视化其中的各个部分。

3.4 TensorBoard模型结构可视化

首先定义模型:

import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)self.pool = nn.MaxPool2d(kernel_size = 2,stride = 2)self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))self.flatten = nn.Flatten()self.linear1 = nn.Linear(64,32)self.relu = nn.ReLU()self.linear2 = nn.Linear(32,1)self.sigmoid = nn.Sigmoid()def forward(self,x):x = self.conv1(x)x = self.pool(x)x = self.conv2(x)x = self.pool(x)x = self.adaptive_pool(x)x = self.flatten(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)y = self.sigmoid(x)return ymodel = Net()
print(model)
Net((conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))(adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))(flatten): Flatten(start_dim=1, end_dim=-1)(linear1): Linear(in_features=64, out_features=32, bias=True)(relu): ReLU()(linear2): Linear(in_features=32, out_features=1, bias=True)(sigmoid): Sigmoid()
)

可视化模型的思路和7.1中介绍的方法一样,都是给定一个输入数据,前向传播后得到模型的结构,再通过TensorBoard进行可视化,使用add_graph:

import torch
writer.add_graph(model, input_to_model = torch.rand(1, 3, 224, 224))
writer.close()

展示结果如下(其中框内部分初始会显示为“Net",需要双击后才会展开):

在这里插入图片描述

3.5 TensorBoard图像可视化

当我们做图像相关的任务时,可以方便地将所处理的图片在tensorboard中进行可视化展示。

  • 对于单张图片的显示使用add_image
  • 对于多张图片的显示使用add_images
  • 有时需要使用torchvision.utils.make_grid将多张图片拼成一张图片后,用writer.add_image显示

这里我们使用torchvision的CIFAR10数据集为例:

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoadertransform_train = transforms.Compose([transforms.ToTensor()])
transform_test = transforms.Compose([transforms.ToTensor()])train_data = datasets.CIFAR10(".", train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10(".", train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)images, labels = next(iter(train_loader))# 仅查看一张图片
writer = SummaryWriter('./pytorch_tb')
writer.add_image('images[0]', images[0])
writer.close()# 将多张图片拼接成一张图片,中间用黑色网格分割
# create grid of images
writer = SummaryWriter('./pytorch_tb')
img_grid = torchvision.utils.make_grid(images)
writer.add_image('image_grid', img_grid)
writer.close()# 将多张图片直接写入
writer = SummaryWriter('./pytorch_tb')
writer.add_images("images",images,global_step = 0)
writer.close()
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz100%|██████████| 170498071/170498071 [00:42<00:00, 4037492.39it/s]Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified

依次运行上面三组可视化(注意不要同时在notebook的一个单元格内运行),得到的可视化结果如下(最后运行的结果在最上面):

在这里插入图片描述

在这里插入图片描述

另外注意上方menu部分,刚刚只有“GRAPHS"栏对应模型的可视化,现在则多出了”IMAGES“栏对应图像的可视化。左侧的滑动按钮可以调整图像的亮度和对比度。

此外,除了可视化原始图像,TensorBoard提供的可视化方案自然也适用于我们在Python中用matplotlib等工具绘制的其他图像,用于展示分析结果等内容。

3.6 TensorBoard连续变量可视化

TensorBoard可以用来可视化连续变量(或时序变量)的变化过程,通过add_scalar实现:

writer = SummaryWriter('./pytorch_tb')
for i in range(500):x = iy = x**2writer.add_scalar("x", x, i) #日志中记录x在第step i 的值writer.add_scalar("y", y, i) #日志中记录y在第step i 的值
writer.close()

可视化结果如下:

在这里插入图片描述

如果想在同一张图中显示多个曲线,则需要分别建立存放子路径(使用SummaryWriter指定路径即可自动创建,但需要在tensorboard运行目录下),同时在add_scalar中修改曲线的标签使其一致即可:

writer1 = SummaryWriter('./pytorch_tb/x')
writer2 = SummaryWriter('./pytorch_tb/y')
for i in range(500):x = iy = x*2writer1.add_scalar("same", x, i) #日志中记录x在第step i 的值writer2.add_scalar("same", y, i) #日志中记录y在第step i 的值
writer1.close()
writer2.close()

在这里插入图片描述

这里也可以用一个writer,但for循环中不断创建SummaryWriter不是一个好选项。此时左下角的Runs部分出现了勾选项,我们可以选择我们想要可视化的曲线。曲线名称对应存放子路径的名称(这里是x和y)。

这部分功能非常适合损失函数的可视化,可以帮助我们更加直观地了解模型的训练情况,从而确定最佳的checkpoint。左侧的Smoothing滑动按钮可以调整曲线的平滑度,当损失函数震荡较大时,将Smoothing调大有助于观察loss的整体变化趋势。

3.7 TensorBoard参数分布可视化

当我们需要对参数(或向量)的变化,或者对其分布进行研究时,可以方便地用TensorBoard来进行可视化,通过add_histogram实现。下面给出一个例子:

import torch
import numpy as np# 创建正态分布的张量模拟参数矩阵
def norm(mean, std):t = std * torch.randn((100, 20)) + meanreturn twriter = SummaryWriter('./pytorch_tb/')
for step, mean in enumerate(range(-10, 10, 1)):w = norm(mean, 1)writer.add_histogram("w", w, step)writer.flush()
writer.close()

结果如下:

在这里插入图片描述

3.8 服务器端使用TensorBoard

一般情况下,我们会连接远程的服务器来对模型进行训练,但由于服务器端是没有浏览器的(纯命令模式),因此,我们需要进行相应的配置,才可以在本地浏览器,使用tensorboard查看服务器运行的训练过程。
本文提供以下几种方式进行,其中前两种方法都是建立SSH隧道,实现远程端口到本机端口的转发,最后一种方法适用于没有下载Xshell等SSH连接工具的用户

  • MobaXterm

    1. 在MobaXterm点击Tunneling
    2. 选择New SSH tunnel,我们会出现以下界面。

    在这里插入图片描述

    1. 对新建的SSH通道做以下设置,第一栏我们选择Local port forwarding< Remote Server>我们填写localhost< Remote port>填写6006,tensorboard默认会在6006端口进行显示,我们也可以根据 tensorboard --logdir=/path/to/logs/ --port=xxxx的命令中的port进行修改,< SSH server> 填写我们连接服务器的ip地址,<SSH login>填写我们连接的服务器的用户名,<SSH port>填写端口号(通常为22),< forwarded port>填写的是本地的一个端口号,以便我们后面可以对其进行访问。
    2. 设定好之后,点击Save,然后Start。在启动tensorboard,这样我们就可以在本地的浏览器输入http://localhost:6006/对其进行访问了
  • Xshell

    1. Xshell的连接方法与MobaXterm的连接方式本质上是一样的,具体操作如下:

    2. 连接上服务器后,打开当前会话属性,会出现下图,我们选择隧道,点击添加
      在这里插入图片描述

    3. 按照下方图进行选择,其中目标主机代表的是服务器,源主机代表的是本地,端口的选择根据实际情况而定。
      在这里插入图片描述

    4. 启动tensorboard,在本地127.0.0.1:6006 或者 localhost:6006进行访问。

  • SSH

    1. 该方法是将服务器的6006端口重定向到自己机器上来,我们可以在本地的终端里输入以下代码:其中16006代表映射到本地的端口,6006代表的是服务器上的端口。
      ssh -L 16006:127.0.0.1:6006 username@remote_server_ip
    
    1. 在服务上使用默认的6006端口正常启动tensorboard
    tensorboard --logdir=xxx --port=6006
    
    1. 在本地的浏览器输入地址
    127.0.0.1:16006 或者 localhost:16006
    

3.9 总结

对于TensorBoard来说,它的功能是很强大的,可以记录的东西不只限于本节所介绍的范围。

主要的实现方案是构建一个SummaryWriter,然后通过add_XXX()函数来实现。

其实TensorBoard的逻辑还是很简单的,它的基本逻辑就是文件的读写逻辑,写入想要可视化的数据,然后TensorBoard自己会读出来。

4 使用wandb可视化训练过程

我们使用了Tensorboard可视化训练过程,但是Tensorboard对数据的保存仅限于本地,也很难分析超参数不同对实验的影响。wandb的出现很好的解决了这些问题,因此在本章节中,我们将对wandb进行简要介绍。
wandb是Weights & Biases的缩写,它能够自动记录模型训练过程中的超参数和输出指标,然后可视化和比较结果,并快速与其他人共享结果。目前它能够和Jupyter、TensorFlow、Pytorch、Keras、Scikit、fast.ai、LightGBM、XGBoost一起结合使用。

经过本节的学习,你将收获:

  • wandb的安装
  • wandb的使用
  • demo演示

4.1 wandb的安装

wandb的安装非常简单,我们只需要使用pip安装即可。

pip install wandb

安装完成后,我们需要在官网注册一个账号并复制下自己的API keys,然后在本地使用下面的命令登录。

!wandb login
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
Aborted!

这时,我们会看到下面的界面,只需要粘贴你的API keys即可。
在这里插入图片描述

4.2 wandb的使用

wandb的使用也非常简单,只需要在代码中添加几行代码即可。

这里的project和entity是你在wandb上创建的项目名称和用户名,如果你还没有创建项目,可以参考官方文档。

4.3 demo演示

下面我们使用一个CIFAR10的图像分类demo来演示wandb的使用。


import random  # to set the python random seed
import numpy  # to set the numpy random seed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import warnings
warnings.filterwarnings('ignore')

使用wandb的第一步是初始化wandb,这里我们使用wandb.init()函数来初始化wandb,其中project是你在wandb上创建的项目名称,name是你的实验名称。

# 初始化wandb
import wandb
wandb.init(project="aa",name="liuyangokay",settings=wandb.Settings(init_timeout=180))

Finishing last run (ID:bj4wnti0) before initializing another…

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced---------------------------------------------------------------------------KeyboardInterrupt                         Traceback (most recent call last)Cell In[17], line 31 # 初始化wandb2 import wandb
----> 3 wandb.init(project="aa",4            name="liuyangokay",settings=wandb.Settings(init_timeout=180))File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/wandb_init.py:1256, in init(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, fork_from, resume_from, settings)1254     wi = _WandbInit()1255     wi.setup(kwargs)
-> 1256     return wi.init()1258 except KeyboardInterrupt as e:1259     if logger is not None:File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/wandb_init.py:654, in _WandbInit.init(self)649 if jupyter and not self.settings.silent:650     ipython.display_html(651         f"Finishing last run (ID:{latest_run._run_id}) before initializing another..."652     )
--> 654 latest_run.finish()656 if jupyter and not self.settings.silent:657     ipython.display_html(658         f"Successfully finished last run (ID:{latest_run._run_id}). Initializing new run:<br/>"659     )File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/wandb_run.py:451, in _run_decorator._noop.<locals>.wrapper(self, *args, **kwargs)448         wandb.termwarn(message, repeat=False)449         return cls.Dummy()
--> 451 return func(self, *args, **kwargs)File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/wandb_run.py:393, in _run_decorator._attach.<locals>.wrapper(self, *args, **kwargs)391         raise e392     cls._is_attaching = ""
--> 393 return func(self, *args, **kwargs)File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/wandb_run.py:2139, in Run.finish(self, exit_code, quiet)2127 @_run_decorator._noop2128 @_run_decorator._attach2129 def finish(self, exit_code: int | None = None, quiet: bool | None = None) -> None:2130     """Mark a run as finished, and finish uploading all data.2131 2132     This is used when creating multiple runs in the same process. We automatically(...)2137         quiet: Set to true to minimize log output2138     """
-> 2139     return self._finish(exit_code, quiet)File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/wandb_run.py:2173, in Run._finish(self, exit_code, quiet)2170 self._is_finished = True2172 try:
-> 2173     self._atexit_cleanup(exit_code=exit_code)2175     # Run hooks that should happen after the last messages to the2176     # internal service, like detaching the logger.2177     for hook in self._teardown_hooks:File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/wandb_run.py:2426, in Run._atexit_cleanup(self, exit_code)2423         os.remove(self._settings.resume_fname)2425 try:
-> 2426     self._on_finish()2428 except KeyboardInterrupt:2429     if not wandb.wandb_agent._is_running():  # type: ignoreFile ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/wandb_run.py:2676, in Run._on_finish(self)2669 exit_handle.add_probe(on_probe=self._on_probe_exit)2671 with progress.progress_printer(2672     self._printer,2673     self._settings,2674 ) as progress_printer:2675     # Wait for the run to complete.
-> 2676     _ = exit_handle.wait(2677         timeout=-1,2678         on_progress=functools.partial(2679             self._on_progress_exit,2680             progress_printer,2681         ),2682     )2684 # Print some final statistics.2685 poll_exit_handle = self._backend.interface.deliver_poll_exit()File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/lib/mailbox.py:279, in MailboxHandle.wait(self, timeout, on_probe, on_progress, release, cancel)276     if self._interface._transport_keepalive_failed():277         raise MailboxError("transport failed")
--> 279 found, abandoned = self._slot._get_and_clear(timeout=wait_timeout)280 if found:281     # Always update progress to 100% when done282     if on_progress and progress_handle and progress_sent:File ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/lib/mailbox.py:126, in _MailboxSlot._get_and_clear(self, timeout)124 def _get_and_clear(self, timeout: float) -> Tuple[Optional[pb.Result], bool]:125     found = None
--> 126     if self._wait(timeout=timeout):127         with self._lock:128             found = self._resultFile ~/miniconda3/envs/llm/lib/python3.10/site-packages/wandb/sdk/lib/mailbox.py:122, in _MailboxSlot._wait(self, timeout)121 def _wait(self, timeout: float) -> bool:
--> 122     return self._event.wait(timeout=timeout)File ~/miniconda3/envs/llm/lib/python3.10/threading.py:607, in Event.wait(self, timeout)605 signaled = self._flag606 if not signaled:
--> 607     signaled = self._cond.wait(timeout)608 return signaledFile ~/miniconda3/envs/llm/lib/python3.10/threading.py:324, in Condition.wait(self, timeout)322 else:323     if timeout > 0:
--> 324         gotit = waiter.acquire(True, timeout)325     else:326         gotit = waiter.acquire(False)KeyboardInterrupt: 

使用wandb的第二步是设置超参数,这里我们使用wandb.config来设置超参数,这样我们就可以在wandb的界面上看到超参数的变化。wandb.config的使用方法和字典类似,我们可以使用config.key的方式来设置超参数。

# 超参数设置
config = wandb.config  # config的初始化
config.batch_size = 64  
config.test_batch_size = 10 
config.epochs = 5  
config.lr = 0.01 
config.momentum = 0.1  
config.use_cuda = False  
config.seed = 2043  
config.log_interval = 10 # 设置随机数
def set_seed(seed):random.seed(config.seed)      torch.manual_seed(config.seed) numpy.random.seed(config.seed) 

第三步是构建训练和测试的pipeline,这里我们使用pytorch的CIFAR10数据集和resnet18来构建训练和测试的pipeline。

def train(model, device, train_loader, optimizer):model.train()for batch_id, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)criterion = nn.CrossEntropyLoss()loss = criterion(output, target)loss.backward()optimizer.step()# wandb.log用来记录一些日志(accuracy,loss and epoch), 便于随时查看网路的性能
def test(model, device, test_loader, classes):model.eval()test_loss = 0correct = 0example_images = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)criterion = nn.CrossEntropyLoss()test_loss += criterion(output, target).item()pred = output.max(1, keepdim=True)[1]correct += pred.eq(target.view_as(pred)).sum().item()example_images.append(wandb.Image(data[0], caption="Pred:{} Truth:{}".format(classes[pred[0].item()], classes[target[0]])))# 使用wandb.log 记录你想记录的指标wandb.log({"Examples": example_images,"Test Accuracy": 100. * correct / len(test_loader.dataset),"Test Loss": test_loss})wandb.watch_called = False def main():use_cuda = config.use_cuda and torch.cuda.is_available()device = torch.device("cuda:0" if use_cuda else "cpu")kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}# 设置随机数set_seed(config.seed)torch.backends.cudnn.deterministic = True# 数据预处理transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载数据train_loader = DataLoader(datasets.CIFAR10(root='dataset',train=True,download=True,transform=transform), batch_size=config.batch_size, shuffle=True, **kwargs)test_loader = DataLoader(datasets.CIFAR10(root='dataset',train=False,download=True,transform=transform), batch_size=config.batch_size, shuffle=False, **kwargs)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')model = resnet18(pretrained=True).to(device)optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)wandb.watch(model, log="all")for epoch in range(1, config.epochs + 1):train(model, device, train_loader, optimizer)test(model, device, test_loader, classes)# 本地和云端模型保存torch.save(model.state_dict(), 'model.pth')wandb.save('model.pth')if __name__ == '__main__':main()
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to dataset/cifar-10-python.tar.gz100%|██████████| 170498071/170498071 [00:27<00:00, 6180102.72it/s]Extracting dataset/cifar-10-python.tar.gz to dataset
Files already downloaded and verifiedWARNING:root:Only 108 Image will be uploaded.
WARNING:root:Only 108 Image will be uploaded.
WARNING:root:Only 108 Image will be uploaded.
WARNING:root:Only 108 Image will be uploaded.
WARNING:root:Only 108 Image will be uploaded.

当我们运行完上面的代码后,我们就可以在wandb的界面上看到我们的训练结果了和系统的性能指标。同时,我们还可以在setting里面设置训练完给我们发送邮件,这样我们就可以在训练完之后及时的查看训练结果了。
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

我们可以发现,使用wandb可以很方便的记录我们的训练结果,除此之外,wandb还为我们提供了很多的功能,比如:模型的超参数搜索,模型的版本控制,模型的部署等等。这些功能都可以帮助我们更好的管理我们的模型,更好的进行模型的迭代和优化。这些功能我们在后面的更新中会进行介绍。

参考资料

  1. https://blog.csdn.net/Python_Ai_Road/article/details/107704530
  2. https://andrewhuman.github.io/cnn-hidden-layout_search
  3. https://cloud.tencent.com/developer/article/1747222
  4. https://github.com/jacobgil/pytorch-grad-cam
  5. https://github.com/MisaOgura/flashtorch

相关文章:

YK人工智能(六)——万字长文学会基于Torch模型网络可视化

1. 可视化网络结构 随着深度神经网络做的的发展&#xff0c;网络的结构越来越复杂&#xff0c;我们也很难确定每一层的输入结构&#xff0c;输出结构以及参数等信息&#xff0c;这样导致我们很难在短时间内完成debug。因此掌握一个可以用来可视化网络结构的工具是十分有必要的…...

对象的实例化、内存布局与访问定位

一、创建对象的方式 二、创建对象的步骤: 一、判断对象对应的类是否加载、链接、初始化: 虚拟机遇到一条new指令&#xff0c;首先去检查这个指令的参数能否在Metaspace的常量池中定位到一个类的符号引用&#xff0c;并且检查这个符号引用代表的类是否已经被加载、解析和初始化…...

Java面试题2025-并发编程进阶(线程池和并发容器类)

线程池 一、什么是线程池 为什么要使用线程池 在开发中&#xff0c;为了提升效率的操作&#xff0c;我们需要将一些业务采用多线程的方式去执行。 比如有一个比较大的任务&#xff0c;可以将任务分成几块&#xff0c;分别交给几个线程去执行&#xff0c;最终做一个汇总就可…...

Vue Router 客户端路由解决方案:axios 响应拦截(跳转到登录页面)

文章目录 引言客户端路由 vs. 服务端路由简单的路由案例术语I Vue Router 提供的组件RouterLinkRouterViewII 创建路由器实例调用 createRouter() 函数创建路由选项III 注册路由器插件通过调用 use() 来完成注册路由器插件的职责对于组合式 API,Vue Router 给我们提供了一些组…...

Redis的通用命令

⭐️前言⭐️ 本文主要介绍Redis的通用命令 &#x1f349;欢迎点赞 &#x1f44d; 收藏 ⭐留言评论 &#x1f349;博主将持续更新学习记录收获&#xff0c;友友们有任何问题可以在评论区留言 &#x1f349;博客中涉及源码及博主日常练习代码均已上传GitHub &#x1f4cd;内容导…...

[Python人工智能] 四十九.PyTorch入门 (4)利用基础模块构建神经网络并实现分类预测

从本专栏开始,作者正式研究Python深度学习、神经网络及人工智能相关知识。前文讲解PyTorch构建回归神经网络。这篇文章将介绍如何利用PyTorch构建神经网络实现分类预测,其是使用基础模块构建。前面我们的Python人工智能主要以TensorFlow和Keras为主,而现在最主流的深度学习框…...

SpringBoot使用 easy-captcha 实现验证码登录功能

文章目录 一、 环境准备1. 解决思路2. 接口文档3. redis下载 二、后端实现1. 引入依赖2. 添加配置3. 后端代码实现4. 前端代码实现 在前后端分离的项目中&#xff0c;登录功能是必不可少的。为了提高安全性&#xff0c;通常会加入验证码验证。 easy-captcha 是一个简单易用的验…...

RabbitMQ 从入门到精通:从工作模式到集群部署实战(一)

#作者&#xff1a;闫乾苓 文章目录 RabbitMQ简介RabbitMQ与VMware的关系架构工作流程RabbitMQ 队列工作模式及适用场景简单队列模式&#xff08;Simple Queue&#xff09;工作队列模式&#xff08;Work Queue&#xff09;发布/订阅模式&#xff08;Publish/Subscribe&#xff…...

Unity中的虚拟相机(Cinemachine)

Unity Cinemachine详解 什么是Cinemachine Cinemachine是Unity官方推出的智能相机系统&#xff0c;它提供了一套完整的工具来创建复杂的相机运动和行为&#xff0c;而无需编写大量代码。它能够大大简化相机管理&#xff0c;提高游戏开发效率。 Cinemachine的主要组件 1. Vi…...

响应式编程_04Spring 5 中的响应式编程技术栈_WebFlux 和 Spring Data Reactive

文章目录 概述响应式Web框架Spring WebFlux响应式数据访问Spring Data Reactive 概述 https://spring.io/reactive 2017 年&#xff0c;Spring 发布了新版本 Spring 5&#xff0c; Spring 5 引入了很多核心功能&#xff0c;这其中重要的就是全面拥抱了响应式编程的设计思想和实…...

网络设备的安全加固

设备的安全始终是信息网络安全的一个重要方面&#xff0c;攻击者往往通过控制网络中设备来破坏系统和信息&#xff0c;或扩大已有的破坏。网络设备包括主机&#xff08;服务器、工作站、PC&#xff09;和网络设施&#xff08;交换机、路由器等&#xff09;。 一般说来&#xff…...

OpenCV:特征检测总结

目录 一、什么是特征检测&#xff1f; 二、OpenCV 中的常见特征检测方法 1. Harris 角点检测 2. Shi-Tomasi 角点检测 3. Canny 边缘检测 4. SIFT&#xff08;尺度不变特征变换&#xff09; 5. ORB 三、特征检测的应用场景 1. 图像匹配 2. 运动检测 3. 自动驾驶 4.…...

Java高频面试之SE-17

hello啊&#xff0c;各位观众姥爷们&#xff01;&#xff01;&#xff01;本牛马baby今天又来了&#xff01;哈哈哈哈哈嗝&#x1f436; Java缓冲区溢出&#xff0c;如何解决&#xff1f; 在 Java 中&#xff0c;缓冲区溢出 (Buffer Overflow) 虽然不是像 C/C 中那样直接可见…...

移动机器人规划控制入门与实践:基于navigation2 学习笔记(一)

课程实践: (1)手写A*代码并且调试,总结优缺点 (2)基于Gazebo仿真,完成给定机器人在给定地图中的导航调试 (3)使用Groot设计自己的导航行为树 掌握一门技术 规划控制概述 常见移动机器人...

每日一题洛谷P5721 【深基4.例6】数字直角三角形c++

#include<iostream> using namespace std; int main() {int n;cin >> n;int t 1;for (int i 0; i < n; i) {for (int j 0; j < n - i; j) {printf("%02d",t);t;}cout << endl;}return 0; }...

RTMP 和 WebRTC

WebRTC(Web Real-Time Communication)和 RTMP(Real-Time Messaging Protocol)是两种完全不同的流媒体协议,设计目标、协议栈、交互流程和应用场景均有显著差异。以下是两者的详细对比,涵盖协议字段、交互流程及核心设计思想。 一、协议栈与设计目标对比 特性RTMPWebRTC传…...

数据库技术基础

1 数据库系统概述 1.1 数据库的4个概念 &#xff08;1&#xff09;数据&#xff08;信息&#xff09; 数据&#xff1a;指已记录或可获取的事实&#xff0c;是数据库存储的最小单元。除文本、数字外&#xff0c;还有图形、图像、声音等。 数据由于能为用户利用才被记录和保…...

如何获取sql数据中时间的月份、年份(类型为date)

可用自带的函数month来实现 如&#xff1a; 创建表及插入数据&#xff1a; create table test (id int,begindate datetime) insert into test values (1,2015-01-01) insert into test values (2,2015-02-01) 执行sql语句,获取月份&#xff1a; select MONTH(begindate)…...

每日Attention学习18——Grouped Attention Gate

模块出处 [ICLR 25 Submission] [link] UltraLightUNet: Rethinking U-shaped Network with Multi-kernel Lightweight Convolutions for Medical Image Segmentation 模块名称 Grouped Attention Gate (GAG) 模块作用 轻量特征融合 模块结构 模块特点 特征融合前使用Group…...

分析用户请求K8S里ingress-nginx提供的ingress流量路径

前言 本文是个人的小小见解&#xff0c;欢迎大佬指出我文章的问题&#xff0c;一起讨论进步~ 我个人的疑问点 进入的流量是如何自动判断进入iptables的四表&#xff1f;k8s nodeport模式的原理&#xff1f; 一 本机环境介绍 节点名节点IPK8S版本CNI插件Master192.168.44.1…...

TensorFlow是个啥玩意?

TensorFlow是一个开源的机器学习框架&#xff0c;由Google开发。它可以帮助开发者构建和训练各种机器学习模型&#xff0c;包括神经网络和深度学习模型。TensorFlow的设计理念是使用数据流图来表示计算过程&#xff0c;其中节点表示数学运算&#xff0c;边表示数据流动。 Tens…...

初识C语言、C语言的学习方向总述与入门

目录 1. 什么是C语言&#xff1f; 2. 第一个C语言程序 3. 数据类型 4. 变量、常量 4.1 定义变量的方法 4.2 变量的命名 4.3 变量的分类 4.4 变量的作用域和生命周期 4.5 常量 5. 字符串转义字符注释 5.1 字符串 5.2 转义字符 6. 注释 7. 选择语句 8. 循环语句 …...

零基础学习书生.浦语大模型-入门岛

第一关&#xff1a;Linux基础知识 任务一&#xff1a;Cursor连接SSH运行代码 使用Remote - SSH插件即可 运行指令 python hello_world.py端口映射 ssh -p 46561 rootssh.intern-ai.org.cn -CNg -L 7860:127.0.0.1:7860 -o StrictHostKeyCheckingno 注&#xff1a;46561&a…...

【R语言】获取数据

R语言自带2种数据存储格式&#xff1a;*.RData和*.rds。 这两者的区别是&#xff1a;前者既可以存储数据&#xff0c;也可以存储当前工作空间中的所有变量&#xff0c;属于非标准化存储&#xff1b;后者仅用于存储单个R对象&#xff0c;且存储时可以创建标准化档案&#xff0c…...

MongoDB学习笔记-解析jsonCommand内容

如果需要屏蔽其他项目对MongoDB的直接访问操作&#xff0c;统一由一个入口访问操作MongoDB&#xff0c;可以考虑直接传入jsonCommand语句解析执行。 相关依赖包 <!-- SpringBootDataMongodb依赖包 --> <dependency><groupId>org.springframework.boot</…...

国产编辑器EverEdit - 工具栏说明

1 工具栏 1.1 应用场景 当用户想显示/隐藏界面的标签栏、工具栏、状态栏、主菜单等界面元素时&#xff0c;可以通过EverEdit的菜单选项进行设置。 1.2 使用方法 选择菜单查看 -> 工具栏&#xff0c;在工具栏的子菜单中选择勾选或去掉勾选对应的选项。 标签栏&#xff1…...

linux 函数 sem_init () 信号量、sem_destroy()

&#xff08;1&#xff09; &#xff08;2&#xff09; 代码举例&#xff1a; #include <stdio.h> #include <stdlib.h> #include <pthread.h> #include <semaphore.h> #include <unistd.h>sem_t semaphore;void* thread_function(void* arg) …...

制造业设备状态监控与生产优化实战:基于SQL的序列分析与状态机建模

目录 1. 背景与挑战 2. 数据建模与采集 2.1 数据表设计 设备状态表(记录设备实时状态变更)...

密码学的数学基础1-素数和RSA加密

数学公式推导是密码学的基础, 故开一个新的课题 – 密码学的数学基础系列 素数 / 质数 质数又称素数。 一个大于1的自然数&#xff0c;除了1和它自身外&#xff0c;不能被其他自然数整除的数叫做质数&#xff1b;否则称为合数&#xff08;规定1既不是质数也不是合数&#xff0…...

C++ Primer 算术运算符

欢迎阅读我的 【CPrimer】专栏 专栏简介&#xff1a;本专栏主要面向C初学者&#xff0c;解释C的一些基本概念和基础语言特性&#xff0c;涉及C标准库的用法&#xff0c;面向对象特性&#xff0c;泛型特性高级用法。通过使用标准库中定义的抽象设施&#xff0c;使你更加适应高级…...

计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战

前一篇文章&#xff0c;Tensor 基本操作5 device 管理&#xff0c;使用 GPU 设备 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started PyTorch 计算图和 Autograd 微积分之于机器学习Computational Graphs 计算图Autograd…...

Vue 图片引用方式详解:静态资源与动态路径访问

目录 前言1. 引用 public/ 目录2. assets/ 目录3. 远程服务器4. Vue Router 动态访问5. 总结6. 扩展&#xff08;图片不显示&#xff09; 前言 &#x1f91f; 找工作&#xff0c;来万码优才&#xff1a;&#x1f449; #小程序://万码优才/r6rqmzDaXpYkJZF 在 Vue 开发中&#x…...

熟练掌握Http协议

目录 基本概念请求数据Get请求方式和Post请求方式 响应数据响应状态码 基本概念 Http协议全称超文本传输协议(HyperText Transfer Protocol)&#xff0c;是网络通信中应用层的协议&#xff0c;规定了浏览器和web服务器数据传输的格式和规则 Http应用层协议具有以下特点&#…...

爬虫学习笔记之Robots协议相关整理

定义 Robots协议也称作爬虫协议、机器人协议&#xff0c;全名为网络爬虫排除标准&#xff0c;用来告诉爬虫和搜索引擎哪些页面可以爬取、哪些不可以。它通常是一个叫做robots.txt的文本文件&#xff0c;一般放在网站的根目录下。 robots.txt文件的样例 对有所爬虫均生效&#…...

血压计OCR文字检测数据集VOC+YOLO格式2147张11类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;2147 标注数量(xml文件个数)&#xff1a;2147 标注数量(txt文件个数)&#xff1a;2147 …...

正则表达式详细介绍

目录 正则表达式详细介绍什么是正则表达式&#xff1f;元字符转义字符字符类限定字符字符分枝字符分组懒惰匹配和贪婪匹配零宽断言 正则表达式详细介绍 什么是正则表达式&#xff1f; 正则表达式是一组由字母和符号组成的特殊文本&#xff0c;它可以用来从文本中找出满足你想…...

初识ArkTS语言

文章目录 ArkTS是HarmonyOS优选的主力应用开发语言。ArkTS围绕应用开发在TypeScript&#xff08;简称TS&#xff09;生态基础上做了进一步扩展&#xff0c;保持了TS的基本风格&#xff0c;同时通过规范定义强化开发期静态检查和分析&#xff0c;提升程序执行稳定性和性能。 从…...

Go语言并发之美:构建高性能键值存储系统

摘要 本文介绍了基于Go语言实现的高性能并发键值存储系统。通过深入探讨Go语言在并发编程中的优势&#xff0c;文章详细阐述了系统的锁机制、分片优化、内存管理和持久化设计等关键环节。这些设计展示了如何在系统开发中进行有效的权衡&#xff0c;以确保最优性能。该系统不仅充…...

6. k8s二进制集群之各节点部署

获取kubernetes源码安装主节点&#xff08;分别执行以下各节点命令&#xff09;安装工作节点&#xff08;同步kebelet和kube-proxy到各工作节点&#xff09;总结 继续上一篇文章《k8s二进制集群之ETCD集群部署》下面介绍一下各节点的部署与配置。 获取kubernetes源码 https:/…...

从0开始使用面对对象C语言搭建一个基于OLED的图形显示框架

目录 前言 环境介绍 代码与动机 架构设计&#xff0c;优缺点 博客系列指引 前言 笔者前段时间花费了一周&#xff0c;整理了一下自从TM1637开始打算的&#xff0c;使用OLED来搭建一个通用的显示库的一个工程。笔者的OLED库已经开源到Github上了&#xff0c;地址在&#xf…...

spring基础总结

先修知识&#xff1a;依赖注入&#xff0c;反转控制&#xff0c;生命周期 IDEA快捷键 Ctrl Altm:提取方法&#xff0c;设置trycatch 通用快捷键&#xff1a; Ctrl F&#xff1a;在当前文件中查找文本。Ctrl R&#xff1a;在当前文件中替换文本。Ctrl Z&#xff1a;撤销…...

基础相对薄弱怎么考研

复习总体规划 明确目标 选择专业和院校&#xff1a;根据你的兴趣、职业规划和自身实力&#xff0c;选择适合自己的专业和院校。可以参考往年的分数线、报录比、复试难度等。了解考试科目&#xff1a;不同专业考试科目不同&#xff0c;一般包括&#xff1a; 公共课&#xff1a…...

代码随想录36 动态规划

leetcode 343.整数拆分 给定一个正整数 n &#xff0c;将其拆分为 k 个 正整数 的和&#xff08; k > 2 &#xff09;&#xff0c;并使这些整数的乘积最大化。 返回 你可以获得的最大乘积 。 示例 1: 输入: n 2 输出: 1 解释: 2 1 1, 1 1 1。 示例 2: 输入: n 1…...

p5r预告信生成器API

p5r预告信生成器API 本人将js生成的p5r预告信使用go语言进行了重写和部署&#xff0c;并开放了其api&#xff0c;可以直接通过get方法获取预告信的png。 快速开始 http://api.viogami.tech/p5cc/:text eg: http://api.viogami.tech/p5cc/persona5 感谢p5r风格字体的制作者和…...

React图标库: 使用React Icons实现定制化图标效果

React图标库: 使用React Icons实现定制化图标效果 图标库介绍 是一个专门为React应用设计的图标库&#xff0c;它包含了丰富的图标集合&#xff0c;覆盖了常用的图标类型&#xff0c;如FontAwesome、Material Design等。React Icons可以让开发者在React应用中轻松地添加、定制各…...

说说Redis的内存淘汰策略?

大家好&#xff0c;我是锋哥。今天分享关于【说说Redis的内存淘汰策略?】面试题。希望对大家有帮助&#xff1b; 说说Redis的内存淘汰策略? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Redis 提供了多种内存淘汰策略&#xff0c;用于在内存达到限制时决定如何…...

【C语言】自定义类型讲解

文章目录 一、前言二、结构体2.1 概念2.2 定义2.2.1 通常情况下的定义2.2.2 匿名结构体 2.3 结构体的自引用和嵌套2.4 结构体变量的定义与初始化2.5 结构体的内存对齐2.6 结构体传参2.7 结构体实现位段 三、枚举3.1 概念3.2 定义3.3 枚举的优点3.3.1 提高代码的可读性3.3.2 防止…...

机器学习8-卷积和卷积核1

机器学习8-卷积和卷积核1 卷积与图像去噪卷积的定义与性质定义性质卷积的原理卷积步骤卷积的示例与应用卷积的优缺点优点缺点 总结 高斯卷积核卷积核尺寸的设置依据任务类型考虑数据特性实验与调优 高斯函数标准差的设置依据平滑需求结合卷积核尺寸实际应用场景 总结 图像噪声与…...

3、C#基于.net framework的应用开发实战编程 - 实现(三、三) - 编程手把手系列文章...

三、 实现&#xff1b; 三&#xff0e;三、编写应用程序&#xff1b; 此文主要是实现应用的主要编码工作。 1、 分层&#xff1b; 此例子主要分为UI、Helper、DAL等层。UI负责便签的界面显示&#xff1b;Helper主要是链接UI和数据库操作的中间层&#xff1b;DAL为对数据库的操…...

PHP 中 `foreach` 循环结合引用使用时可能出现的问题

问题背景 假设你有如下 PHP 代码&#xff1a; <?php $arr array(1, 2, 3, 4);// 使用引用遍历并修改数组元素 foreach ($arr as &$value) {$value $value * 2; } // 此时 $arr 变为 array(2, 4, 6, 8)// 再使用非引用方式遍历数组 foreach ($arr as $key > $val…...