当前位置: 首页 > news >正文

rust-candle学习笔记12-实现因果注意力

参考:about-pytorch

定义结构体:

struct CausalAttention {w_qkv: Linear,dropout: Dropout, d_model: Tensor,mask: Tensor,device: Device,   
}

定义new方法:

impl CausalAttention {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, seq_len: usize, dropout: f32, device: Device) -> Result<Self> {Ok(Self { w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,mask: Tensor::tril2(seq_len, DType::U32, &device)?,dropout: Dropout::new(dropout),device})}
}

定义forward方法:

    fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> { let qkv = self.w_qkv.forward(x)?;let (batch_size, seq_len, _) = qkv.dims3()?;let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;let q = qkv.get_on_dim(2, 0)?;let q = q.reshape((batch_size, seq_len, ()))?;let k = qkv.get_on_dim(2, 1)?;let k = k.reshape((batch_size, seq_len, ()))?;let v = qkv.get_on_dim(2, 2)?;let v = v.reshape((batch_size, seq_len, ()))?;let mut attn_score = q.matmul(&k.t()?)?;// println!("attn_score: {:?}\n", attn_score.to_vec3::<f32>()?);let dim = attn_score.rank() - 1;let mask_dim = attn_score.dims()[dim];let mask = self.mask.broadcast_as(attn_score.shape())?;// println!("mask: {:?}\n", mask);// println!("mask: {:?}\n", mask.to_vec3::<u32>()?);attn_score = masked_fill(&attn_score, &mask, f32::NEG_INFINITY)?;// println!("attn_score: {:?}\n", attn_score);// println!("attn_score: {:?}\n", attn_score.to_vec3::<f32>()?);let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?; let attn_weights = ops::softmax(&attn_score, dim)?;// println!("attn_weights: {:?}\n", attn_weights);// println!("attn_weights: {:?}\n", attn_weights.to_vec3::<f32>()?); let attn_weights = self.dropout.forward(&attn_weights, train)?;// println!("dropout attn_weights: {:?}\n", attn_weights);// println!("dropout attn_weights: {:?}\n", attn_weights.to_vec3::<f32>()?); let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}

测试:

fn main() -> Result<()> {let device = Device::cuda_if_available(0)?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55, 0.43, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55], (2, 6, 3), &device)?;let model = CausalAttention::new(vb.clone(), 3, 2, 6, 0.5, device.clone())?;let output = model.forward(&input, true)?;println!("output: {:?}\n", output);println!("output: {:?}\n", output.to_vec3::<f32>()?);Ok(())
}

相关文章:

rust-candle学习笔记12-实现因果注意力

参考&#xff1a;about-pytorch 定义结构体&#xff1a; struct CausalAttention {w_qkv: Linear,dropout: Dropout, d_model: Tensor,mask: Tensor,device: Device, } 定义new方法&#xff1a; impl CausalAttention {fn new(vb: VarBuilder, embedding_dim: usize, ou…...

vue3使用tailwindcss报错问题

npm create vitelatestnpm install -D tailwindcss postcss autoprefixernpx tailwindcss init 4. 不过执行 npx tailwindcss init 的时候控制台就报错了PS E:\vite-demo> npx tailwindcss init npm ERR! cb.apply is not a function npm ERR! A complete log of this run c…...

MySQL COUNT(*) 查询优化详解!

目录 前言1. COUNT(*) 为什么慢&#xff1f;—— InnoDB 的“计数烦恼” &#x1f914;2. MySQL 执行 COUNT(*) 的方式 (InnoDB)3. COUNT(*) 优化策略&#xff1a;快&#xff01;准&#xff01;狠&#xff01;策略一&#xff1a;利用索引优化带 WHERE 子句的 COUNT(*) (最常见且…...

5.Redission

5.1 前文锁问题 基于 setnx 实现的分布式锁存在下面的问题&#xff1a; 重入问题&#xff1a;重入问题是指 获得锁的线程可以再次进入到相同的锁的代码块中&#xff0c;可重入锁的意义在于防止死锁&#xff0c;比如 HashTable 这样的代码中&#xff0c;他的方法都是使用 sync…...

RAG 赋能客服机器人:多轮对话与精准回复

一、引言 在人工智能技术飞速发展的今天&#xff0c;客服机器人已成为企业提升服务效率的重要工具。然而&#xff0c;传统客服系统在多轮对话连贯性和精准回复能力上存在明显短板。检索增强生成&#xff08;Retrieval-Augmented Generation, RAG&#xff09;技术通过结合大语言…...

rust-candle学习笔记13-实现多头注意力

参考&#xff1a;about-pytorch 定义结构体&#xff1a; use core::f32;use candle_core::{DType, Device, Result, Tensor}; use candle_nn::{embedding, linear_no_bias, linear, ops, Dropout, Linear, Module, VarBuilder, VarMap};struct MultiHeadAttention {w_qkv: Li…...

PyTorch API 5 - 全分片数据并行、流水线并行、概率分布

文章目录 全分片数据并行 (FullyShardedDataParallel)torch.distributed.fsdp.fully_shardPyTorch FSDP2 (fully_shard) Tensor Parallelism - torch.distributed.tensor.parallel分布式优化器流水线并行为什么需要流水线并行&#xff1f;什么是 torch.distributed.pipelining&…...

STL-list

一、 list的介绍 std::list 是 C 标准模板库&#xff08;STL&#xff09;中的一种双向链表容器。每个元素包含指向前后节点的指针&#xff0c;支持高效插入和删除操作&#xff0c;但随机访问性能较差。 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#x…...

WPF中如何自定义控件

WPF自定义控件简化版&#xff1a;账户菜单按钮&#xff08;AccountButton&#xff09; 我们以**“账户菜单按钮”为例&#xff0c;用更清晰的架构实现一个支持标题显示、渐变背景、选中状态高亮**的自定义控件。以下是分步拆解&#xff1a; 一、控件核心功能 我们要做一个类似…...

华为云Git使用与GitCode操作指南

案例介绍 本文档带领开发者学习如何在云主机上基于GitCode来使用Git来管理自己的项目代码,并使用一些常用的Git命令来进行Git环境的设置。 案例内容 1 概述 1.1 背景介绍 Git 是一个快速、可扩展的分布式版本控制系统,它拥有异常丰富的命令集,可以提供高级操作和对内部…...

UniRepLknet助力YOLOv8:高效特征提取与目标检测性能优化

文章目录 一、引言二、UniRepLknet 的框架原理&#xff08;一&#xff09;架构概述&#xff08;二&#xff09;架构优势 三、UniRepLknet 在 YOLOv8 中的集成&#xff08;一&#xff09;集成方法&#xff08;二&#xff09;代码实例 四、实验与对比&#xff08;一&#xff09;对…...

【软件工程】基于频谱的缺陷定位

基于频谱的缺陷定位&#xff08;Spectrum-Based Fault Localization, SBFL&#xff09;是一种通过分析程序执行覆盖信息&#xff08;频谱数据&#xff09;来定位代码中缺陷的方法。其核心思想是&#xff1a;通过测试用例的执行结果&#xff08;成功/失败&#xff09;和代码覆盖…...

stm32之IIC

目录 1.I2C1.1 简介1.2 硬件电路1.3 时序基本单元1.4 时序实例1.4.1 指定地址写1.4.2 当前地址读1.4.3 指定地址读 2.MPU60502.1 简介2.2 参数2.3 硬件电路2.4 框图2.5 文档 3.软件操作MPU60504.I2C通信外设4.1 简介4.2 I2C框图4.3 基本结构4.4 主机发送/接收4.5 软件/硬件波形…...

阿里云购买ECS 安装redis mysql nginx jdk 部署jar 部署web

阿里云服务维护 1.安装JDK 查询要安装jdk的版本,命令&#xff1a;yum -y list java* 命令&#xff1a;yum install -y java-1.8.0-openjdk.x86_64 yum install -y java-17-openjdk.x86_64 2.安装nginx 启用 EPEL 仓库 sudo yum install epel-release 安装 Nginx sudo yum …...

记录 ubuntu 安装中文语言出现 software database is broken

搜索出来的结果是 sudo apt-get install language-pack-zh-han* 然而,无效,最后手动安装如下 apt install language-pack-zh-hans apt install language-pack-zh-hans-base apt install language-pack-gnome-zh-hans apt install fonts-arphic-uming apt install libreoffic…...

质数和约数

一、知识和经验 把质数和约数放在一起就是因为他们有非常多的联系&#xff0c;为了验证这个观点我们可以先学习唯一分解定理&#xff1a;一个大于 1 的自然数一定能被唯一分解为有限个质数的乘积。 而且一个数不仅能被质数分解&#xff0c;原本也应该被自己的约数分解&#xf…...

OSPF的四种特殊区域(Stub、Totally Stub、NSSA、Totally NSSA)详解

OSPF的四种特殊区域&#xff08;Stub、Totally Stub、NSSA、Totally NSSA&#xff09;通过限制LSA的传播来优化网络性能&#xff0c;减少路由表规模。以下是它们的核心区别&#xff1a; 1. Stub 区域&#xff08;末梢区域&#xff09; 允许的LSA类型&#xff1a;Type 1-3&#…...

Docker中运行的Chrome崩溃问题解决

问题 各位看官是否在 Docker 容器中的 Linux 桌面环境&#xff08;如Xfce&#xff09;上启动Chrome &#xff0c;遇到了令人沮丧的频繁崩溃问题&#xff1f;尤其是在打开包含图片、视频的网页&#xff0c;或者进行一些稍复杂的操作时&#xff0c;窗口突然消失&#xff1f;如果…...

【从零实现JsonRpc框架#3】线程模型与性能优化

1.Muduo 的线程模型 Muduo 基于 Reactor 模式 &#xff0c;采用 单线程 Reactor 和 多线程 Reactor 相结合的方式&#xff0c;通过事件驱动和线程池实现高并发。 1. 单线程模型 核心思想 &#xff1a;所有 I/O 操作&#xff08;accept、read、write&#xff09;和业务逻辑均…...

Kubernetes资源管理之Request与Limit配置黄金法则

一、从"酒店订房"看K8s资源管理 想象你经营一家云上酒店&#xff08;K8s集群&#xff09;&#xff0c;每个房间&#xff08;Node节点&#xff09;都有固定数量的床位&#xff08;CPU&#xff09;和储物柜&#xff08;内存&#xff09;。当客人&#xff08;Pod&#…...

Windows 上使用 WSL 2 后端的 Docker Desktop

执行命令 docker pull hello-world 执行命令 docker run hello-world 执行命令 wsl -d Ubuntu...

OpenLayers根据任意数量控制点绘制贝塞尔曲线

以下是使用OpenLayers根据任意数量控制点绘制贝塞尔曲线的完整实现方案。该方案支持三个及以上控制点&#xff0c;使用递归算法计算高阶贝塞尔曲线。 实现思路 贝塞尔曲线原理&#xff1a;使用德卡斯特里奥算法&#xff08;De Casteljau’s Algorithm&#xff09;递归计算任意…...

使用 Jackson 在 Java 中解析和生成 JSON

JSON(JavaScript Object Notation)是一种轻量级、跨语言的数据交换格式,因其简单易读和高效解析而广泛应用于 Web 开发、API 通信和数据存储。在 Java 中,处理 JSON 是许多应用程序的核心需求,尤其是在与 RESTful 服务交互或管理配置文件时。Jackson 是一个功能强大且广受…...

Qt中在子线程中刷新UI的方法

Qt中在子线程中刷新UI的方法 在Qt中UI界面并不是线程安全的&#xff0c;意味着在子线程中不能随意操作UI界面组件&#xff08;比如按钮、标签&#xff09;等&#xff0c;如果强行操作这些组件有可能会导致程序崩溃。那么在Qt中如何在子线程中刷新UI控件呢&#xff1f; 两种方…...

封装 RabbitMQ 消息代理交互的功能

封装了与 RabbitMQ 消息代理交互的功能&#xff0c;包括发送和接收消息&#xff0c;以及管理连接和通道。 主要组件 依赖项&#xff1a; 代码使用了多个命名空间&#xff0c;包括 Microsoft.Extensions.Configuration&#xff08;用于配置管理&#xff09;、RabbitMQ.Client&a…...

关于ffmpeg的简介和使用总结

主要参考&#xff1a; 全网最全FFmpeg教程&#xff0c;从新手到高手的蜕变指南 - 知乎 (zhihu.com) FFmpeg入门教程&#xff08;非常详细&#xff09;从零基础入门到精通&#xff0c;看完这一篇就够了。-CSDN博客 FFmpeg教程&#xff08;超级详细版&#xff09; - 个人文章 - S…...

计算机图形学编程(使用OpenGL和C++)(第2版)学习笔记 08.阴影

阴影 没有阴影的渲染效果如下&#xff0c;看起来不真实&#xff1a; 有阴影的渲染效果如下&#xff0c;看起来真实&#xff1a; 显示阴影有两种方式&#xff0c;一种是原书中的方式&#xff0c;另一种是采用光线追踪技术&#xff0c;该技术可以参考ShaderToy学习笔记 08.阴…...

[面试]SoC验证工程师面试常见问题(七)低速接口篇

SoC验证工程师面试常见问题(七)低速接口篇 摘要:低速接口是嵌入式系统和 SoC (System on Chip) 中常用的通信接口,主要用于设备间的短距离、低带宽数据传输。相比高速接口(如 PCIe、USB 3.0),低速接口的传输速率较低(通常在 kbps 到几 Mbps 范围),但具有简单…...

算法训练营第十三天|226.翻转二叉树、101. 对称二叉树、 104.二叉树的最大深度、111.二叉树的最小深度

递归 递归三部曲&#xff1a; 1.确定参数和返回值2.确定终止条件3.确定单层逻辑 226.翻转二叉树 题目 思路与解法 第一想法&#xff1a; 递归&#xff0c;对每个结点进行反转 # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, le…...

电子电器架构 --- 车载网关的设计

我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的汽车电子工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 钝感力的“钝”&#xff0c;不是木讷、迟钝&#xff0c;而是直面困境的韧劲和耐力&#xff0c;是面对外界…...

`C_PiperInterface` 类接口功能列表

C_PiperInterface 类接口功能列表 C_PiperInterface 提供了全面的接口&#xff0c;用于控制 Piper 机械臂的运动、查询状态、设置参数以及管理 SDK 限制。 官仓链接 以下是 C_PiperInterface 类中所有接口的功能总结&#xff1a; 1. 初始化与连接相关接口 __new__: 实现单例…...

D. Apple Tree Traversing 【Codeforces Round 1023 (Div. 2)】

D. Apple Tree Traversing 题目大意 有一个包含 n n n 个节点的苹果树&#xff0c;初始时每个节点上有一个苹果。你有一张纸&#xff0c;初始时纸上没有任何内容。 你需要通过以下操作遍历苹果树&#xff0c;直到所有苹果都被移除&#xff1a; • 选择一个苹果路径 ( u , v…...

Docker镜像搬运工:save与load命令的实战指南

在日常的容器化开发中&#xff0c;镜像的搬运和部署是每个开发者必须掌握的技能。今天我们将深入探讨Docker的"save"和"load"这对黄金搭档&#xff0c;揭秘它们在镜像管理中的妙用。 一、基础认知&#xff1a;镜像的打包与解包 docker save 和 docker loa…...

查看Electron 应用的调试端口

以下是一些可以知道已发布第三方 Electron 应用调试端口的方法&#xff1a; * **通过命令行参数查看** &#xff1a; * 如果该 Electron 应用在启动时添加了类似 --remote-debugging-portxxxx 或 --inspectxxxx 的参数&#xff0c;那么其调试端口就是该参数指定的端口号。比…...

各种环境测试

加载测试专用属性 当在测试时想要加入某些配置且对其他测试类不产生影响是可以用Import注释添加配置 测试类中启动web环境 默认为none不开启...

腾讯云低代码实战:零基础搭建家政维修平台

目录 1. 欢迎与项目概览1.1 教程目的与受众1.2 项目愿景与目标&#xff1a;我们要搭建一个怎样的平台&#xff1f;1.3 平台核心构成与架构解析1.4 技术栈选择与考量1.5 如何高效阅读本教程 欢迎来到“腾讯云云开发低代码实战&#xff1a;从零搭建家政维修服务平台”开发教程&am…...

居然智家亮相全零售AI火花大会 AI大模型赋能家居新零售的进阶之路

当人工智能技术以摧枯拉朽之势重构商业世界时&#xff0c;零售业正在经历一场静默而深刻的革命。在这场变革中&#xff0c;居然智家作为新零售领域的创新标杆&#xff0c;凭借其在AI技术应用上的超前布局和持续深耕&#xff0c;已悄然构建起从消费场景到产业生态的智能化闭环。…...

微服务6大拆分原则

微服务6大拆分原则 微服务拆分是指将一个大型应用程序拆分成独立服务的过程&#xff0c;在微服务拆分时&#xff0c;需要考虑以下6大微服务拆分原则 一、单一职责原则 微服务单一职责原则&#xff0c;是指每个微服务应该专注于解决一个明确定义的业务领域或功能&#xff0c;…...

进程间通信--管道【Linux操作系统】

文章目录 进程间通信&#xff08;IPC&#xff09;进程间通信的目的1. 数据交换2. 资源共享3. 进程协同4. 系统解耦5. 分布式计算IPC 的典型方式对比总结 进程间通信的前提 匿名管道匿名管道的原理创建匿名管道的过程如果不关闭不需要的读写端会怎样&#xff1f;为什么父进程要同…...

模型实时自主训练系统设计

模型实时自主训练系统设计 一、系统架构 #mermaid-svg-MLuTBuo7ehvStoqS {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-MLuTBuo7ehvStoqS .error-icon{fill:#552222;}#mermaid-svg-MLuTBuo7ehvStoqS .error-text{f…...

5.1 神经网络: 层和块

1 层&#xff08;Layer&#xff09; 1.1 定义 层是深度学习模型中的基本构建单元&#xff0c;它由一组神经元组成&#xff0c;负责对输入数据进行特定的数学运算和变换&#xff0c;以提取数据的某种特征或表示。每一层可以看作是一个函数&#xff0c;它接收输入数据&#xff…...

鸿蒙系统使用ArkTS开发语言支持身份证阅读器、社保卡读卡器等调用二次开发SDK

har库导入&#xff1a; { "license": "", "devDependencies": {}, "author": "", "name": "entry", "description": "Please describe the basic information.", &qu…...

【Bootstrap V4系列】学习入门教程之 组件-输入组(Input group)

Bootstrap V4系列 学习入门教程之 组件-输入组(Input group) 输入组(Input group)Basic example一、Wrapping 包装二、Sizing 尺寸三、Multiple inputs 多输入四、Multiple addons 多个插件五、Button addons 按钮插件六、Buttons with dropdowns 带下拉按钮七、Custom for…...

图像处理篇--- HTTP|RTSP|MJPEG视频流格式

文章目录 前言一、MJPEG (Motion JPEG)基本概念技术特点编码方式传输协议数据格式 优势实现简单低延迟兼容性好容错性强 劣势带宽效率低不支持音频缺乏标准控制 典型应用 二、RTSP (Real Time Streaming Protocol)基本概念技术特点协议栈工作流程传输模式 优势专业流媒体支持高…...

`RotationTransition` 是 Flutter 中的一个动画组件,用于实现旋转动画效果

RotationTransition 是 Flutter 中的一个动画组件&#xff0c;用于实现旋转动画效果。它允许你对子组件进行动态的旋转变换&#xff0c;从而实现平滑的动画效果。RotationTransition 通常与 AnimationController 和 Tween 一起使用&#xff0c;以控制动画的开始、结束和过渡效果…...

养生:开启健康生活的密钥

在快节奏的现代生活中&#xff0c;养生已成为追求健康的重要方式。从饮食、运动到生活习惯&#xff0c;每一个细节都关乎身体的健康。以下为你介绍科学养生的实用方法&#xff0c;助你打造健康生活。 饮食养生&#xff1a;均衡营养&#xff0c;滋养身体 合理的饮食是养生的基…...

大模型微调算法原理:从通用到专用的桥梁

前言 本文聚焦大模型落地中的核心矛盾——理论快速发展与实际应用需求之间的脱节,并系统探讨微调技术作为解决这一矛盾的关键手段。尽管大模型展现出强大的通用能力,但其在垂直领域的直接应用仍面临适配性不足、计算成本高等挑战。微调通过在预训练模型基础上进行针对性优化,…...

引言:Client Hello 为何是 HTTPS 安全的核心?

当用户在浏览器中输入 https:// 时&#xff0c;看似简单的操作背后&#xff0c;隐藏着一场加密通信的“暗战”。Client Hello 作为 TLS 握手的首个消息&#xff0c;不仅决定了后续通信的加密强度&#xff0c;还可能成为攻击者的突破口。据统计&#xff0c;超过 35% 的网站因 TL…...

深度学习中的目标检测:从 PR 曲线到 AP

深度学习中的目标检测&#xff1a;从 PR 曲线到 AP 在目标检测任务中&#xff0c;评估模型的性能是非常重要的。通过使用不同的评估指标和标准&#xff0c;我们可以量化模型的准确性与效果。今天我们将重点讨论 PR 曲线&#xff08;Precision-Recall Curve&#xff09;、平均精…...

测试左移系列-产品经理实战-实战认知1

课程&#xff1a;B站大学 记录产品经理实战项目系统性学习&#xff0c;从产品思维&#xff0c;用户画像&#xff0c;用户体验&#xff0c;增长数据驱动等不同方向理解产品&#xff0c;从0到1去理解产品从需求到落地的全过程&#xff0c;测试左移方向&#xff08;靠近需求、设计…...