当前位置: 首页 > news >正文

CrossFormer实战:使用CrossFormer实现图像分类任务(一)

摘要

CrossFormer是一种新型的视觉Transformer架构,旨在通过引入跨尺度注意力机制来提升计算机视觉任务的性能。该模型特别关注不同尺度特征之间的交互,解决了现有视觉Transformer在处理多尺度特征时的不足。

在这里插入图片描述

研究背景

在计算机视觉中,特征的多尺度性对于理解和处理图像至关重要。然而,许多现有的视觉Transformer模型未能有效利用这些跨尺度特征,主要原因包括:

  • 输入嵌入在每一层都是相同尺度的,缺乏跨尺度特征。
  • 一些模型为了降低计算成本,牺牲了小尺度特征。

核心创新

CrossFormer提出了以下关键组件,以解决上述问题:

  • Cross-scale Embedding Layer (CEL)

    • CEL通过将每个嵌入与多个不同尺度的图像块混合,提供了跨尺度特征。这使得自注意力模块能够接收到多尺度的信息,从而增强模型的表达能力。
  • Long Short Distance Attention (LSDA)

    • LSDA将自注意力模块分为短距离和长距离两个部分。这种设计不仅降低了计算负担,还保留了小尺度和大尺度特征,使得模型在处理复杂视觉任务时更加高效。
  • Dynamic Position Bias (DPB)

    • DPB模块使得相对位置偏差能够适应可变大小的图像,增强了模型的灵活性。

本文使用CrossFormer模型实现图像分类任务,模型选择tiny,在植物幼苗分类任务ACC达到了96%+。

请添加图片描述
请添加图片描述

通过深入阅读本文,您将能够掌握以下关键技能与知识:

  1. 数据增强的多种策略:包括利用PyTorch的transforms库进行基本增强,以及进阶技巧如CutOut、MixUp、CutMix等,这些方法能显著提升模型泛化能力。

  2. CrossFormer模型的训练实现:了解如何从头开始构建并训练CrossFormer,涵盖模型定义、数据加载、训练循环等关键环节。

  3. 混合精度训练:学习如何利用PyTorch自带的混合精度训练功能,加速训练过程同时减少内存消耗。

  4. 梯度裁剪技术:掌握梯度裁剪的应用,有效防止梯度爆炸问题,确保训练过程的稳定性。

  5. 分布式数据并行(DP)训练:了解如何在多GPU环境下使用PyTorch的分布式数据并行功能,加速大规模模型训练。

  6. 可视化训练过程:学习如何绘制训练过程中的loss和accuracy曲线,直观监控模型学习状况。

  7. 评估与生成报告:掌握在验证集上评估模型性能的方法,并生成详细的评估报告,包括ACC等指标。

  8. 测试脚本编写:学会编写测试脚本,对测试集进行预测,评估模型在实际应用中的表现。

  9. 学习率调整策略:理解并应用余弦退火策略动态调整学习率,优化训练效果。

  10. 自定义统计工具:使用AverageMeter类或其他工具统计和记录训练过程中的ACC、loss等关键指标,便于后续分析。

  11. 深入理解ACC1与ACC5:掌握图像分类任务中ACC1(Top-1准确率)和ACC5(Top-5准确率)的含义及其计算方法。

  12. 指数移动平均(EMA):学习如何在模型训练中应用EMA技术,进一步提升模型在测试集上的表现。

若您在以上任一领域基础尚浅,感到理解困难,推荐您参考我的专栏“经典主干网络精讲与实战”,该专栏从零开始,循序渐进地讲解上述所有知识点,助您轻松掌握深度学习中的这些核心技能。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

安装einops,执行命令:

pip install einops

数据增强Cutout和Mixup

为了提高模型的泛化能力和性能,我在数据预处理阶段加入了Cutout和Mixup这两种数据增强技术。Cutout通过随机遮挡图像的一部分来强制模型学习更鲁棒的特征,而Mixup则通过混合两张图像及其标签来生成新的训练样本,从而增加数据的多样性。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

Mixup 是一种在图像分类任务中常用的数据增强技术,它通过将两张图像以及其对应的标签进行线性组合来生成新的数据和标签。
参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)在深度学习中是一种用于模型参数优化的技术,它通过计算参数的指数移动平均值来平滑模型的学习过程。这种方法有助于提高模型的稳定性和泛化能力,特别是在训练后期。以下是关于EMA的总结,表达进行了优化:

EMA概述

EMA是一种加权移动平均技术,其中每个新的平均值都是前一个平均值和当前值的加权和。在深度学习中,EMA被用于模型参数的更新,以减缓参数在训练过程中的快速波动,从而得到更加平滑和稳定的模型表现。

工作原理

在训练过程中,除了维护当前模型的参数外,还额外保存一份EMA参数。每个训练步骤或每隔一定步骤,根据当前模型参数和EMA参数,按照指数衰减的方式更新EMA参数。具体来说,EMA参数的更新公式通常如下:

EMA new = decay × EMA old + ( 1 − decay ) × model_parameters \text{EMA}_{\text{new}} = \text{decay} \times \text{EMA}_{\text{old}} + (1 - \text{decay}) \times \text{model\_parameters} EMAnew=decay×EMAold+(1decay)×model_parameters
其中,decay是一个介于0和1之间的超参数,控制着旧EMA值和新模型参数值之间的权重分配。较大的decay值意味着EMA更新时更多地依赖于旧值,即平滑效果更强。

应用优势

  1. 稳定性:EMA通过平滑参数更新过程,减少了模型在训练过程中的波动,使得模型更加稳定。
  2. 泛化能力:由于EMA参数是历史参数的平滑版本,它往往能捕捉到模型训练过程中的全局趋势,因此在测试或评估时,使用EMA参数往往能获得更好的泛化性能。
  3. 快速收敛:虽然EMA本身不直接加速训练过程,但通过稳定模型参数,它可能间接地帮助模型更快地收敛到更优的解。

使用场景

EMA在深度学习中的使用场景广泛,特别是在需要高度稳定性和良好泛化能力的任务中,如图像分类、目标检测等。在训练大型模型时,EMA尤其有用,因为它可以帮助减少过拟合的风险,并提高模型在未见数据上的表现。

具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

CrossFormer_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  └─crossformer.py
├─mean_std.py
├─makedata.py
├─train.py
└─test.py

mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
train.py:训练models文件下DilateFormer的模型
crossformer:来源官方代码,在官方的模型基础上做了一些修改。

计算mean和std

在深度学习中,特别是在处理图像数据时,计算数据的均值(mean)和标准差(standard deviation, std)并进行归一化(Normalization)是加速模型收敛、提高模型性能的关键步骤之一。这里我将详细解释这两个概念,并讨论它们如何帮助模型学习。

均值(Mean)

均值是所有数值加和后除以数值的个数得到的平均值。在图像处理中,我们通常对每个颜色通道(如RGB图像的三个通道)分别计算均值。这意味着,如果我们的数据集包含多张图像,我们会计算所有图像在R通道上的像素值的均值,同样地,我们也会计算G通道和B通道的均值。

标准差(Standard Deviation, Std)

标准差是衡量数据分布离散程度的统计量。它反映了数据点与均值的偏离程度。在计算图像数据的标准差时,我们也是针对每个颜色通道分别进行的。标准差较大的颜色通道意味着该通道上的像素值变化较大,而标准差较小的通道则相对较为稳定。

归一化(Normalization)

归一化是将数据按比例缩放,使之落入一个小的特定区间,通常是[0, 1]或[-1, 1]。在图像处理中,我们通常会使用计算得到的均值和标准差来进行归一化,公式如下:

Normalized Value = Original Value − Mean Std \text{Normalized Value} = \frac{\text{Original Value} - \text{Mean}}{\text{Std}} Normalized Value=StdOriginal ValueMean

注意,在某些情况下,为了简化计算并确保数据非负,我们可能会选择将数据缩放到[0, 1]区间,这时使用的是最大最小值归一化,而不是基于均值和标准差的归一化。但在这里,我们主要讨论基于均值和标准差的归一化,因为它能保留数据的分布特性。

为什么需要归一化?

  1. 加速收敛:归一化后的数据具有相似的尺度,这有助于梯度下降算法更快地找到最优解,因为不同特征的梯度更新将在同一数量级上,从而避免了某些特征因尺度过大或过小而导致的训练缓慢或梯度消失/爆炸问题。

  2. 提高精度:归一化可以改善模型的泛化能力,因为它使得模型更容易学习到特征之间的相对关系,而不是被特征的绝对大小所影响。

  3. 稳定性:归一化后的数据更加稳定,减少了训练过程中的波动,有助于模型更加稳定地收敛。

如何计算和使用mean和std

  1. 计算全局mean和std:在整个数据集上计算mean和std。这通常是在训练开始前进行的,并使用这些值来归一化训练集、验证集和测试集。

  2. 使用库函数:许多深度学习框架(如PyTorch、TensorFlow等)提供了计算mean和std的便捷函数,并可以直接用于数据集的归一化。

  3. 动态调整:在某些情况下,特别是当数据集非常大或持续更新时,可能需要动态地计算mean和std。这通常涉及到在训练过程中使用移动平均(如EMA)来更新这些统计量。

计算并使用数据的mean和std进行归一化是深度学习中的一项基本且重要的预处理步骤,它对于加速模型收敛、提高模型性能和稳定性具有重要意义。新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

CrossFormer代码

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_NEG_INF = -1000000class Mlp(nn.Module):r"""2-layer MLP"""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass DynamicPosBias(nn.Module):r"""DPB moduleUse a MLP to predict position bias used in attention."""def __init__(self, dim, num_heads, residual):super().__init__()self.residual = residualself.num_heads = num_headsself.pos_dim = dim // 4self.pos_proj = nn.Linear(2, self.pos_dim)self.pos1 = nn.Sequential(nn.LayerNorm(self.pos_dim),nn.ReLU(inplace=True),nn.Linear(self.pos_dim, self.pos_dim),)self.pos2 = nn.Sequential(nn.LayerNorm(self.pos_dim),nn.ReLU(inplace=True),nn.Linear(self.pos_dim, self.pos_dim))self.pos3 = nn.Sequential(nn.LayerNorm(self.pos_dim),nn.ReLU(inplace=True),nn.Linear(self.pos_dim, self.num_heads))def forward(self, biases):if self.residual:pos = self.pos_proj(biases)  # 2Wh-1 * 2Ww-1, headspos = pos + self.pos1(pos)pos = pos + self.pos2(pos)pos = self.pos3(pos)else:pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))return posdef flops(self, N):flops = N * 2 * self.pos_dimflops += N * self.pos_dim * self.pos_dimflops += N * self.pos_dim * self.pos_dimflops += N * self.pos_dim * self.num_headsreturn flopsclass Attention(nn.Module):r""" Multi-head self attention module with dynamic position bias.Args:dim (int): Number of input channels.group_size (tuple[int]): The height and width of the group.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if setattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,position_bias=True):super().__init__()self.dim = dimself.group_size = group_size  # Wh, Wwself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.position_bias = position_biasif position_bias:self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)# generate mother-setposition_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))  # 2, 2Wh-1, 2Ww-1biases = biases.flatten(1).transpose(0, 1).float()self.register_buffer("biases", biases, persistent=False)# get pair-wise relative position index for each token inside the groupcoords_h = torch.arange(self.group_size[0])coords_w = torch.arange(self.group_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.group_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.group_size[1] - 1relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.register_buffer("relative_position_index", relative_position_index, persistent=False)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.softmax = nn.Softmax(dim=-1)def forward(self, x, mask=None):"""Args:x: input features with shape of (num_groups*B, N, C)mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None"""B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scale# @ stands for matrix multiplicationattn = (q @ k.transpose(-2, -1))if self.position_bias:pos = self.pos(self.biases)  # 2Wh-1 * 2Ww-1, heads# select position biasrelative_position_bias = pos[self.relative_position_index.view(-1)].view(self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return xdef extra_repr(self) -> str:return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'def flops(self, N):# calculate flops for 1 group with token length of Nflops = 0# qkv = self.qkv(x)flops += N * self.dim * 3 * self.dim# attn = (q @ k.transpose(-2, -1))flops += self.num_heads * N * (self.dim // self.num_heads) * N#  x = (attn @ v)flops += self.num_heads * N * N * (self.dim // self.num_heads)# x = self.proj(x)flops += N * self.dim * self.dimif self.position_bias:flops += self.pos.flops(N)return flopsclass CrossFormerBlock(nn.Module):r""" CrossFormer Block.Args:dim (int): Number of input channels.input_resolution (tuple[int]): Input resulotion.num_heads (int): Number of attention heads.group_size (int): Group size.interval (int): Interval for LDA.lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.drop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float, optional): Stochastic depth rate. Default: 0.0act_layer (nn.Module, optional): Activation layer. Default: nn.GELUnorm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNormnum_patch_sizeimpl_type (str): use_extra_conv (bool): Extra convolution layer. Default: True"""def __init__(self, dim, input_resolution, num_heads, group_size=7, interval=8, lsda_flag=0,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1,pad_type=0, use_extra_conv=True, use_cpe=False, no_mask=False, adaptive_interval=False):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.num_heads = num_headsself.group_size = group_sizeself.interval = intervalself.lsda_flag = lsda_flagself.mlp_ratio = mlp_ratioself.num_patch_size = num_patch_sizeself.pad_type = pad_typeself.use_extra_conv = use_extra_convself.use_cpe = use_cpeif min(self.input_resolution) <= self.group_size:# if group size is larger than input resolution, we don't partition groupsself.lsda_flag = 0self.group_size = min(self.input_resolution)self.norm1 = norm_layer(dim)self.attn = Attention(dim, group_size=to_2tuple(self.group_size), num_heads=num_heads,qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,position_bias=(not use_cpe))if self.use_cpe:self.cpe = nn.Conv2d(in_channels=input_resolution[0], out_channels=input_resolution[0], kernel_size=3,padding=1, groups=input_resolution[0])self.norm_cpe = norm_layer(dim)if adaptive_interval:self.interval = int(np.ceil(self.input_resolution[0] / self.group_size))if self.use_extra_conv:self.ex_kernel = [3, 3]padding = (self.ex_kernel[0] - 1) // 2self.ex_conv = nn.Conv2d(dim, dim, self.ex_kernel, padding=padding, groups=dim)self.ex_ln = norm_layer(dim)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim, elementwise_affine=True)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)# compute attention maskattn_mask = Noneif not no_mask:H, W = self.input_resolutionsize_div = self.interval * self.group_size if self.lsda_flag == 1 else self.group_sizepad_w = (size_div - W % size_div) % size_divpad_h = (size_div - H % size_div) % size_divif self.pad_type == 0:pad_l = pad_t = 0else:pad_l = pad_w // 2pad_t = pad_h // 2pad_r = pad_w - pad_lpad_b = pad_h - pad_tHp = H + pad_hWp = W + pad_wmask = torch.zeros((1, Hp, Wp, 1))if pad_h > 0:mask[:, -pad_b:, :, :] = -1mask[:, : pad_t, :, :] = -1if pad_w > 0:mask[:, :, -pad_r:, :] = -1mask[:, :, : pad_l, :] = -1if self.lsda_flag == 0:  # 0 for SDAG = Gh = Gw = self.group_sizenG = Hp * Wp // G ** 2# attn_maskif pad_w > 0 or pad_h > 0:mask = mask.reshape(1, Hp // G, G, Wp // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous()mask = mask.reshape(nG, 1, G * G)attn_mask = torch.zeros((nG, G * G, G * G))attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF)else:attn_mask = Noneelse:  # 1 for LDAI = self.intervalG = Gh = Gw = self.group_sizeRh, Rw = Hp // (Gh * I), Wp // (Gw * I)nG = I ** 2 * Rh * Rw# attn_maskif pad_w > 0 or pad_h > 0:mask = mask.reshape(1, Rh, Gh, I, Rw, Gw, I, 1).permute(0, 1, 4, 3, 6, 2, 5, 7).contiguous()mask = mask.reshape(nG, 1, Gh * Gw)attn_mask = torch.zeros((nG, Gh * Gw, Gh * Gw))attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF)else:attn_mask = Noneself.register_buffer("attn_mask", attn_mask, persistent=False)def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)if self.use_cpe:x = x + self.norm_cpe(self.cpe(x))# paddingsize_div = self.interval * self.group_size if self.lsda_flag == 1 else self.group_sizepad_w = (size_div - W % size_div) % size_divpad_h = (size_div - H % size_div) % size_divif self.pad_type == 0:pad_l = pad_t = 0else:pad_l = pad_w // 2pad_t = pad_h // 2pad_r = pad_w - pad_lpad_b = pad_h - pad_tx = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))_, Hp, Wp, _ = x.shape# group embeddingsif self.lsda_flag == 0:  # 0 for SDAG = Gh = Gw = self.group_sizex = x.reshape(B, Hp // G, G, Wp // G, G, C).permute(0, 1, 3, 2, 4, 5)x = x.reshape(B * Hp * Wp // G ** 2, G ** 2, C)else:  # 1 for LDAI = self.intervalG = Gh = Gw = self.group_sizeRh, Rw = Hp // (Gh * I), Wp // (Gw * I)x = x.reshape(B, Rh, Gh, I, Rw, Gw, I, C).permute(0, 1, 4, 3, 6, 2, 5, 7).contiguous()x = x.reshape(B * Rh * Rw * I * I, Gh * Gw, C)# multi-head self-attentionx = self.attn(x, mask=self.attn_mask)  # nW*B, G*G, C# ungroup embeddingsif self.lsda_flag == 0:x = x.reshape(B, Hp // G, Wp // G, G, G, C).permute(0, 1, 3, 2, 4,5).contiguous()  # B, Hp//G, G, Wp//G, G, Celse:x = x.reshape(B, Rh, Rw, I, I, Gh, Gw, C).permute(0, 1, 5, 3, 2, 6, 4,7).contiguous()  # B, Rh, Gh, I, Rw, Gw, I, Cx = x.view(B, Hp, Wp, C)# remove paddingif pad_w > 0 or pad_h > 0:x = x[:, pad_t:H + pad_t, pad_l:W + pad_l, :].contiguous()x = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))if self.use_extra_conv:x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()x = self.ex_conv(x)x = x.permute(0, 2, 3, 1).view(B, H * W, C).contiguous()x = self.ex_ln(x)return xdef extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}, " \f"interval={self.interval}"def flops(self):flops = 0H, W = self.input_resolution# norm1flops += self.dim * H * W# LSDAnW = H * W / self.group_size / self.group_sizeflops += nW * self.attn.flops(self.group_size * self.group_size)# mlpflops += 2 * H * W * self.dim * self.dim * self.mlp_ratio# norm2flops += self.dim * H * Wreturn flopsclass PatchMerging(nn.Module):r""" Patch Merging Layer.Args:input_resolution (tuple[int]): Resolution of input feature.dim (int): Number of input channels.norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):super().__init__()self.input_resolution = input_resolutionself.dim = dimself.reductions = nn.ModuleList()self.patch_size = patch_sizeself.norm = norm_layer(dim)for i, ps in enumerate(patch_size):if i == len(patch_size) - 1:out_dim = 2 * dim // 2 ** ielse:out_dim = 2 * dim // 2 ** (i + 1)stride = 2padding = (ps - stride) // 2self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps,stride=stride, padding=padding))def forward(self, x):"""x: B, H*W, C"""H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."x = self.norm(x)x = x.view(B, H, W, C).permute(0, 3, 1, 2)xs = []for i in range(len(self.reductions)):tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)xs.append(tmp_x)x = torch.cat(xs, dim=2)return xdef extra_repr(self) -> str:return f"input_resolution={self.input_resolution}, dim={self.dim}"def flops(self):H, W = self.input_resolutionflops = H * W * self.dimfor i, ps in enumerate(self.patch_size):if i == len(self.patch_size) - 1:out_dim = 2 * self.dim // 2 ** ielse:out_dim = 2 * self.dim // 2 ** (i + 1)flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dimreturn flopsclass Stage(nn.Module):""" CrossFormer blocks for one stage.Args:dim (int): Number of input channels.input_resolution (tuple[int]): Input resolution.depth (int): Number of blocks.num_heads (int): Number of attention heads.group_size (int): variable G in the paper, one group has GxG embeddingsmlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.drop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNormdownsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: Noneuse_checkpoint (bool): Whether to use checkpointing to save memory. Default: False."""def __init__(self, dim, input_resolution, depth, num_heads, group_size, interval,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,patch_size_end=[4], num_patch_size=None, use_cpe=False, pad_type=0,no_mask=False, adaptive_interval=False, use_acl=False):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.depth = depthself.use_checkpoint = use_checkpoint# build blocksself.blocks = nn.ModuleList()for i in range(depth):lsda_flag = 0 if (i % 2 == 0) else 1# use extra convolution block every 3 blocksuse_extra_conv = ((i + 1) % 3 == 0) and (i < depth - 1) and use_aclself.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,num_heads=num_heads, group_size=group_size[i], interval=interval,lsda_flag=lsda_flag,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop, attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer,num_patch_size=num_patch_size,use_extra_conv=use_extra_conv,use_cpe=use_cpe,pad_type=pad_type,no_mask=no_mask,adaptive_interval=adaptive_interval))# patch merging layerif downsample is not None:self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer,patch_size=patch_size_end, num_input_patch_size=num_patch_size)else:self.downsample = Nonedef forward(self, x):for blk in self.blocks:if self.use_checkpoint:x = checkpoint.checkpoint(blk, x)else:x = blk(x)if self.downsample is not None:x = self.downsample(x)return xdef extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"def flops(self):flops = 0for blk in self.blocks:flops += blk.flops()if self.downsample is not None:flops += self.downsample.flops()return flopsclass PatchEmbed(nn.Module):r""" Image to Patch EmbeddingArgs:img_size (int): Image size.  Default: 224.patch_size (int): Patch token size. Default: [4].in_chans (int): Number of input image channels. Default: 3.embed_dim (int): Number of linear projection output channels. Default: 96.norm_layer (nn.Module, optional): Normalization layer. Default: None"""def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):super().__init__()img_size = to_2tuple(img_size)# patch_size = to_2tuple(patch_size)patches_resolution = [img_size[0] // patch_size[0], img_size[0] // patch_size[0]]self.img_size = img_sizeself.patch_size = patch_sizeself.patches_resolution = patches_resolutionself.num_patches = patches_resolution[0] * patches_resolution[1]self.in_chans = in_chansself.embed_dim = embed_dimself.projs = nn.ModuleList()for i, ps in enumerate(patch_size):if i == len(patch_size) - 1:dim = embed_dim // 2 ** ielse:dim = embed_dim // 2 ** (i + 1)stride = patch_size[0]padding = (ps - patch_size[0]) // 2self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = Nonedef forward(self, x):B, C, H, W = x.shape# FIXME look at relaxing size constraintsassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."xs = []for i in range(len(self.projs)):tx = self.projs[i](x).flatten(2).transpose(1, 2)xs.append(tx)  # B Ph*Pw Cx = torch.cat(xs, dim=2)if self.norm is not None:x = self.norm(x)return xdef flops(self):Ho, Wo = self.patches_resolutionflops = 0for i, ps in enumerate(self.patch_size):if i == len(self.patch_size) - 1:dim = self.embed_dim // 2 ** ielse:dim = self.embed_dim // 2 ** (i + 1)flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])if self.norm is not None:flops += Ho * Wo * self.embed_dimreturn flopsclass CrossFormer(nn.Module):r""" CrossFormerA PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention`  -Args:img_size (int | tuple(int)): Input image size. Default 224patch_size (int | tuple(int)): Patch size. Default: 4in_chans (int): Number of input image channels. Default: 3num_classes (int): Number of classes for classification head. Default: 1000embed_dim (int): Patch embedding dimension. Default: 96depths (tuple(int)): Depth of each stage.num_heads (tuple(int)): Number of attention heads in different layers.group_size (int): Group size. Default: 7mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: Nonedrop_rate (float): Dropout rate. Default: 0attn_drop_rate (float): Attention dropout rate. Default: 0drop_path_rate (float): Stochastic depth rate. Default: 0.1norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.ape (bool): If True, add absolute position embedding to the patch embedding. Default: Falsepatch_norm (bool): If True, add normalization after patch embedding. Default: Trueuse_checkpoint (bool): Whether to use checkpointing to save memory. Default: Falseuse_cpe (bool): Whether to use conditional positional encoding. Default: Falsegroup_type (str): Strategy to change the group size in different stages. Default: constantpad_type (bool): 0 to pad in one direction, otherwise 1. Default: 0"""def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],group_size=[7, 7, 7, 7], crs_interval=[8, 4, 2, 1], mlp_ratio=4.,qkv_bias=True, qk_scale=None,drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,norm_layer=nn.LayerNorm, ape=False, patch_norm=True,use_checkpoint=False, merge_size=[[2], [2], [2]], use_cpe=False,group_type='constant', pad_type=0, no_mask=False,adaptive_interval=False, use_acl=False, **kwargs):super().__init__()self.num_classes = num_classesself.num_layers = len(depths)self.embed_dim = embed_dimself.ape = apeself.patch_norm = patch_normself.num_features = int(embed_dim * 2 ** (self.num_layers - 1))self.mlp_ratio = mlp_ratio# split image into non-overlapping patchesself.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)num_patches = self.patch_embed.num_patchespatches_resolution = self.patch_embed.patches_resolutionself.patches_resolution = patches_resolution# absolute position embeddingif self.ape:self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))trunc_normal_(self.absolute_pos_embed, std=.02)self.pos_drop = nn.Dropout(p=drop_rate)# stochastic depthdpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule# compute group size for each layergroup_size = self.compute_group_size(group_size, depths, patches_resolution, group_type)# build layersself.layers = nn.ModuleList()num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]for i_layer in range(self.num_layers):patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else Nonenum_patch_size = num_patch_sizes[i_layer]layer = Stage(dim=int(embed_dim * 2 ** i_layer),input_resolution=(patches_resolution[0] // (2 ** i_layer),patches_resolution[1] // (2 ** i_layer)),depth=depths[i_layer],num_heads=num_heads[i_layer],group_size=group_size[i_layer],interval=crs_interval[i_layer],mlp_ratio=self.mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],norm_layer=norm_layer,downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,use_checkpoint=use_checkpoint,patch_size_end=patch_size_end,num_patch_size=num_patch_size,use_cpe=use_cpe,pad_type=pad_type,no_mask=no_mask,adaptive_interval=adaptive_interval,use_acl=use_acl)self.layers.append(layer)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()self.apply(self._init_weights)def compute_group_size(self, group_size=[7, 7, 7, 7], depths=[2, 2, 6, 2], resolution=[56, 56],group_type='constant'):r"""genenrate group size for crossformeroutput:- rst_group_size: should be in the shape [[4], [4, 4], [14, 14, 14], [7, 7]] if the depths = [1, 2, 3, 2]"""rst_group_size = []# compute linear fraction patch sizemin_size = 4total_depth = sum(depths)step_size = (1 - min_size / resolution[0]) / total_depthgroup_fraction = np.arange(min_size / resolution[0], 1.0, step_size)cnt = 0for i_stage in range(len(depths)):rst_group_size.append([])cur_resolution = resolution[0] // 2 ** i_stagefor i_block in range(depths[i_stage]):if group_type == 'constant':# constant group size for each stagerst_group_size[i_stage].append(group_size[i_stage])elif group_type == 'linear':# the fraction of group size relative to input resolution grow in lineargz = cur_resolution * group_fraction[cnt]rst_group_size[i_stage].append(max(4, int(np.ceil(gz))))elif group_type == 'linear_div':# if fraction > 1/2, let fraction = 1/2 if fraction < 3/4 else 1gz = cur_resolution * group_fraction[cnt]if gz > cur_resolution // 2:gz = cur_resolution if gz > cur_resolution * 3 / 4 or i_stage != 2 else cur_resolution // 2rst_group_size[i_stage].append(max(4, int(np.ceil(gz))))elif group_type == 'alter':# if fraction > 1/2, let fraction alter between 1/2 and 1gz = cur_resolution * group_fraction[cnt]if gz > cur_resolution // 2:gz = cur_resolution if cnt % 2 != 0 or i_stage != 2 else cur_resolution // 2rst_group_size[i_stage].append(max(4, int(np.ceil(gz))))elif group_type == '7_14':rst_group_size[i_stage].append(group_size[i_stage] if i_stage != 2 or i_block >= 4 else group_size[i_stage] // 2)cnt += 1print("Group Size:")print(rst_group_size)return rst_group_sizedef _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)@torch.jit.ignoredef no_weight_decay(self):return {'absolute_pos_embed'}@torch.jit.ignoredef no_weight_decay_keywords(self):return {'relative_position_bias_table'}def forward_features(self, x):x = self.patch_embed(x)if self.ape:x = x + self.absolute_pos_embedx = self.pos_drop(x)for layer in self.layers:x = layer(x)x = self.norm(x)  # B L Cx = self.avgpool(x.transpose(1, 2))  # B C 1x = torch.flatten(x, 1)return xdef forward(self, x):x = self.forward_features(x)x = self.head(x)return xdef flops(self):flops = 0flops += self.patch_embed.flops()for i, layer in enumerate(self.layers):flops += layer.flops()flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)flops += self.num_features * self.num_classesreturn flopsdef cros_tiny_patch4_group7_224(pretrained=None, **kwargs):model = CrossFormer(embed_dim=64, depths=[1, 1, 8, 6], num_heads=[2, 4, 8, 16],group_size=[7, 7, 7, 7], patch_size=[4, 8, 16, 32],drop_path_rate=0.5,merge_size=[[2, 4], [2, 4], [2, 4]], group_type="constant",use_acl=False,**kwargs)if pretrained:checkpoint = torch.load(pretrained,map_location="cpu")model.load_state_dict(checkpoint["model"],strict=False)return modeldef cros_small_patch4_group7_224(pretrained=None, **kwargs):model = CrossFormer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],group_size=[7, 7, 7, 7], patch_size=[4, 8, 16, 32],merge_size=[[2, 4], [2, 4], [2, 4]], group_type="constant",use_acl=False,**kwargs)if pretrained:checkpoint = torch.load(pretrained,map_location="cpu")model.load_state_dict(checkpoint["model"],strict=False)return modeldef cros_base_patch4_group7_224(pretrained=None, **kwargs):model = CrossFormer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],group_size=[7, 7, 7, 7], patch_size=[4, 8, 16, 32],merge_size=[[2, 4], [2, 4], [2, 4]], group_type="constant",use_acl=False,**kwargs)if pretrained:checkpoint = torch.load(pretrained,map_location="cpu")model.load_state_dict(checkpoint["model"],strict=False)return modeldef cros_large_patch4_group7_224(pretrained=None, **kwargs):model = CrossFormer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],group_size=[7, 7, 7, 7], patch_size=[4, 8, 16, 32],merge_size=[[2, 4], [2, 4], [2, 4]], group_type="constant",use_acl=False,**kwargs)if pretrained:checkpoint = torch.load(pretrained,map_location="cpu")model.load_state_dict(checkpoint["model"],strict=False)return modeldef cros_pp_small_patch4_group_const_224(pretrained=None, **kwargs):model = CrossFormer(embed_dim=64, depths=[2, 2, 18, 2], num_heads=[2, 4, 8, 16],group_size=[4, 4, 14, 7], patch_size=[4, 8, 16, 32],interval=[4,2,1,1],merge_size=[[2, 4], [2, 4], [2, 4]], group_type="constant",drop_path_rate=0.2,use_acl=True,**kwargs)if pretrained:checkpoint = torch.load(pretrained,map_location="cpu")model.load_state_dict(checkpoint["model"],strict=False)return modeldef cros_pp_base_patch4_group_const_224(pretrained=None, **kwargs):model = CrossFormer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],group_size=[4, 4, 14, 7], patch_size=[4, 8, 16, 32],interval=[4,2,1,1],merge_size=[[2, 4], [2, 4], [2, 4]], group_type="constant",drop_path_rate=0.3,use_acl=True,**kwargs)if pretrained:checkpoint = torch.load(pretrained,map_location="cpu")model.load_state_dict(checkpoint["model"],strict=False)return modeldef cros_pp_large_patch4_group_const_224(pretrained=None, **kwargs):model = CrossFormer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],group_size=[4, 4, 14, 7], patch_size=[4, 8, 16, 32],interval=[4,2,1,1],merge_size=[[2, 4], [2, 4], [2, 4]], group_type="constant",drop_path_rate=0.5,use_acl=True,**kwargs)if pretrained:checkpoint = torch.load(pretrained,map_location="cpu")model.load_state_dict(checkpoint["model"],strict=False)return modeldef cros_pp_huge_patch4_group_const_224(pretrained=None, **kwargs):model = CrossFormer(embed_dim=128, depths=[6, 6, 18, 2], num_heads=[4, 8, 16, 32],group_size=[4, 4, 14, 7], patch_size=[4, 8, 16, 32],interval=[4,2,1,1],merge_size=[[2, 4], [2, 4], [2, 4]], group_type="constant",use_acl=True,drop_path_rate=0.5,**kwargs)if pretrained:checkpoint = torch.load(pretrained,map_location="cpu")model.load_state_dict(checkpoint["model"],strict=False)return model

相关文章:

CrossFormer实战:使用CrossFormer实现图像分类任务(一)

摘要 CrossFormer是一种新型的视觉Transformer架构&#xff0c;旨在通过引入跨尺度注意力机制来提升计算机视觉任务的性能。该模型特别关注不同尺度特征之间的交互&#xff0c;解决了现有视觉Transformer在处理多尺度特征时的不足。 研究背景 在计算机视觉中&#xff0c;特征…...

性能测试工具Jmeter中的FTP脚本开发

FTP文件传输协议是TCP/IP协议组织中的常用协议之一&#xff0c;主要用在internet上双向传输文件。FTP协议具有客户端和服务器端两个部分组成部分&#xff0c;具有上传与下载两种功能。Jmeter也提供了FTP请求的测试支持&#xff0c;实现了上传和下载功能测试。 对于上图的FTP请求…...

探索微软 M365 安全:全方位守护数字世界

在当今这个科技呈井喷式飞速发展,数字化浪潮以汹涌澎湃、锐不可当之势席卷全球的时代,企业与个人仿若置身于一片浩瀚无垠、信息奔涌的海洋之中,尽情畅享着技术革新所带来的无穷无尽便利。然而,恰如平静海面下潜藏着暗礁与汹涌暗流,网络安全问题恰似隐匿在暗处、随时可能给…...

Qt C++读写NFC标签NDEF网址URI

本示例使用的发卡器&#xff1a;https://item.taobao.com/item.htm?spma21dvs.23580594.0.0.1d292c1biFgjSs&ftt&id615391857885 #include "mainwindow.h" #include "ui_mainwindow.h" #include <QDebug> #include "QLibrary" …...

[SMARTFORMS] 自定义SMARTFORMS表单页格式

在SMARTFORMS表单开发过程中&#xff0c;用户打印的纸张有可能不是标准的页格式&#xff0c;需要我自定义页格式 具体操作步骤如下所示 1.定义页格式 事务码SPAD&#xff0c;点击"完全管理" 点击"设备类型"中的页格式的"显示"按钮 点击创建按…...

大模型笔记:KV cache

1 为什么要使用KV cache 假设模型最终生成了四个token 对于第一个token&#xff0c;他的attention的计算方法为&#xff1a; 有了第一个token之后&#xff0c;生成第二个token的时候&#xff1a; sottmaxed表示已经逐行softmax后的结果同理&#xff0c;对于第三个token&…...

Android车机DIY开发之学习篇(三)替换Logo以正点原子为例

Android车机DIY开发之学习篇(三)替换Logo以正点原子为例 启动 logo 包括 u-boot 阶段 logo 内核阶段 logo /sdk/kernel-5.10 目录下替换 logo.bmp 654270 logo_kernel.bmp 654270 编译 Linux 内核...

宝塔面板 php8.0 安装 fileinfo 拓展失败

系统&#xff1a;Albaba Cloud Linux release 3 &#xff08;OpenAnolis Editon&#xff09;即 Centos 平替 异常提示&#xff1a; cc: fatal error: ** signal terminated program cc1 compilation terminated. make: *** [Makefile:211: libmagic/apprentice.lo] Error 1搜…...

机器学习数据预处理preprocessing

预处理方法预处理方法预处理方法BinarizerFunctionTransformerKBinsDiscretizerKernelCentererLabelBinarizerLabelEncoderMaxAbsScalerMinMaxScalerMultiLabelBinarizer sklearn.preprocessing.Binarizer 设定一个阈值&#xff08;threshold&#xff09;&#xff0c;对于每个…...

网络安全 | 什么是Bot防护?

关注&#xff1a;CodingTechWork Bot防护介绍 随着互联网服务的普及和发展&#xff0c;越来越多的网站和应用遭遇了自动化攻击&#xff08;Bot攻击&#xff09;。Bot防护是一种安全技术&#xff0c;旨在检测和阻止自动化程序&#xff08;即“机器人”或“bot”&#xff09;对网…...

Qt学习笔记第81到90讲

第81讲 串口调试助手实现自动发送 为这个名叫“定时发送”的QCheckBox编写槽函数。 想要做出定时发送的效果&#xff0c;必须引入QT框架下的毫秒级定时器QTimer&#xff0c;查阅手册了解详情。 在widget.h内添加新的私有成员变量&#xff1a; QTimer *timer; 在widget类的构造…...

如何在本地部署大模型并实现接口访问( Llama3、Qwen、DeepSeek等)

如何在本地部署大模型并实现接口访问&#xff08; Llama3、Qwen、DeepSeek等&#xff09; 如何在本地部署大模型并实现接口访问&#xff08; Llama3、Qwen、DeepSeek等&#xff09;模型地址模型下载模型部署指定显卡运行app.py 运行环境requirements 调用接口代码调用 结语 如何…...

使用 Linux tracepoint、perf 和 eBPF 跟踪数据包

大家读完觉得有帮助记得关注和点赞&#xff01;&#xff01;&#xff01; 目录 1 破局 1.1 逃离迷宫&#xff1a;上帝视角 1.2 网络跟踪&#xff1a;渴求利器 1.3 巨人肩膀&#xff1a;perf/eBPF 2 Perf 2.1 安装 perf 2.2 测试环境 2.3 初体验&#xff1a;跟踪 ping …...

给DevOps加点料:融入安全性的DevSecOps

从前&#xff0c;安全防护只是特定团队的责任&#xff0c;在开发的最后阶段才会介入。当开发周期长达数月、甚至数年时&#xff0c;这样做没什么问题&#xff1b;但是现在&#xff0c;这种做法现在已经行不通了。 采用 DevOps 可以有效推进快速频繁的开发周期&#xff08;有时…...

MySQL视图笔记

视图的理解 ①视图是一种 虚拟表 &#xff0c;本身是 不具有数据 的&#xff0c;占用很少的内存空间&#xff0c;它是 SQL 中的一个重要概念。 ②视图建立在已有表的基础上, 视图赖以建立的这些表称为基表。 ③对视图中的数据进行增加删除和修改&#xff0c;对应的数据表&a…...

【Ubuntu与Linux操作系统:十、C/C++编程】

第10章 C/C编程 10.1 Linux编程基础 Linux编程基础涵盖了C/C语言在Linux环境中的特点和使用方法。Linux以其高性能和开源特性成为系统编程的重要平台。 1. C语言与Linux的关系 Linux内核主要是用C语言编写的&#xff0c;因此学习C语言是理解Linux底层机制的必要前提。C语言的…...

豆包MarsCode:可以在线用的智能AI编程助手

大家好&#xff0c;今天我想和大家分享一个我最近发现的宝藏工具——豆包MarsCode。 作为一个程序员&#xff0c;我一直在寻找能够提高工作效率、快捷、 优化代码质量的在线编程工具。豆包MarsCode IDE&#xff0c;这个由字节跳动推出的智能编程助手&#xff0c;让我眼前一亮&…...

RabbitMQ基础(简单易懂)

RabbitMQ高级篇请看&#xff1a; RabbitMQ高级篇-CSDN博客 目录 什么是RabbitMQ&#xff1f; MQ 的核心概念 1. RabbitMQ 的核心组件 2. Exchange 的类型 3. 数据流向说明 如何安装RabbitQueue&#xff1f; WorkQueue&#xff08;工作队列&#xff09;&#xff1a; Fa…...

UE5 使用内置组件进行网格切割

UE引擎非常强大&#xff0c;直接内置了网格切割功能并封装为蓝图节点&#xff0c;这项功能在UE4中就存在&#xff0c;并且无需使用Chaos等模块。那么就来学习下如何使用内置组件实现网格切割。 1.配置测试用StaticMesh 对于被切割的模型&#xff0c;需要配置一些参数。以UE5…...

【面试题】技术场景 6、Java 生产环境 bug 排查

生产环境 bug 排查思路 分析日志&#xff1a;首先通过分析日志查看是否存在错误信息&#xff0c;利用之前讲过的 elk 及查看日志的命令缩小查找错误范围&#xff0c;方便定位问题。远程 debug 适用环境&#xff1a;一般公司正式生产环境不允许远程 debug&#xff0c;多在测试环…...

macOS 安装tomcat9

macOS 安装tomcat9 URL&#xff1a;https://tomcat.apache.org/download-90.cgi 解压之后放到指定目录 /Users/lanren/install/tomcat-9 自己取个名字就行 给权限&#xff1a; ① 先进行权限修改&#xff1a;终端输入sudo chmod 755 /Users/lanren/install/tomcat-9/bin/…...

多线程之旅:属性及其基本操作

上次分享到了&#xff0c;多线程中是是如何创建的&#xff0c;那么接下来&#xff0c;小编继续分享下多线程的相关知识。 多线程中的一些基本属性。 基本属性 属性获取方法IDgetId()名称getName()状态getState()优先级getPriority()是否后台线程isDemo()是否存活isAlive()是…...

隧道网络:为数据传输开辟安全通道

什么是隧道网络&#xff1f; 想象一下&#xff0c;你正在一个陌生的城市旅行&#xff0c;并且想要访问家里的电脑。但是&#xff0c;直接连接是不可能的&#xff0c;因为家庭网络通常受到防火墙或路由器的保护&#xff0c;不允许外部直接访问。这时候&#xff0c;隧道网络&…...

Python爬虫-汽车之家各车系周销量榜数据

前言 本文是该专栏的第43篇,后面会持续分享python爬虫干货知识,记得关注。 在本专栏之前,笔者在文章《Python爬虫-汽车之家各车系月销量榜数据》中,有详细介绍,如何爬取“各车系车型的月销量榜单数据”的方法以及完整代码教学教程。 而本文,笔者同样以汽车之家平台为例,…...

【机器学习】时序数据与序列建模:理论与实践的全面指南

云边有个稻草人-CSDN博客 目录 云边有个稻草人-CSDN博客 引言 一、时序数据的特点与挑战 1.1 时序数据的特点 1.2 序列建模的挑战 二、传统方法概览 2.1 ARIMA 模型 2.2 Prophet 三、深度学习方法 3.1 RNN 和 LSTM 3.2 Attention 和 Transformer 3.3 自监督学习 四、…...

java.net.SocketException: Connection reset 异常原因分析和解决方法

导致此异常的原因&#xff0c;总结下来有三种情况&#xff1a; 一、服务器端偶尔出现了异常&#xff0c;导致连接关闭 解决方法&#xff1a; 采用出错重试机制 二、 服务器端和客户端使用的连接方式不一致 解决方法&#xff1a; 服务器端和客户端使用相同的连接方式&#xff…...

【华为OD-E卷 - 恢复数字序列 100分(python、java、c++、js、c)】

【华为OD-E卷 - 恢复数字序列 100分&#xff08;python、java、c、js、c&#xff09;】 题目 对于一个连续正整数组成的序列&#xff0c;可以将其拼接成一个字符串&#xff0c;再将字符串里的部分字符打乱顺序。如序列8 9 10 11 12&#xff0c;拼接成的字符串为89101112&…...

05、Redis持久化

Redis是在内存中操作的&#xff0c;我们服器关闭重启机器后&#xff0c;正常是之前在redis中操作的数据都不存在了&#xff0c;但是实际上我们开机后重新启动redis服务&#xff0c;一样可以看到之前操作的数据。这是为什么呢&#xff1f; 我们在redis的安装目录下可以看到有一…...

Python爬虫基础——selenium模块进阶(模拟鼠标操作)

主要内容包括&#xff1a;模拟鼠标操作。常用的鼠标操作有单击、双击、右击、长按、拖动、移动等&#xff0c;模拟这些操作需要用到selenium模块中的ActionChains类。该类的基本使用方法是将实例化好的WebDriver对象作参数传到该类中&#xff0c;实例化成一个ActionChains对象&…...

C++ macro: The # operator

C macro: The # operator 1. The # operator2. Stringizing (字符串化)References 1. The # operator The # operator converts a parameter of a function-like macro into a character string literal. #define STR(x) #xAll subsequent invocations of the macro STR woul…...

一学就废|Python基础碎片,文件读写

文件处理是指通过编程接口对文件执行诸如创建、打开、读取、写入和关闭等操作的过程。它涉及管理程序与存储设备上的文件系统之间的数据流&#xff0c;确保数据得到安全高效的处理。 Python 中的文件模式 打开文件时&#xff0c;我们必须指定我们想要的模式&#xff0c;该模式…...

使用MATLAB正则表达式从文本文件中提取数据

使用MATLAB正则表达式从文本文件中提取数据 使用Python正则表达式从文本文件中提取数据的代码请看这篇文章使用正则表达式读取文本数据【Python】-CSDN博客 文本数据格式 需要提取 V 后面的数据, 并绘制出曲线. index 1V 0.000000W 0.000000E_theta 0.000000UINV 0.0…...

Java基于SSM框架的在线视频教育系统小程序【附源码、文档】

博主介绍&#xff1a;✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3…...

Git文件夹提交错了,怎么撤销?

最近提交了一些不应该提交的文件夹到git中,现在需要移除它们,现在简单记录一下操作日志: 情况一 文件夹已经被添加到 Git&#xff0c;但未提交 如果文件夹已经被 git add 添加到暂存区中&#xff0c;但尚未提交&#xff0c;你可以使用以下命令将其从暂存区中移除: git rm -r …...

Unity TextMesh Pro入门

概述 TextMesh Pro是Unity提供的一组工具&#xff0c;用于创建2D和3D文本。与Unity的UI文本和Text Mesh系统相比&#xff0c;TextMesh Pro提供了更好的文本格式控制和布局管理功能。 本文介绍了TMP_Text组件和Tmp字体资产(如何创建字体资产和如何解决缺字问题),还有一些高级功…...

大疆C++开发面试题及参考答案

虚函数的作用是什么&#xff1f;虚函数机制是如何实现的&#xff1f;虚表指针在内存中的存放位置在哪里&#xff1f; 虚函数主要用于实现多态性。多态是面向对象编程中的一个重要概念&#xff0c;它允许通过基类指针或引用调用派生类中重写的函数。这样可以在运行时根据对象的实…...

极品飞车6里的赛道简介

极品飞车里有很多赛道,赛道分为前向赛道Forward、后向赛道Backward。前向赛道Forward是从A点到B点;后向赛道Backward是前向赛道的逆过程,即从B点到A点。这里介绍极品飞车6的赛道长度、中英文名称翻译、难度等级。 序号赛道英文名赛道中文名总长(km)急弯难度等级1Alpine Trai…...

Swagger学习⑰——@Link注解

介绍 Link 是 Swagger/OpenAPI 3.0 注解库中的一个注解&#xff0c;用于在 OpenAPI 文档中定义链接&#xff08;Link&#xff09;。链接是一种在 API 响应中提供相关操作或资源引用的机制&#xff0c;通常用于描述操作之间的关系或提供额外的操作提示。 Link 注解的作用 Link…...

Cline(原Claude Dev)开源的IDE AI插件,如何搭配OpenRouter实现cursor功能,Cline怎么使用

Cline&#xff08;原Claude Dev&#xff09;是一个开源的IDE AI插件&#xff0c;可以使用你的命令行界面和编辑器的人工智能助手。 你可以直接在VS Code编辑器进行安装。如果你使用过Cursor AI IDE的话&#xff0c;可以尝试最新发布的Cline3.1版本。 在OpenRouter上&#xff0…...

WEB前端-3.1

目录 CSS部分 什么是CSS CSS的书写方式 网页引入CSS的方式 css的颜色、大小、边线 文本和字体样式 CSS选择器 属性选择器 伪类选择器 伪元素选择器 文本样式 display属性 背景样式 精灵图、雪碧图 元素定位 绝对定位 相对定位 浮动定位 浮动 CSS部分 什么是…...

灌区闸门自动化控制系统-精准渠道量测水-灌区现代化建设

项目背景 本项目聚焦于黑龙江某一灌区的现代化改造工程&#xff0c;该灌区覆盖广阔&#xff0c;灌溉面积高达7.5万亩&#xff0c;地域上跨越6个乡镇及涵盖17个村庄。项目核心在于通过全面的信息化建设&#xff0c;强力推动节水灌溉措施的实施&#xff0c;旨在显著提升农业用水的…...

QT中引入OpenCV库总结(qmake方式和cmake方式)

文章目录 前言opencv环境配置一、opencv库获取的两种方式二、qmake和cmake配置2.1、 qmake2.2、cmake2.2.1、引入opencv示例 三、qt与opencv对应关系四、问题 前言 我的软件环境&#xff0c;写在前面 Windows10QT5.12.12VS2017OpenCV4.5.4 opencv环境配置 一、opencv库获取…...

【DAPM杂谈之三】DAPM的初始化流程

本文主要分析DAPM的设计与实现 内核的版本是&#xff1a;linux-5.15.164&#xff0c;下载链接&#xff1a;Linux内核下载 主要讲解有关于DAPM相关的知识&#xff0c;会给出一些例程并分析内核如何去实现的 /**************************************************************…...

消息队列架构、选型、专有名词解释

私人博客传送门 消息队列专有名词解释 | 魔筝炼药师 MQ选型 | 魔筝炼药师 MQ架构 | 魔筝炼药师 MQ顺序消息 | 魔筝炼药师...

Scala语言的计算机基础

Scala语言的计算机基础 Scala是一种现代的编程语言&#xff0c;兼具面向对象和函数式编程的特性&#xff0c;广泛应用于大数据处理、后端开发和分布式系统等领域。本文将围绕Scala语言的基础知识&#xff0c;包括其语法特点、数据结构、函数式编程思想、与Java的关系以及在实际…...

爬虫基础之爬取歌曲宝歌曲批量下载

声明&#xff1a;本案列仅供学习交流使用 任何用于非法用途均与本作者无关 需求分析: 网站:邓紫棋-mp3在线免费下载-歌曲宝-找歌就用歌曲宝-MP3音乐高品质在线免费下载 (gequbao.com) 爬取 歌曲名 歌曲 实现歌手名称下载所有歌曲 本案列所使用的模块 requests (发送…...

书说 MySQL 的悲观锁和乐观锁

什么是乐观锁&#xff1f;什么是悲观锁&#xff1f; 悲观锁&#xff1a; 悲观锁是一种基于悲观态度的控制机制&#xff08;最坏的程度想&#xff0c;每次并发一定会造成阻塞&#xff09;&#xff0c;用于防止数据冲突。它采取预防性措施&#xff0c;在修改数据之前将其锁定&a…...

Linux WEB漏洞

定义&#xff1a;Linux Web 漏洞是指在基于 Linux 操作系统的 Web 应用程序、Web 服务器软件或者相关的网络服务配置中存在的安全弱点。这些漏洞可能导致攻击者未经授权访问敏感信息、篡改网页内容、执行恶意代码&#xff0c;甚至完全控制服务器。 常见类型及原理 SQL 注入漏…...

AIDD - 人工智能药物设计 -深度学习赋能脂质纳米颗粒设计,实现高效肺部基因递送

Nat. Biotechnol. | 深度学习赋能脂质纳米颗粒设计&#xff0c;实现高效肺部基因递送 今天为大家介绍的是来自美国麻省理工和爱荷华大学卡弗医学院团队的一篇论文。可离子化脂质&#xff08;ionizable lipids&#xff09;是脂质纳米颗粒&#xff08;lipid nanoparticles&#…...

Selenium 进行网页自动化操作的一个示例,绕过一些网站的自动化检测。python编程

初级教程 selenium 教程和视频教程s原理与安装 - 白月黑羽 https://www.byhy.net/auto/selenium/01/#chrome%201 Selenium 自动化环境安装_哔哩哔哩_bilibili Selenium 自动化环境安装是Python Selenium Web自动化 2024版 - 自动化测试 爬虫的第2集视频&#xff0c;该合集共…...