当前位置: 首页 > news >正文

神经网络代码入门解析

神经网络代码入门解析

import torch
import matplotlib.pyplot as pltimport randomdef create_data(w, b, data_num):  # 数据生成x = torch.normal(0, 1, (data_num, len(w)))y = torch.matmul(x, w) + b  # 矩阵相乘再加bnoise = torch.normal(0, 0.01, y.shape)  # 为y添加噪声y += noisereturn x, ynum = 500true_w = torch.tensor([8.1, 2, 2, 4])
true_b = 1.1X, Y = create_data(true_w, true_b, num)# plt.scatter(X[:, 3], Y, 1)  # 画散点图 对X取全部的行的第三列,标签Y,点大小
# plt.show()def data_provider(data, label, batchsize):  # 每次取batchsize个数据length = len(label)indices = list(range(length))# 这里需要把数据打乱random.shuffle(indices)for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]get_data = data[get_indices]get_label = label[get_indices]yield get_data, get_label  # 有存档点的returnbatchsize = 16
# for batch_x, batch_y in data_provider(X, Y, batchsize):
#     print(batch_x, batch_y)
#     break# 定义模型
def fun(x, w, b):pred_y = torch.matmul(x, w) + breturn pred_y# 定义loss
def maeLoss(pre_y, y):return torch.sum(abs(pre_y-y))/len(y)# sgd(梯度下降)
def sgd(paras, lr):with torch.no_grad():  # 这部分代码不计算梯度for para in paras:para -= para.grad * lr  # 不能写成 para = para - paras.grad * lr !!!! 这句相当于要创建一个新的para,会导致报错para.grad.zero_()  # 将使用过的梯度归零lr = 0.01
w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)
b_0 = torch.tensor(0.01, requires_grad=True)
print(w_0, b_0)epochs = 50
for epoch in range(epochs):data_loss = 0for batch_x, batch_y in data_provider(X, Y, batchsize):pred_y = fun(batch_x, w_0, b_0)loss = maeLoss(pred_y, batch_y)loss.backward()sgd([w_0, b_0], lr)data_loss += lossprint("epoch %03d: loss: %.6f" % (epoch, data_loss))print("真实函数值:", true_w, true_b)
print("训练得到的函数值:", w_0, b_0)idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())
plt.scatter(X[:, idx].detach().numpy(), Y, 1)
plt.show()

逐步分析代码

1.数据生成

image-20250301120222530

首先设计一个函数create_data,提供我们所需要的数据集的x与y

def create_data(w, b, data_num):  # 数据生成x = torch.normal(0, 1, (data_num, len(w)))  # 生成特征数据,形状为 (data_num, len(w))y = torch.matmul(x, w) + b  # 计算目标值 y = x * w + bnoise = torch.normal(0, 0.01, y.shape)  # 生成噪声,形状与 y 相同y += noise  # 为 y 添加噪声,模拟真实数据中的随机误差return x, y
  • torch.normal() 生成一个张量

    • torch.normal(0, 1, (data_num, len(w))):生成一个形状为 (data_num, len(w)) 的张量,其中的元素是从均值为 0、标准差为 1 的正态分布中随机采样的。
  • torch.matmul() 让矩阵相乘

    matmul: matrix multiply

  • 再使用torch.normal()生成一个张量,添加到y上,相当于为y添加了随机的噪声

    噪声的引入是为了模拟真实数据中的随机误差,使生成的数据更接近现实场景。

2.设计一个数据加载器

def data_provider(data, label, batchsize):  # 每次取 batchsize 个数据length = len(label)indices = list(range(length))random.shuffle(indices)  # 打乱数据顺序,避免模型学习到顺序特征for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]  # 获取当前批次的索引get_data = data[get_indices]  # 获取当前批次的数据get_label = label[get_indices]  # 获取当前批次的标签yield get_data, get_label  # 返回当前批次的数据和标签

data_provider可以分批提供数据,并通过yield来返回已实现记忆功能

首先把list y顺序打乱,这样就相当于从生成的训练集y中随机读取,若不打乱数据,可能造成训练结果的不理想

打乱数据可以避免模型在训练过程中学习到数据的顺序特征,从而提高模型的泛化能力。

之后分段遍历打乱的y,返回对应的局部的数据集来给神经网络进行训练

3.定义模型函数

image-20250301122853184

def fun(x, w, b):pred_y = torch.matmul(x, w) + b  # 计算预测值 y = x * w + breturn pred_y

fun(x, w, b) 是一个线性模型,形式为 y = x * w + b,其中 x 是输入特征,w 是权重,b 是偏置。

4.定义Loss函数

image-20250301122958888

def maeLoss(pre_y, y):return torch.sum(abs(pre_y - y)) / len(y)  # 计算平均绝对误差 (MAE)
  • maeLoss 是平均绝对误差(Mean Absolute Error, MAE),它计算预测值 pre_y 和真实值 y 之间的绝对误差的平均值。
  • 公式为:MAE = (1/n) * Σ|pre_y - y|,其中 n 是样本数量。

5.梯度下降sgd函数

# sgd(梯度下降)
def sgd(paras, lr):with torch.no_grad():  # 这部分代码不计算梯度for para in paras:para -= para.grad * lr  # 不能写成 para = para - paras.grad * lr !!!! 这句相当于要创建一个新的para,会导致报错para.grad.zero_()  # 将使用过的梯度归零

这里需要使用torch.no_grad()来避免重复计算梯度

image-20250301123531781

在前向过程中已经累计过一次梯度了,如果在梯度下降过程中又累计了梯度,那么就会造成不必要的麻烦

PyTorch 会累积梯度,如果不手动清零,梯度会不断累积,导致参数更新错误。

para -= para.grad * lr就是将参数w修正的过程(w=w-(dy^/dw)*learningRate)

torch.no_grad() 是一个上下文管理器,用于禁用梯度计算。在参数更新时,禁用梯度计算可以避免不必要的计算和内存占用。

5.开始训练

epochs = 50
for epoch in range(epochs):data_loss = 0num_batches = len(Y) // batchsize  # 计算批次数量for batch_x, batch_y in data_provider(X, Y, batchsize):pred_y = fun(batch_x, w_0, b_0)  # 前向传播loss = maeLoss(pred_y, batch_y)  # 计算损失loss.backward()  # 反向传播sgd([w_0, b_0], lr)  # 更新参数data_loss += loss.item()  # 累积损失print("epoch %03d: loss: %.6f" % (epoch, data_loss / num_batches))  # 打印平均损失

先定义一个训练轮次epochs=50,表示训练50轮

在每轮训练中将loss记录下来,以此评价训练的效果

首先用data_provider来获取数据集中随机的一部分

接着传入相应数据给模型函数,通过前向传播获得预测y值pred_y

调用Loss计算函数,获取这次的loss,再通过反向传播loss.backward()计算梯度

loss.backward() 是反向传播的核心步骤,用于计算损失函数对模型参数的梯度。

再通过梯度下降sgd([w_0, b_0], lr)来更新模型的参数

最终将这组数据的loss累加到这轮数据的loss中

6.结果绘制

idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy() * w_0[idx].detach().numpy() + b_0.detach().numpy())  # 绘制预测直线
plt.scatter(X[:, idx].detach().numpy(), Y, 1)  # 绘制真实数据点
plt.show()

相关文章:

神经网络代码入门解析

神经网络代码入门解析 import torch import matplotlib.pyplot as pltimport randomdef create_data(w, b, data_num): # 数据生成x torch.normal(0, 1, (data_num, len(w)))y torch.matmul(x, w) b # 矩阵相乘再加bnoise torch.normal(0, 0.01, y.shape) # 为y添加噪声…...

Android 数据库查询对比(APN案例)

功能背景 APN 数据通常存储在数据库中,由TelephonyProvider提供。当用户进入APN设置界面时,Activity会启动,AOSP源码通过ContentResolver查询APN数据。关键分析点在于这个查询操作是否在主线程执行,因为主线程上的耗时操作会导致…...

神卓 S500 异地组网设备实现监控视频异地组网的详细步骤

一、设备与环境准备 硬件清单 主设备:神卓 S500 异地组网路由器 1子设备:神卓 S500 或兼容设备 N(需通过官网认证)监控设备:支持 RTSP/ONVIF 协议的 NVR、摄像头网络要求:各网点需稳定联网(推荐…...

golang安装(1.23.6)

1.切换到安装目录 cd /usr/local 2.下载安装包 wget https://go.dev/dl/go1.23.6.linux-amd64.tar.gz 3.解压安装包 sudo tar -C /usr/local -xzf go1.23.6.linux-amd64.tar.gz 4.配置环境变量 vi /etc/profile export PATH$…...

leetcode35.搜索插入位置

题目: 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 示例 1: 输入: nums [1,3,5,6], target 5 输出…...

LeetCode第57题_插入区间

LeetCode 第57题:插入区间 题目描述 给你一个 无重叠的 ,按照区间起始端点排序的区间列表。在列表中插入一个新的区间,你需要确保列表中的区间仍然有序且不重叠(如果有必要的话,可以合并区间)。 难度 中…...

人工智能之数学基础:线性代数中矩阵的运算

本文重点 矩阵的运算在解决线性方程组、描述线性变换等方面发挥着至关重要的作用。通过对矩阵进行各种运算,可以简化问题、揭示问题的本质特征。在实际应用中,我们可以利用矩阵运算来处理图像变换、数据分析、电路网络等问题。深入理解和掌握矩阵的运算,对于学习线性代数以…...

SQL Server 创建用户并授权

创建用户前需要有一个数据库,创建数据库命令如下: CREATE DATABASE [数据库名称]; CREATE DATABASE database1;一、创建登录用户 方式1:SQL命令 命令格式:CREATE LOGIN [用户名] WITH PASSWORD ‘密码’; 例如,创…...

MySQL双主搭建-5.7.35

文章目录 上传并安装MySQL 5.7.35双主复制的配置实例一:172.25.0.19:实例二:172.25.0.20: 配置复制用户在实例 1 (172.25.0.19)上执行:在实例 2 (172.25.0.20)上执行&…...

RNN实现精神分裂症患者诊断(pytorch)

RNN理论知识 RNN(Recurrent Neural Network,循环神经网络) 是一种 专门用于处理序列数据(如时间序列、文本、语音、视频等)的神经网络。与普通的前馈神经网络(如 MLP、CNN)不同,RNN…...

Python中字符串的常用操作

一、r原样输出 在 Python 中,字符串前加 r(即 r"string" 或 rstring)表示创建一个原始字符串(raw string)。下面详细介绍原始字符串的特点、使用场景及与普通字符串的对比。 特点 忽略转义字符&#xff1…...

uniapp 本地数据库多端适配实例(根据运行环境自动选择适配器)

项目有个需求,需要生成app和小程序,app支持离线数据库,如果当前没有网络提醒用户开启离线模式,所以就随便搞了下,具体的思路就是: 一个接口和多个实现类(类似后端的模板设计模式)&am…...

Spring Cloud Gateway 整合Spring Security

做了一个Spring Cloud项目,网关采用 Spring Cloud Gateway,想要用 Spring Security 进行权限校验,由于 Spring Cloud Gateway 采用 webflux ,所以平时用的 mvc 配置是无效的,本文实现了 webflu 下的登陆校验。 1. Sec…...

【异地访问本地DeepSeek】Flask+内网穿透,轻松实现本地DeepSeek的远程访问

写在前面:本博客仅作记录学习之用,部分图片来自网络,如需引用请注明出处,同时如有侵犯您的权益,请联系删除! 文章目录 前言依赖Flask构建本地网页访问LM Studio 开启网址访问DeepSeek 调用模板Flask 访问本…...

Windows对比MacOS

Windows对比MacOS 文章目录 Windows对比MacOS1-环境变量1-Windows添加环境变量示例步骤 1:打开环境变量设置窗口步骤 2:添加系统环境变量 2-Mac 系统添加环境变量示例步骤 1:打开终端步骤 2:编辑环境变量配置文件步骤 3&#xff1…...

React实现无缝滚动轮播图

实现效果: 由于是演示代码,我是直接写在了App.tsx里面在 文件位置如下: App.tsx代码如下: import { useState, useEffect, useCallback, useRef } from "react"; import { ImageContainer } from "./view/ImageC…...

Ubuntu20.04确认cuda和cudnn已经安装成功

当我们通过官网安装cuda和cudnn时,终端执行完命令后我们仍不能确定是否已经安装成功。接下来教大家用几句命令测试。 cuda 检测版本号 nvcc -V如果输出如下,则安装成功。 可以看到版本号是11.2 cudnn检测版本号 有两种命令:如果你的cudn…...

sqlilab 46 关(布尔、时间盲注)

sqlilabs 46关(布尔、时间盲注) 46关有变化了,需要我们输入sort,那我们就从sort1开始 递增测试: 发现测试到sort4就出现报错: 我们查看源码: 从图中可看出:用户输入的sort值被用于查…...

AI时代保护自己的隐私

人工智能最重要的就是数据,让我们面对现实,大多数人都不知道他们每天要向人工智能提供多少数据。你输入的每条聊天记录,你发出的每条语音命令,人工智能生成的每张图片、电子邮件和文本。我建设了一个网站(haptool.com)&#xff0c…...

模型优化之强化学习(RL)与监督微调(SFT)的区别和联系

强化学习(RL)与监督微调(SFT)是机器学习中两种重要的模型优化方法,它们在目标、数据依赖、应用场景及实现方式上既有联系又有区别。 想了解有关deepseek本地训练的内容可以看我的文章: 本地基于GGUF部署的…...

Buildroot 添加自定义模块-内置文件到文件系统

目录 概述实现步骤1. 创建包目录和文件结构2. 配置 Config.in3. 定义 cp_bin_files.mk4. 添加源文件install.shmy.conf 5. 配置与编译 概述 Buildroot 是一个高度可定制和模块化的嵌入式 Linux 构建系统,适用于从简单到复杂的各种嵌入式项目. buildroot的源码中bui…...

蓝牙接近开关模块感应开锁手机靠近解锁支持HID低功耗

ANS-BT101M是安朔科技推出的蓝牙接近开关模块,低功耗ble5.1,采用UART通信接口,实现手机自动无感连接,无需APP,人靠近车门自动开锁,支持苹果、安卓、鸿蒙系统,也可以通过手机手动开锁或上锁&…...

计算机毕业设计SpringBoot+Vue.js基于工程教育认证的计算机课程管理平台(源码+文档+PPT+讲解)

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…...

企业知识库搭建:14款开源与免费系统选择

本文介绍了以下14 款知识库管理系统:1.Worktile;2.PingCode;3.石墨文档; 4. 语雀; 5. 有道云笔记; 6. Bitrix24; 7. Logseq等。 在如今的数字化时代,企业和团队面临着越来越多的信息…...

蓝桥杯(握手问题)

小蓝组织了一场算法交流会议,总共有 50 人参加了本次会议。在会议上,大家进行了握手交流。按照惯例他们每个人都要与除自己以外的其他所有人进行一次握手 (且仅有一次)。 但有 7个人,这 7 人彼此之间没有进行握手 (但这 7 人与除这 7 人以外…...

如何使用 Jenkins 实现 CI/CD 流水线:从零开始搭建自动化部署流程

如何使用 Jenkins 实现 CI/CD 流水线:从零开始搭建自动化部署流程 在软件开发过程中,持续集成(CI)和持续交付(CD)已经成为现代开发和运维的标准实践。随着代码的迭代越来越频繁,传统的手动部署方式不仅低效,而且容易出错。为了提高开发效率和代码质量,Jenkins作为一款…...

c++字符编码/乱码问题

基本概念 c11版本引入了char16_t和char32_t两个类型,他们的特点分别如下: char16_t 16位的unicode字符类型用于表示UTF-16编码大小:2字节字面量前缀:u char32_t 32位unicode字符类型用于表示UTF-32编码大小:4字节…...

侯捷 C++ 课程学习笔记:深入理解类与继承

文章目录 每日一句正能量一、课程背景二、学习内容:类与继承(一)类的基本概念1. 类的定义与实例化2. 构造函数与析构函数 (二)继承1. 单继承与多继承2. 虚函数与多态 三、学习心得四、总结 每日一句正能量 有种承担&am…...

初始化列表

一:声明,定义,赋值的区别 ①:声明 这里,int _year; int _month;int _day; 是成员变量的声明,它们告诉编译器: 类 Date中有三个成员变量_year和 _month和_day。 它们的类型分别都是 int 此…...

7.1 - 定时器之中断控制LED实验

文章目录 1 实验任务2 系统框图3 软件设计 1 实验任务 本实验任务是通过CPU私有定时器的中断,每 200ms 控制一次PS LED灯的亮灭。 2 系统框图 3 软件设计 注意事项: 定时器中断在收到中断后,只需清除中断状态,无需禁用中断、启…...

Pytest之fixture的常见用法

文章目录 1.前言2.使用fixture执行前置操作3.使用conftest共享fixture4.使用yield执行后置操作 1.前言 在pytest中,fixture是一个非常强大和灵活的功能,用于为测试函数提供固定的测试数据、测试环境或执行一些前置和后置操作等, 与setup和te…...

【分库分表】基于mysql+shardingSphere的分库分表技术

目录 1.什么是分库分表 2.分片方法 3.测试数据 4.shardingSphere 4.1.介绍 4.2.sharding jdbc 4.3.sharding proxy 4.4.两者之间的对比 5.留个尾巴 1.什么是分库分表 分库分表是一种场景解决方案,它的出现是为了解决一些场景问题的,哪些场景喃…...

合并两个有序链表:递归与迭代的实现分析

合并两个有序链表:递归与迭代的实现分析 在算法与数据结构的世界里,链表作为一种基本的数据结构,经常被用来解决各种问题。特别是对于有序链表的合并,既是经典面试题,也是提高编程能力的重要练习之一。合并两个有序链…...

HTML AI 编程助手

HTML AI 编程助手 引言 随着人工智能技术的飞速发展,编程领域也迎来了新的变革。HTML,作为网页制作的基础语言,与AI技术的结合,为开发者带来了前所未有的便利。本文将探讨HTML AI编程助手的功能、应用场景以及如何利用它提高编程…...

备战蓝桥杯Day11 DFS

DFS 1.要点 (1)朴素dfs 下面保存现场和恢复现场就是回溯法的思想,用dfs实现,而本质是用递归实现,代码框架: ans; //答案,常用全局变量表示 int mark[N]; //记录状态i是否被处理过 …...

Oracle 认证为有哪几个技术方向

Oracle 认证技术方向,分别是数据库管理、开发、云平台,每个方向都有不同的学习等级 数据库运维方向 Oracle Certified Professional(OCP):19c OCA内容已和OCP合并 OCP 19c属于oracle认证专家,要求考生掌握深…...

25物理学研究生复试面试问题汇总 物理学专业知识问题很全! 物理学复试全流程攻略 物理学考研复试调剂真题汇总

正在为物理考研复试专业面试发愁的你,是不是不知道从哪开始准备? 学姐告诉你,其实物理考研复试并没有你想象的那么难!只要掌握正确的备考方法,稳扎稳打,你也可以轻松拿下高分!今天给大家准备了…...

网络安全技术与应用

文章详细介绍了网络安全及相关技术,分析了其中的一类应用安全问题——PC机的安全问题,给出了解决这类问题的安全技术——PC防火墙技术。 1 网络安全及相关技术 自20世纪…...

APISIX Dashboard上的配置操作

文章目录 登录配置路由配置消费者创建后端服务项目配置上游再创建一个路由测试 登录 http://192.168.10.101:9000/user/login?redirect%2Fdashboard 根据docker 容器里的指定端口: 配置路由 通过apisix 的API管理接口来创建(此路由,直接…...

深度剖析数据分析职业成长阶梯

一、数据分析岗位剖析 目前,数据分析领域主要有以下几类岗位:业务数据分析师、商业数据分析师、数据运营、数据产品经理、数据工程师、数据科学家等,按照工作侧重点不同,本文将上述岗位分为偏业务和偏技术两大类,并对…...

HarmonyOS学习第11天:布局秘籍RelativeLayout进阶之路

布局基础:RelativeLayout 初印象 在 HarmonyOS 的界面开发中,布局是构建用户界面的关键环节,它决定了各个组件在屏幕上的位置和排列方式。而 RelativeLayout(相对布局)则是其中一种功能强大且灵活的布局方式&#xff0…...

问题修复-后端返给前端的时间展示错误

问题现象: 后端给前端返回的时间展示有问题。 需要按照yyyy-MM-dd HH:mm:ss 的形式展示 两种办法: 第一种 在实体类的属性上添加JsonFormat注解 第二种(建议使用) 扩展mvc框架中的消息转换器 代码: 因为配置类继…...

怎么排查页面响应慢的问题

一、排查流程图 -----------------| 全局监控报警触发 |-----------------|▼-----------------| 定位异常服务节点 |-----------------|------------------▼ ▼ ----------------- ----------------- | 基础设施层排查 | | 应用层代码排查 | | (网…...

第二十四:5.2【搭建 pinia 环境】axios 异步调用数据

第一步安装&#xff1a;npm install pinia 第二步&#xff1a;操作src/main.ts 改变里面的值的信息&#xff1a; <div class"count"><h2>当前求和为&#xff1a;{{ sum }}</h2><select v-model.number"n">  // .number 这里是…...

SpringBoot——生成Excel文件

在Springboot以及其他的一些项目中&#xff0c;或许我们可能需要将数据查询出来进行生成Excel文件进行数据的展示&#xff0c;或者用于进行邮箱发送进行附件添加 依赖引入 此处demo使用maven依赖进行使用 <dependency><groupId>org.apache.poi</groupId>&…...

java高级(IO流多线程)

file 递归 字符集 编码 乱码gbk&#xff0c;a我m&#xff0c;utf-8 缓冲流 冒泡排序 //冒泡排序 public static void bubbleSort(int[] arr) {int n arr.length;for (int i 0; i < n - 1; i) { // 外层循环控制排序轮数for (int j 0; j < n -i - 1; j) { // 内层循环…...

MySQL 用户权限管理深度解析:从基础到高阶实践(2000字指南)

MySQL 用户权限管理是数据库安全与运维的核心环节。无论是本地开发环境还是企业级生产环境,合理配置用户权限、理解版本差异、遵循安全规范都至关重要。本文将从 ​基础权限配置、版本差异详解、安全加固策略、高阶权限管理、故障排查​ 等多个维度展开,覆盖 MySQL 5.7、8.0 …...

【0011】HTML其他文本格式化标签详解(em标签、strong标签、b标签、i标签、sup标签、sub标签......)

如果你觉得我的文章写的不错&#xff0c;请关注我哟&#xff0c;请点赞、评论&#xff0c;收藏此文章&#xff0c;谢谢&#xff01; 本文内容体系结构如下&#xff1a; 本文旨在深入探讨HTML中其他的文本格式化标签&#xff0c;主要有<em> 标签、<strong> 标签、…...

数据虚拟化的中阶实践:从概念到实现

数据虚拟化的中阶实践:从概念到实现 在大数据时代,数据的数量、种类和来源呈现爆炸式增长,如何高效、灵活地访问和利用这些数据成为了企业面临的重要问题。数据虚拟化作为一种创新的技术,正逐渐成为解决这一难题的关键。它通过抽象化层将底层数据源与应用程序隔离,使得数…...

AI辅助学习vue第十四章

第十四章&#xff1a;技术引领与未来展望 在第十五章&#xff0c;你已经在Vue技术领域深耕许久&#xff0c;积累了丰富的经验与卓越的影响力。此时&#xff0c;你将站在行业前沿&#xff0c;引领技术走向&#xff0c;为Vue技术的未来发展开辟新道路。 1. 引领Vue技术发展方向…...