当前位置: 首页 > news >正文

Swin Transformer模型详解(附pytorch实现)

写在前面

Swin Transformer(Shifted Window Transformer)是一种新颖的视觉Transformer模型,在2021年由微软亚洲研究院提出。这一模型提出了一种基于局部窗口的自注意力机制,显著改善了Vision Transformer(ViT)在处理高分辨率图像时的性能,尤其是在图像分类、物体检测等计算机视觉任务中表现出色。

Swin Transformer的最大创新之一是其引入了“平移窗口”机制,克服了传统自注意力方法在大图像处理时计算资源消耗过大的问题。这一机制使得模型能够在不同层次上以局部的方式计算自注意力,同时保持全局信息的处理能力。

在本文中,我们将通过详细的分析,介绍Swin Transformer的模型结构、核心思想及其实现,最后提供一个基于PyTorch的简单实现。

论文地址:https://arxiv.org/pdf/2103.14030

官方代码实现:https://github.com/microsoft/Swin-Transformer

Swin网络结构

如下图所示,Swin Transformer的Encoder采用分层的方式,通过多个阶段(Stage)逐渐减少特征图的分辨率,同时增加特征维度。每个Stage包含若干个Transformer Block。

每个Block通常由以下几个部分组成:

  • Window-based Self-Attention:每个Block使用窗口自注意力机制,在每个窗口内计算自注意力。这种方式减少了计算量,因为自注意力只在局部窗口内进行计算,而不是整个图像。
  • Shifted Window:为了增强不同窗口之间的联系,Swin Transformer在每一层的Block中采用了“窗口位移”策略。每一层中的窗口会偏移一定的步长,使得窗口之间的重叠区域增加,从而促进信息交流。

Patch Partition

Patch Partition 是将输入图像分割成固定大小的块(patch)并将其映射到高维空间的操作。就相当于是VIT模型当中的 Patch Embedding。

from functools import partialimport torch
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPathLayerNorm = partial(nn.LayerNorm, eps=1e-6)class PatchPartition(nn.Module):def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):super().__init__()self.patch_size = to_2tuple(patch_size)self.embed_dim = embed_dimself.proj = nn.Conv2d(in_channels, self.embed_dim,kernel_size=self.patch_size, stride=self.patch_size)self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shapeif H % self.patch_size[0] != 0:pad_h = self.patch_size[0] - H % self.patch_size[0]x = F.pad(x, (0, 0, 0, pad_h))if W % self.patch_size[1] != 0:pad_w = self.patch_size[1] - W % self.patch_size[1]x = F.pad(x, (0, pad_w, 0, 0))x = self.proj(x)     # [B, embed_dim, H/patch_size, W/patch_size]Wh, Ww = x.shape[2:]x = x.flatten(2).transpose(1, 2)    # [B, num_patches, embed_dim]# Linear Embeddingx = self.norm(x)# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)return x, Wh, Wwif __name__=="__main__":batch_size = 1in_channels = 3height, width = 30, 32patch_size = 4embed_dim = 96x = torch.randn(batch_size, in_channels, height, width)patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)output,_ ,_ = patch_partition(x)print(f"Output shape: {output.shape}")

Patch Merging

PatchMerging 这一层用于将输入的特征图进行下采样,类似于卷积神经网络中的池化层。

如果图像的高度或宽度是奇数,PatchMerging 会进行填充,使得其变为偶数。这是因为下采样操作需要将图像分割为以2为步长的区域。如果图像的高度或宽度是奇数,直接进行切片会导致不均匀的分割,因此需要填充以保证每个块的大小一致。

这里我们在吧如上图的相同颜色块提取并进行拼接,沿着通道维度合并成一个更大的特征,将合并后的张量重新调整形状,新的空间分辨率是原来的一半(H/2 和 W/2)。

class PatchMerging(nn.Module):def __init__(self, dim, norm_layer=LayerNorm):super().__init__()self.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x, H, W):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)if H % 2 == 1 or W % 2 == 1:x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C)  # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return xif __name__=="__main__":batch_size = 1in_channels = 3height, width = 30, 32patch_size = 4embed_dim = 96x = torch.randn(batch_size, in_channels, height, width)patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)output, Wh, Ww = patch_partition(x)patch_merging = PatchMerging(dim=embed_dim)output = patch_merging(output, Wh, Ww)print(output.shape)

在代码中呢就是在高和宽的维度通过切片的形式获得,x0表示的是左上角,x1表示的是右上角,x2表示的是左下角,x3表示的是右下角。经过一系列操作后,最后通过线性层实现通道数翻倍。

W-MSA

W-MSA(Window-based Multi-Head Self-Attention)是Swin Transformer中的一个核心创新,它是为了优化传统自注意力机制在高分辨率输入图像处理中的效率问题而提出的。

\Omega (MSA) = 4hwC^{2}+2(hw)^{2}C

\Omega (W$-$MSA) = 4hwC^{2} + 2(M)^{2}hwC

这是原论文当中给出的计算公式,h,w和C分别表示特征的高度,宽度和深度,M表示窗口的大小。在标准的 Transformer 模型中,自注意力机制需要对整个输入进行计算,这使得计算和内存的消耗随着输入的增大而急剧增长。而在图像任务中,输入图像往往具有非常高的分辨率,因此直接应用标准的全局自注意力在计算上不可行。

W-MSA 通过在局部窗口内进行自注意力计算来解决这一问题,极大地减少了计算和内存开销,同时保持了模型的表示能力。

SW-MSA

SW-MSA (Shifted Window-based Multi-Head Self-Attention)结合了局部窗口化自注意力和窗口偏移(shifted)策略,既提升了计算效率,又能在捕捉局部信息的基础上,保持对全局信息的建模能力。

左侧就是刚刚说到的W-MSA,经过窗口的偏移变成了右边的SW-MSA,偏移的策略能够让模型在每一层的计算中捕捉到不同窗口之间的依赖关系,避免了 W-MSA 只能在单一窗口内计算的局限。这样,相邻窗口之间的信息就能够通过偏移和交错的方式进行交流,增强了模型的全局感知能力。

但是,现在的窗口从原来的四个变成了九个,如果对每一个窗口再进行W-MSA那就太麻烦了。为了应对这种情况,作者提出了一种 高效批处理计算方法,旨在优化窗口偏移后的大规模窗口计算。其核心思想是:通过批处理计算的方式来有效地处理这些偏移后的窗口,而不是每个窗口单独计算。

意思就是说将图中的A,B,C的位置通过偏移和交错方式变化后,可以将这些窗口的计算统一进行批处理,而不是一个一个地处理。这样可以显著减少计算时间和内存占用。

这个过程我个人感觉比较像是卡诺图,具体的过程可以看我下面画的图:

然后这里的4还和原来的一样,5和3组合成一个窗口,1和7组合成一个窗口,8、2、6、0又组合成一个窗口,这样就和原来一样是4个4x4的窗口了,保证了计算量的不变。但是如果这样做了就会将不相邻的信息混合在一起了。作者这里采用掩蔽机制将自注意力计算限制在每个子窗口内,其实就是创建一个蒙板来屏蔽信息。

Relative Position Bias

关于这一部分,作者没有怎么提,只是经过了相对位置偏移,指标有明显的提示。

关于这一部分,我是参考的官方代码以及b站的讲解视频理解的。首先需要创建一个相对位置偏置的参数表,它的范围是从[-Wh+1, Wh-1],这里的 +1 和 -1 是因为偏移量是相对于当前元素的位置而言的,当前元素自身的偏移量为0,但我们不包括0在偏移量的计算中(因为0表示没有偏移,通常会在自注意力机制中以其他方式处理)。因此,对于垂直方向(或水平方向),总的偏移量数量是 win_h(或 win_w)的正偏移量数量加上 win_h(或 win_w)的负偏移量数量,再减去一个(因为我们不计算0偏移量)。因此,相对位置偏置表的尺寸为:

[(2 * Wh - 1) * (2 * Ww - 1), num_heads]

每个元素的查询(Query)和键(Key)之间的内积会得到一个相似度分数,在这些分数的基础上,会加入相对位置偏置,调整相似度:

Attention = softmax((QK^T + Relative_Position_Bias) / sqrt(d_k))

其中,Q 是查询向量,K 是键向量,Relative_Position_Bias 是根据相对位置计算得到的偏置。加入相对位置偏置后,模型可以更好地捕捉到局部结构的依赖关系。

网络实现

"""
Copyright (c) 2025, Auorui.
All rights reserved.Swin Transformer: Hierarchical Vision Transformer using Shifted Windows<https://arxiv.org/pdf/2103.14030>
use for reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer.pyhttps://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/swin_transformer/model.py
"""
from functools import partialimport torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath
from pyzjr.nn.models.bricks.initer import trunc_normal_LayerNorm = partial(nn.LayerNorm, eps=1e-6)class PatchPartition(nn.Module):def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):super().__init__()self.patch_size = to_2tuple(patch_size)self.embed_dim = embed_dimself.proj = nn.Conv2d(in_channels, self.embed_dim,kernel_size=self.patch_size, stride=self.patch_size)self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shapeif H % self.patch_size[0] != 0:pad_h = self.patch_size[0] - H % self.patch_size[0]x = F.pad(x, (0, 0, 0, pad_h))if W % self.patch_size[1] != 0:pad_w = self.patch_size[1] - W % self.patch_size[1]x = F.pad(x, (0, pad_w, 0, 0))x = self.proj(x)     # [B, embed_dim, H/patch_size, W/patch_size]Wh, Ww = x.shape[2:]x = x.flatten(2).transpose(1, 2)    # [B, num_patches, embed_dim]# Linear Embeddingx = self.norm(x)# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)return x, Wh, Wwclass MLP(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_ratio=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop_ratio)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass PatchMerging(nn.Module):def __init__(self, dim, norm_layer=LayerNorm):super().__init__()self.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x, H, W):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)if H % 2 == 1 or W % 2 == 1:x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C)  # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return xclass WindowAttention(nn.Module):"""Window based multi-head self attention (W-MSA) module with relative position bias.It supports shifted and non-shifted windows."""def __init__(self,dim,window_size,num_heads,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,):super().__init__()self.dim = dimself.window_size = to_2tuple(window_size)win_h, win_w = self.window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))   # [2*Wh-1 * 2*Ww-1, nHeads]   Offset Range: -Wh+1, Wh-1self.register_buffer("relative_position_index",self.get_relative_position_index(win_h, win_w), persistent=False)trunc_normal_(self.relative_position_bias_table, std=.02)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attention_dropout_ratio)self.proj = nn.Linear(dim, dim, bias=proj_bias)self.proj_drop = nn.Dropout(proj_drop)self.softmax = nn.Softmax(dim=-1)def get_relative_position_index(self, win_h: int, win_w: int):# get pair-wise relative position index for each token inside the windowcoords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing='ij'))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += win_h - 1  # shift to start from 0relative_coords[:, :, 1] += win_w - 1relative_coords[:, :, 0] *= 2 * win_w - 1return relative_coords.sum(-1)  # Wh*Ww, Wh*Wwdef forward(self, x, mask=None):"""Args:x: input features with shape of (num_windows*B, N, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[:3]q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xdef window_partition(x, window_size: int):"""将feature map按照window_size划分成一个个没有重叠的windowArgs:x: (B, H, W, C)window_size (int): window size(M)Returns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)# permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]# view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windowsdef window_reverse(windows, window_size: int, H: int, W: int):"""将一个个window还原成一个feature mapArgs:windows: (num_windows*B, window_size, window_size, C)window_size (int): Window size(M)H (int): Height of imageW (int): Width of imageReturns:x: (B, H, W, C)"""B = int(windows.shape[0] / (H * W / window_size / window_size))# view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)# permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]# view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)return xclass SwinTransformerBlock(nn.Module):r""" Swin Transformer Block."""mlp_ratio = 4def __init__(self,dim,num_heads,window_size=7,shift_size=0,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,drop_path_ratio=0.,norm_layer=LayerNorm,act_layer=nn.GELU,):super(SwinTransformerBlock, self).__init__()self.dim = dimself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeassert 0 <= self.shift_size < window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim,window_size=self.window_size,num_heads=num_heads,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,)self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * self.mlp_ratio)self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_ratio=proj_bias)self.H = Noneself.W = Nonedef forward(self, x, mask_matrix):"""Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature.mask_matrix: Attention mask for cyclic shift."""B, L, C = x.shapeH, W = self.H, self.Wassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# pad feature maps to multiples of window sizepad_l = pad_t = 0pad_r = (self.window_size - W % self.window_size) % self.window_sizepad_b = (self.window_size - H % self.window_size) % self.window_sizex = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))_, Hp, Wp, _ = x.shape# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))attn_mask = mask_matrixelse:shifted_x = xattn_mask = None# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, Cx_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xif pad_r > 0 or pad_b > 0:x = x[:, :H, :W, :].contiguous()x = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass BasicLayer(nn.Module):""" A basic Swin Transformer layer for one stage."""def __init__(self,dim,num_layers,num_heads,drop_path,window_size=7,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,norm_layer=LayerNorm,act_layer=nn.GELU,downsample=None):super().__init__()self.window_size = window_sizeself.shift_size = window_size // 2self.num_layers = num_layers# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim,num_heads=num_heads,window_size=window_size,shift_size=0 if (i % 2 == 0) else window_size // 2,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,drop_path_ratio=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer,act_layer=act_layer)for i in range(num_layers)])# patch merging layerif downsample is not None:self.downsample = downsample(dim=dim, norm_layer=norm_layer)else:self.downsample = Nonedef forward(self, x, H, W):""" Forward function.Args:x: Input feature, tensor size (B, H*W, C).H, W: Spatial resolution of the input feature."""# calculate attention mask for SW-MSAHp = int(np.ceil(H / self.window_size)) * self.window_sizeWp = int(np.ceil(W / self.window_size)) * self.window_sizeimg_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))for blk in self.blocks:blk.H, blk.W = H, Wx = blk(x, attn_mask)if self.downsample is not None:x = self.downsample(x, H, W)H, W = (H + 1) // 2, (W + 1) // 2return x, H, Wclass SwinTransformer(nn.Module):""" Swin Transformer backbone."""def __init__(self,patch_size=4,in_channels=3,num_classes=1000,embed_dim=96,depths=(2, 2, 6, 2),num_heads=(3, 6, 12, 24),window_size=7,qkv_bias=True,proj_bias=True,attention_dropout_ratio=0.,proj_drop=0.,drop_path_rate=0.2,norm_layer=LayerNorm,patch_norm=True,):super().__init__()self.num_classes = num_classesself.num_layers = len(depths)self.num_layers = len(depths)self.embed_dim = embed_dimself.patch_norm = patch_norm# stage4输出特征矩阵的channelsself.num_features = int(embed_dim * 2 ** (self.num_layers - 1))# split image into non-overlapping patchesself.patch_embed = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)self.pos_drop = nn.Dropout(p=proj_drop)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]layers = []for i_layer in range(self.num_layers):layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),num_layers=depths[i_layer],num_heads=num_heads[i_layer],window_size=window_size,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout_ratio=attention_dropout_ratio,proj_drop=proj_drop,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],norm_layer=norm_layer,downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,)layers.append(layer)self.layers = nn.Sequential(*layers)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def forward(self, x):# x: [B, L, C]x, H, W = self.patch_embed(x)x = self.pos_drop(x)for layer in self.layers:x, H, W = layer(x, H, W)x = self.norm(x)  # [B, L, C]x = self.avgpool(x.transpose(1, 2))x = torch.flatten(x, 1)x = self.head(x)return xdef swin_t(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=96,depths=(2, 2, 6, 2),num_heads=(3, 6, 12, 24),num_classes=num_classes)return modeldef swin_s(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=96,depths=(2, 2, 18, 2),num_heads=(3, 6, 12, 24),num_classes=num_classes)return modeldef swin_b(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=128,depths=(2, 2, 18, 2),num_heads=(4, 8, 16, 32),num_classes=num_classes)return modeldef swin_l(num_classes) -> SwinTransformer:model = SwinTransformer(in_channels=3,patch_size=4,window_size=7,embed_dim=192,depths=(2, 2, 18, 2),num_heads=(6, 12, 24, 48),num_classes=num_classes)return modelif __name__=="__main__":import pyzjrdevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = swin_l(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)pyzjr.summary_1(net, input_size=(3, 224, 224))# swin_t Total params: 27,499,108# swin_s Total params: 48,792,676# swin_b Total params: 86,683,780# swin_l Total params: 194,906,308

参考文章

Swin-Transformer网络结构详解_swin transformer-CSDN博客

Swin-transformer详解_swin transformer-CSDN博客 

【深度学习】详解 Swin Transformer (SwinT)-CSDN博客 

推荐的视频:12.1 Swin-Transformer网络结构详解_哔哩哔哩_bilibili 

相关文章:

Swin Transformer模型详解(附pytorch实现)

写在前面 Swin Transformer&#xff08;Shifted Window Transformer&#xff09;是一种新颖的视觉Transformer模型&#xff0c;在2021年由微软亚洲研究院提出。这一模型提出了一种基于局部窗口的自注意力机制&#xff0c;显著改善了Vision Transformer&#xff08;ViT&#xf…...

opencv进行人脸识别环境搭建

1. 构建人脸识别环境 1) 下载安装opencv 下载地址&#xff1a;Releases - OpenCV 参考博文&#xff1a;OpenCV下载安装教程&#xff08;Windows&#xff09;-CSDN博客 下载对应系统的opencv&#xff0c;如windows版&#xff0c;opencv-4.5.5-vc14_vc15.exe 2) 然后解压缩到…...

java小灶课详解:关于char和string的区别和对应的详细操作

char和string的区别与操作详解 在编程语言中&#xff0c;char和string是用于处理字符和字符串的两种重要数据类型。它们在存储、操作和应用场景上存在显著差异。本文将从以下几个方面详细解析两者的区别及常见操作。 1. 基本定义与存储差异 char&#xff1a; 定义&#xff1a;…...

计算机网络之---RIP协议

RIP协议的作用 RIP (Routing Information Protocol) 协议是一个基于距离矢量的路由协议&#xff0c;它在网络中用来动态地交换路由信息。RIP 是最早的路由协议之一&#xff0c;通常用于小型和中型网络中。它的工作原理简单&#xff0c;易于实现&#xff0c;但在一些大型网络中效…...

F#语言的文件操作

F#语言的文件操作 F#是一种功能性编程语言&#xff0c;运行在.NET平台上&#xff0c;特别适合处理并发和复杂的数据处理任务。在这篇文章中&#xff0c;我们将介绍F#语言中的文件操作&#xff0c;包括读取、写入和管理文件的基本方法。通过实例来帮助理解&#xff0c;适合初学…...

微信小程序开发设置支持scss文件

在微信小程序开发中&#xff0c;默认是不支持scss文件的&#xff0c;创建文件的时候&#xff0c;css文件默认创建的是wxss后缀结尾的&#xff0c;但是用习惯了scss的怎么办呢&#xff1f; 首先找到project.config.json文件&#xff0c;打开文件在setting下设置useCompilerPlug…...

【Excel笔记_3】execl的单元格是#DIV/0!,判断如果是这个,则该单元格等于空

在 Excel 中&#xff0c;可以使用 IF 函数来判断单元格是否是 #DIV/0! 错误&#xff0c;并将其替换为空值&#xff08;即空字符串 ""&#xff09;。具体公式如下&#xff1a; IF(ISERROR(A1), "", A1)或者&#xff0c;如果只想判断 #DIV/0! 错误&#xff…...

51单片机入门基础

目录 一、基础知识储备 &#xff08;一&#xff09;了解51单片机的基本概念 &#xff08;二&#xff09;掌握数字电路基础 &#xff08;三&#xff09;学习C语言编程基础 二、开发环境搭建 &#xff08;一&#xff09;硬件准备 &#xff08;二&#xff09;软件准备 三、…...

设计模式 行为型 访问者模式(Visitor Pattern)与 常见技术框架应用 解析

访问者模式&#xff08;Visitor Pattern&#xff09;是一种行为设计模式&#xff0c;它允许你在不改变元素类的前提下定义作用于这些元素的新操作。这种模式将算法与对象结构分离&#xff0c;使得可以独立地变化那些保存在复杂对象结构中的元素的操作。 假设我们有一个复杂的对…...

stable diffusion 量化学习笔记

文章目录 一、一些tensorRT背景及使用介绍1&#xff09;深度学习介绍2&#xff09;TensorRT优化策略介绍3&#xff09;TensorRT基础使用流程4&#xff09;dynamic shape 模式5&#xff09;TensorRT模型转换 二、TensorRT转onnx模型1&#xff09;onnx介绍2&#xff09;背景知识&…...

金融项目实战 04|JMeter实现自动化脚本接口测试及持续集成

目录 一、⾃动化测试理论 二、自动化脚本 1、添加断言 1️⃣注册、登录 2️⃣认证、充值、开户、投资 2、可重复执行&#xff1a;清除测试数据脚本按指定顺序执行 1️⃣如何可以做到可重复执⾏&#xff1f; 2️⃣清除测试数据&#xff1a;连接数据库setup线程组 ①明确…...

无需昂贵GPU:本地部署开源AI项目LocalAI在消费级硬件上运行大模型

无需昂贵GPU&#xff1a;本地部署开源AI项目LocalAI在消费级硬件上运行大模型 随着人工智能技术的快速发展&#xff0c;越来越多的AI模型被广泛应用于各个领域。然而&#xff0c;运行这些模型通常需要高性能的硬件支持&#xff0c;特别是GPU&#xff08;图形处理器&#xff09…...

selenium学习笔记

一.搭建环境 1.安装chrome #下载chrome wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb#安装chrome apt --fix-broken install ./google-chrome-stable_current_amd64.deb2.安装chromedriver 首先先查看版本&#xff1a;google-chrome --…...

SOME/IP协议详解 基础解读 涵盖SOME/IP协议解析 SOME/IP通讯机制 协议特点 错误处理机制

车载以太网协议栈总共可划分为五层&#xff0c;分别为物理层&#xff0c;数据链路层&#xff0c;网络层&#xff0c;传输层&#xff0c;应用层&#xff0c;其中今天所要介绍的内容SOME/IP就是一种应用层协议。 SOME/IP协议内容按照AUTOSAR中的描述&#xff0c;我们可以更进一步…...

nginx 实现 正向代理、反向代理 、SSL(证书配置)、负载均衡 、虚拟域名 ,使用其他中间件监控

我们可以详细地配置 Nginx 来实现正向代理、反向代理、SSL、负载均衡和虚拟域名。同时&#xff0c;我会介绍如何使用一些中间件来监控 Nginx 的状态和性能。 1. 安装 Nginx 如果你还没有安装 Nginx&#xff0c;可以通过以下命令进行安装&#xff08;以 Ubuntu 为例&#xff0…...

基于单片机的智能花卉浇水系统的设计与实现

摘要&#xff1a; 随着人们生活水平的不断提高&#xff0c;生活节奏也越来越快。人们经常忽视办公室或者家居的花卉&#xff0c;忘记浇水。本文设计了一种基于单片机的智能浇水系统。目的是解决养殖花卉的人忘记浇水的问题。本系统以单片机AT89S52为控制芯片&#xff0c;能够按…...

《使用 YOLOV8 和 KerasCV 进行高效目标检测》

《使用 YOLOV8 和 KerasCV 进行高效目标检测》 作者&#xff1a;Gitesh Chawda创建日期&#xff1a;2023/06/26最后修改时间&#xff1a;2023/06/26描述&#xff1a;使用 KerasCV 训练自定义 YOLOV8 对象检测模型。 &#xff08;i&#xff09; 此示例使用 Keras 2 在 Colab 中…...

【Domain Generalization(3)】领域泛化与文生图之 -- QUOTA 任意领域中的生成物体的数量可控

系列文章目录 【Domain Generalization(1)】增量学习/在线学习/持续学习/迁移学习/多任务学习/元学习/领域适应/领域泛化概念理解第一篇了解了 DG 的概念&#xff0c;那么接下来将介绍 DG 近年在文生图中的相关应用/代表性工作。【Domain Generalization(2)】领域泛化在文生图…...

qml XmlListModel详解

1、概述 XmlListModel是QtQuick用于从XML数据创建只读模型的组件。它可以作为各种view元素的数据源&#xff0c;比如ListView、GridView、PathView等&#xff1b;也可以作为其他和model交互的元素的数据源。通过XmlRole定义角色&#xff0c;如name、age和height&#xff0c;并…...

CAPL如何设置TCP/IP传输层动态端口范围

在TCP/IP协议中,应用程序通过传输层协议TCP/UDP传输数据,接收方传输层收到数据后,根据传输层端口号把接收的数据上交给正确的应用程序。我们可以简单地认为传输层端口号是应用程序的标识,这就是为什么我们说应用程序在使用TCP/IP协议通信时要打开传输层端口号或者绑定端口号…...

Pandas常用数据类型

扩展库pandas常用的数据结构如下&#xff1a; &#xff08;1&#xff09;Series&#xff1a;带标签的一维数组 &#xff08;2&#xff09;DatetimeIndes&#xff1a;时间序列 &#xff08;3&#xff09;DateFrame&#xff1a;带标签且大小可变的二维表格结构 &#xff08;4…...

【AI大模型】BERT GPT ELMo模型的对比

目录 &#x1f354; BERT, GPT, ELMo之间的不同点 &#x1f354; BERT, GPT, ELMo各自的优点和缺点 &#x1f354; 小结 学习目标 理解BERT, GPT, ELMo相互间的不同点理解BERT, GPT, ELMo相互比较下的各自优点和缺点 &#x1f354; BERT, GPT, ELMo之间的不同点 关于特征提取…...

探索AGI:智能助手与自我赋能的新时代

目录 1 AGI1.1 DeepMind Levels&#xff08;2023年11月)1.2 OpenAI Levels&#xff08;2024年7月&#xff09;1.3 对比与总结1.4 AGI可能诞生哪里 2 基于AI的智能自动化助手2.1 通用型大模型2.2 专业的Agent和模型工具开发框架2.3 编程与代码生成助手2.4 视频和多模态生成2.5 商…...

Oracle Dataguard(主库为双节点集群)配置详解(5):将主库复制到备库并启动同步

Oracle Dataguard&#xff08;主库为双节点集群&#xff09;配置详解&#xff08;5&#xff09;&#xff1a;将主库复制到备库并启动同步 目录 Oracle Dataguard&#xff08;主库为双节点集群&#xff09;配置详解&#xff08;5&#xff09;&#xff1a;将主库复制到备库并启动…...

webrtc自适应分辨率的设置

DegradationPreference 是一个枚举类&#xff0c;用于在视频编码或实时通信&#xff08;如 WebRTC&#xff09;中指定系统资源不足时如何处理质量下降的策略。以下是该枚举类的中文解释&#xff1a; enum class DegradationPreference {// 禁用&#xff1a;不根据资源过载信号…...

提供的 IP 地址 10.0.0.5 和子网掩码位 /26 来计算相关的网络信息

网络和IP地址计算器 https://www.sojson.com/convert/subnetmask.html提供的 IP 地址 10.0.0.5 和子网掩码位 /26 来计算相关的网络信息。 子网掩码转换 子网掩码 /26 的含义二进制表示:/26 表示前 26 位是网络部分&#xff0c;剩下的 6 位是主机部分。对应的子网掩码为 255…...

WPF系列八:图形控件Path

简介 Path控件支持一种称为路径迷你语言&#xff08;Path Mini-Language&#xff09;的紧凑字符串格式&#xff0c;用于描述复杂的几何图形。这种语言通过一系列命令字母和坐标来定义路径上的点和线段&#xff0c;最终绘制出想要的图形。 绘制任意形状&#xff1a;可以用来绘…...

如何移除git中被跟踪的commit文件

忽略已被跟踪的文件 问题描述 如果某个文件已经被 Git 跟踪&#xff08;即已被提交到仓库&#xff09;&#xff0c;即使后来将其添加到 .gitignore 文件中&#xff0c;Git 仍会继续跟踪它。 解决方案 更新 .gitignore 文件 将需要忽略的文件加入 .gitignore&#xff1a; .env…...

15. C语言 函数指针与回调函数

本章目录: 前言什么是函数指针&#xff1f;定义声明方式 函数指针的基本用法示例&#xff1a;最大值函数输出示例&#xff1a; 回调函数与函数指针什么是回调函数&#xff1f;通俗解释 示例&#xff1a;回调函数实现动态数组填充输出示例&#xff1a; 进一步探索&#xff1a;带…...

tomcat12启动流程源码分析

信息: Server.服务器版本: Apache Tomcat/12.0.x-dev 信息: Java虚拟机版本: 21下载源码https://github.com/apache/tomcat&#xff0c;并用idea打开&#xff0c;配置ant编译插件&#xff0c;或者使用我的代码 启动脚本是/bin/startup.bat&#xff0c;内部又执行了bin\cata…...

Pycharm 使用教程

一、基本配置 1. 切换Python解释器 pycharm切换解释器版本 2. pycharm虚拟环境配置 虚拟环境的目的&#xff1a;创建适用于该项目的环境&#xff0c;与系统环境隔离&#xff0c;防止污染系统环境&#xff08;包括需要的库&#xff09;虚拟环境配置存放在项目根目录下的 ven…...

数据仓库: 9- 数据仓库数据治理

目录 9- 数据治理9.1 数据标准化9.1.1 数据标准化的定义9.1.2 数据标准化的重要性9.1.3 数据标准化的主要内容9.1.4 数据标准化的实施步骤9.1.5 数据标准化常用工具9.1.6 数据标准化的挑战与应对策略9.1.7 案例分析9.1.8 总结 9.2 主数据管理(MDM)9.2.1 主数据管理的核心目标9.…...

Kutools for Excel 简体中文版 - 官方正版授权

Kutools for Excel 是一款超棒的 Excel 插件&#xff0c;就像给你的 Excel 加了个超能助手。它有 300 多种实用功能&#xff0c;现在还有 AI 帮忙&#xff0c;能把复杂的任务变简单&#xff0c;重复的事儿也能自动搞定&#xff0c;不管是新手还是老手都能用得顺手。有了它&…...

回归预测 | MATLAB实MLR多元线性回归多输入单输出回归预测

回归预测 | MATLAB实MLR多元线性回归多输入单输出回归预测 目录 回归预测 | MATLAB实MLR多元线性回归多输入单输出回归预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 回归预测 | MATLAB实MLR多元线性回归多输入单输出回归预测。 程序设计 完整代码&#xff1a;回…...

lerna使用指南

lerna版本 以下所有配置命令都是基于v8.1.9&#xff0c;lerna v5 v7版本差别较大&#xff0c;在使用时&#xff0c;注意自身的lerna版本。 lerna开启缓存及缓存配置 nx缓存是v5版本以后才有的&#xff0c;小于该版本的无法使用该功能。 初始化配置 缓存配置文件nx.json&am…...

LightGCN:为推荐系统简化图卷积网络的创新之作

LightGCN: Simplifying and Powering Graph Convolution Network for RecommendationSIGIR2020Collaborative Filtering, Recommendation, Embedding Propagation, Graph Neural Network &#x1f31f; 研究背景 在信息爆炸的互联网时代&#xff0c;个性化推荐系统成为缓解信…...

【图像去噪】论文精读:High-Quality Self-Supervised Deep Image Denoising(HQ-SSL)

请先看【专栏介绍文章】:【图像去噪(Image Denoising)】关于【图像去噪】专栏的相关说明,包含适配人群、专栏简介、专栏亮点、阅读方法、定价理由、品质承诺、关于更新、去噪概述、文章目录、资料汇总、问题汇总(更新中) 文章目录 前言Abstract1 Introduction2 Convoluti…...

Elasticsarch:使用全文搜索在 ES|QL 中进行过滤 - 8.17

8.17 在 ES|QL 中引入了 match 和 qstr 函数&#xff0c;可用于执行全文过滤。本文介绍了它们的作用、使用方法、与现有文本过滤方法的区别、当前的限制以及未来的改进。 ES|QL 现在包含全文函数&#xff0c;可用于使用文本查询过滤数据。我们将回顾可用的文本过滤方法&#xf…...

17.C语言输入输出函数详解:从缓存原理到常用函数用法

目录 1.前言2.缓存和字节流3.printf4.scanf5.sscanf6.getchar与putchar7.puts与gets 1.前言 本篇原文为&#xff1a;C语言输入输出函数详解&#xff1a;从缓存原理到常用函数用法。 更多C进阶、rust、python、逆向等等教程&#xff0c;可点击此链接查看&#xff1a;酷程网 C…...

高等数学学习笔记 ☞ 不定积分与积分公式

1. 不定积分的定义 1. 原函数与导函数的定义&#xff1a; 若函数可导&#xff0c;且&#xff0c;则称函数是函数的一个原函数&#xff0c;函数是函数的导函数。 备注&#xff1a; ①&#xff1a;若函数是连续的&#xff0c;则函数一定存在原函数&#xff0c;反之不对。 ②&…...

Debye-Einstein-模型拟合比热容Python脚本

固体比热模型中的德拜模型和爱因斯坦模型是固体物理学中用于估算固体热容的两种重要原子振动模型。 爱因斯坦模型基于三种假设&#xff1a;1.晶格中的每一个原子都是三维量子谐振子&#xff1b;2.原子不互相作用&#xff1b;3.所有的原子都以相同的频率振动&#xff08;与德拜…...

Ubuntu24.04安装AppImage报错AppImages require FUSE to run.

报错如下&#xff1a; 解决&#xff1a; sudo apt install libfuse2t64如果不行&#xff1a; sudo add-apt-repository universe sudo apt install libfuse2t64安装时又报错&#xff1a; [10354:0109/100149.571068:FATAL:setuid_sandbox_host.cc(158)] The SUID sandbox hel…...

3_CSS3 渐变 --[CSS3 进阶之路]

CSS3 引入了渐变&#xff08;gradients&#xff09;&#xff0c;它允许在两个或多个指定的颜色之间显示平滑的过渡。CSS3 支持两种类型的渐变&#xff1a; 线性渐变&#xff08;Linear Gradients&#xff09;&#xff1a;颜色沿着一条线性路径变化&#xff0c;可以是水平、垂直…...

uniapp 左右滑动切换Tab

各种开发会遇到很多奇葩的需求&#xff0c;今天这个是在页面 左右滑动&#xff0c;然后自动去切换Tab <viewtouchstart"touchStart"touchcancel"touchCancel"touchend"touchEnd"><components is"xxx"/></view>//---…...

STM32 FreeRTOS 任务创建和删除实验(动态方法)

动态创建,堆栈是在FreeRTOS管理的堆内存里,注意任务不要重复创建。 xxxxx_STACK_SIZE 128 uxTaskGetStackHighWaterMark()获取指定任务的任务栈的历史剩余最小值,根据这个结果适当调整启动任务的大小。 实验目标 学会 xTaskCreate( ) 和 vTaskDelete( ) 的使用: start_…...

宝塔面板 申请证书后 仍然提示不安全

证书显示有效&#xff0c;但是网站显示不安全 导致的原因是引入静态文件使用的是HTTP&#xff0c;查看方法为F12打开console控制台 可以看到静态文件全部都是HTTP 网站采用wordpress搭建&#xff0c;基于问题解决&#xff0c;其他方式搭建也是一样&#xff0c;处理掉所有的H…...

透明部署、旁路逻辑串联的区别

背景 需讨论防火墙到底是串联&#xff0c;还是旁挂。 通常串联指的就是“透明部署”&#xff0c;旁挂指的就是“逻辑串联”。 透明部署&#xff08;串联&#xff09; 也称为透明模式或桥接模式&#xff0c;是一种安全设备的部署方式。在这种模式下&#xff0c;安全设备被串联…...

C++实现设计模式---原型模式 (Prototype)

原型模式 (Prototype) 原型模式 是一种创建型设计模式&#xff0c;它通过复制现有对象来创建新对象&#xff0c;而不是通过实例化。 意图 使用原型实例指定要创建的对象类型&#xff0c;并通过复制该原型来生成新对象。提供一种高效创建对象的方式&#xff0c;尤其是当对象的…...

Canvas简历编辑器-选中绘制与拖拽多选交互方案

Canvas简历编辑器-选中绘制与拖拽多选交互方案 在之前我们聊了聊如何基于Canvas与基本事件组合实现了轻量级DOM&#xff0c;并且在此基础上实现了如何进行管理事件以及多层级渲染的能力设计。那么此时我们就依然在轻量级DOM的基础上&#xff0c;关注于实现选中绘制与拖拽多选交…...

kotlin的dagger hilt依赖注入

依赖注入&#xff08;dependency injection, di&#xff09;是设计模式的一种&#xff0c;它的实际作用是给对象赋予实例变量。 基础认识 class MainActivity : ComponentActivity() {override fun onCreate(savedInstanceState: Bundle?) {super.onCreate(savedInstanceSta…...