当前位置: 首页 > news >正文

pytorch nn.RNN demo

之前已经讲过关于RNNCell的实现了.

这里用LLM写了一个简单的nn.RNN demo:

import torch
import torch.nn as nn# 设置随机种子以便结果可复现
torch.manual_seed(42)# 定义模型参数
input_size = 4      # 输入特征维度
hidden_size = 8     # 隐藏层维度
num_layers = 2      # RNN 层数(修改为2层)
seq_len = 10        # 序列长度
batch_size = 3      # 批量大小# 创建2层RNN模型
model = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=False  # 输入输出格式: [seq_len, batch_size, feature_size]
)# 生成随机输入数据 [seq_len, batch_size, input_size]
x = torch.randn(seq_len, batch_size, input_size)
print(f"输入 x 的形状: {x.shape}  # [seq_len, batch_size, input_size]")# 初始化隐藏状态 (可选)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
print(f"初始隐藏状态 h0 的形状: {h0.shape}  # [num_layers, batch_size, hidden_size]")# 前向传播
output, h_n = model(x, h0)
# output: 所有时间步的最后一层隐藏状态
# h_n: 所有层的最后一个时间步的隐藏状态print(f"\n输出结果:")
print(f"output (所有时间步的最后一层隐藏状态) 的形状: {output.shape}  # [seq_len, batch_size, hidden_size]")
print(f"h_n (所有层的最后时间步隐藏状态) 的形状: {h_n.shape}  # [num_layers, batch_size, hidden_size]")# 验证 h_n 与 output 的关系(修正后的逻辑)
print(f"\n验证 h_n 与 output 的关系:")
# 最后一层的最后状态应等于 output 的最后时间步
assert torch.allclose(h_n[-1], output[-1]), "最后一层的最后状态应等于output的最后时间步"
print(" 最后一层的最后状态与 output 的最后时间步相等")# 打印第一层和第二层的最后隐藏状态
print(f"\n第一层的最后隐藏状态:")
print(h_n[0, 0, :5])  # 打印第一个样本的前5个元素
print(f"\n第二层的最后隐藏状态:")
print(h_n[1, 0, :5])  # 打印第一个样本的前5个元素

可以看到,nn.RNN默认会输出两个张量:一个是最后一个时间步的所有层,一个是最后一层的所有时间步。它是不会输出“所有时间步的所有层”的。

最后再给出与RNNCell部分类似的,一个完整的训练+测试的demo:

import torch
import torch.nn as nn
import torch.optim as optim# 配置
input_size = 4
hidden_size = 16
seq_len = 6
batch_size = 8
num_classes = 2
epochs = 30# 模型定义
class RNNClassifier(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super().__init__()self.rnn = nn.RNN(input_size, hidden_size, batch_first=False)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):# x: [seq_len, batch_size, input_size]output, h_n = self.rnn(x)  # h_n: [num_layers=1, batch_size, hidden_size]out = self.fc(h_n.squeeze(0))  # 使用最后一层的隐藏状态return out# 数据生成逻辑不变
def generate_batch(batch_size, seq_len, input_size):x = torch.randn(seq_len, batch_size, input_size)last_step = x[-1]labels = (last_step[:, 0] > 0).long()return x, labels# 初始化模型与训练配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNClassifier(input_size, hidden_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练
for epoch in range(epochs):model.train()x_batch, y_batch = generate_batch(batch_size, seq_len, input_size)x_batch, y_batch = x_batch.to(device), y_batch.to(device)logits = model(x_batch)loss = criterion(logits, y_batch)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 5 == 0 or epoch == 0:pred = logits.argmax(dim=1)acc = (pred == y_batch).float().mean().item()print(f"[Epoch {epoch+1}] Loss: {loss.item():.4f}, Acc: {acc:.2f}")# 测试
model.eval()
with torch.no_grad():x_test, y_test = generate_batch(1, seq_len, input_size)x_test, y_test = x_test.to(device), y_test.to(device)pred = model(x_test).argmax(dim=1)print("\nTest sample:")print("Target label:", y_test.item())print("Predicted   :", pred.item())

相关文章:

pytorch nn.RNN demo

之前已经讲过关于RNNCell的实现了. 这里用LLM写了一个简单的nn.RNN demo: import torch import torch.nn as nn# 设置随机种子以便结果可复现 torch.manual_seed(42)# 定义模型参数 input_size 4 # 输入特征维度 hidden_size 8 # 隐藏层维度 num_layer…...

高防服务器流量“清洗”什么意思

在当今数字化的时代,网络安全成为了备受关注的焦点。其中,高防服务器流量“清洗”这个概念,对于许多朋友来说可能还比较陌生。今天,就让我们一起来揭开它神秘的面纱。 首先,咱们得明白,高防服务器流量“清…...

Unity3D开发AI桌面精灵/宠物系列 【六】 人物模型 语音口型同步 LipSync 、梅尔频谱MFCC技术、支持中英文自定义编辑- 基于 C# 语言开发

Unity3D开发AI桌面精灵/宠物系列 【六】 人物模型 语音口型同步 LipSync 、梅尔频谱MFCC技术 C# 语言开发 该系列主要介绍怎么制作AI桌面宠物的流程,我会从项目开始创建初期到最终可以和AI宠物进行交互为止,项目已经开发完成,我会仔细梳理一下…...

Java详解LeetCode 热题 100(17):LeetCode 41. 缺失的第一个正数(First Missing Positive)详解

文章目录 1. 题目描述2. 理解题目3. 解法一:排序法(不满足题目要求)3.1 思路3.2 Java代码实现3.3 代码详解3.4 复杂度分析3.5 不足之处 4. 解法二:哈希表法4.1 思路4.2 Java代码实现4.3 代码详解4.4 复杂度分析4.5 不足之处 5. 解…...

Kafka消息路由分区机制深度解析:架构设计与实现原理

一、消息路由系统的核心架构哲学 1.1 分布式系统的三元悖论 在分布式消息系统的设计过程中,架构师需要平衡三个核心诉求:数据一致性、系统可用性和分区容忍性。Kafka的分区路由机制本质上是对CAP定理的实践解: 一致性维度:通过…...

用C语言实现了——一个基于顺序表的插入排序演示系统

一、知识要点、 插入排序是一种简单直观的排序算法,它的工作方式类似于我们整理扑克牌。 基本原理: 插入排序通过构建有序序列来工作。它每次从无序序列中取出一个元素,然后将其插入到已排序序列的适当位置。这个过程重复进行,…...

linux libdbus使用案例

以下是一个基于 Linux libdbus 的详细指南,包含服务端和客户端的完整代码示例,涵盖 方法调用、信号发送 和 异步消息处理。libdbus 是 D-Bus 的底层 C 库,直接操作 D-Bus 协议,适合需要精细控制的场景。 1. libdbus 的核心机制 连接管理:通过 dbus_bus_get 连接系统总线或…...

Apple Vision Pro空间视频创作革命:从180度叙事到沉浸式语法的重构——《Adventure》系列幕后技术深度解析

🌌 引言:沉浸式媒体的“语法实验室” Apple Vision Pro的推出标志着空间计算时代的到来,而《Adventure》系列作为其原生内容标杆,正在成为沉浸式叙事的“语法实验室”。导演Charlotte Mikkelborg与播客主持人Kent Bye的对话揭示了这一领域的技术突破、创作挑战与行业生态…...

[特殊字符] 苍穹外卖项目中的 WebSocket 实战:实现来单与催单提醒功能

🚀 苍穹外卖项目中的 WebSocket 实战:实现来单与催单提醒功能 在现代 Web 应用中,实时通信成为提升用户体验的关键技术之一。WebSocket 作为一种在单个 TCP 连接上进行全双工通信的协议,被广泛应用于需要实时数据交换的场景&#…...

【C/C++】深度解析C++ Allocator:优化内存管理的关键

文章目录 深度解析C Allocator:优化内存管理的关键1 默认 std::allocator2 自定义 Allocator3 自定义 Allocator 的实现3.1 基本结构3.2 使用自定义 Allocator 4 关键特性详解4.1 rebind 机制4.2 状态化 Allocator 5 应用示例:内存池 Allocator5.1 简单内…...

gitlab+portainer 实现Ruoyi Vue前端CI/CD

1. 场景 最近整了一个Ruoyi Vue 项目,需要实现CICD,经过一番坎坷,最终达成,现将技术要点和踩坑呈现。 具体操作流程和后端大同小异,后端操作参考连接如下: https://blog.csdn.net/leinminna/article/detai…...

CAPL编程系列_04

1_ 测试模块TestModule:基本使用 1)在Simulation Setup 中创建并配置 Test Module节点 2)编写测试脚本 【1】测试用例函数(testcase):实现具体测试逻辑 【2】主测试函数(Main Test)&…...

Weblogic SSRF漏洞复现(CVE-2014-4210)【vulhub靶场】

漏洞概述: Weblogic中存在一个SSRF漏洞,利用该漏洞可以发送任意HTTP请求,进而攻击内网中redis、fastcgi等脆弱组件。 漏洞形成原因: WebLogic Server 的 UDDI 组件(uddiexplorer.war)中的 SearchPublicR…...

科技的成就(六十八)

623、杰文斯悖论 杰文斯悖论是1865年经济学家威廉斯坦利杰文斯提出的一悖论:当技术进步提高了效率,资源消耗不仅没有减少,反而激增。例如,瓦特改良的蒸汽机让煤炭燃烧更加高效,但结果却是煤炭需求飙升。 624、代码混…...

知从科技闪耀2025上海车展:以创新驱动未来出行新篇章

上海,2025年4月23日——全球汽车科技领域的年度盛会——2025上海国际汽车工业展览会(简称“上海车展”)于5月2日圆满落幕。作为智能汽车软件与系统解决方案的领军企业,知从科技受邀参展,并在活动期间全方位展示了其在智…...

【iOS安全】Dopamine越狱 iPhone X iOS 16.6 (20G75) | 解决Jailbreak failed with error

Dopamine越狱 iPhone X iOS 16.6 (20G75) Dopamine兼容设备 参考:https://www.bilibili.com/opus/977469285985157129 A9 - A11(iPhone6s-X):iOS15.0-16.6.1 A12-A14(iPhoneXR-12PM&#xf…...

医疗数据迁移质量与效率的深度研究:三维六阶框架与实践创新

引言 随着医疗信息化建设的深入推进,医疗数据作为医疗机构的核心资产,其价值与日俱增。在医院信息系统升级、迁移或整合过程中,数据迁移的质量与效率直接关系到医疗服务的连续性、患者信息的安全性以及医院运营的稳定性。传统数据迁移方法往往面临时间长、风险高、成本大等…...

[6-8] 编码器接口测速 江协科技学习笔记(7个知识点)

1 2 在STM32微控制器的定时器模块中,CNT通常指的是定时器的计数器值。以下是CNT是什么以及它的用途: 是什么: • CNT:代表定时器的当前计数值。在STM32中,定时器从0开始计数,直到达到预设的自动重装载值&am…...

java类加载阶段与双亲委派机制

java执行过程:.java->.class->然后被jvm加载解释执行。 一、类加载机制的三个阶段 ​​加载(Loading)​​ ​​任务​​:通过类的全限定名获取二进制字节流(如从文件系统、网络等),将字节流转换为方…...

医院网络安全托管服务(MSS)深度解读与实践路径

医疗行业网络安全挑战与MSS的应运而生 医疗行业在数智化转型的过程中面临着前所未有的网络安全挑战。根据2025年的最新数据,医疗行业将面临大量网络攻击,其中高达91%与勒索软件有关,且45%的数据泄露事件源于第三方供应商。医疗机构的平均数据…...

计算图存储采用矩阵吗,和张量关系

计算图存储采用矩阵吗,和张量关系 计算图的存储方式与张量的关系 一、计算图的存储方式 计算图(Computational Graph)是一种用于描述数学运算的有向无环图(DAG),其节点代表运算(如加减乘除、矩阵乘法、激活函数等),边代表运算的输入和输出(通常是张量)。计算图的…...

RPA 自动化实现自动发布

📕我是廖志伟,一名Java开发工程师、《Java项目实战——深入理解大型互联网企业通用技术》(基础篇)、(进阶篇)、(架构篇)清华大学出版社签约作家、Java领域优质创作者、CSDN博客专家、…...

博途软件直接寻址AMS348i读取位置值详解

一、AMS348i简介 AMS348i是一种高性能绝对值编码器,常用于工业自动化领域的位置检测。它具有以下特点: 高精度位置测量 多种通信接口(如SSI、PROFIBUS、PROFINET等) 坚固的工业设计 支持多种安装方式 二、元器件及配件 设备…...

MySQL 学习(十)执行一条查询语句的内部执行过程、MySQL分层

目录 一、MySQL 执行流程图二、MySQL的分层2.1 连接阶段2.2 查询缓存阶段(Query Cache,MySQL 8.0已移除)2.3 解析与预处理阶段(词法分析、语法分析、预处理器)2.4 查询优化阶段2.5 执行引擎阶段 三、常见面试题3.1 MyS…...

C语言中的指定初始化器

什么是指定初始化器? C99标准引入了一种更灵活、直观的初始化语法——指定初始化器(designated initializer), 可以在初始化列表中直接引用结构体或联合体成员名称的语法。通过这种方式,我们可以跳过某些不需要初始化的成员,并且可以以任意顺序对特定成员进行初始化。这…...

什么是 NB-IoT ?窄带IoT 应用

物联网使各种应用能够与大量无线通信设备进行连接和通信。它有望为智能城市、公用事业、制造设施、农业应用、远程工业机械等提供动力。这些应用均可使用窄带物联网(NB-IoT )网络协议。 例如,智能城市可使用 NB-IoT 监控整个城市的街道照明、…...

CSRF 和 XSS 攻击分析与防范

CSRF 和 XSS 攻击分析与防范 CSRF (跨站请求伪造) 什么是 CSRF? CSRF (Cross-Site Request Forgery) 是一种攻击方式,攻击者诱使用户在已登录目标网站的情况下,执行非预期的操作。 攻击流程: 用户登录可信网站 A在不登出 A 的…...

Window下Jmeter多机压测方法

1.概述 Jmeter多机压测的原理,是通过单个jmeter客户端,控制多个远程的jmeter服务器,使他们同步的对服务器进行压力测试。 以此方式收集测试数据的好处在于: 保存测试采样数据到本地机器通过单台机器管理多个jmeter执行引擎测试…...

Apache RocketMQ ACL 2.0 全新升级

📖知识延伸:本文相关知识库已收录至「RocketMQ 中文社区」,同步更新更多进阶内容 引言 RocketMQ 作为一款流行的分布式消息中间件,被广泛应用于各种大型分布式系统和微服务中,承担着异步通信、系统解耦、削峰填谷和消…...

第九讲 | 模板进阶

模板进阶 一、非类型模板参数1、模板参数的分类2、应用场景3、array4、注意 二、模板的特化1、概念2、函数模板特化3、类模板特化(1)、全特化:全部模板参数都特化成具体的类型(2)、偏/半特化:部分模板参数特…...

联合建模组织学和分子标记用于癌症分类|文献速递-深度学习医疗AI最新文献

Title 题目 Joint modeling histology and molecular markers for cancer classification 联合建模组织学和分子标记用于癌症分类 01 文献速递介绍 癌症是对人类致命的恶性肿瘤,早期准确诊断对癌症治疗至关重要。目前,病理诊断仍是癌症诊断的金标准…...

会计要素+借贷分录+会计科目+账户,几个银行会计的重要概念

1.借贷分录还是借贷分路 正确表述是“借贷分录”。 “分录”即会计分录,它是指预先确定每笔经济业务所涉及的账户名称,以及计入账户的方向和金额的一种记录,简称分录。 在借贷记账法下,会计分录通过“借”和“贷”来表示记账方向…...

【C++】set和multiset的常用接口详解

前⾯我们已经接触过STL中的部分容器如:string、vector、list、deque、array、forward_list等,本篇文章将介绍一下map和multiset的使用。 1. 序列式容器和关联式容器 在介绍set之前我们先简单介绍一下什么是序列式容器和关联式容器。 前⾯我们已经接触过S…...

PostgreSQL 联合索引生效条件

最近面试的时候,总会遇到一个问题 在 PostgreSQL 中,联合索引在什么条件下会生效? 特此记录~ 前置信息 数据库版本 PostgreSQL 14.13, compiled by Visual C build 1941, 64-bit 建表语句 CREATE TABLE people (id SERIAL PRIMARY KEY,c…...

聊聊redisson的lockWatchdogTimeout

序 本文主要研究一下redisson的lockWatchdogTimeout lockWatchdogTimeout redisson/src/main/java/org/redisson/config/Config.java private long lockWatchdogTimeout 30 * 1000;/*** This parameter is only used if lock has been acquired without leaseTimeout param…...

数据结构第七章(三)-树形查找:红黑树

树形查找(二) 红黑树一、红黑树1.定义2.黑高3.性质 二、插入1.插入步骤2.举例 总结 红黑树 红黑树来喽~ 我们在上一篇说了二叉排序树(BST)和平衡二叉树(AVL),那么既然都有这两个了,…...

C++篇——多态

目录 引言 1,什么是多态 2. 多态的定义及实现 2_1,多态的构成条件 2_2,虚函数 2_3,虚函数的重写 2_4,虚函数重写的两个例外 2_4_1,协变(基类与派生类虚函数返回值类型不同) 2_4_2. 析构函数的重写(基类…...

AI实时对话的通信基础,WebRTC技术综合指南

在通过您的网络浏览器进行音频和视频通话、屏幕共享或实时数据传输时,您可能并不常思考其背后的技术。推动这些功能的核心力量之一就是WebRTC。2011年由谷歌发布的这个开源项目,如今已发展成为一个高度全面且不断扩展的生态系统。尤其是在AI技术大幅突破…...

【寻找Linux的奥秘】第五章:认识进程

请君浏览 前言1. 冯诺依曼体系结构数据流动 2. 操作系统(Operating System)2.1 概念2.2 设计OS的目的2.3 如何理解“管理”2.4 系统调用和库函数概念 3. 进程3.1 基本概念3.1.1 查看进程3.1.2 创建进程 3.2 进程状态3.2.1 简单介绍3.2.2 运行&&阻…...

uniapp微信小程序-长按按钮百度语音识别回显文字

流程图&#xff1a; 话不多说&#xff0c;上代码&#xff1a; <template><view class"content"><view class"speech-chat" longpress"startSpeech" touchend"endSpeech"><view class"animate-block" …...

支付宝创建商家订单收款码(统一收单线下交易预创建).net开发的软件附带大型XML文件可以删除吗?AlipaySDKNet.OpenAPI.xml

支付宝创建商家订单收款码&#xff08;统一收单线下交易预创建&#xff09;一个程序55MB&#xff0c;XML就带了35MB AlipaySDKNet.OpenAPI.xml&#xff0c;BouncyCastle.Crypto.xml 支付宝店铺收款码创建的程序&#xff0c;这些文件可以不用吗 在支付宝店铺收款码创建的程序中…...

Profinet转Ethernet/IP网关模块通信协议适配配置

案例背景 在某自动化生产车间中&#xff0c;现有控制系统采用了西门子 S7 - 1500 PLC 作为主要控制器&#xff0c;负责生产流程的核心控制。同时&#xff0c;由于部分设备的历史原因&#xff0c;存在使用 AB 的 PLC 进行特定环节控制的情况。为了实现整个生产系统的信息交互与…...

4.6/Q1,GBD数据库最新文章解读

文章题目&#xff1a;Global burden, subtype, risk factors and etiological analysis of enteric infections from 1990-2021: population based study DOI&#xff1a;10.3389/fcimb.2025.1527765 中文标题&#xff1a;1990-2021 年肠道感染的全球负担、亚型、危险因素和病因…...

数字孪生技术:开启未来的“镜像”技术

想象一下&#xff0c;你拥有一个与现实世界一模一样的 “数字分身”&#xff0c;它不仅长得像你&#xff0c;行为举止、思维方式也和你毫无二致&#xff0c;甚至能提前预知你的下一步行动。这听起来像是科幻电影里的情节&#xff0c;但数字孪生技术却让它在现实中成为了可能。数…...

Java 序列化(Serialization)

一、理论说明 1. 序列化的定义 Java 序列化是指将对象转换为字节流的过程&#xff0c;以便将其存储到文件、数据库或通过网络传输。反序列化则是将字节流重新转换为对象的过程。通过实现java.io.Serializable接口&#xff0c;类可以被标记为可序列化的&#xff0c;该接口是一…...

Python解析Excel入库如何做到行的拆分

我们读取解析Excel入库经常会遇到这种场景&#xff0c;那就是行的拆分&#xff0c;如图&#xff1a; 比如我们入库&#xff0c;要以name为主键&#xff0c;可是表格name的值全是以逗号分割的多个&#xff0c;这怎么办呢&#xff1f;这就必须拆成多行了啊。 代码如下&#xff…...

信创国产化监控 | 达梦数据库监控全解析

达梦数据库&#xff08;DM Database&#xff09;是国产数据库的代表产品之一&#xff0c;在政府、金融、电信、能源等多个关键行业应用广泛&#xff0c;它具有高兼容性、高安全性、高可用性、高性能、自主可控等特点。随着国产化替代进程加速&#xff0c;达梦数据库在关键信息基…...

Parsec解决PnP连接失败的问题

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、准备环境二、DMZ三、端口映射1.Parsec设置固定端口2.路由器设置端口转发3.重启被控端Parsec四、多少一句1.有光猫管理员账号2.没有光猫管理员账号总结 前言…...

LLM笔记(二)LLM数据基础

核心目标: 构建 LLM 的数据基础&#xff0c;将原始文本转化为模型可处理的、包含丰富语义和结构信息的数值形式。 一、 环境与库准备 (Environment & Libraries): 必要库确认: 在开始之前&#xff0c;确保 torch (PyTorch深度学习框架) 和 tiktoken (OpenAI的高效BPE分词…...

让三个线程(t1、t2、t3)按顺序依次打印 A、B、C

public class ThreadWait {private static final Object lock = new Object();private static boolean t1Output=true;private static boolean t2Output=false;private static boolean t3Output=false;public static void main(String[] args) {//线程1new Thread(new Runnable…...