当前位置: 首页 > news >正文

自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。

1. 创建自定义数据集

首先,我们使用 NumPy 创建一个简单的二分类数据集。假设我们的数据集包含两个特征。

import numpy as np# 生成随机数据
np.random.seed(42)
X = np.random.randn(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # 标签为1或0# 打印数据集的前5个样本
print(X[:5], y[:5])

2. 构建逻辑回归模型

接下来,我们使用 PyTorch 来构建一个简单的逻辑回归模型。PyTorch 提供了 torch.nn.Module 类,能够轻松实现神经网络模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 转换 NumPy 数据为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 创建数据集并加载
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(2, 1)  # 2个输入特征,1个输出def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型
model = LogisticRegressionModel()

3. 训练模型

接下来,我们定义损失函数和优化器,并训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):for inputs, labels in dataloader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

4. 保存模型

模型训练完成后,我们可以将模型保存到文件中。PyTorch 提供了 torch.save() 方法来保存模型的状态字典。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存!")

5. 加载模型并进行预测

我们可以加载保存的模型,并对新数据进行预测。

# 加载模型
loaded_model = LogisticRegressionModel()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval()  # 切换到评估模式# 进行预测
with torch.no_grad():test_data = torch.tensor([[1.5, -0.5]], dtype=torch.float32)prediction = loaded_model(test_data)print(f'预测值: {prediction.item():.4f}')

6. 总结

在这篇博客中,我们展示了如何使用 NumPy 创建一个简单的自定义数据集,并使用 PyTorch 实现一个逻辑回归模型。我们还展示了如何保存训练好的模型,并加载模型进行预测。通过保存和加载模型,我们可以在不同的时间或环境中重复使用已经训练好的模型,而不需要重新训练它。

相关文章:

自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。 1. 创建自定义数据集 首先,我们使用 NumPy 创建一个简单的…...

LangChain概述

文章目录 为什么需要LangChainLLM应用开发的最后1公里LangChain的2个关键词LangChain的3个场景LangChain的6大模块 为什么需要LangChain 首先想象一个开发者在构建一个LLM应用时的常见场景。当你开始构建一个新项目时,你可能会遇到许多API接口、数据格式和工具。对于…...

Ubuntu 16.04安装Lua

个人博客地址:Ubuntu 16.04安装Lua | 一张假钞的真实世界 在Linux系统上使用以下命令编译安装Lua: curl -R -O http://www.lua.org/ftp/lua-5.3.3.tar.gz tar zxf lua-5.3.3.tar.gz cd lua-5.3.3 make linux test 安装make 编译过程如果提示以下信息…...

独立开发者产品日刊:将 Figma 设计转化为全栈应用、对话 PDF生成思维导图、视频转 AI 笔记、AI问答引擎、Mac 应用启动器切换器

独立开发者产品日刊,每日汇集 ProductHunt 热榜产品介绍,用一个 Slogan 帮你概括产品内容,期望能够让你快速浏览get最新产品创意,激发在产品上的灵感。 Lovable Builder.io Slogan:将 Figma 设计转化为全栈应用 类别…...

【算法】经典博弈论问题——威佐夫博弈 python

目录 威佐夫博弈(Wythoff Game)【模板】 威佐夫博弈(Wythoff Game) 有两堆石子,数量任意,可以不同,游戏开始由两个人轮流取石子 游戏规定,每次有两种不同的取法 1)在任意的一堆中取走任意多的石子 2)可以在两堆中同时取走相同数量…...

Julius AI 人工智能数据分析工具介绍

Julius AI 是一款由 Casera Labs 开发的人工智能数据分析工具,旨在通过自然语言交互和强大的算法能力,帮助用户快速分析和可视化复杂数据。这款工具特别适合没有数据科学背景的用户,使数据分析变得简单高效。 核心功能 自然语言交互&#x…...

SpringBoot+Electron教务管理系统 附带详细运行指导视频

文章目录 一、项目演示二、项目介绍三、运行截图四、主要代码1.查询课程表代码2.保存学生信息代码3.用户登录代码 一、项目演示 项目演示地址: 视频地址 二、项目介绍 项目描述:这是一个基于SpringBootElectron框架开发的教务管理系统。首先&#xff…...

1.23学习记录

web XYNU2024信安杯 哎~想她了 源代码找到提示,访问页面第一层数组绕过,第二层发现ls /可以执行,接着用less代替tac和cat less /fl[a-z]g exp: URL/?fj1[]1&fj2[]2&cmdless /fl[a-z]gmisc [SWPU 2020]套娃 下载附件…...

论文笔记(六十三)Understanding Diffusion Models: A Unified Perspective(三)

Understanding Diffusion Models: A Unified Perspective(三) 文章概括 文章概括 引用: article{luo2022understanding,title{Understanding diffusion models: A unified perspective},author{Luo, Calvin},journal{arXiv preprint arXiv:…...

NeetCode刷题第17天(2025.1.27)

文章目录 086 Course Schedule II 课程安排二087 Graph Valid Tree 图有效树088 Number of Connected Components in an Undirected Graph 无向图中的连接组件数量 086 Course Schedule II 课程安排二 您将获得一个数组 prerequisites ,其中 prerequisites[i] [a,…...

Seed Edge- AGI(人工智能通用智能)长期研究计划

Seed Edge 是字节跳动豆包大模型团队推出的 AGI(人工智能通用智能)长期研究计划12。以下是对它的具体介绍1: 名称含义 “Seed” 即豆包大模型团队名称,“Edge” 代表最前沿的 AGI 探索,整体意味着该项目将在 AGI 领域…...

企业知识管理推动企业整体效能提升与创新能力发展的路径探索

内容概要 企业知识管理是指通过对组织内外部知识的识别、获取、整合与应用,提升企业整体运营效能与竞争力的一系列管理活动。其重要性在于,知识作为一种无形资产,能够显著影响企业的决策质量和创新能力。在当今快速发展的市场环境中&#xf…...

【gopher的java学习笔记】一文讲懂controller,service,mapper,entity是什么

刚开始上手Java和Spring时,就被controller,service,mapper,entity这几个词搞懵了,搞不懂这些究竟代表什么,感觉使用golang开发的时候也没太接触过这些名词啊~ 经过两三个月的开发后,逐渐搞懂了这…...

ZZNUOJ(C/C++)基础练习1000——1010(详解版)

目录 1000 : AB Problem C语言版 C版 1001 : 植树问题 C语言版 C版 1002 : 简单多项式求和 C语言版 C版 1003 : 两个整数的四则运算 C语言版 C版 1004 : 三位数的数位分离 C语言版 C版 补充代…...

volatile之四类内存屏障指令 内存屏障 面试重点 底层源码

目录 volatile 两大特性 可见性 有序性 总结 什么是内存屏障 四个 CPU 指令 四大屏障 重排 重排的类型 为什么会有重排? 线程中的重排和可见性问题 如何防止重排引发的问题? 总结 happens-before 和 volatile 变量规则 内存屏障指令 写操作…...

【deepseek】deepseek-r1本地部署-第二步:huggingface.co替换为hf-mirror.com国内镜像

一、背景 由于国际镜像国内无法直接访问,会导致搜索模型时加载失败,如下: 因此需将国际地址替换为国内镜像地址。 二、操作 1、使用vscode打开下载路径 2、全局地址替换 关键字 huggingface.co 替换为 hf-mirror.com 注意:务…...

Lesson 121 The man in a hat

Lesson 121 The man in a hat 词汇 customer n. 顾客,回头客 相关:Customs 注意大写 n. 海关,关税    custom n. 1. 风俗习惯【通常是复数】    例句:这些是中国人的习俗。       These are Chinese customs.    …...

Excel中LOOKUP函数的使用

文章目录 VLOOKUP(垂直查找):HLOOKUP(水平查找):LOOKUP(基础查找):XLOOKUP(高级查找,较新版本Excel提供): 在Excel中&…...

【Convex Optimization Stanford】Lec3 Function

【Convex Optimization Stanford】Lec3 Function 前言凸函数的定义对凸函数在一条线上的限制增值扩充? 一阶条件二阶条件一些一阶/二阶条件的例子象集和sublevel set关于函数凸性的扩展(Jesen Inequality)保持函数凸性的操作非负加权和 & 仿射函数的…...

香港维尔利健康科技集团重金投资,内地多地体验中心同步启动

香港维尔利健康科技集团近期宣布,将投资数亿港元在内地多个城市建立全新的健康科技体验中心。这一战略举措旨在进一步拓展集团在内地市场的布局,推动创新医疗技术的普及和应用。 多地布局,覆盖主要城市 据悉,维尔利健康科技集团将…...

2025春晚临时直播源接口

临时接口 https://lualist.kooldns.cn/d/%E6%96%87%E4%BB%B6/%E4%B8%B4%E6%97%B6.txt?signCuMCVC5Q0Sb4wA8OQVdNep55oKtaiSxT1du4-DgWeOo:0 备用接口 https://vip.123pan.cn/1833444709/q/%E4%B8%B4%E6%97%B6.txt...

二叉树的层序遍历||力扣--107

目录 题目 思路 代码 题目 给你二叉树的根节点 root ,返回其节点值 自底向上的层序遍历 。 (即按从叶子节点所在层到根节点所在的层,逐层从左向右遍历) 示例 1: 输入:root [3,9,20,null,null,15,7] 输出…...

DeepSeek LLM解读

背景: 量化巨头幻方探索AGI(通用人工智能)新组织“深度求索”在成立半年后,发布的第一代大模型DeepSeek试用地址:DeepSeek ,免费商用,完全开源。作为一家隐形的AI巨头,幻方拥有1万枚…...

如何写美赛(MCM/ICM)论文中的Summary部分

美赛(MCM/ICM)作为一个数学建模竞赛,要求参赛者在有限的时间内解决一个复杂的实际问题,并通过数学建模、数据分析和计算机模拟等手段给出有效的解决方案。在美赛的论文中,Summary部分(通常也称为摘要)是非常关键的,它是整个论文的缩影,能让评审快速了解你解决问题的思…...

C++ 拷贝构造

拷贝构造函数会在以下几种场景中被调用: 1. 用一个对象显式初始化另一个对象。 2. 对象按值传递给函数。 3. 函数按值返回对象。 4. 将对象插入到容器中。 5. 明确调用拷贝构造函数。 1. 当用一个对象显式初始化另一个对象时 MyClass obj1("Hello"); MyClass obj2…...

kotlin内联函数——runCatching

1.runCatching作用 代替try{}catch{}异常处理,用于捕获异常。 2.runCatching函数介绍 参数:上下文引用对象为参数返回值:lamda表达式结果 调用runCatching函数,如果调用成功则返回其封装的结果,并可回调onSuccess函…...

代码随想录算法训练营第三十八天-动态规划-完全背包-322. 零钱兑换

太难了 但听了前面再听这道题感觉递推公式也不是不难理解 动规五部曲 dp[j]代表装满容量为j(也就是目标值)的背包最少物品数量递推公式:dp[j] std::min(dp[j], dp[j - coins[i]] 1)当使用coins[i]这张纸币时,要向前找到容量为…...

跨域问题及解决方案

跨域问题不仅影响开发效率,还可能导致项目进度延误。因此,理解和掌握跨域问题的原理及其解决方案对于前端开发者和后端开发者来说都至关重要。本文将详细介绍什么是跨域、跨域产生的原因,以及常见的后端跨域解决方案。 文章目录 一、什么是跨…...

Python Matplotlib库:从入门到精通

Python Matplotlib库:从入门到精通 在数据分析和科学计算领域,可视化是一项至关重要的技能。Matplotlib作为Python中最流行的绘图库之一,为我们提供了强大的绘图功能。本文将带你从Matplotlib的基础开始,逐步掌握其高级用法&…...

相互作用感知的蛋白-小分子对接模型 - Interformer 评测

Interformer 是一个应用于分子对接和亲和力预测的深度学习模型,基于 Graph-Transdormer 架构的模型,利用相互作用(氢键、疏水)感知的混合密度网络(interaction-aware mixture den sity network, MDN&#x…...

力扣【669. 修剪二叉搜索树】Java题解

一开始在想为什么题目说存在唯一答案。然后发现是二叉搜索树就合理了。如下图:如果0节点小于low,那其左子树也都小于low,故可以排除;对于4,其右子树也是可以排除。 代码如下: class Solution {public Tre…...

装机爱好者的纯净工具箱

对于每一位电脑用户来说,新电脑到手后的第一件事通常是检测硬件性能。今天为大家介绍一款开源且无广告的硬件检测工具——入梦工具箱。 主要功能 硬件信息一目了然 打开入梦工具箱,首先看到的是硬件信息概览。这里不仅包含了内存、主板、显卡、硬盘等常…...

Spring Boot 实现文件上传和下载

文章目录 Spring Boot 实现文件上传和下载一、引言二、文件上传1、配置Spring Boot项目2、创建文件上传控制器3、配置文件上传大小限制 三、文件下载1、创建文件下载控制器 四、使用示例1、文件上传2、文件下载 五、总结 Spring Boot 实现文件上传和下载 一、引言 在现代Web应…...

【开源免费】基于SpringBoot+Vue.JS在线考试学习交流网页平台(JAVA毕业设计)

本文项目编号 T 158 ,文末自助获取源码 \color{red}{T158,文末自助获取源码} T158,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…...

Electron学习笔记,安装环境(1)

1、支持win7的Electron 的版本是18,这里node.js用的是14版本(node-v14.21.3-x86.msi)云盘有安装包 Electron 18.x (截至2023年仍在维护中): Chromium: 96 Node.js: 14.17.0 2、安装node环境,node-v14.21.3-x86.msi双击运行选择安…...

个人通知~~~

因学业问题,作品发布比较【慢】所以将间隔调整为3天一作 另外声明:二月一号正式改名:饼干帅成渣 (谐音) 没关住的快关注,求求了。不求点赞,评论,收藏。 最后祝大家新年快乐&…...

C# 中 [MethodImpl(MethodImplOptions.Synchronized)] 的使用详解

总目录 前言 在C#中,[MethodImpl(MethodImplOptions.Synchronized)] 是一个特性(attribute),用于标记方法,使其在执行时自动获得锁。这类似于Java中的 synchronized 关键字,确保同一时刻只有一个线程可以执…...

盛水最多的容器

hello 大家好!今天开写一个新章节,每一天一道算法题。让我们一起来学习算法思维吧! function maxArea(height) {// 初始化最大水量为 0let maxWater 0;// 初始化左指针,指向数组的第一个元素let left 0;// 初始化右指针&#xf…...

数据分析系列--①RapidMiner软件安装

目录 一、软件下载及账号注册 1.软件下载 1.1 CSDN下载国内下载,国内镜像相对快,点击下载 1.2 官网软件下载地址:AI Studio 2025.0 ,服务器在国外相对较慢. 2.软件注册 2.1 点击 注册界面 开始注册,如图: 3.邮箱验证 二、软件安装 1. 新年文件夹,名字最好为英文名 2. 双…...

DeepSeek-R1:开源Top推理模型的实现细节、使用与复现

核心观点 ● 直接用强化学习就可以让模型获得显著的推理能力,说明并不一定需要SFT才行。 ● 强化学习并不一定需要复杂的奖励模型,使用简单的规则反而取得意想不到的效果。 ● 通过知识蒸馏让小模型一定程度上也有推理能力,甚至在某些场景下…...

Linux(19)——使用正则表达式匹配文本

新年快乐! 目录 一、正则表达式: 二、通过 grep 匹配正则表达式: 三、查找匹配项: 一、正则表达式: 正则表达式使用模式匹配机制查找特定内容,vim、grep 和 less 命令都可以使用正则表达式,P…...

STM32 对射式红外传感器配置

这次用的是STM32F103的开发板(这里面的exti.c文件没有how to use this driver 配置说明) 对射式红外传感器 由一个红外发光二极管和NPN光电三极管组成,M3固定安装孔,有输出状态指示灯,输出高电平灯灭,输出…...

2025年1月25日(赋值前引用)

pycharm 提示&#xff1a; 局部变量 ‘start_time’ 可能在赋值前引用 局部变量 ‘stop_time’ 可能在赋值前引用 Traceback (most recent call last):File "/home/raspberry/Desktop/python/01_其他/04_超声波/ml_01_超声波测距.py", line 63, in <module>mai…...

C++基础(1)

目录 1. C发展历史 2. C第一个程序 3. 命名空间 3.1 namespace的价值 3.2 命名空间的定义 3.3 命名空间的使用 4. C输入和输出 5. 缺省参数 6. 函数重载 6.1 实现函数重载的条件 6.2 函数重载的应用 1. C发展历史 C的起源可以追溯到1979年&#xff0c;当时Bjarne…...

【第十天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-两种常见的字符串算法(持续更新)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的字符串算法2.字符串算法3.详细的字符串算法1&#xff09;KMP算法2&#xff09;Rabin-Karp算法 总结 前言…...

开发者交流平台项目部署到阿里云服务器教程

本文使用PuTTY软件在本地Windows系统远程控制Linux服务器&#xff1b;其中&#xff0c;Windows系统为Windows 10专业版&#xff0c;Linux系统为CentOS 7.6 64位。 1.工具软件的准备 maven&#xff1a;https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-m…...

一次端口监听正常,tcpdump无法监听到指定端口报文问题分析

tcpdump命令&#xff1a; sudo tcpdump -i ens2f0 port 6471 -XXnnvvv 下面是各个部分的详细解释&#xff1a; 1.tcpdump: 这是用于捕获和分析网络数据包的命令行工具。 2.-i ens2f0: 指定监听的网络接口。ens2f0 表示本地网卡&#xff09;&#xff0c;即计算机该指定网络接口捕…...

【C++题解】1393. 与7无关的数?

欢迎关注本专栏《C从零基础到信奥赛入门级&#xff08;CSP-J&#xff09;》 问题&#xff1a;1393. 与7无关的数&#xff1f; 类型&#xff1a;简单循环 题目描述&#xff1a; 一个整数&#xff0c;如果这个数能够被 7 整除&#xff0c;或者其中有一位是7&#xff0c;我们称…...

【自学嵌入式(6)天气时钟:软硬件准备、串口模块开发】

天气时钟&#xff1a;软硬件准备、串口模块开发 软硬件准备接线及模块划分ESP8266开发板引脚图软件准备 串口模块编写串口介绍Serial库介绍 近期跟着网上一些教学视频&#xff0c;编写了一个天气时钟&#xff0c;本篇及往后数篇都将围绕天气时钟的制作过程展开。本文先解决硬件…...

高低频混合组网系统中基于地理位置信息的信道测量算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 2.算法运行软件版本 matlab2022a 3.部分核心程序 &#xff08;完整版代码包含详细中文注释和操作步骤视频&#xff09…...