当前位置: 首页 > news >正文

批量归一化(Batch Normalization)原理与PyTorch实现

批量归一化(Batch Normalization)是加速深度神经网络训练的常用技术。本文通过Fashion-MNIST数据集,演示如何从零实现批量归一化,并对比PyTorch内置API的简洁实现方式。


1. 从零实现批量归一化

1.1 批量归一化函数实现

import torch
from torch import nn
from d2l import torch as d2ldef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):if not torch.is_grad_enabled():# 预测模式下使用移动平均X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 全连接层:特征维计算均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 卷积层:通道维计算均值和方差mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下更新移动平均X_hat = (X - mean) / torch.sqrt(var + eps)moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和平移return Y, moving_mean.data, moving_var.data

1.2 批量归一化层类

class BatchNorm(nn.Module):def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)Y, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

1.3 构建含批量归一化的网络

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))

1.4 训练与结果

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

输出结果

loss 0.277, train acc 0.898, test acc 0.835
28009.9 examples/sec on cuda:0

训练曲线

2. 使用PyTorch内置批量归一化

2.1 简洁实现网络结构

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5),nn.BatchNorm2d(6),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5),nn.BatchNorm2d(16),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(256, 120),nn.BatchNorm1d(120),nn.Sigmoid(),nn.Linear(120, 84),nn.BatchNorm1d(84),nn.Sigmoid(),nn.Linear(84, 10))

 2.2 训练与结果对比

d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

输出结果

loss 0.264, train acc 0.902, test acc 0.849
44608.4 examples/sec on cuda:0

训练曲线

3. 关键参数分析

查看第一个批量归一化层的缩放(gamma)和偏移(beta)参数:

print(net[1].gamma.reshape((-1,)), 
print(net[1].beta.reshape((-1,)))

 输出

(tensor([0.3957, 2.2124, 2.8581, 2.1908, 3.6253, 3.5650], device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
tensor([ 0.1832, -2.5689, -3.2450, -0.7221, 1.1290, 2.2353], device='cuda:0', grad_fn=<ReshapeAliasBackward0>))

4. 结论

  1. 性能对比:PyTorch内置实现相比手动实现,测试准确率从83.5%提升到84.9%,且训练速度更快(44k样本/秒 vs 28k样本/秒)

  2. 实现差异:内置API自动处理设备迁移和参数初始化,代码更简洁

  3. 注意事项:全连接层使用nn.BatchNorm1d,卷积层使用nn.BatchNorm2d

完整代码已通过测试,可直接复现实验结果。批量归一化能有效加速收敛并提升模型泛化能力,是深度网络设计的必备组件。


提示:运行代码需要安装d2l库(pip install d2l)并支持GPU环境。

相关文章:

批量归一化(Batch Normalization)原理与PyTorch实现

批量归一化&#xff08;Batch Normalization&#xff09;是加速深度神经网络训练的常用技术。本文通过Fashion-MNIST数据集&#xff0c;演示如何从零实现批量归一化&#xff0c;并对比PyTorch内置API的简洁实现方式。 1. 从零实现批量归一化 1.1 批量归一化函数实现 import t…...

Flutter 文本组件深度剖析:从基础到高级应用

引言 在 Flutter 应用开发中&#xff0c;文本是向用户传达信息的重要媒介。Flutter 提供了丰富且强大的文本组件和相关属性&#xff0c;使开发者能够轻松实现多样化的文本展示效果。无论是简单的静态文本显示&#xff0c;还是复杂的富文本渲染&#xff0c;Flutter 都能满足需求…...

FABC是什么?

在销售和品牌营销领域&#xff0c;FABC 是一种用于构建销售话术和营销信息的框架&#xff0c;其全称为 Features&#xff08;特点&#xff09;、Advantages&#xff08;优势&#xff09;、Benefits&#xff08;利益&#xff09;、Case&#xff08;案例&#xff09;。该模型帮助…...

【MySQL】MVCC工作原理、事务隔离机制、undo log回滚日志、间隙锁

一、什么是MVCC&#xff1f; MVCC&#xff0c;即 Multiversion Concurrency Control&#xff08;多版本并发控制&#xff09;&#xff0c;它是数据库实现并发控制的一种方式。 MVCC 的核心思想是&#xff1a; 为每个事务提供数据的“快照”版本&#xff0c;从而避免加锁&…...

Spring Boot 集成 RocketMQ 全流程指南:从依赖引入到消息收发

前言 在分布式系统中&#xff0c;消息中间件是解耦服务、实现异步通信的核心组件。RocketMQ 作为阿里巴巴开源的高性能分布式消息中间件&#xff0c;凭借其高吞吐、低延迟、高可靠等特性&#xff0c;成为企业级应用的首选。而 Spring Boot 通过其“约定优于配置”的设计理念&a…...

PCL 点云RANSAC提取平面(非内置函数)

文章目录 一、算法实现1.1实现步骤二、实现代码三、实现效果参考资料一、算法实现 1.1实现步骤 1、确定模型。三个点确定一个平面,方程式为 a x + b y + c z + 1 = 0 ax+by+cz+1=0...

中介者模式:理论、实践与 Spring 源码解析

摘要 本论文以中介者模式为核心,系统阐述其设计原理、应用场景及在 Spring 框架中的实现机制。通过机票预订系统、银行交易系统等典型案例,具象化展示模式如何解耦复杂对象交互;结合 Spring 5.3.29 源码,深入剖析事件驱动模型中ApplicationEventPublisher与ApplicationLis…...

2025.04.14【Table】| 生信数据表图技巧

Custom title A set of examples showing how to customize the titles of a table made with GT Custom footer How to customize the footer and the references section of a gt table 文章目录 Custom titleCustom footer 生信数据可视化&#xff1a;Table图表详解1. R语…...

Unified Modeling Language,统一建模语言

UML&#xff08;Unified Modeling Language&#xff0c;统一建模语言&#xff09;是一种标准化的图形化建模语言&#xff0c;用于可视化、规范和文档化软件系统的设计。UML 提供了一套通用的符号和规则&#xff0c;帮助开发者、架构师和团队成员更好地理解和沟通软件系统的结构…...

OCP证书有效期是永久,但需要更新

在数据库管理领域&#xff0c;OCP证书作为Oracle认证体系中的重要组成部分&#xff0c;一直是数据库专业人士追求的目标。许多考证者会有疑惑:OCP证书是永久有效的吗&#xff1f;需要更新吗&#xff1f; Oracle官方明确规定&#xff1a;OCP证书一经获得&#xff0c;终身有效。无…...

服务器本地搭建

socket函数 它用于创建一个新的套接字&#xff08;socket&#xff09;。 函数原型 #include <sys/socket.h> int socket(int domain, int type, int protocol);参数解释 domain&#xff1a;它指定了通信所使用的协议族&#xff0c;常见的取值如下&#xff1a; AF_INET…...

调节磁盘和CPU的矛盾——InnoDB的Buffer Pool

缓存的重要性 无论是用于存储用户数据的索引【聚簇索引、二级索引】还是各种系统数据&#xff0c;都是以页的形式存放在表空间中【对一个/几个实际文件的抽象&#xff0c;存储在磁盘上】如果需要访问某页的数据&#xff0c;就会把完整的页数据加载到内存中【即使只访问页中的一…...

[dp12_回文子串] 最长回文子串 | 分割回文串 IV

目录 1.回文子串 题解 2.最长回文子串 题解 3.分割回文串 IV 题解 dp[i][j] 表示 s 字符串 [i, j] 的子串&#xff0c;是否是回文串( 建始末表&#xff09; 将两个 for 循环的结果&#xff0c;借助二维 dp 来存 1.回文子串 链接&#xff1a;647. 回文子串 给你一个字符…...

分布式应用架构的演变

整体演变过程 第一阶段&#xff1a;单一应用架构 单一应用架构&#xff0c;是把所有服务都放在一个项目中&#xff0c;进行打包部署到服务器上&#xff0c;如果流量特别大的话&#xff0c;就在另外的服务器上部署相同的功能模块用来分摊流量。但是这样的话&#xff0c;一旦有某…...

zephyr RTOS 中 bt_le_adv_start函数的功能应用

目录 概述 1 功能 1.1 功能介绍 1.2 函数原型 2 参数说明 2.1 广播参数&#xff08;bt_le_adv_param&#xff09; 2.2 常用广播选项&#xff08;options&#xff09; 2.3 广播数据&#xff08;bt_data&#xff09; 3 示例代码 3.1 启动可连接广播&#xff08;带设备名…...

双按键控制LED(中断优先级)

1.启动时&#xff0c;两个LED灯熄灭&#xff0c;1秒钟后&#xff08;定时器实现&#xff09;&#xff0c;LED自动点亮&#xff1b; 2.按键1按下后&#xff0c;通过中断int0把两个LED熄灭5s时间&#xff0c;int0优先级设置为最高&#xff08;优先级必须设置&#xff0c;设置后才…...

美团即时零售大动作,将独立的闪购将会改变什么?

4月12日上午&#xff0c;美团核心本地商业CEO王莆中在社交媒体上发文&#xff0c;宣布美团将在下周正式发布即时零售品牌&#xff0c;标志着美团将进一步发展即时零售业务。 首先&#xff0c;从市场格局角度来看&#xff0c;美团将独立的闪购品牌推出&#xff0c;会进一步加剧…...

如何安装git?

以下是 Windows、macOS 和 Linux 系统安装 Git 的详细步骤&#xff1a; 一、Windows 系统安装 Git 下载安装包 访问 Git 官网下载页&#xff0c;点击下载 Windows 版安装程序&#xff08;如 Git-2.45.1-64-bit.exe&#xff09;。 运行安装程序 安装选项&#xff1a; 选择安装路…...

Ubuntu上docker、docker-compose的安装

今天来实践下Ubuntu上面安装docker跟docker-compose&#xff0c;为后面安装dify、fastgpt做准备。 一、安装docker sudo apt-get updatesudo apt-get install docker.io 然后系统输入 docker --version 出现下图即为docker安装成功。 二、安装docker-compose 我先看下系统…...

ubuntu如何设置静态ip

服务器有时是通过dhcp动态获取ip的&#xff0c;有时出于远程登录方便的考虑&#xff0c;会将其设置为静态ip&#xff0c;以下是设置静态ip的方法 在 Ubuntu 中设置静态 IP 的方法取决于你使用的网络管理工具&#xff08;如 netplan、NetworkManager 或 ifconfig&#xff09;。…...

js原型和原型链

js原型&#xff1a; 1、原型诞生的目的是什么呢&#xff1f; js原型的产生是为了解决在js对象实例之间共享属性和方法&#xff0c;并把他们很好聚集在一起&#xff08;原型对象上&#xff09;。每个函数都会创建一个prototype属性&#xff0c;这个属性指向的就是原型对象。 …...

大数据 - 2. Hadoop - HDFS

前言 HDFS&#xff1a;分布式文件系统 为什么海量数据需要分布式存储技术&#xff1f; 文件过大时&#xff0c;单台服务器无法承担&#xff0c;要靠数量来解决。数量的提升带来的是网络传输、磁盘读写、CPU、内存等各方面的提升。 众多的服务器一起工作&#xff0c;如何保证…...

嵌入式硬件常用总线接口知识体系总结和对比

0.前言 在嵌入式工程实现中,多多少少我们都使用过总线,各种各样的总线应用于不同场合,不同场景有不同的优势,但是我们在作为工程师过程中在如何选择项目合适的总线,根据什么来选?需要我们对项目全局和总线特征有所了解,本文目的就是对比多种总线的关键特征 我们在聊到…...

prime 1 靶场笔记(渗透测试)

环境说明&#xff1a; 靶机prime1和kali都使用的是NAT模式&#xff0c;网段在192.168.144.0/24。 Download (Mirror): https://download.vulnhub.com/prime/Prime_Series_Level-1.rar 一.信息收集 1.主机探测&#xff1a; 使用nmap进行全面扫描扫描&#xff0c;找到目标地址及…...

(二十四)安卓开发中的AppCompatActivity详解

在安卓开发中&#xff0c;AppCompatActivity 是一个非常核心的类&#xff0c;它继承自 Activity&#xff0c;并通过 Android Support Library&#xff08;现已迁移至 AndroidX&#xff09;提供了对 ActionBar 和 Material Design 的支持。它的主要作用是帮助开发者在不同版本的…...

AI大模型+全渠道整合:容联七陌智能客服赋能制造业升级

自《中国制造2025》战略提出以来&#xff0c;制造业的智能化发展进入快车道&#xff0c;但行业仍面临劳动力成本上升、供应链不透明、客户需求碎片化等挑战。企业亟需通过技术手段实现降本增效&#xff0c;而智能化客户服务成为关键突破口。 与此同时&#xff0c;客服行业正经历…...

Vue 技术解析:从核心概念到实战应用

Vue.js 是一款流行的渐进式前端框架&#xff0c;以其简洁的 API、灵活的组件化结构和高效的响应式数据绑定而受到开发者的广泛欢迎。本文将深入解析 Vue 技术的核心概念、原理和应用场景&#xff0c;帮助开发者更好地理解和使用 Vue.js。 一、Vue 的设计哲学与核心概念 &…...

中英文提示词对AI IDE编程能力影响有多大?

深度剖析 &#x1f9e0;&#xff1a;中英文提示词对AI IDE编程能力影响有多大&#xff1f;&#xff08;附实战建议&#xff09; 作者&#xff1a;AI助手 | 日期&#xff1a;2023-10-27 | 标签&#xff1a;AI, IDE, Prompt Engineering, LLM, 编程效率 摘要&#xff1a;随着 AI…...

ARM处理器程序烧写方式

一、烧写原理 无论是jtag还是串口烧写&#xff0c;本质都是先通过上位机&#xff08;keil 或者flymcu或者芯片官方上位机等烧写bin的上位机&#xff09;往mcu的ram里烧写一段代码即.FLM文件&#xff0c;这段代码在上位机&#xff08;keil体现在配置项里&#xff0c;flymcu应该…...

AI 项目详细开发步骤指南

AI 项目详细开发步骤指南 一、环境搭建详解 1. JDK 17 安装与配置 Windows 系统安装步骤&#xff1a; 访问 Oracle 官网下载 JDK 17 安装包&#xff1a;https://www.oracle.com/java/technologies/downloads/#java17下载 Windows x64 Installer 版本双击安装包&#xff0c;…...

文本纠错WPS插件:提升文档质量的利器

文本纠错WPS插件&#xff1a;提升文档质量的利器 引言 在数字化办公日益普及的今天&#xff0c;文档的质量直接影响到我们的工作效率和形象。一个错别字或标点错误&#xff0c;可能就会让我们的专业形象大打折扣。今天&#xff0c;我要向大家介绍一款强大的WPS插件——文本纠…...

Node.js 模块包的管理和使用是

一、模块包的概念 1.模块分类&#xff1a; 核心模块&#xff1a;Node.js 内置模块&#xff08;如 fs, http, path&#xff09;&#xff0c;无需安装直接引用。 本地模块&#xff1a;开发者自己编写的模块文件&#xff0c;通过相对路径引入。 第三方模块&#xff1a;通过 npm…...

腾讯云golang一面

go垃圾回收机制 参考自&#xff1a;https://zhuanlan.zhihu.com/p/334999060 go 1.3 标记清除法 缺点 go 1.5 三色标记法 屏障机制 插入屏障 但是如果栈不添加,当全部三色标记扫描之后,栈上有可能依然存在白色对象被引用的情况(如上图的对象9). 所以要对栈重新进行三色标记扫…...

【Three.js基础学习】35.Particles Cursor Animation Shader

前言 关于着色器应用和画布&#xff0c;实现黑白色照片动态效果 一、代码 script.js ​ import * as THREE from three import { OrbitControls } from three/addons/controls/OrbitControls.js import particlesVertexShader from ./shaders/particles/vertex.glsl import p…...

安卓性能调优之-掉帧测试

掉帧指的是某一帧没有在规定时间内完成渲染&#xff0c;导致 UI 画面不流畅&#xff0c;产生视觉上的卡顿、跳帧现象。 Android目标帧率&#xff1a; 一般情况下&#xff0c;Android设备的屏幕刷新率是60Hz&#xff0c;即每秒需要渲染60帧&#xff08;Frame Per Second, FPS&a…...

六、分布式嵌入

六、分布式嵌入 文章目录 六、分布式嵌入前言一、先要配置torch.distributed环境二、Distributed Embeddings2.1 EmbeddingBagCollectionSharder2.2 ShardedEmbeddingBagCollection 三、Planner总结 前言 我们已经使用了TorchRec的主模块&#xff1a;EmbeddedBagCollection。我…...

13-scala模式匹配

模式匹配是检查某个值&#xff08;value&#xff09;是否匹配某一个模式的机制&#xff0c;一个成功的匹配同时会将匹配值解构为其组成部分。它是Java中的switch语句的升级版&#xff0c;同样可以用于替代一系列的 if/else 语句。 语法 一个模式匹配语句包括一个待匹配的值&a…...

Multisim使用说明详尽版--(2025最新版)

一、Multisim14前言 1.1、主流电路仿真软件 1. Multisim&#xff1a;NI开发的SPICE标准仿真工具&#xff0c;支持模拟/数字电路混合仿真&#xff0c;内置丰富的元件库和虚拟仪器&#xff08;示波器、频谱仪等&#xff09;&#xff0c;适合教学和竞赛设计。官网&#xff1a;艾…...

试一下阿里云新出的mcp服务

前言 MCP这段时间的发展可谓是如火如荼&#xff0c;各种教程也是层出不穷&#xff0c;基本的教程都是如何集成各类型的mcp(比如高德地图)到开发工具(比如cursor)&#xff0c;效果很好&#xff0c;但是有个问题就是&#xff0c;配置教程较为繁琐。 阿里云悄然上线的mcp 今天早上…...

正弦波有效值和平均值(学习笔记)

一个周期的正弦波在坐标轴上围的面积有多大&#xff1f; 一般正弦波以 y Asin(wx)表示&#xff0c;其中A为振幅&#xff0c;W为角速度。周期T 2π/w; 确定积分区间是x 0&#xff0c;到x 2π。 计算绝对值积分&#xff1a; 变量代还&#xff1a;wx θ&#xff0c;dx dθ…...

科研软件分享

这个帖子不定期更新&#xff0c;分享博主自己使用的很好用的科研软件 1 connectedpaper Connected Papers | Find and explore academic papers 2 Semantic Scholar...

Python(12)深入解析Python参数传递:从底层机制到高级应用实践

目录 一、参数传递的编程哲学1.1 参数传递的本质1.2 参数传递类型矩阵 二、参数传递核心规则2.1 位置参数与关键字参数2.2 可变参数处理 三、参数传递高级特性3.1 类型约束与提示3.2 参数内存优化 四、参数传递工程实践4.1 防御性参数校验4.2 参数依赖注入 五、参数传递性能优化…...

MVCC是什么?MVCC的作用是什么?MVCC实现方式有哪些?

MVCC&#xff08;多版本并发控制&#xff09;详解 一、MVCC是什么&#xff1f; MVCC&#xff08;Multi-Version Concurrency Control&#xff0c;多版本并发控制&#xff09;是数据库管理系统中的一种并发控制机制&#xff0c;它通过维护数据的多个版本来实现非阻塞读和高并发…...

007.Gitlab CICD缓存与附件

文章目录 缓存与产物缓存与产物概述 同分支不同job数据共享默认数据共享不同 Job 数据共享 不同分支相同job数据共享跨分支同job数据共享 不同分支不同job数据共享跨分支跨job数据共享 将文件/夹保存为附件产物介绍创建产物跨job共享产物 缓存与产物 缓存与产物概述 缓存是一…...

A006-基于Selenium和JMeter的吉屋web端的自动化测试设计与实现

产出&#xff1a;自动测试脚本测试用例开题报告自动化测试报告论文jmeter性能测试 --------------------**论文主要内容***----- 第1章 吉屋web端需求分析 1.1 吉屋web端功能需求分析 由于社会对知识获取的需求不断增长&#xff0c;海量繁多的房屋信息已难以依靠传统人工高效…...

图像预处理-边缘填充,透视变换和色彩空间基础

一.边缘填充 一般来图片操作之后会有空区域&#xff0c;就是对空出来的区域进行了像素值的填充&#xff0c;(0&#xff0c;0&#xff0c;0)也就是黑色像素值的填充。 # 默认黑色填充 import cv2 as cvimg cv.imread(../images/lena.png) # 先让原图旋转45度 M cv.getRotatio…...

数字化赋能,众趣科技助力智慧园区深化管理运营能力

数字化、网络化和智能化&#xff0c;被公认为是未来社会发展的大趋势。随着全球物联网、云计算等新一代信息技术不断成熟&#xff0c;传统的招商管理运营模式难以满足园区当下所需&#xff0c;以“园区互联网”为理念的“智慧园区”应运而生&#xff0c;同时融入社交、移动、物…...

《AI大模型应知应会100篇》 第16篇:AI安全与对齐:大模型的灵魂工程

第16篇&#xff1a;AI安全与对齐&#xff1a;大模型的灵魂工程 摘要 在人工智能技术飞速发展的今天&#xff0c;大型语言模型&#xff08;LLM&#xff09;已经成为推动社会进步的重要工具。然而&#xff0c;随着这些模型能力的增强&#xff0c;如何确保它们的行为符合人类的期…...

MCP的另一面

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…...

Golang|锁相关

文章目录 并发安全性与原子操作读写锁分布式锁 并发安全性与原子操作 普通数据类型在并发读写中是会出现问题的&#xff0c;有时候操作会被吞&#xff0c;导致脏写&#xff0c;比如上面n加了两次应该为2&#xff0c;但是由于并发&#xff0c;n最后还是只加了一次 读写锁 sync.…...