当前位置: 首页 > news >正文

# 构建和训练一个简单的CBOW词嵌入模型

构建和训练一个简单的CBOW词嵌入模型

在自然语言处理(NLP)领域,词嵌入是一种将词汇映射到连续向量空间的技术,这些向量能够捕捉词汇之间的语义关系。在这篇文章中,我们将构建和训练一个简单的Continuous Bag of Words(CBOW)词嵌入模型,使用PyTorch框架。
在这里插入图片描述

简介

CBOW模型是一种预测给定上下文中目标词的模型。它通过学习上下文词的向量表示来预测目标词。这种方法在处理大量文本数据时非常有效,因为它可以捕捉词汇之间的语义和语法关系。

环境准备

在开始之前,确保你的环境中安装了以下库:

  • PyTorch
  • NumPy
  • tqdm(用于显示进度条)

如果未安装,可以通过以下命令安装:

pip install torch numpy tqdm

数据准备

我们将使用一个简单的文本语料库来训练我们的模型。这个语料库包含一些句子,我们将从中提取词汇和构建训练数据集。

raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells."""

数据预处理

首先,我们需要将文本转换为模型可以理解的格式。这包括创建词汇表、单词到索引的映射以及构建训练数据集。

vocab= set(raw_text)  # 集合。词库,里面内容就是无人
vocab_size = len(vocab) # 计算词汇集合的大小word_to_idx = {word: i for i, word in enumerate(vocab)}  # for 循环的复合写法。第1次循环,i得到的索引号, word 取1个单词
idx_to_word = {i: word for i, word in enumerate(vocab)}data = []  # 获取上下文词, 将上下文词作为输入, 目标词作为输出。构建训练数据集。
# 遍历文本,提取上下文和目标词对,构建训练数据集
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):  # (2, 60)。context = ([raw_text[i - (2-j)] for j in range(CONTEXT_SIZE)] + # 获取左边的上下文词[raw_text[i + j + 1] for j in range(CONTEXT_SIZE)]   # 获取右边的上下文词)  # 获取上下文词 ['we','are','to','study']target = raw_text[i]  # 获取目标词 'about'data.append((context, target))  # 将上下文词和目标词保存到data中[((['we','are','to','study']),'about')]

模型定义

接下来,我们定义CBOW模型。这个模型包括嵌入层、投影层和输出层。

# 定义CBOW模型类
class CBOW(nn.Module):def __init__(self, vocab_size, embedding_dim):super(CBOW, self).__init__()self.embeddings = nn.Embedding(vocab_size, embedding_dim) # 定义嵌入层self.proj = nn.Linear(embedding_dim, 128) # 定义投影层self.output = nn.Linear(128, vocab_size) # 定义输出层def forward(self, inputs):embeds = sum(self.embeddings(inputs)).view(1, -1) # 计算上下文词的嵌入向量的平均值out = F.relu(self.proj(embeds)) # 通过ReLU激活函数out = self.output(out) # 通过输出层nll_prob = F.log_softmax(out, dim=1) # 计算负对数似然损失return nll_prob

训练模型

现在,我们训练模型。我们将使用Adam优化器和负对数似然损失函数。

# 确定设备(GPU或CPU)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(device)
# 初始化模型并将其移动到指定设备
model = CBOW(vocab_size, 10).to(device)# 初始化优化器
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器# 初始化损失列表
losses = []
# 定义损失函数
loss_function = nn.NLLLoss()
# 设置模型为训练模式
model.train()
# 训练模型100个epoch
for epoch in tqdm(range(100)): # 开始训练total_loss = 0# 遍历所有数据for context, target in data:context_vector = make_context_vector(context, word_to_idx).to(device)target = torch.tensor([word_to_idx[target]]).to(device)# 开始前向传播train_predict = model(context_vector)loss = loss_function(train_predict, target)# 反向传播optimizer.zero_grad()  # 梯度值清零loss.backward() # 反向传播计算得到每个参数的梯度值optimizer.step() #根据梯度更新网络参数total_loss += loss.item()# 记录每个epoch的损失losses.append(total_loss)print(losses)

测试模型

最后,我们测试模型,看看它如何预测给定上下文的下一个词。

# 定义一个上下文列表,包含四个词
context = ['People', 'create', 'to', 'direct']# 将上下文词转换为模型可以理解的向量形式
context_vector = make_context_vector(context, word_to_idx).to(device)# 将模型设置为评估模式,在评估模式下,模型的行为会有所改变,例如不会应用Dropout等
model.eval()# 使用模型进行预测,传入上下文向量
predict = model(context_vector)# 获取预测结果中概率最高的索引值,这个索引值对应预测的下一个词
max_idx = predict.argmax(1)# 打印出输入的上下文
print('context', context)# 使用索引到单词的映射字典,将预测的索引值转换为对应的单词,并打印出来
print("Predicted next word:", idx_to_word[max_idx.item()])

运行结果

4cf4867c49b51598c4ac.png)

结论

通过这篇文章,我们构建了一个简单的CBOW词嵌入模型,并使用PyTorch框架进行了训练和测试。这个模型能够学习词汇的向量表示,并预测给定上下文的下一个词。这对于许多NLP任务,如文本分类、情感分析等,都是非常有用的。

相关文章:

# 构建和训练一个简单的CBOW词嵌入模型

构建和训练一个简单的CBOW词嵌入模型 在自然语言处理(NLP)领域,词嵌入是一种将词汇映射到连续向量空间的技术,这些向量能够捕捉词汇之间的语义关系。在这篇文章中,我们将构建和训练一个简单的Continuous Bag of Words…...

Ubuntu20.04下GraspNet复现流程中的问题

pytorchcudacudnn的版本问题相对于GraspNet来说至关重要!!!至关重要!!!至关重要!!!(重要的事情说三边) 我的显卡是3070 那么首先说结论 使用30系…...

【ROS2】机器人操作系统安装到Ubuntu简介

主要参考: https://book.guyuehome.com/ROS2/1.系统架构/1.3_ROS2安装方法/ 官方文档:https://docs.ros.org/en/humble/Installation.html 虚拟机与ubuntu系统安装 略,见参考文档 ubutun换国内源,略 1. 设置本地语言 确保您有…...

从0到1掌握机器学习核心概念:用Python亲手构建你的第一个AI模型(超多代码+可视化)

🧠 一、开始 真正动手实现一个完整的AI项目!从数据预处理、特征工程、模型训练,到评估与调优,一步步还原你在动画视频中看到的所有核心知识点。 📦 二、环境准备 建议使用 Python 3.8,推荐工具&#xff1…...

Java面试题汇总

1王二哥 https://javabetter.cn/sidebar/sanfene/redis.html#_10-redis-%E6%8C%81%E4%B9%85%E5%8C%96%E6%96%B9%E5%BC%8F%E6%9C%89%E5%93%AA%E4%BA%9B-%E6%9C%89%E4%BB%80%E4%B9%88%E5%8C%BA%E5%88%AB 2.小林 https://www.xiaolincoding.com/redis/data_struct/command.html#…...

Ollama API 应用指南

1. 基础信息 默认地址: http://localhost:11434/api数据格式: application/json支持方法: POST(主要)、GET(部分接口) 2. 模型管理 API (1) 列出本地模型 端点: GET /api/tags功能: 获取已下载的模型列表。示例:curl http://lo…...

React SSR + Redux 导致的 Hydration 报错踩坑记录与修复方案

一条“Hydration failed”的错误,让我损失了半天时间 背景 我在用 Next.js App Router Redux 开发一个任务管理应用,一切顺利,直到打开了 SSR(服务端渲染),突然看到这个令人头皮发麻的报错: …...

【论文精读】Reformer:高效Transformer如何突破长序列处理瓶颈?

目录 一、引言:当Transformer遇到长序列瓶颈二、核心技术解析:从暴力计算到智能优化1. 局部敏感哈希注意力(LSH Attention):用“聚类筛选”替代“全量计算”关键步骤:数学优化: 2. 可逆残差网络…...

iOS18 MSSBrowse闪退

iOS18 MSSBrowse闪退 问题方案结果 问题 最近升级了电脑系统(15.4.1),并且也升级了xcode(16.3)开发工具。之后打包公司很早之前开发的项目。 上线之后发现在苹果手机系统18以上,出现了闪退问题。 涉及到的是第三方MSSBrowse,在选择图片放大的…...

create_function()漏洞利用

什么是 create_function() create_function() 是 PHP 早期提供的一个用来创建匿名函数的函数: $func create_function($a,$b, return $a $b;); echo $func(1, 2); // 输出 3 第一个参数是函数的参数列表(字符串形式),第二个参…...

leetcode-数组

数组 31. 下一个排列 题目 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如,arr [1,2,3] ,以下这些都可以视作 arr 的排列:[1,2,3]、[1,3,2]、[3,1,2]、[2,3,1] 。 整数数组的 下一个排列 是指其整数的下一个字典序更大…...

Tailwind CSS 实战:基于 Kooboo 构建个人博客页面

在现代 web 开发中,Tailwind CSS 作为一款实用优先的 CSS 框架,能让开发者迅速搭建出具有良好视觉效果的页面;Kooboo 则是一个强大的快速开发平台,提供了便捷的页面管理和数据处理功能。本文将详细介绍如何结合 Tailwind CSS 和 K…...

C#学习1_认识项目/程序结构

一、C#项目文件的构成 1.新建一个项目 2.运行项目 3.认识文件 1)解决方案(Solution):组织多个项目的容器 抽象理解:餐厅 解决方案.sln文件,点击即可进入VS编辑 2)项目(…...

边缘计算在工业自动化中的应用:开启智能制造新时代

在工业4.0的浪潮中,智能制造成为推动工业发展的核心驱动力。随着物联网(IoT)技术的广泛应用,工业设备之间的互联互通变得越来越紧密,但这也带来了数据处理和传输的挑战。边缘计算作为一种新兴技术,通过将计…...

《MySQL:MySQL表的内外连接》

表的连接分为内连接和外连接。 内连接 内连接实际上就是利用where子句对两种表形成的笛卡尔积进行筛选,之前的文章中所用的查询都是内连接,也是开发中使用的最多的连接查询。 select 字段 from 表1 inner join 表2 on 连接条件 and 其他条件&#xff1…...

人工智能催化民航业变革:五大应用案例

航空业正在经历一场前所未有的技术革命,人工智能正成为变革的主要催化剂。从停机坪到航站楼,从维修机库到客户服务中心,人工智能正在从根本上重塑航空公司的运营和服务方式。这种转变并非仅仅停留在理论上——全球主要航空公司已从人工智能投…...

机器视觉中有哪些常见的光学辅助元件及其作用?

在机器视觉领域,光学元件如透镜、反射镜和棱镜扮演着至关重要的角色。它们不仅是高精度图像捕获的基础,也是提升机器视觉系统性能的关键。深入了解这些光学元件的功能和应用,可以帮助我们更好地掌握机器视觉技术的精髓。 透镜:精…...

Stream API 对两个 List 进行去重操作

在 Java 8 及以上版本中&#xff0c;可以使用 Lambda 表达式和 Stream API 对两个 List 进行去重操作。以下是几种常见的去重场景及对应的 Lambda 表达式实现方式&#xff1a; 1. 合并两个 List 并去重 List<String> list1 Arrays.asList("A", "B"…...

lerna 8.x 详细教程

全局安装 lerna npm install lerna -g初始化项目 mkdir lerna-cli-do cd lerna-cli-do npm init -y初始化项目 lerna init --packages="packages/*"lerna create 创建子项目 lerna create core lerna create util...

ROS第十二梯:ros-noetic和Anaconda联合使用

1) 概述 ros-noetic默认Python版本是Python2.7,但在使用过程中,通常需要明确调用python3进行编译。 Anaconda: 支持创建独立的python2/3环境,避免系统库冲突; 方便安装ROS依赖的科学计算库(如Numpy,Pandas)和机器学习框架; 核心目标:在anaconda环…...

网络原理 - 5(TCP - 2 - 三次握手与四次挥手)

目录 3. 连接管理 建立连接 - 三次挥手 三次握手的意义 断开连接 - 四次挥手 握手和挥手的相似和不同之处 连接管理过程中涉及到的 TCP 状态转换 完&#xff01; 3. 连接管理 连接管理分为建立连接 和 断开连接~&#xff08;important 重点&#xff01;&#xff09; 建…...

【开源】STM32HAL库移植Arduino OneWire库驱动DS18B20和MAX31850

项目开源链接 github主页https://github.com/snqx-lqh本项目github地址https://github.com/snqx-lqh/STM32F103C8T6HalDemo作者 VXQinghua-Li7 &#x1f4d6; 欢迎交流 如果开源的代码对你有帮助&#xff0c;希望可以帮我点个赞&#x1f44d;和收藏 项目说明 最近在做一个项目…...

【maven-7.1】POM文件中的属性管理:提升构建灵活性与可维护性

在Maven项目中&#xff0c;POM (Project Object Model) 文件是核心配置文件&#xff0c;而属性管理则是POM中一个强大但常被低估的特性。良好的属性管理可以显著提升项目的可维护性、减少重复配置&#xff0c;并使构建过程更加灵活。本文将深入探讨Maven中的属性管理机制。 1.…...

DC-2寻找Flag1、2、3、4、5,wpscan爆破、git提权

一、信息收集 1、主机探测 arp-scan -l 探测同网段2、端口扫描 nmap -sS -sV 192.168.66.136 80/tcp open http Apache httpd 2.4.10 ((Debian)) 7744/tcp open ssh OpenSSH 6.7p1 Debian 5deb8u7 (protocol 2.0)这里是扫描出来两个端口&#xff0c;80和ssh&…...

数据结构手撕--【栈和队列】

目录 1、栈 2、队列 1、栈 先进后出&#xff08;都在栈顶进行操作&#xff09; 使用数组结构比使用链式结构更优&#xff0c;因为数组在尾上插入数据的代价更小。并且采用动态长度的数组来表示。 定义结构体 #include <stdio.h> #include <stdlib.h> #include &l…...

八大排序——选择排序/堆排序

八大排序——选择排序/堆排序 目录 一、选择排序 二、堆排序 2.1 大顶堆&#xff08;升序&#xff09; 2.1.1 步骤 2.1.2 代码实现 2.2 小顶堆&#xff08;降序&#xff09; 一、选择排序 每一趟从待排序序列中找到其最小值&#xff0c;然后和待排序序列的第一个值进行交换&am…...

【KWDB 创作者计划】_深度学习篇---归一化反归一化

文章目录 前言一、归一化(Normalization)1. 定义2. 常用方法Min-Max归一化Z-Score标准化(虽常称“标准化”,但广义属归一化)小数缩放(Decimal Scaling)3. 作用4. 注意事项二、反归一化(Denormalization)1. 定义2.方法3. 应用场景三、Python示例演示四、归一化 vs. 标准…...

windows端远程控制ubuntu运行脚本程序并转发ubuntu端脚本输出的网页

背景 对于一些只能在ubuntu上运行的脚本&#xff0c;并且这个脚本会在ubuntu上通过网页展示运行结果。我们希望可以使用windows远程操控ubuntu&#xff0c;在windows上查看网页内容。 方法 start cmd.exe /k "sshpass -p passwd ssh namexxx.xxx.xxx.xxx "cd /hom…...

推荐系统(二十四):Embedding层的参数是如何在模型训练过程中学习的?

近来有不少读者私信我关于嵌入层&#xff08;Embedding层&#xff09;参数在模型训练过程中如何学习的问题。虽然之前已经在不少文章介绍过 Embedding&#xff0c;但是为了读者更好地理解&#xff0c;笔者将通过本文详细解读嵌入层&#xff08;Embedding Layer&#xff09;的参…...

【Ubuntu】关于系统分区、挂载点、安装位置的一些基本信息

在ubuntu22及以前的版本中&#xff0c;最好是手动配置分区及其挂载点&#xff0c;通常我们会配置成3/4个分区&#xff1a; 引导区&#xff0c;交换区&#xff0c;根挂载点&#xff0c;home挂载点&#xff08;有时根挂载点和home合二为一&#xff09; 配置各种环境所占用的内存 …...

概率dp总结

概率 DP 用于解决概率问题与期望问题&#xff0c;建议先对 概率 & 期望 的内容有一定了解。一般情况下&#xff0c;解决概率问题需要顺序循环&#xff0c;而解决期望问题使用逆序循环&#xff0c;如果定义的状态转移方程存在后效性问题&#xff0c;还需要用到 高斯消元 来优…...

深入解析:RocketMQ、RabbitMQ和Kafka的区别与使用场景

互联网大厂Java求职者面试&#xff1a;RocketMQ、RabbitMQ和Kafka的深入解析 故事场景&#xff1a;严肃且专业的面试官与架构师程序员马架构 在一家知名的互联网大厂&#xff0c;Java求职者正在接受一场严格的面试。面试官是一位经验丰富的技术专家&#xff0c;他将通过多轮提…...

探秘Transformer系列之(30)--- 投机解码

探秘Transformer系列之&#xff08;30&#xff09;— 投机解码 文章目录 探秘Transformer系列之&#xff08;30&#xff09;--- 投机解码0x00 概述0x01 背景1.1 问题1.2 自回归解码 0x02 定义 & 历史2.1 投机解码2.2 发展历史 0x03 Blockwise Parallel Decoding3.1 动机3.2…...

【CSS】层叠,优先级与继承(三):超详细继承知识点

目录 继承一、什么是继承&#xff1f;2.1 祖先元素2.2 默认继承/默认不继承 二、可继承属性2.1 字体相关属性2.2 文本相关属性2.3 列表相关属性 三、不可继承属性3.1 盒模型相关属性3.2 背景相关属性 四、属性初始值4.1 根元素4.2 属性的初始值4.3 得出结论 五、强制继承5.1 in…...

SpringBoot中6种自定义starter开发方法

在SpringBoot生态中,starter是一种特殊的依赖,它能够自动装配相关组件,简化项目配置。 自定义starter的核心价值在于: • 封装复杂的配置逻辑,实现开箱即用 • 统一技术组件的使用规范,避免"轮子"泛滥 • 提高开发效率,减少重复代码 方法一:基础配置类方式 …...

时间自动填写——电子表格公式的遗憾(DeepSeek)

now()/today()缘源来&#xff0c;人肉值粘胜无依。用函数抓取系统时间&#xff0c;人肉CTRLC“永葆青春”——直接时间数据存储。 笔记模板由python脚本于2025-04-23 23:21:44创建&#xff0c;本篇笔记适合想要研究电子表格日期自动填写的coder翻阅。 【学习的细节是欢悦的历程…...

AUTODL关闭了程序内存依然占满怎么办

AutoDL帮助文档 关闭了程序&#xff0c;使用nvidia-smi查看&#xff0c;内存任然爆满&#xff1a; 执行 ps -ef | grep train | awk {print $2} | xargs kill -9...

Spark集群搭建之Yarn模式

1.把spark安装包复制到你存放安装包的目录下&#xff0c;例如我的是/opt/software cd /opt/software&#xff0c;进入到你存放安装包的目录 然后tar -zxvf 你的spark安装包的完整名字 -C /opt/module&#xff0c;进行解压。例如我的spark完整名字是spark-3.1.1-bin-hadoop3.2.…...

CSS-跟随图片变化的背景色

CSS-跟随图片变化的背景色 获取图片的主要颜色并用于背景渐变需要安装依赖 colorthief获取图片的主要颜色. 并丢给背景注意 getPalette并不是个异步方法 import styles from ./styles.less; import React, { useState } from react; import Colortheif from colorthief;cons…...

一,开发环境安装

环境安装选择的版本如下 Python3.7 Anaconda5.3.1 CUDA 10.0 Pycharm Anaconda安装:下载地址 CUDA 10.0安装,包下载地址...

局部最小实验--用最小成本确保方向正确

### **将「局部最小实验」融入「简单、专注、本分」认知框架的实践方案** 你的核心认知框架是 **「简单、专注、本分」**&#xff0c;而 **「局部最小实验」**&#xff08;MVP思维&#xff09;本质上是一种 **低成本验证、快速迭代** 的方法论。二者看似矛盾&#xff08;简单…...

【网络应用程序设计】实验三:网络聊天室

个人博客&#xff1a;https://alive0103.github.io/ 代码在GitHub&#xff1a;https://github.com/Alive0103/XDU-CS-lab 能点个Star就更好了&#xff0c;欢迎来逛逛哇~❣ 主播写的刚够满足基本功能&#xff0c;多有不足&#xff0c;仅供参考&#xff0c;还请提PR指正&#xff…...

【泊松过程和指数分布】

泊松过程的均值函数与方差函数计算 1. 泊松过程的定义 泊松过程是一个计数过程 { N ( t ) , t ≥ 0 } \{N(t), t \geq 0\} {N(t),t≥0}&#xff0c;满足以下条件&#xff1a; 独立增量&#xff1a;在不相交时间段内事件发生次数相互独立&#xff1b;平稳增量&#xff1a;在时…...

leetcode-排序

排序 面试题 01.01. 判定字符是否唯一 题目 实现一个算法&#xff0c;确定一个字符串 s 的所有字符是否全都不同。 示例 1&#xff1a; 输入: s "leetcode" 输出: false 示例 2&#xff1a; 输入: s "abc" 输出: true限制&#xff1a; 0 < len(s) &…...

Axure中继器表格:实现复杂交互设计的利器

在产品原型设计领域&#xff0c;Axure凭借其强大的元件库和交互功能&#xff0c;成为设计师们手中的得力工具。其中&#xff0c;中继器元件在表格设计方面展现出了独特的优势&#xff0c;结合动态面板等元件&#xff0c;能够打造出功能丰富、交互体验良好的表格原型。本文将深入…...

容器内部无法访问宿主机服务的原因及解决方法

容器内部无法访问宿主机服务的原因及解决方法 问题原因 当你在Docker容器内部尝试访问宿主机上的服务(如192.168.130.148:8000)时失败,通常有以下几种原因: 网络隔离:Docker容器默认使用自己的网络命名空间,与宿主机网络隔离IP地址误解:容器内看到的宿主机IP与外部网络不…...

IMU---MPU6050

一、芯片概述 1. 基本定位 型号&#xff1a;MPU6050&#xff0c;InvenSense&#xff08;现TDK&#xff09;推出的全球首款6轴MEMS运动传感器&#xff0c;集成3轴加速度计、3轴陀螺仪&#xff0c;内置温度传感器&#xff08;非6轴核心功能&#xff09;。定位&#xff1a;低成本…...

提高Spring Boot开发效率的实践

Spring Boot开发效率的重要性 Spring Boot 作为一个开源的 Java 框架,旨在简化新 Spring 应用和微服务的创建与开发 1。其核心特性,如自动配置、约定优于配置以及内嵌服务器,极大地降低了开发门槛,使得开发者可以更专注于业务逻辑的实现 1。在现代应用开发领域,Spring Bo…...

Spring Boot的优点:赋能现代Java开发的利器

Spring Boot 是基于 Spring 框架的快速开发框架&#xff0c;自 2014 年发布以来&#xff0c;凭借其简洁性、灵活性和强大的生态系统&#xff0c;成为 Java 后端开发的首选工具。尤其在 2025 年&#xff0c;随着微服务、云原生和 DevOps 的普及&#xff0c;Spring Boot 的优势更…...

Python内置函数---breakpoint()

用于在代码执行过程中动态设置断点&#xff0c;暂停程序并进入调试模式。 1. 基本语法与功能 breakpoint(*args, kwargs) - 参数&#xff1a;接受任意数量的位置参数和关键字参数&#xff0c;但通常无需传递&#xff08;默认调用pdb.set_trace()&#xff09;。 - 功能&#x…...