当前位置: 首页 > news >正文

Pytorch学习笔记(七)Learn the Basics - Optimizing Model Parameters

这篇博客瞄准的是 pytorch 官方教程中 Learn the Basics 章节的 Optimizing Model Parameters 部分。

  • 官网链接:https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html
完整网盘链接: https://pan.baidu.com/s/1L9PVZ-KRDGVER-AJnXOvlQ?pwd=aa2m 提取码: aa2m 

Optimizing Model Parameters

训练模型是一个迭代过程;在每次迭代中模型都会对输出进行预测,计算预测与真实值的误差(即损失),对误差相对于其参数的导数使用梯度下降优化这些参数。


Step1. 准备代码

这里使用前面关于 datasetsDataLoader 以及构建模型的部分加载代码。

  • 导入必要的库
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
  • 准备训练与测试数据集
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" 
print(device)training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)
  • 将数据集转换为迭代器
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
  • 定义神经网络模型
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logitc = self.linear_relu_stack(x)return logiticdevice = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" 
model = NeuralNetwork().to(device)

Step2. 定义超参数

超参数是可调整的参数,用来控制模型优化过程。不同的超参数值会影响模型训练和收敛速度。

通常情况下至少存在以下三个超参数:

  • Number of Epochs : 迭代数据集的次数;
  • Batch Size: 每次训练的样本数量;
  • Learning Rate: 在每个批次/迭代中更新模型参数的程度。较小的值会导致学习速度变慢,而较大的值可能会导致训练期间的行为不可预测。
learning_rate = 1e-3
batch_size = 64
epochs = 5

Step3. 定义优化器与损失函数

一旦确定了超参数就可以使用优化循环来训练和优化模型。优化循环的每次迭代称为一个epoch。每个epoch由两个主要部分组成:

  • The Train Loop:迭代训练数据集并尝试收敛到最佳参数;
  • The Validation/Test Loop:迭代测试数据集以检查模型性能是否有所改善;

当训练数据时,未经训练的网络很大概率不会给出正确答案。损失函数所得结果与目标值的差值是我们在训练期间想要最小化的目标。为了计算损失,使用给定数据样本的输入进行推理,并将其与真实数据标签值进行比较。常见的损失函数包括用于回归任务的 nn.MSELoss(均方误差)和用于分类的 nn.NLLLoss(负对数似然)。nn.CrossEntropyLoss 结合了 nn.LogSoftmaxnn.NLLLoss 两者的功能。将模型的输出 logits 传递给 nn.CrossEntropyLoss,由它将对 logits 进行规范化并计算预测误差:

  • 定义损失函数:
loss_fn = nn.CrossEntropyLoss()
  • 定义优化器:
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

在训练循环中,优化分为三个步骤:

  1. 调用 optimizer.zero_grad() 重置模型参数的梯度。默认情况下梯度会相加;为了防止重复计算,在每次迭代时明确将其归零;
  2. 通过调用 loss.backward() 反向传播预测损失。PyTorch 会存储相对于每个参数的损失梯度;
  3. 一旦获得了梯度,调用 optimizer.step() 来根据反向传递中收集的梯度调整参数。

Step4. 定义训练与测试循环

  • 训练循环:
def train_loop(dataloader, model, loss_fn, optimizer):size = len(dataloader)model.train()for batch, (X,y) in enumerate(dataloader):# 计算lossX,y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)# 反向传播loss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), batch*batch_size+len(X)print(f"loss: {loss:>7f}, [{current:>5d}/{size:>5d}]")
  • 测试循环:
def test_loop(dataloader, model, loss_fn):model.eval()size = len(dataloader)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X,y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

Step5. 训练

epochs = 20for t in range(epochs):print(f"Epoch {t+1}/{epochs}\n------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done")

在这里插入图片描述

相关文章:

Pytorch学习笔记(七)Learn the Basics - Optimizing Model Parameters

这篇博客瞄准的是 pytorch 官方教程中 Learn the Basics 章节的 Optimizing Model Parameters 部分。 官网链接:https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html 完整网盘链接: https://pan.baidu.com/s/1L9PVZ-KRDGVER-AJnXOvlQ?pwd…...

数据可视化TensorboardX和tensorBoard安装及使用

tensorBoard 和TensorboardX 安装及使用指南 tensorBoard 和 TensorBoardX 是用于可视化机器学习实验和模型训练过程的工具。TensorBoard 是 TensorFlow 官方提供的可视化工具,而 TensorBoardX 是其社区驱动的替代品,支持 PyTorch 等其他框架。以下是它…...

工业4G路由器赋能智慧停车场高效管理

工业4G路由器作为智慧停车场管理系统通信核心,将停车场内的各个子系统连接起来,包括车牌识别系统、道闸控制系统、车位检测系统、收费系统以及监控系统等。通过4G网络,将这些系统采集到的数据传输到云端服务器或管理中心,实现信息…...

深度学习1—Python基础

深度学习1—python基础 你的第一个程序 print(hello world and hello deep learning!)基本数据结构 空值 (None):在 Python 中,None 是一个特殊的对象,用于表示空值或缺失的值。它不同于数字 0,因为 0 是一个有意义的数字&#…...

数据结构十三、set map

一、set 1、size / empty size:返回set中实际元素的个数 empty:判断set是否为空 2、begin / end 这是两个迭代器,因此可以使用范围for来遍历整个红黑树。其中,遍历是按照中序遍历的顺序,因此是一个有序序列。 3、in…...

【大模型基础_毛玉仁】3.5 Prompt相关应用

目录 3.5 相关应用3.5.1 基于大语言模型的Agent3.5.2 数据合成3.5.3 Text-to-SQL3.5.4 GPTs 3.5 相关应用 Prompt工程应用广泛,能提升大语言模型处理基础及复杂任务的能力,在构建Agent、数据合成、Text-to-SQL转换和设计个性化GPTs等方面不可或缺。 . …...

自动驾驶VLA模型技术解析与模型设计

1.前言 2025年被称为“VLA上车元年”,以视觉语言动作模型(Vision-Language-Action Model, VLA)为核心的技术范式正在重塑智能驾驶行业。VLA不仅融合了视觉语言模型(VLM)的感知能力和端到端模型的决策能力,…...

【AI】Orin NX+ubuntu22.04上移植YoloV11,并使用DeepStream测试成功

【AI】郭老二博文之:AI学习目录汇总 1、烧写系统 新到的开发板,已经烧写好Ubuntu系统,版本为22.04。 如果没有升级到Ubuntu22.04,可以在电脑Ubuntu系统中使用SDKManager来烧写Ubuntu系统,网络情况好的话,也可以直接将CUDA、cuDNN、TensorRT、Deepstream等也安装上。 2…...

vscode 通过Remote-ssh远程连接服务器报错 could not establish connection to ubuntu

vscode 通过Remote-ssh插件远程连接服务器报错 could not establish connection to ubuntu,并且出现下面的错误打印: [21:00:57.307] Log Level: 2 [21:00:57.350] SSH Resolver called for "ssh-remoteubuntu", attempt 1 [21:00:57.359] r…...

ESP32S3 WIFI 实现TCP服务器和静态IP

一、 TCP服务器代码 代码由station_example_main的官方例程修改 /* WiFi station ExampleThis example code is in the Public Domain (or CC0 licensed, at your option.)Unless required by applicable law or agreed to in writing, thissoftware is distributed on an &q…...

第三课:Stable Diffusion图生图入门及应用

文章目录 Part01 图生图原理Part02 图生图基本流程Part03 随机种子作用解析Part04 图生图的拓展应用 Part01 图生图原理 当提示词不能足够表达用户需求的时候,加入图片能让AI更好的理解你的想法图片上的像素信息会在加噪和去噪的过程中,作为一种特征反映…...

蓝桥与力扣刷题(蓝桥 蓝桥骑士)

题目:小明是蓝桥王国的骑士,他喜欢不断突破自我。 这天蓝桥国王给他安排了 N 个对手,他们的战力值分别为 a1,a2,...,an,且按顺序阻挡在小明的前方。对于这些对手小明可以选择挑战,也可以选择避战。 身为高傲的骑士&a…...

Photoshop怎样保存为ico格式

1. 打开图像 开启 Photoshop 软件,选择 “文件” 菜单,点击 “打开” 选项,然后找到你想要保存为 ICO 格式的图像文件并打开。 2. 调整图像大小(可选) ICO 图标通常有特定尺寸要求,你可以根据需求调整图像…...

Ubuntu xinference部署本地模型bge-large-zh-v1.5、bge-reranker-v2-m3

bge-large-zh-v1.5 下载模型到指定路径: modelscope download --model BAAI/bge-large-zh-v1.5 --local_dir ./bge-large-zh-v1.5自定义 embedding 模型,custom-bge-large-zh-v1.5.json: {"model_name": "custom-bge-large…...

python笔记之判断月份有多少天

1、通过随机数作为目标月份 import random month random.randint(1,12)2、判断对应的年份是闰年还是平年 因为2月这个特殊月份,闰年有29天,而平年是28天,所以需要判断对应的年份属于闰年还是平年,代码如下 # 判断年份是闰年还…...

Kotlin泛型: 协变|逆变|不变

引言 无论java 通配符上限还是下限,都多少存在缺陷,要么存不安全,要么取不安全。而kotlin就解决这个问题。让out 纯输出, 让in纯输入。 java这块知识: java泛型的协变、逆变和不变-CSDN博客 协变 生产者out T 协变…...

高斯数据库的空分区的查看和清理

在 高斯数据库(GaussDB) 中,分区表是一种常见的表设计方式,用于优化大数据的查询性能。 一、空分区的影响: 存储空间占用 元数据开销:即使分区中没有数据,数据库仍然需要维护分区的元数据&…...

word使用自带的公式

文章目录 Word公式中word公式快捷键:word2016公式框输入多行word 公式加入空格:word公式如何输入矩阵:公式图片转为Latex语法word 能直接输入 latex 公式么word公式中有的是斜体有的不是 word文本中将文字转为上标的快捷键 Tips几个好用的网站&#xff1…...

Linux系统-ls命令

一、ls命令的定义 Linux ls命令(英文全拼:list directory contents)用于显示指定工作目录下之内容(列出目前工作目录所含的文件及子目录)。 二、ls命令的语法 ls [选项] [目录或文件名] ls [-alrtAFR] [name...] 三、参数[选项…...

数据结构:利用递推式计算next表

next 表是 KMP 算法的核心内容,下面介绍一种计算 next 表的方法:利用递推式计算 如图 6.3.1 所示,在某一趟匹配中,当对比到最后一个字符的时候,发现匹配失败(s[i] ≠ t[j])。根据 BF 算法&…...

Git操作

1 git init 项目初始化&#xff08;init&#xff09;成仓库 2、git add 管理文件 3、git commit -m <message> 告诉Git&#xff0c;把文件提交到仓库 4、git status 查看当前管理文件的状态&#xff0c;命令 5、git log 查看提交&#xff08;commit&#xff09;的…...

什么是快重传

原理&#xff1a; 在TCP连接中&#xff0c;接受方会对收到的数据包发送确认&#xff08;ACK&#xff09;。如果接受方收到一个乱序的数据包&#xff08;即期望的下一个数据包尚未到达&#xff09;&#xff0c;它会重复发送对上一个已成功接受的数据包的确认。 当发送方连续收…...

计算机网络——物理层设备

目录 ​编辑 中继器 集线器&#xff08;Hub&#xff09; 集线器&#xff0c;中继器的一些特性 集线器和中继器不能“无限串联” 集线器连接的网络&#xff0c;物理上是星型拓扑&#xff0c;逻辑上是总线型拓扑 集线器连接的各网段会“共享带宽” 中继器 如果我们想要网络…...

CSS 预处理器

在面试中回答关于 CSS 预处理器的问题时&#xff0c;你可以从以下几个方面进行回答&#xff0c;展示你的知识深度和实践经验&#xff1a; 1. 什么是 CSS 预处理器&#xff1f; 你可以从定义和目的入手&#xff1a; “CSS 预处理器是一种扩展 CSS 功能的工具&#xff0c;它允许…...

解锁智能制造新体验:兰亭妙微 UE/UI 设计赋能行业变革

在智能制造时代的滚滚浪潮中&#xff0c;企业的数字化转型不仅是技术的革新&#xff0c;更是用户体验与交互界面的全面升级。然而&#xff0c;许多制造企业在这一转型过程中&#xff0c;面临着一系列 UI/UE 设计难题&#xff0c;严重阻碍了企业的数字化发展进程。兰亭妙微凭借专…...

计算机网络高频(三)UDP基础

计算机网络高频(三)UDP基础 1.UDP的头部格式是什么样的?⭐ UDP 头部具有以下字段: 源端口(Source Port):16 位字段,表示发送方的端口号。目标端口(Destination Port):16 位字段,表示接收方的端口号。长度(Length):16 位字段,表示 UDP 数据报(包括头部和数据部…...

Oracle数据库服务器地址变更与监听配置修改完整指南

一、前言 在企业IT运维中&#xff0c;Oracle数据库服务器地址变更是常见的运维操作。本文将详细介绍如何安全、高效地完成Oracle数据库服务器地址变更及相关的监听配置修改工作&#xff0c;确保数据库服务在迁移后能够正常运行。 二、准备工作 1. 环境检查 确认新旧服务器I…...

获取1688.item_password接口:解析淘口令真实URL

一、接口介绍 1688的item_password接口主要用于将1688平台的淘口令短链接转换为实际商品链接。它基于1688平台的后台数据和规则&#xff0c;对用户传入的淘口令进行解析和验证&#xff0c;通过相应的算法和数据匹配&#xff0c;找到对应的商品信息&#xff0c;并生成可直接访问…...

计算机网络的分类及其性能指标

一. 计算机网络的分类 1. 按分布范围分类 广域网&#xff08;WAN&#xff09; 也称远程网。广域网提供长距离通信&#xff0c;通常是几十千米到几千千米的区域&#xff0c;比如跨国通信。连接广域网的各结点交换机的链路一般是高速链路&#xff0c;具有较大的通信容量城域网&…...

Redis原理:watch命令

在前面的文章中有提到&#xff0c;在multi 前可以通过watch 来观察哪些key&#xff0c;被观察的这些key&#xff0c;会被redis服务器监控&#xff0c;涉及该key被修改时&#xff0c;则在exec 命令执行过程中会被识别出来&#xff0c;exec 就不会再执行命令。 源码分析 // 监控…...

微服务中的服务发现与注册中心

在微服务架构中&#xff0c;服务实例的数量可能随着流量负载自动扩展或缩减&#xff0c;因此服务之间如何高效地进行通信成为一个重要问题。本篇博客将介绍服务发现的概念&#xff0c;并结合 Consul 和 自定义注册中心 进行实践&#xff0c;帮助开发者在微服务架构下高效管理服…...

Flutter网络请求封装:高效、灵活、易用的Dio工具类

在Flutter开发中&#xff0c;网络请求是必不可少的功能。为了简化代码、提高开发效率&#xff0c;我们通常会封装一个网络请求工具类。本文基于Dio库&#xff0c;详细介绍如何封装一个高效、灵活、易用的网络请求工具类&#xff0c;支持以下功能&#xff1a; 单例模式&#xf…...

Axure项目实战:智慧城市APP(三)教育查询(显示与隐藏交互)

亲爱的小伙伴&#xff0c;在您浏览之前&#xff0c;烦请关注一下&#xff0c;在此深表感谢&#xff01; 课程主题&#xff1a;教育查询 主要内容&#xff1a;教育公告信息&#xff0c;小升初、初升高、高考成绩查询&#xff1b;教育公告信息为传统的信息页面&#xff0c;小升…...

案例实践 | 招商局集团以长安链构建“基于DID的航运贸易数据资产目录链”

概览 案例名称 基于DID的航运贸易数据资产目录链 业主单位 招商局集团 上线时间 2024年10月 用户群体 供数用数企业和个人 用户规模 集团内20企业 案例背景 招商局集团深入落实“促进数据高效流通使用、赋能实体经济”精神&#xff0c;深化集团数字化水平&#xff0c…...

计算机网络入门:物理层与数据链路层详解

&#x1f310; &#xff08;专业解析 中学生也能懂&#xff01;&#xff09; &#x1f4d6; 前言 计算机网络就像数字世界的“高速公路系统”&#xff0c;而物理层和数据链路层是这条公路的基石。本文用 专业视角 和 生活化比喻 &#xff0c;带你轻松理解这两层的核心原理&a…...

使用 Docker 部署 RabbitMQ 的详细指南

使用 Docker 部署 RabbitMQ 的详细指南 在现代应用程序开发中&#xff0c;消息队列系统是不可或缺的一部分。RabbitMQ 是一个流行的开源消息代理软件&#xff0c;它实现了高级消息队列协议&#xff08;AMQP&#xff09;。本文将详细介绍如何使用 Docker 部署 RabbitMQ&#xf…...

数据结构之基本队列-顺序结构实现-初始化-判断队列是否为空(front=rear)-出队-入队-队尾满了,调整队列-获取队头元素

数据结构之基本队列-顺序结构实现-初始化-判断队列是否为空(frontrear)-出队-入队-队尾满了&#xff0c;调整队列-获取队头元素——完整可运行代码 #include <stdio.h>#define MAXSIZE 100 typedef int ElemType;typedef struct {ElemType data[MAXSIZE];int front;int…...

如何用 Postman 发送 POST 请求?

POST 请求是 HTTP 协议中用于提交数据的一种方法&#xff0c;Postman 提供了丰富的功能来支持用户发送包含各种信息的 POST 请求&#xff0c;如文本数据、JSON 或 XML 数据结构、文件等。 Postman 发送 post 请求教程...

基于Spring Boot的网上商城系统的设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导&#xff0c;欢迎高校老师/同行前辈交流合作✌。 技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;…...

mysql中的聚簇索引,什么是聚簇索引和非聚簇索引

文章目录 1. 什么是聚簇索引2. 非聚簇索引3. 聚簇索引的优缺点4. 聚簇索引的使用场景5. 聚簇索引和主键索引的异同前言: 在继续讲解专栏内容之前,先学习几个概念,以便更好了解: 什么是聚簇索引什么是回表这篇文章详细分析 聚簇索引。回表的理解可以进入这篇文章:什么是回表…...

涨薪技术|使用Dockerfile创建镜像

上次的推文内容中介绍了如何使用docker commit的方法来构建镜像&#xff0c;相反推荐使用被称为Dockerfile的定义文件和docker build命令来构建镜像。Dockerfile使用基本的基于DSL语法的指令来构建一个Docker镜像&#xff0c;之后使用docker build命令基于该Dockerfile中的指令…...

OpenFeign在微服务中的远程服务调用工作流程

OpenFeign作为声明式的HTTP客户端,在微服务架构中的远程调用工作流程可分为以下标准步骤: 一、初始化阶段 1. 接口定义(声明式API) @FeignClient(name = "user-service", path = "/api/users") public interface UserServiceClient {@GetMapping(&q…...

力扣14. 最长公共前缀:Java四种解法详解

力扣14. 最长公共前缀&#xff1a;Java四种解法详解 题目描述 编写一个函数来查找字符串数组中的最长公共前缀。如果不存在公共前缀&#xff0c;返回空字符串 ""。 示例&#xff1a; 输入&#xff1a;strs ["flower","flow","flight&quo…...

关于deepin上WPS读取windows上的docx文件

最近在尝试着用deepin替代windows&#xff0c;在deepin上安装了wps读取在windows上编辑的docx和xlsx文件&#xff0c;遇到类似如下的错误&#xff1a; 系统缺失字体&#xff1a;Symbol、Wingdings、Wingdings2、Wingdings3、Webdings、MT Extra WPS无法正确的显示某些符号&…...

利用python调接口获取物流标签,并转成PDF保存在指定的文件夹。

需求 调用get label 接口将接口返回的base64文件转换成pdf文件命名用接口返回值的单号命名保存再指定的文件夹重 实现代码 # -*- coding: utf-8 -*- import requests import base64 import os import json # 新增json模块导入url "http://releasud.com/api/label/Lab…...

31天Python入门——第15天:日志记录

你好&#xff0c;我是安然无虞。 文章目录 日志记录python的日志记录模块创建日志处理程序并配置输出格式将日志内容输出到控制台将日志写入到文件 logging更简单的一种使用方式 日志记录 日志记录是一种重要的应用程序开发和维护技术, 它用于记录应用程序运行时的关键信息和…...

“自动驾驶背后的数学” 专栏导读

专栏链接&#xff1a; 自动驾驶背后的数学 专栏以“自动驾驶背后的数学”为主题&#xff0c;从基础到深入&#xff0c;再到实际应用和未来展望&#xff0c;全面解析自动驾驶技术中的数学原理。开篇用基础数学工具搭建自动驾驶的整体框架&#xff0c;吸引儿童培养兴趣&#xff0…...

Redis中的数据类型与适用场景

目录 前言1. 字符串 (String)1.1 特点1.2 适用场景 2. 哈希 (Hash)2.1 特点2.2 适用场景 3. 列表 (List)3.1 特点3.2 适用场景 4. 集合 (Set)4.1 特点4.2 适用场景 5. 有序集合 (Sorted Set)5.1 特点5.2 适用场景 6. Redis 数据类型的选型建议结语 前言 Redis 作为一款高性能的…...

python并发爬虫

爬虫多线程方法生成 from threading import Threaddef func(name):for i in range(100):print(f"{name}完成了{i}项任务")if __name__ __main__:t1 Thread(targetfunc, args(老杨,))t2 Thread(targetfunc, args(老李,))t3 Thread(targetfunc, args(老孙,))t1.st…...

react-create-app整合windicss

引用&#xff1a;https://blog.csdn.net/gitblog_00339/article/details/142544145 package.json: "react": "^19.0.0","react-dom": "^19.0.0","react-scripts": "5.0.1","typescript": "^4.9.5…...