当前位置: 首页 > news >正文

计算机视觉注意力机制【一】常用注意力机制整理

在做目标检测项目,尤其是基于 YOLOv5 或 YOLOv7 的改进实验时,我发现不同注意力机制对模型性能的提升确实有明显影响,比如提高小目标检测能力、增强特征表达等。但每次找代码都得翻论文、找 GitHub,效率很低。所以我干脆把常见的注意力模块(比如 SE、CBAM、ShuffleAttention、SimAM 等)都整理到一起,统一了格式和接口,方便自己后续做结构替换和对比实验。这个整理也能帮助我更系统地理解各类注意力机制的原理和实现方式,也希望能为有类似需求的人提供一些参考。


后续会基于注意机制与 YOLO 目标检测进行融合,欢迎各位关注➕收藏


文章目录

      • 🔹 1. SEAttention(Squeeze-and-Excitation Attention)
      • 🔹 2. ShuffleAttention
      • 🔹 3. CrissCrossAttention(CCA)
      • 🔹 4. S2-MLPv2 Attention
      • 🔹 5. SimAM
      • 🔹 6. SKAttention(Selective Kernel)
      • 🔹 7. NAMAttention(Normalization-based Attention)
      • 🔹 8. SOCA(Second-order Channel Attention)
      • 🔹 9. CBAM(Convolutional Block Attention Module)
      • 🔹 10. GAMAttention
      • 🔹 11. Coordinate attention
      • 🔹 12. Efficient Channel Attention(ECA)

🔹 1. SEAttention(Squeeze-and-Excitation Attention)

来源:
https://arxiv.org/abs/1709.01507
机制:
全局平均池化 → 两层 MLP → Sigmoid → 通道权重调整

class SEAttention(nn.Module):def __init__(self, channel=512,reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)

🔹 2. ShuffleAttention

来源:
https://arxiv.org/pdf/2102.00240.pdf
机制:
通道注意力 + 空间注意力,利用 GroupNorm 和 Shuffle 操作

class ShuffleAttention(nn.Module):def __init__(self, channel=512,reduction=16,G=8):super().__init__()self.G=Gself.channel=channelself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sigmoid=nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.size()#group into subfeaturesx=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w#channel_splitx_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w#channel attentionx_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1x_channel=x_0*self.sigmoid(x_channel)#spatial attentionx_spatial=self.gn(x_1) #bs*G,c//(2*G),h,wx_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,wx_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w# concatenate along channel axisout=torch.cat([x_channel,x_spatial],dim=1)  #bs*G,c//G,h,wout=out.contiguous().view(b,-1,h,w)# channel shuffleout = self.channel_shuffle(out, 2)return out

🔹 3. CrissCrossAttention(CCA)

来源: CCNet-Pure-Pytorch
机制: 分别在 H 和 W 方向做注意力交叉计算,融合上下文

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmaxdef INF(B,H,W):return -torch.diag(torch.tensor(float("inf")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)class CrissCrossAttention(nn.Module):""" Criss-Cross Attention Module"""def __init__(self, in_dim):super(CrissCrossAttention,self).__init__()self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)self.softmax = Softmax(dim=3)self.INF = INFself.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):m_batchsize, _, height, width = x.size()proj_query = self.query_conv(x)proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)proj_key = self.key_conv(x)proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)proj_value = self.value_conv(x)proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)concate = self.softmax(torch.cat([energy_H, energy_W], 3))att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)#print(concate)#print(att_H) att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)#print(out_H.size(),out_W.size())return self.gamma*(out_H + out_W) + x

🔹 4. S2-MLPv2 Attention

来源:
https://arxiv.org/abs/2108.01072
机制:
Spatial Shift + MLP + 分支融合

def spatial_shift1(x):b,w,h,c = x.size()x[:,1:,:,:c//4] = x[:,:w-1,:,:c//4]x[:,:w-1,:,c//4:c//2] = x[:,1:,:,c//4:c//2]x[:,:,1:,c//2:c*3//4] = x[:,:,:h-1,c//2:c*3//4]x[:,:,:h-1,3*c//4:] = x[:,:,1:,3*c//4:]return xdef spatial_shift2(x):b,w,h,c = x.size()x[:,:,1:,:c//4] = x[:,:,:h-1,:c//4]x[:,:,:h-1,c//4:c//2] = x[:,:,1:,c//4:c//2]x[:,1:,:,c//2:c*3//4] = x[:,:w-1,:,c//2:c*3//4]x[:,:w-1,:,3*c//4:] = x[:,1:,:,3*c//4:]return xclass SplitAttention(nn.Module):def __init__(self,channel=512,k=3):super().__init__()self.channel=channelself.k=kself.mlp1=nn.Linear(channel,channel,bias=False)self.gelu=nn.GELU()self.mlp2=nn.Linear(channel,channel*k,bias=False)self.softmax=nn.Softmax(1)def forward(self,x_all):b,k,h,w,c=x_all.shapex_all=x_all.reshape(b,k,-1,c) a=torch.sum(torch.sum(x_all,1),1) hat_a=self.mlp2(self.gelu(self.mlp1(a))) hat_a=hat_a.reshape(b,self.k,c) bar_a=self.softmax(hat_a) attention=bar_a.unsqueeze(-2) out=attention*x_all out=torch.sum(out,1).reshape(b,h,w,c)return outclass S2Attention(nn.Module):def __init__(self, channels=512 ):super().__init__()self.mlp1 = nn.Linear(channels,channels*3)self.mlp2 = nn.Linear(channels,channels)self.split_attention = SplitAttention()def forward(self, x):b,c,w,h = x.size()x=x.permute(0,2,3,1)x = self.mlp1(x)x1 = spatial_shift1(x[:,:,:,:c])x2 = spatial_shift2(x[:,:,:,c:c*2])x3 = x[:,:,:,c*2:]x_all=torch.stack([x1,x2,x3],1)a = self.split_attention(x_all)x = self.mlp2(a)x=x.permute(0,3,1,2)return x

🔹 5. SimAM

机制: 使用方差引导通道激活,无需参数

import torch
import torch.nn as nnclass SimAM(torch.nn.Module):def __init__(self, channels = None,out_channels = None, e_lambda = 1e-4):super(SimAM, self).__init__()self.activaton = nn.Sigmoid()self.e_lambda = e_lambdadef __repr__(self):s = self.__class__.__name__ + '('s += ('lambda=%f)' % self.e_lambda)return s@staticmethoddef get_module_name():return "simam"def forward(self, x):b, c, h, w = x.size()n = w * h - 1x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5 #atbriassreturn x * self.activaton(y)  

🔹 6. SKAttention(Selective Kernel)

机制: 多尺度卷积 + Soft attention 选择合适卷积核

class SKAttention(nn.Module):def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):super().__init__()self.d=max(L,channel//reduction)self.convs=nn.ModuleList([])for k in kernels:self.convs.append(nn.Sequential(OrderedDict([('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),('bn',nn.BatchNorm2d(channel)),('relu',nn.ReLU())])))self.fc=nn.Linear(channel,self.d)self.fcs=nn.ModuleList([])for i in range(len(kernels)):self.fcs.append(nn.Linear(self.d,channel))self.softmax=nn.Softmax(dim=0)def forward(self, x):bs, c, _, _ = x.size()conv_outs=[]### split atbriassfor conv in self.convs:conv_outs.append(conv(x))feats=torch.stack(conv_outs,0)#k,bs,channel,h,w### fuseU=sum(conv_outs) #bs,c,h,w### reduction channelS=U.mean(-1).mean(-1) #bs,cZ=self.fc(S) #bs,d### calculate attention weightweights=[]for fc in self.fcs:weight=fc(Z)weights.append(weight.view(bs,c,1,1)) #bs,channelattention_weughts=torch.stack(weights,0)#k,bs,channel,1,1attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1### fuseV=(attention_weughts*feats).sum(0)return V

🔹 7. NAMAttention(Normalization-based Attention)

机制: 使用 BN 参数的归一化特性引导注意力

import torch.nn as nn
import torch
from torch.nn import functional as Fclass Channel_Att(nn.Module):def __init__(self, channels, t=16):super(Channel_Att, self).__init__()self.channels = channelsself.bn2 = nn.BatchNorm2d(self.channels, affine=True)def forward(self, x):residual = xx = self.bn2(x)weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())x = x.permute(0, 2, 3, 1).contiguous()x = torch.mul(weight_bn, x)x = x.permute(0, 3, 1, 2).contiguous()x = torch.sigmoid(x) * residual #return xclass NAMAttention(nn.Module):def __init__(self, channels, out_channels=None, no_spatial=True):super(NAMAttention, self).__init__()self.Channel_Att = Channel_Att(channels)def forward(self, x):x_out1=self.Channel_Att(x)return x_out1  

🔹 8. SOCA(Second-order Channel Attention)

机制: 基于协方差池化和矩阵平方根的高阶通道注意力机制

import numpy as np
import torch
from torch import nn
from torch.nn import initfrom torch.autograd import Functionclass Covpool(Function):@staticmethoddef forward(ctx, input):x = inputbatchSize = x.data.shape[0]dim = x.data.shape[1]h = x.data.shape[2]w = x.data.shape[3]M = h*wx = x.reshape(batchSize,dim,M)I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device)I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype)y = x.bmm(I_hat).bmm(x.transpose(1,2))ctx.save_for_backward(input,I_hat)return y@staticmethoddef backward(ctx, grad_output):input,I_hat = ctx.saved_tensorsx = inputbatchSize = x.data.shape[0]dim = x.data.shape[1]h = x.data.shape[2]w = x.data.shape[3]M = h*wx = x.reshape(batchSize,dim,M)grad_input = grad_output + grad_output.transpose(1,2)grad_input = grad_input.bmm(x).bmm(I_hat)grad_input = grad_input.reshape(batchSize,dim,h,w)return grad_inputclass Sqrtm(Function):@staticmethoddef forward(ctx, input, iterN):x = inputbatchSize = x.data.shape[0]dim = x.data.shape[1]dtype = x.dtypeI3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1)A = x.div(normA.view(batchSize,1,1).expand_as(x))Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device)Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1)if iterN < 2:ZY = 0.5*(I3 - A)Y[:,0,:,:] = A.bmm(ZY)else:ZY = 0.5*(I3 - A)Y[:,0,:,:] = A.bmm(ZY)Z[:,0,:,:] = ZYfor i in range(1, iterN-1):ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:]))Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY)Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:])ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]))y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)ctx.save_for_backward(input, A, ZY, normA, Y, Z)ctx.iterN = iterNreturn y@staticmethoddef backward(ctx, grad_output):input, A, ZY, normA, Y, Z = ctx.saved_tensorsiterN = ctx.iterNx = inputbatchSize = x.data.shape[0]dim = x.data.shape[1]dtype = x.dtypeder_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA))I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)if iterN < 2:der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace))else:dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) -Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom))dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:])for i in range(iterN-3, -1, -1):YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:])ZY = Z[:,i,:,:].bmm(Y[:,i,:,:])dldY_ = 0.5*(dldY.bmm(YZ) - Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - ZY.bmm(dldY))dldZ_ = 0.5*(YZ.bmm(dldZ) - Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) -dldZ.bmm(ZY))dldY = dldY_dldZ = dldZ_der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x))grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)for i in range(batchSize):grad_input[i,:,:] += (der_postComAux[i] \- grad_aux[i] / (normA[i] * normA[i])) \*torch.ones(dim,device = x.device).diag()return grad_input, Nonedef CovpoolLayer(var):return Covpool.apply(var)def SqrtmLayer(var, iterN):return Sqrtm.apply(var, iterN)class SOCA(nn.Module):# second-order Channel attentiondef __init__(self, channel, reduction=8):super(SOCA, self).__init__()self.max_pool = nn.MaxPool2d(kernel_size=2)self.conv_du = nn.Sequential(nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),nn.ReLU(inplace=True),nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),nn.Sigmoid())def forward(self, x):batch_size, C, h, w = x.shape  # x: NxCxHxWN = int(h * w)min_h = min(h, w)h1 = 1000w1 = 1000if h < h1 and w < w1:x_sub = xelif h < h1 and w > w1:W = (w - w1) // 2x_sub = x[:, :, :, W:(W + w1)]elif w < w1 and h > h1:H = (h - h1) // 2x_sub = x[:, :, H:H + h1, :]else:H = (h - h1) // 2W = (w - w1) // 2x_sub = x[:, :, H:(H + h1), W:(W + w1)]cov_mat = CovpoolLayer(x_sub) # Global Covariance pooling layercov_mat_sqrt = SqrtmLayer(cov_mat,5) # Matrix square root layer( including pre-norm,Newton-Schulz iter. and post-com. with 5 iteration)cov_mat_sum = torch.mean(cov_mat_sqrt,1)cov_mat_sum = cov_mat_sum.view(batch_size,C,1,1)y_cov = self.conv_du(cov_mat_sum)return y_cov*x

🔹 9. CBAM(Convolutional Block Attention Module)

机制: Channel Attention + Spatial Attention 串联使用

class ChannelAttentionModule(nn.Module):def __init__(self, c1, reduction=16):super(ChannelAttentionModule, self).__init__()mid_channel = c1 // reductionself.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.shared_MLP = nn.Sequential(nn.Linear(in_features=c1, out_features=mid_channel),nn.LeakyReLU(0.1, inplace=True),nn.Linear(in_features=mid_channel, out_features=c1))self.act = nn.Sigmoid()#self.act=nn.SiLU()def forward(self, x):avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0),-1)).unsqueeze(2).unsqueeze(3)maxout = self.shared_MLP(self.max_pool(x).view(x.size(0),-1)).unsqueeze(2).unsqueeze(3)return self.act(avgout + maxout)class SpatialAttentionModule(nn.Module):def __init__(self):super(SpatialAttentionModule, self).__init__()self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)self.act = nn.Sigmoid()def forward(self, x):avgout = torch.mean(x, dim=1, keepdim=True)maxout, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avgout, maxout], dim=1)out = self.act(self.conv2d(out))return outclass CBAM(nn.Module):def __init__(self, c1,c2):super(CBAM, self).__init__()self.channel_attention = ChannelAttentionModule(c1)self.spatial_attention = SpatialAttentionModule()def forward(self, x):out = self.channel_attention(x) * xout = self.spatial_attention(out) * outreturn out

🔹 10. GAMAttention

原理图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

import numpy as np
import torch
from torch import nn
from torch.nn import initclass GAMAttention(nn.Module):#https://paperswithcode.com/paper/global-attention-mechanism-retain-informationdef __init__(self, c1, c2, group=True,rate=4):super(GAMAttention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3), nn.BatchNorm2d(int(c1 /rate)),nn.ReLU(inplace=True),nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3), nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)x = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle out = x * x_spatial_attreturn out  def channel_shuffle(x, groups=2):B, C, H, W = x.size()out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()out=out.view(B, C, H, W) return out

🔹 11. Coordinate attention

class h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)
class CA(nn.Module):# Coordinate Attention for Efficient Mobile Network Design'''Recent studies on mobile network design have demonstrated the remarkable effectiveness of channel attention (e.g., the Squeeze-and-Excitation attention) for liftingmodel performance, but they generally neglect the positional information, which is important for generating spatially selective attention maps. In this paper, we propose anovel attention mechanism for mobile iscyy networks by embedding positional information into channel attention, whichwe call “coordinate attention”. Unlike channel attentionthat transforms a feature tensor to a single feature vector iscyy via 2D global pooling, the coordinate attention factorizes channel attention into two 1D feature encoding processes that aggregate features along the two spatial directions, respectively'''def __init__(self, inp, oup, reduction=32):super(CA, self).__init__()mip = max(8, inp // reduction)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.act = h_swish()self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)def forward(self, x):identity = xn,c,h,w = x.size()pool_h = nn.AdaptiveAvgPool2d((h, 1))pool_w = nn.AdaptiveAvgPool2d((1, w))x_h = pool_h(x)x_w = pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.act(y) x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()out = identity * a_w * a_hreturn out   

🔹 12. Efficient Channel Attention(ECA)

import torch.nn as nn
import torch
from torch.nn import functional as Fclass ECAttention(nn.Module):"""Constructs a ECA module.Args:channel: Number of channels of the input feature mapk_size: Adaptive selection of kernel size automg"""def __init__(self, c1,c2, k_size=3):super(ECAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):y = self.avg_pool(x)y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)y = self.sigmoid(y)return x * y.expand_as(x)

相关文章:

计算机视觉注意力机制【一】常用注意力机制整理

在做目标检测项目&#xff0c;尤其是基于 YOLOv5 或 YOLOv7 的改进实验时&#xff0c;我发现不同注意力机制对模型性能的提升确实有明显影响&#xff0c;比如提高小目标检测能力、增强特征表达等。但每次找代码都得翻论文、找 GitHub&#xff0c;效率很低。所以我干脆把常见的注…...

交替序列长度的最大值

1、题目描述 给出n个正整数&#xff0c;你可以随意从中挑选一些数字组成 一段序列S&#xff0c;该序列满足以下两个条件&#xff1a; 1.奇偶交替排列&#xff1a;例如&#xff1a;"奇&#xff0c;偶&#xff0c;奇&#xff0c;偶&#xff0c;奇.…" 或者 "偶&a…...

追踪大型语言模型的思想(下)(来自针对Claude的分析)

多步推理 正如我们上面所讨论的&#xff0c;语言模型回答复杂问题的一种方式就是简单地记住答案。例如&#xff0c;如果问“达拉斯所在州的首府是哪里&#xff1f;”&#xff0c;一个“机械”的模型可以直接学会输出“奥斯汀”&#xff0c;而无需知道德克萨斯州&#xff0c;达拉…...

嵌入式通信协议总览篇:万物互联的基石

嵌入式系统的世界,是靠协议“说话”的世界。 在你设计一个智能设备、构建一个工业控制系统、开发一款 IoT 网关时,一个核心问题始终绕不开:**这些设备之间如何“对话”?**答案就是——通信协议。 本篇作为系列第一章,将带你全面理解嵌入式通信协议的全貌,为后续深入学习…...

Android 连接德佟打印机全实例+踩坑

文章目录 1. sdk下载2. 开始开发2.1 打印之前准备工作2.2 打印机是否连接检测2.3 打印框架设计 最近有个需求是要连接 德佟打印机 进行打印相关事宜, 现在就遇到的问题简单阐述一下。 1. sdk下载 我们首先需要在官网下载对应的SDK&#xff0c;地址为&#xff1a;https://www.d…...

TikTok 矩阵运营新手实操保姆级教程 2.0 版本

在当下这个全球化的数字浪潮中&#xff0c;TikTok 这片充满机遇的流量蓝海&#xff0c;正吸引着无数创业者和品牌方争相角逐。而要想在这激烈的竞争中脱颖而出&#xff0c;TikTok 矩阵运营无疑是至关重要的制胜法宝。今天&#xff0c;就给大家送上这份超实用的新手实操教程&…...

WordPress:Locoy.php火车头采集

<?php /* 模块参数列表&#xff1a; post_title 必选 标题 post_content 必选 内容 tag 可选 标签 post_category 可选 分类 post_date 可选 时间 post_excerpt 可选 摘要 post_author 可选 作者 category_description 可选 分类信息 post_cate_meta[name] 可选 自定义分…...

C++ 有哪些标准版本

目录 1.主要分为以下几个版本C98&#xff08;ISO/IEC 14882:1998&#xff09; 第一个国际标准C03&#xff08;ISO/IEC 14882:2003&#xff09;小幅度修订C11&#xff08;ISO/IEC 14882:2011&#xff09;一次重大更新C14&#xff08;ISO/IEC 14882:2014&#xff09;增量改进C17&…...

二、MySQL操作命令汇总

文章目录 二、MySQL操作命令汇总1.数据库操作2.表的增删改查2.1 查表2.2 建表给表添加注释假如表已经存在 2.3 删表2.4 查看表结构2.5 改表 3.简单查询3.1 查询单个字段3.2 查询多个字段3.3 查询所有字段3.4 查询结果去重3.5 查询结果排序3.6 查询结果限制条数3.7 查询分组结果…...

编程日志4.28

队列的链表表示代码 #include<iostream> #include<stdexcept> using namespace std; //队列 类的声明 template<typename T>//1.模板声明&#xff0c;表明Queue类是一个通用的模板类&#xff0c;可以用于存储任何类型的元素T class Queue {//2.Queue类的声…...

Qt 中信号与槽(signal-slot)机制支持 多种连接方式(ConnectionType)

Qt 中信号与槽&#xff08;signal-slot&#xff09;机制支持 多种连接方式&#xff08;ConnectionType&#xff09; Qt 中信号与槽&#xff08;signal-slot&#xff09;机制支持 多种连接方式&#xff08;ConnectionType&#xff09;&#xff0c;用于控制信号发出后如何调用槽…...

Python案例实战《手势识别》

目录 1、效果图2、手势识别关键步骤&#xff08;1&#xff09; 导入必要的库&#xff08;2&#xff09;配置 MediaPipe&#xff08;3&#xff09;启动摄像头&#xff08;4&#xff09;设置手指张开判断的距离阈值&#xff08;5&#xff09;计算手指之间的欧几里得距离&#xff…...

NGINX `ngx_http_charset_module` 字符集声明与编码转换

一、模块定位与功能 ngx_http_charset_module 主要提供两大能力&#xff1a; 响应头声明&#xff1a;在 Content-Type 头部自动添加 ; charsetXXX&#xff0c;告知客户端所用字符集。单向编码转换&#xff1a;在 NGINX 层将一种单字节编码&#xff08;如 koi8-r、windows-125…...

进程与线程详细介绍

目录 一 进程概念 二 进程的组成 2.1 PCB 2.2 数据段 2.3 程序段 三 进程的五大特点 四 进程的创建与销毁 五 线程概念 六 线程特征 七 进程与线程的区别与联系 区别 联系 一 进程概念 进程是程序的一次执行过程&#xff0c;是操作系统进行资源分配和调度的基本单位…...

JAVA中ArrayList的解析

gogogo出发喽&#xff01;让我们来认识一下它吧 什么是ArrayList Java 中的 ArrayList 是 Java 集合框架中的重要类&#xff0c;用于实现动态数组 动态数组&#xff1a;可按需自动扩展或缩小&#xff0c;无需手动管理数组大小。比如不断向 ArrayList 添加元素时&#xff0c;…...

【LLM+Code】Devin PromptTools详细解读

Devin 官网&#xff1a;https://devin.ai/ Prompt 大部分篇幅都是tools的直出的description和parameters的一些信息 其他的包含 Communicatework的一些指导Best PracticesInformation HandlingData SecurityResponse Limitationsplanthink You are Devin, a software engi…...

AI应用开发实战分享

一、前言 30年前的IntelWindows互相绑定&#xff0c;让世界被计算机技术重构了一次&#xff0c;有了程序员这个工种。十几年前iPhone、Android前后脚发布&#xff0c;智能手机和移动App互相绑定&#xff0c;引爆了一个长达十几年的移动互联网大跃进时代。而随着人工智能大模型…...

浅聊find_package命令的搜索模式(Search Modes)

背景 find_package应该算是我们使用最多的cmake命令了。但是它是如何找到上游库的.cmake文件的&#xff1f; 根据官方文档&#xff0c;整理下find_package涉及到的搜索模式。 搜索模式 find_package涉及到的搜索模式有两种&#xff1a;模块模式(Module mode)和配置模式(Conf…...

FPGA图像处理(二)-----彩色图像灰度化

由于fpga实现除法相对复杂&#xff0c;故将除法变为乘法再移位。因此每种方法对图像输入数据均分3步进行&#xff0c;极其有效信号打三拍处理。 timescale 1ns / 1ps // // Description: 彩色图像灰度化 // module image_rgb2gray(input wire clk ,input wir…...

Ultralytics中的YOLODataset和BaseDataset

YOLODataset 和 BaseDataset 是 Ultralytics YOLO 框架中用于加载和处理数据集的两个关键类。 YOLODataset类&#xff08;ultralytics/data/dataset.py&#xff09;继承于 BaseDataset类&#xff08;ultralytics/data/base.py&#xff09; BaseDataset() BaseDataset 是一个…...

Mac 使用 Charles代理生成https服务

在Mac电脑上使用Charles软件通过代理生成HTTPS服务&#xff0c;让手机访问电脑的开发地址&#xff0c;可按以下步骤操作&#xff1a; 一、Charles软件设置 安装与启动Charles&#xff1a;从Charles官网下载并安装Charles软件&#xff0c;之后启动它。开启代理服务 点击菜单栏…...

【PostgreSQL】数据库主从库备份与高可用部署

文章目录 一、架构设计原理二、部署清单示例2.1 StatefulSet配置片段2.2 Service配置三、配置详解3.1 主节点postgresql.conf3.2 从节点配置四、初始化流程4.1 创建复制用户4.2 配置pg_hba.conf五、故障转移示例5.1 自动切换脚本5.2 手动提升从节点六、监控与维护6.1 关键监控指…...

ERP进销存系统源码,SaaS模式多租户ERP管理系统,SpringBoot、Vue、UniAPP技术框架

SaaS ERP管理系统源码&#xff0c;覆盖了整个生产企业所有部门的管理&#xff1a;采购、销售、仓库、生产、财务、质量、OA&#xff1a; ERP源码技术架构&#xff1a;SpringBootVueElementUIUniAPP ERP系统功能清单&#xff1a; 流程处理中心&#xff1a;待审批任务、已审批任…...

Decode rpc invocation failed: null -> DecodeableRpcInvocation

DecodeableRpcInvocation 异常情况解决方法 错误警告官方FAQ 异常情况 记录一下Dubbo调用异常 java.util.concurrent.ExecutionException: org.apache.dubbo.remoting.TimeoutException: Waiting server-side response timeout by scan timer. start time: 2025-05-07 22:09:5…...

VAE和Stable Diffusion的关系

文章目录 ✅ 简单回顾&#xff1a;什么是 VAE&#xff1f;&#x1f504; Stable Diffusion 和 VAE 的关系&#xff1a;&#x1f3af; 编码器&#xff1a;&#x1f4a5; 解码器&#xff1a; &#x1f914; 那 Stable Diffusion 本身是 VAE 吗&#xff1f;&#x1f9e0; 简要对比…...

stable Diffusion模型结构

详细描述一下stable Diffusion的推理过程 其实很简单 prompt先经过textencoder tokenizer&#xff0c;embedding 随机生成噪声图片 通过vae encode压缩成潜空间大小 unet with cross attn 去噪 并融合文本信息 # 上面两个信息如何混合 cross-attention sd模型中各种不同的采样器…...

Milvus(16):索引解释

索引是建立在数据之上的附加结构。其内部结构取决于所使用的近似近邻搜索算法。索引可以加快搜索速度&#xff0c;但在搜索过程中会产生额外的预处理时间、空间和 RAM。此外&#xff0c;使用索引通常会降低召回率&#xff08;虽然影响可以忽略不计&#xff0c;但仍然很重要&…...

数字化转型-4A架构之应用架构

系列文章 数字化转型-4A架构&#xff08;业务架构、应用架构、数据架构、技术架构&#xff09;数字化转型-4A架构之业务架构 前言 应用架构AA&#xff08;Application Architecture&#xff09;是规划支撑业务的核心系统与功能模块&#xff0c;实现端到端协同。 一、什么是应…...

中间件-RocketMQ

RocketMQ 基本架构消息模型消费者消费消息模式顺序消息机制延迟消息批量消息事务消息消息重试最佳实践 基本架构 nameServer: 维护broker列表信息&#xff0c;客户端连接时只需要连接nameServer。可配置成集群。 broker&#xff1a;broker分为master和slave&#xff0c;master负…...

AI开发playwright tool提示词

[TASK] 生成一个isModuleElementObject function&#xff0c;若element的qa-test class在对象moduleObj {"qa-test-mycourses-course": "qa-test-mycourses-course-title", "qa-test-discussion-module": "qa-test-discussion-description&…...

《Origin画百图》之带显著性标记的多因子分组柱状图

带显著性标记的多因子分组柱状图 需要数据&#xff1a; 组1&#xff08;大类&#xff09; 组2&#xff08;小类&#xff09; Y数据 Y误差 选中Y数据和Y误差两列数据&#xff0c; 点击绘图--分组图--多因子分组柱状图 数据列就是上一步选择的Y和Y误差&#xff0c; 点击子组…...

邮件发送频率如何设置?尊重文化差异是关键!

一、不同文化背景&#xff0c;邮件频率大不同 1.工作习惯不一样 一些西方国家&#xff0c;美国和欧洲工作时间和个人时间分得很清楚。工作日的上午 9 点到下午 5 点&#xff0c;这期间发邮件&#xff0c;收件人大概率会看也会回。但是在深夜或者周末发邮件容易让收件人觉得你…...

Python 识别图片上标点位置

Python识别图片上标点位置 要识别图片上的标点位置&#xff0c;可以使用Python中的OpenCV库。以下是几种常见的方法&#xff1a; 方法一&#xff1a;使用颜色阈值识别 import cv2 import numpy as np# 读取图片 image cv2.imread(image.jpg)# 转换为HSV颜色空间 hsv cv2.c…...

JDK Version Manager (JVMS)

以下是使用 JDK Version Manager (JVMS) 工具在Windows系统中安装JDK的详细步骤及注意事项&#xff0c;结合多篇搜索结果整理而成&#xff1a; --- 一、安装前准备 1. 下载JVMS - 访问 [GitHub Releases页面](https://github.com/ystyle/jvms/releases) 或镜像地址&#x…...

办公学习 效率提升 超级PDF处理软件 转换批量 本地处理

各位办公小能手们&#xff01;我跟你们说啊&#xff0c;有个软件叫超级PDF&#xff0c;那可真是PDF文件处理界的全能选手&#xff0c;专门解决咱们办公、学习时文档管理的各种难题。接下来我给大家好好唠唠它的厉害之处。 先说说它的核心功能。第一是格式转换&#xff0c;这软件…...

阿里云服务器-centos部署定时同步数据库数据-dbswitch

前言&#xff1a; 本文章介绍通过dbswitch工具实现2个mysql数据库之间实现自动同步数据。 应用场景&#xff1a;公司要求实现正式环境数据库数据自动冷备 dbswitch依赖环境&#xff1a;git ,maven,jdk 方式一&#xff1a; 不需要在服务器中安装git和maven&#xff0c;直接用…...

C++函数栈帧详解

函数栈帧的创建和销毁 在不同的编译器下&#xff0c;函数调用过程中栈帧的创建是略有差异的&#xff0c;具体取决于编译器的实现&#xff01; 且需要注意的是&#xff0c;越高级的编译器越不容易观察到函数栈帧的内部的实现&#xff1b; 关于函数栈帧的维护这里我们要重点介…...

Wireshark抓账号密码

训练内容&#xff1a; 1. 安装Ethereal或者Wireshark&#xff0c;熟悉网络嗅探器的使用方法&#xff1b; 2. 实现浏览器与IIS服务器的ssl安全访问&#xff1b; 3. 利用网络嗅探器截获浏览器访问IIS服务器之间数据包&#xff0c;包括有ssl安全连接&#xff08;https方式&am…...

【hot100】bug指南记录1

之前学了一阵C&#xff0c;还是更熟悉C的语法呀&#xff0c;转Java还有点不适应........ 这个系列纯纯记录自己刷题犯的愚蠢的错误......hhhh&#xff0c;我是人&#xff0c;one 愚蠢的码人...... 巩固巩固基础好吗&#xff1f;&#xff01;编程菜鸟.......hhh&#xff0c;又…...

物联网从HomeAssistant开始

文章目录 一、在树梅派5上安装home-assistant二、接入米家1.对比下趋势2.手动安装插件3.配置方式 三、接入公牛1.手动安装插件2.配置方式 一、在树梅派5上安装home-assistant https://www.home-assistant.io/installation/ https://github.com/home-assistant/operating-syste…...

2025年渗透测试面试题总结-网络安全、Web安全、渗透测试笔试总结(一)(附回答)(题目+回答)

网络安全领域各种资源&#xff0c;学习文档&#xff0c;以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具&#xff0c;欢迎关注。 网络安全、Web安全、渗透测试笔试总结(一) 1.什么是 WebShell? 2.什么是网络钓鱼&#xff1f; 3.你获取网络…...

C++ set和map系列(关联式容器)的介绍及使用

欢迎来到干货小仓库 "一个好汉三个帮&#xff0c;程序员同样如此" 1.关联式容器 STL中的容器分为两类&#xff0c;序列式容器和关联式容器。 序列式容器&#xff1a;例如STL库中的vector、list和deque、forward_list(C11)等&#xff0c;这些容器统称为序列式容器&…...

C#与Halcon联合编程

一、加载图片 导入并初始化 using HalconDotNet; ho_Image new HObject();需要在引用中导入 halcondotnet.dll 关联句柄 打开新窗口 //创建一个句柄变量 绑定winform 窗口 HTuple winfowFater this.pictureBox1.Handle; //打开新的窗口 HOperatorSet.SetWindowAttr(&qu…...

5.0.4 VisualStateManager(视觉状态管理器)使用说明

在 WPF 中,VisualStateManager(视觉状态管理器)是用于管理控件在不同状态下的外观变化的核心组件。它通过定义视觉状态(如按钮的默认、悬停、按下状态)和状态过渡动画,使控件在不同交互场景下动态切换样式,而无需重写整个控件模板。以下是其核心用法和示例: 1. 基本概…...

onenet连接微信小程序(mqtt协议)

一、关于mqtt协议 mqtt协议常用于物联网&#xff0c;是一种轻量级的消息推送协议。 其中有三个角色&#xff0c;Publisher设备&#xff08;客户端&#xff09;发布主题到服务器&#xff0c;其他的设备通过订阅主题&#xff0c;获取该主题下的消息&#xff0c;Publisher可以发…...

IT需求规格说明书,IT软件系统需求设计文档(DOC)

1 范围 1.1 系统概述 1.2 文档概述 1.3 术语及缩略语 2 引用文档 3 需求 3.1 要求的状态和方式 3.2 系统能力需求 3.3 系统外部接口需求 3.3.1 管理接口 3.3.2 业务接口 3.4 系统内部接口需求 3.5 系统内部数据需求 3.6 适应性需求 3.7 安全性需求 3.8 保密性需…...

探索 DevExpress:构建卓越应用的得力助手

探索 DevExpress&#xff1a;构建卓越应用的得力助手 在当今竞争激烈的软件开发领域&#xff0c;打造高效、美观且功能强大的应用程序是每个开发者的追求。而 DevExpress 作为一款备受瞩目的开发工具&#xff0c;为开发者们提供了实现这一目标的有力支持。在本专栏博客中&…...

康养休闲旅游住宿服务实训室:构建产教融合新标杆

随着健康中国战略的深入实施与银发经济市场的持续扩张&#xff0c;康养休闲旅游作为融合健康管理、文化体验与休闲度假的复合型产业&#xff0c;正迎来前所未有的发展机遇。北京凯禾瑞华科技有限公司依托其在智慧康养领域的技术积淀与产业洞察&#xff0c;创新推出“康养休闲旅…...

Python 程序设计教程:构建您的第一个计算器类

Python 程序设计教程:构建您的第一个计算器类 1. 引言:为什么要学习类? 面向对象编程 (Object-Oriented Programming, OOP) 是一种强大的编程范式,它通过将数据和操作数据的函数(方法)捆绑在一起来组织和结构化代码 1。类 (Class) 是 OOP 的核心概念,不仅在 Python 中…...

深入浅出理解常见的分布式ID解决方案

分布式ID在构建大规模分布式系统时扮演着至关重要的角色&#xff0c;主要用于确保在分布式环境中数据的唯一性和一致性。以下是分布式ID的几个主要作用&#xff1a; 确保唯一性&#xff1a;在分布式系统中&#xff0c;可能有成千上万个实例同时请求ID。分布式ID生成系统能保证即…...