《深度学习实践教程》[吴微] ch-5 3/5层全连接神经网络
一、练习课本上3层全连接神经网络识别手写数字。
答案代码:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 定义一些超参数
batch_size = 64
learning_rate = 0.02class Batch_Net(nn.Module):"""在上面的Activation_Net的基础上,增加了一个加快收敛速度的方法——批标准化"""def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Batch_Net, self).__init__()self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x# 数据预处理。transforms.ToTensor()将图片转换成PyTorch中处理的对象Tensor,并且进行标准化(数据在0~1之间)
# transforms.Normalize()做归一化。它进行了减均值,再除以标准差。两个参数分别是均值和标准差
# transforms.Compose()函数则是将各种预处理的操作组合到了一起
data_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])# 数据集的下载器
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 选择模型
#model = net.simpleNet(28 * 28, 300, 100, 10)
# model = Activation_Net(28 * 28, 300, 100, 10)
model = Batch_Net(28 * 28, 300, 100, 10)
#if torch.cuda.is_available():# model = model.cuda()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练模型
epoch = 0
for data in train_loader:img, label = dataimg = img.view(img.size(0), -1)if torch.cuda.is_available():img = img.cuda()label = label.cuda()else:img = Variable(img)label = Variable(label)out = model(img)loss = criterion(out, label)print_loss = loss.data.item()optimizer.zero_grad()loss.backward()optimizer.step()epoch+=1if epoch%100 == 0:print('epoch: {}, loss: {:.4}'.format(epoch, loss.data.item()))# 模型评估
model.eval()
eval_loss = 0
eval_acc = 0
for data in test_loader:img, label = dataimg = img.view(img.size(0), -1)if torch.cuda.is_available():img = img.cuda()label = label.cuda()out = model(img)loss = criterion(out, label)eval_loss += loss.data.item()*label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()eval_acc += num_correct.item()
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_dataset)),eval_acc / (len(test_dataset))
))
运行结果:
……
二、课后练习
设计一个5层全连接神经网络,实现给MNIST数据集的分类,其中:
batch_size = 32, learning_rate = 0.01, epochs = 100, input_size = 28*28,
hidden_size1 = 400, hidden_size2 = 300, hideen_size3 = 200, hidden_size4 = 100.
隐藏层中要带有激励函数ReLU()和批标准化函数。
答案代码:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 定义一些超参数
batch_size = 32
learning_rate = 0.01class Batch_Net(nn.Module):"""在上面的Activation_Net的基础上,增加了一个加快收敛速度的方法——批标准化"""def __init__(self, in_dim, n_hidden_1, n_hidden_2, n_hidden_3, n_hidden_4,out_dim):super(Batch_Net, self).__init__()self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, n_hidden_3), nn.BatchNorm1d(n_hidden_3), nn.ReLU(True))self.layer4 = nn.Sequential(nn.Linear(n_hidden_3, n_hidden_4), nn.BatchNorm1d(n_hidden_4), nn.ReLU(True))self.layer5 = nn.Sequential(nn.Linear(n_hidden_4, out_dim))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.layer5(x)return x# 数据预处理。transforms.ToTensor()将图片转换成PyTorch中处理的对象Tensor,并且进行标准化(数据在0~1之间)
# transforms.Normalize()做归一化。它进行了减均值,再除以标准差。两个参数分别是均值和标准差
# transforms.Compose()函数则是将各种预处理的操作组合到了一起
data_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])# 数据集的下载器
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 选择模型
#model = net.simpleNet(28 * 28, 300, 100, 10)
# model = Activation_Net(28 * 28, 300, 100, 10)
model = Batch_Net(28 * 28, 400, 300,200, 100, 10)
#if torch.cuda.is_available():# model = model.cuda()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练模型
epoch = 0
for data in train_loader:img, label = dataimg = img.view(img.size(0), -1)if torch.cuda.is_available():img = img.cuda()label = label.cuda()else:img = Variable(img)label = Variable(label)out = model(img)loss = criterion(out, label)print_loss = loss.data.item()optimizer.zero_grad()loss.backward()optimizer.step()epoch+=1if epoch%100 == 0:print('epoch: {}, loss: {:.4}'.format(epoch, loss.data.item()))# 模型评估
model.eval()
eval_loss = 0
eval_acc = 0
for data in test_loader:img, label = dataimg = img.view(img.size(0), -1)if torch.cuda.is_available():img = img.cuda()label = label.cuda()out = model(img)loss = criterion(out, label)eval_loss += loss.data.item()*label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()eval_acc += num_correct.item()
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_dataset)),eval_acc / (len(test_dataset))
))
运行结果:
从程序运行结果来看,loss为0.111534,准确率为96.93%。
声明:文章仅供学习使用。著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
相关文章:
《深度学习实践教程》[吴微] ch-5 3/5层全连接神经网络
一、练习课本上3层全连接神经网络识别手写数字。 答案代码: import torch from torch import nn, optim from torch.autograd import Variable from torch.utils.data import DataLoader from torchvision import datasets, transforms# 定义一些超参数 batch_size…...
OrcaFex11.5
OrcaFlex 11.5是一款专业的海洋工程动态分析软件 由英国Orcina公司开发 主要用于模拟和分析海洋结构物在复杂海洋环境中的动态响应 该软件广泛应用于海上油气开发 海上风电 海洋可再生能源等领域 OrcaFlex 11.5具有强大的建模和仿真能力 支持多种海洋结构物的模拟 包括船舶 …...
MUX-vlan
MUX-VLAN 理论环节 1. 定义与核心作用 Principal VLAN(主VLAN) 是 MUX VLAN(Multiplex VLAN)架构的核心组件,充当公共资源的访问枢纽,实现以下核心功能: 资源共享:允许所有从VLAN…...
vue3中解决 return‘ inside ‘finally‘ block报错的问题
vue3中解决 return’ inside ‘finally’ block报错的问题 这个错误信息通常表明你在使用Vue 3框架时,在finally块中不正确地使用了return语句。在JavaScript中,finally块是保证执行的最后一个代码块,用于释放资源或执行清理操作,…...
TestStand API 简介
TestStand API 简介 在自动化测试领域,TestStand 凭借其灵活的架构和强大的功能,成为众多开发者的首选工具。而 TestStand API(Application Programming Interface,应用程序编程接口)则是打开 TestStand 强大功能的 “…...
vue2+element实现Table表格嵌套输入框、选择器、日期选择器、表单弹出窗组件的行内编辑功能
vue2element实现Table表格嵌套输入框、选择器、日期选择器、表单弹出窗组件的行内编辑功能 文章目录 vue2element实现Table表格嵌套输入框、选择器、日期选择器、表单弹出窗组件的行内编辑功能前言一、准备工作二、行内编辑1.嵌入Input文本输入框1.1遇到问题1.文本框内容修改失…...
【Docker系列】使用格式化输出与排序技巧
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...
针对面试-redis篇
1. 缓存穿透 什么是缓存穿透? 缓存穿透就是有人查询一个不存在的数据,数据库查询不到数据也不会直接写入缓存,就会导致每次请求都查数据库。 解决方案一:缓存空数据 当数据库中不存在该数据时,直接把查到的空数据给…...
HTML8:媒体元素
视频和音频 视频元素 video 音频 audio <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>媒体元素学习</title> </head> <body> <!--音频和视频 src:资源路径 controls:控制条…...
把其他conda的env复制到自己电脑的conda上
把其他conda的env复制到自己电脑的conda上 一 拷贝 将要拷贝的env环境拷贝到自己电脑的放置env环境的文件夹中 二 添加配置 找到.conda文件夹下的environments.txt文件,添加配置 三 测试 查看环境是否拷贝成功 激活环境 自此就拷贝成功了!&am…...
抖音热门视频评论数追踪爬虫获取
自动追踪抖音账号收藏夹视频的评论数变化 功能: 1、自动追踪特定抖音账号收藏夹视频热度变化,评论增速超过x,自动通知到钉钉或飞书 2、最新最先进的js逆向算法,无封号风险 3、支持私有化定制 4、可同时追踪500-5w个视频的热度…...
Hive优化秘籍:大数据处理加速之道
目录 一、认识 Hive 性能瓶颈 二、优化从基础开始:查询语句 2.1 列与分区裁剪 2.2 谓词下推 2.3 合理使用排序 三、解决数据倾斜难题 3.1 数据倾斜原因剖析 3.2 针对性优化策略 四、优化 join 操作 4.1 MapJoin 的应用 4.2 大表 join 优化技巧 五、调整 …...
机器学习例题——预测facebook签到位置(K近邻算法)和葡萄酒质量预测(线性回归)
一、预测facebook签到位置 代码展示: import pandas as pd from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import…...
10B扩散文生图模型F-Lite技术报告速读
F Lite 技术报告解析 一、研究背景与目标 F Lite 是一个开源的 100 亿参数文本到图像的扩散变换器(DiT)模型。该研究的目标是探索在中等数据规模和计算资源条件下,大规模扩散模型的性能边界。F Lite 基于 Freepik 内部数据集训练࿰…...
源码分析之Leaflet中Marker
概述 Marker类用于创建一个标记点对象,可以用于在地图上添加标记点。Marker类继承自Layer类,提供了一些方法用于创建标记点对象。 源码分析 源码实现 Marker类实现如下: export var Marker Layer.extend({options: {icon: new IconDefault(), // 默认图标实例…...
从0开始学习大模型--Day2--大模型的工作流程以及初始Agent
大模型的工作流程 分词化(Tokenization)与词表映射 分词化(Tokenization)是自然语言处理(NLP)中的重要概念,它是将段落和句子分割成更小的分词(token)的过程。 将一个…...
P48-56 应用游戏标签
这一段课主要是把每种道具的游戏Tag进行了整理与应用 AuraAbilitySystemComponentBase.h // Fill out your copyright notice in the Description page of Project Settings. #pragma once #include "CoreMinimal.h" #include "AbilitySystemComponent.h"…...
4.29 tag的完整实现和登录页面的初步搭建
解释了v-for中每个属性的作用: 打印当前route的信息:(里面会有path的信息)当前的路由信息吧! handleMenu() 菜单选择!点击左侧菜单的栏目就会显示在Home.vue的tag上 这个方法的作用是让Home.vue上出现对应的…...
【Vue.js】 插槽通信——具名插槽通信
目录 前景基本语法命名规则默认内容使用建议 具体实例父组件 index.vue子组件 Category.vue 效果 前景 下面的父子组件代码仍然在Vue.js演练平台直接运行 基本语法 在子组件中定义插槽 <!-- Category.vue --> <slot name"插槽名称">默认内容</slo…...
从设备交付到并网调试:CET中电技术分布式光伏全流程管控方案详解
四月的最后一个工作日,当分布式光伏电站并网指示灯依次亮起的瞬间,CET中电技术与客户共同交出了一份满意的答卷。面对430政策窗口期的考验,我们凭借可靠的技术和高效的团队协作,在系统调试与并网对接的每个步骤都展现出过硬能力&a…...
(十)深入了解AVFoundation-采集:录制视频功能的实现
引言 在前文章中,我们深入探讨了如何通过 AVCaptureSession 配置 iOS 中的捕捉输入及输出。并通过使用 AVCaptureDeviceInput 和 AVCapturePhotoOutput,我们实现了基础的照片捕获功能,并配置了 PHPreviewView 来显示实时预览。 在本篇中&am…...
数据分析汇报七步法:用结构化思维驱动决策
在当今数据驱动的商业环境中,高效的数据汇报不仅是信息传递的工具,更是撬动决策的杠杆。基于您提供的五张核心图示,我们提炼出一套「七步汇报框架」,将复杂的数据分析转化为清晰的行动指南。这套方法论通过「现状-诊断-预见…...
推荐两本集成电路制作书籍
本书共分19章,涵盖先进集成电路工艺的发展史,集成电路制造流程、介电薄膜、金属化、光刻、刻蚀、表面清洁与湿法刻蚀、掺杂、化学机械平坦化,器件参数与工艺相关性,DFM(Design for Manufacturing)ÿ…...
认识Grafana及其面板(Panel)
Grafana简介 Grafana 是一款开源的数据可视化与监控平台,以其强大的数据展示能力、灵活的插件生态和广泛的兼容性,成为企业监控、IT运维、DevOps、物联网(IoT)和业务分析等领域的核心工具。 数据源(Data Source) 对于Grafana而言,Promethe…...
FlinkCDC采集MySQL8.4报错
报错日志 原因: MySQL8.4版本中弃用show MASTER STATUS语法 改为:SHOW BINARY LOG STATUS 解决方案: 1、降MySQL版本 2、修改源码...
Webview通信系统学习指南
Webview通信系统学习指南 一、定义与核心概念 1. 什么是Webview? 定义:Webview是移动端(Android/iOS)内置的轻量级浏览器组件,用于在原生应用中嵌入网页内容。作用:实现H5页面与原生应用的深度交互&…...
人工智能如何革新数据可视化领域?探索未来趋势
在当今数字化时代,数据如同汹涌浪潮般不断涌现。据国际数据公司(IDC)预测,全球每年产生的数据量将从 2018 年的 33ZB 增长到 2025 年的 175ZB。面对如此海量的数据,如何有效理解和利用这些数据成为了关键问题。数据可视…...
探索Hello Robot开源移动操作机器人Stretch 3的新技术亮点与市场定位
Hello Robot 推出的 Stretch 3 机器人凭借其前沿技术和多功能性在众多产品中占据优势。Stretch 3 机器人采用开源设计,为开发者提供了灵活的定制空间,能够满足各种不同的需求。其配备的灵活手腕组件和 Intel Realsense D405 摄像头,显著增强了…...
机器人系统设置
机器人系统设置 机器人系统设置与操作指南 1. 系统设置基础功能 偏好设置 控制柜名称修改:通过文本框输入新名称并确认主题切换:支持橙色/蓝色主题(需重启生效) 语言与日期 系统语言/键盘语言设置时间格式:支持系统时…...
C/C++ 扩展智能提示太慢或无法解析项目
问题 C/C 扩展不解析项目,导致源码中的变量、函数都为灰色状态,无法进行跳转。 有时候 log 会报如下错误: Attempting to get defaults from C compiler in "compilerPath" property: D:/Development/Tools/mingw64/bin/gcc.exe…...
通过Kubernetes 外部 DNS控制器来自动管理Azure DNS 和 AKS
前言: 将应用程序及其服务部署到 Kubernetes 集群后,一个问题浮现:如何使用自定义域名访问它?一个简单的解决方案是创建一条 A 记录,将域名指向服务 IP 地址。这可以手动完成,但随着服务数量的增加&#x…...
Elasticsearch知识汇总之ElasticSearch监控方案
八 ElasticSearch监控方案 8.1 ElasticSearch监控指标 监控指标为磐基生产项指标,以下‘监控项名称’‘指标名称 ‘使用的公式‘都已详细说明,图表如下: 监控项名称 指标英文名称 使用的公式 elasticsearch集群健康状态 Elastic_Cluster…...
【能力比对】K8S数据平台VS数据平台
🔥🔥 AllData大数据产品是可定义数据中台,以数据平台为底座,以数据中台为桥梁,以机器学习平台为中层框架,以大模型应用为上游产品,提供全链路数字化解决方案。 ✨AllData数据中台官方平台&…...
AutoDL+SSH在vscode中远程使用GPU训练深度学习模型
注册AutoDL账号 AutoDL官网:AutoDL 注册登录之后,如果你是学生,一定要进行学生认证,可以省钱。 认证之后,打开算力市场, 进行GPU选择 根据自己需要的环境选择版本 ,选好之后创建并开机 这里注…...
【C语言干货】野指针
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、什么是野指针?二、野指针的三大成因 1.指针未初始化2.指针越界访问2.指针指向已释放的内存 前言 提示:以下是本篇文章正文内容&…...
QT生成保存 Excel 文件的默认路径,导出的文件后缀自动加(1)(2)等等
//生成保存 Excel 文件的默认路径 QString MainWidget::getDefaultFilePath() const { QString basePath pathEdit->text(); if (basePath.isEmpty() || !QDir(basePath).exists()) { basePath QStandardPaths::writableLocation(QStandardPaths::DocumentsLocation); } r…...
React Native【详解】搭建开发环境,创建项目,启动项目
下载安装 node https://nodejs.cn/download/ 查看 npx 版本 npx -v若无 npx 则安装 npm install -g npx创建项目 npx create-expo-applatestRN_demo 为自定义的项目名称 下载安装 Python 2.7 下载安装 JAVA JDK https://www.oracle.com/java/technologies/downloads/#jdk24-…...
AIDC智算中心建设:存储核心技术解析
目录 一、智算中心存储概述 1、存储发展 2、智算存储指导政策 3、智算智能存储必要性 二、智算中心存储架构及特征 1、智算存储中心架构 2、智算存储特征 三、智算中心存储核心技术解析 1、长记忆存储范式为推理提质增效 2、数据编织加强全局数据高效处理 3、超节点…...
第11次:用户注册(完整版)
第一步:定义用户模型类 class User(AbstractUser):mobile models.CharField(max_length11, uniqueTrue, verbose_name手机号)class Meta:db_table tb_userverbose_name 用户verbose_name_plural verbose_namedef __str__(self):return self.username第二步&…...
论文速读《Embodied-R: 基于强化学习激活预训练模型具身空间推理能力》
项目主页:https://embodiedcity.github.io/Embodied-R/ 论文链接:https://arxiv.org/pdf/2504.12680 代码链接:https://github.com/EmbodiedCity/Embodied-R.code 0. 简介 具身智能是通用人工智能的重要组成部分。我们希望预训练模型不仅能在…...
VMware Fusion安装win11 arm;使用Mac远程连接到Win
目录 背景步骤1. 安装Fusion2. 下载Win113. 安装Win113.1 初始步骤3.2 进入安装 4. 安装Windows APP 背景 最近国补太火热了,让Macbook来到6000这个价位。实在没忍住,最后入手了一台M3芯片的Macbook Air(jd6799)。 既然运维出身&…...
【ARM】DS-试用授权离线激活
1、 文档目标 解决客户无法在公司网络管控下进行ARM DS 试用激活,记录解决方案。 2、 问题场景 客户在ARM DS激活时无法连接到ARM认证网址,客户公司网络管理无法开放全部网络权限,只能针对特定网址和网络端口可以开放或客户公司开发环境无法…...
泰迪杯特等奖案例学习资料:基于卷积神经网络与集成学习的网络问政平台留言文本挖掘与分析
(第八届“泰迪杯”数据挖掘挑战赛A题特等奖案例深度解析) 一、案例背景与核心挑战 1.1 应用场景与行业痛点 随着“互联网+政务”的推进,网络问政平台成为政府与民众沟通的重要渠道。某市问政平台日均接收留言超5000条,涉及民生、环保、交通等20余类诉求。然而,传统人工…...
基于 ReentrantReadWriteLock 实现高效并发控制
在多线程 Java 应用中,管理共享资源的访问是确保数据一致性和避免竞争条件的关键挑战。在某些场景中,多个线程需要频繁读取共享数据,而只有一个线程偶尔需要更新数据。例如,在一个网页投票系统中,大量用户可能同时查看投票结果(读操作),而投票更新(写操作)则相对较少…...
代理式AI(Agentic AI):2025年企业AI转型的催化剂
李升伟 摘译 步入2025:代理式AI开启企业智能化转型新纪元 随着2025年临近,企业已不再纠结"是否采用人工智能",而是迫切追问"如何加速AI进化"。传统AI系统在敏捷性、扩展性和自主性上的局限日益显现,新一代技…...
MySQL中MVCC指什么?
简要回答: MVCC(multi version concurrency control)即多版本并发控制,为了确保多线程下数据的安全,可以通过undo log和ReadView来实现不同的事务隔离级别。 对于已提交读和可重复读隔离级别的事务来说,M…...
购物数据分析
这是一个关于电商双11美妆数据分析的项目页面,包含版本记录、运行代码提示、评论等功能模块的相关描述。,会涉及数据处理、可视化、统计分析等代码逻辑,用于处理美妆电商双11相关数据,如销售数据统计、消费者行为分析等 。 数据源…...
基于GA遗传优化的不同规模城市TSP问题求解算法matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 旅行商问题(Traveling Salesman Problem,TSP)是一个经典的组合优化问题,旨在找到一个旅行商在访问多个城市后回到起…...
Nginx 安全防护与 HTTPS 部署
目录 一. 核心安全配置 1. 隐藏版本号 2. 限制危险请求方法 3. 请求限制(CC 攻击防御) 4. 防盗链 二. 高级防护 1. 动态黑名单 2. nginx https 配置 2.1 https 概念 2.1.1 https 为什么不安全 2.1.2 安全通信的四大原则 2.1.3 HTTPS 通信原理…...
隐私计算框架FATE二次开发心得整理(工业场景实践)
文章目录 版本介绍隐私计算介绍前言FATE架构总体架构FateBoard架构前端架构后端架构 FateClient架构创建DAG方式DAG生成任务管理python SDK方式 FateFlow架构Eggroll架构FATE算法架构Cpn层FATE ML层 组件新增流程新增组件流程新增算法流程 版本介绍 WeBank的FATE开源版本 2.2.…...