当前位置: 首页 > news >正文

自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

代码:

import torch
import numpy as np
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score# 定义数据:x_data 是特征,y_data 是标签(目标值)
data = [[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]]# 将数据转为 numpy 数组
data = np.array(data)# 提取 x_data 和 y_data
x_data = data[:, 0]  # 取第一列作为输入特征
y_data = data[:, 1]  # 取第二列作为目标标签# 将数据转换为 PyTorch 张量
x_train = torch.tensor(x_data, dtype=torch.float32)  # 输入特征
y_train = torch.tensor(y_data, dtype=torch.float32)  # 目标标签# 使用 TensorDataset 来创建一个数据集
from torch.utils.data import DataLoader, TensorDatasetdataset = TensorDataset(x_train, y_train)  # 使用训练数据创建数据集
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)  # 将数据集转换为 DataLoader,批大小为 2,且每个 epoch 都会随机打乱数据# 定义损失函数:均方误差损失 (MSELoss)
criterion = nn.MSELoss()# 定义线性回归模型
class LinearModel(nn.Module):def __init__(self):super(LinearModel, self).__init__()# 使用一个线性层,输入为1维,输出为1维self.layers = nn.Linear(1, 1)def forward(self, x):# 直接返回线性层的输出return self.layers(x)model=LinearModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epoches =500
for n in range(1,epoches+1):epoch_loss=0#以前都是所有数据一块训练,现在是按照批次进行训练for batch_x,batch_y in dataloader:#现在x_train 相当于10个样本,但是现在维度,添加一个维度#10x1   变成样本 x 维度形式y_prd=model(batch_x.unsqueeze(1))#计算损失#y_prd在前面,y_true 是后面batch_loss=criterion(y_prd.squeeze(1),batch_y)#梯度更新#清空之前存储在优化器中的梯度optimizer.zero_grad()#损失函数对模型参数的梯度batch_loss.backward()#根据优化算法更新参数optimizer.step()#计算一下epoch的损失epoch_loss=epoch_loss+batch_loss# 5、显示频率设置#计算一下epoch的平均损失avg_loss=epoch_loss/(len(dataloader))# 不先画图if n % 10 == 0 or n == 1:print(f"epoches:{n},loss:{avg_loss}")torch.save(model.state_dict(),'model.pth')model.load_state_dict(torch.load("model.pth"))
#评估模型
# 评估模型一定要加下面这句话
model.eval()
# 定义数据
x_test=torch.tensor([[1.8]],dtype=torch.float32)
#添加上下文不需要计算梯度
with torch.no_grad():y_pred=model(x_test)threshold = 50  # 设定阈值
y_pred_class = int(y_pred.item() > threshold)# 输出预测结果
print(f"预测值 : {y_pred.item():.4f}")
print(f"预测类 : {y_pred_class}")# 假设真实标签也是 1 或 0,我们用一个假的真实标签来计算评估指标(你可以根据实际情况替换)
y_true_class = 1 if y_data[1] > threshold else 0  # 假设我们预测的是第二个样本# 计算精确度、召回率和 F1 分数
accuracy = accuracy_score([y_true_class], [y_pred_class])
precision = precision_score([y_true_class], [y_pred_class])
recall = recall_score([y_true_class], [y_pred_class])
f1 = f1_score([y_true_class], [y_pred_class])# 输出分类评估指标
print(f"precision : {precision:.4f}")
print(f"recall : {recall:.4f}")
print(f"f1 : {f1:.4f}")

结果:

相关文章:

自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

代码: import torch import numpy as np import torch.nn as nn from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score# 定义数据:x_data 是特征,y_data 是标签(目标值) data [[-0…...

Linux02——Linux的基本命令

目录 ls 常用选项及功能 综合示例 注意事项 cd和pwd命令 cd命令 pwd命令 相对路径、绝对路径和特殊路径符 特殊路径符号 mkdir命令 1. 功能与基本用法 2. 示例 3. 语法与参数 4. -p选项 touch-cat-more命令 1. touch命令 2. cat命令 3. more命令 cp-mv-rm命…...

MySQL数据库(二)- SQL

目录 ​编辑 一 DDL (一 数据库操作 1 查询-数据库(所有/当前) 2 创建-数据库 3 删除-数据库 4 使用-数据库 (二 表操作 1 创建-表结构 2 查询-所有表结构名称 3 查询-表结构内容 4 查询-建表语句 5 添加-字段名数据类型 6 修改-字段数据类…...

Docker自定义镜像

Dockerfile自定义镜像 一:镜像结构 镜像是将应用程序及其需要的系统函数库、环境、配置、依赖打包而成。 我们以MySQL为例,来看看镜像的组成结构: 简单来说,镜像就是在系统函数库、运行环境基础上,添加应用程序文件、…...

网络协议基础

文章目录 前言一、网络协议分层1.应用层2.传输层3.网络层4.数据链路层5.物理层 二、图解IP1.IP基本认识(1)IP的作用(2)IP与MAC的关系 2.IP地址的基础知识(1)IP地址的定义(2)IP地址的…...

c语言进阶(简单的函数 数组 指针 预处理 文件 结构体)

c语言补充 格式 void函数头 {} 中的是函数体 sum函数名 &#xff08;&#xff09; 参数表 #include <stdio.h>void sum(int begin, int end) {int i;int sum 0;for (i begin ; i < end ; i) {sum i;}printf("%d到%d的和是%d\n", begin, end, sum); …...

Pytorch框架从入门到精通

目录 一、Tensors 1.1 初始化一个Tensor 1&#xff09;赋值初始化 2&#xff09;从 NumPy 数组初始化 3&#xff09;从另一个张量 4&#xff09;使用随机值或常量值 1.2 Tensor 的属性 1.3 对 Tensor 的操作 1.3.1 总体介绍 1.3.2 索引和切片 1.3.3 算术运算 矩阵乘…...

Vue.js组件开发-实现全屏图片文字缩放切换特效

使用 Vue 实现全屏图片文字缩放切换特效 步骤 创建 Vue 项目&#xff1a;使用 Vue CLI 来快速创建一个新的 Vue 项目。设计组件结构&#xff1a;创建一个包含图片和文字的组件&#xff0c;并实现缩放和切换效果。实现样式&#xff1a;使用 CSS 来实现全屏显示、缩放和切换动画…...

在 WSL2 中重启 Ubuntu 实例

在 WSL2 中重启 Ubuntu 实例&#xff0c;可以按照以下步骤操作&#xff1a; 方法 1: 使用 wsl 命令 关闭 Ubuntu 实例: 打开 PowerShell 或命令提示符&#xff0c;运行以下命令&#xff1a; wsl --shutdown这会关闭所有 WSL2 实例。 重新启动 Ubuntu: 再次打开 Ubuntu&#x…...

Flutter 新春第一弹,Dart 宏功能推进暂停,后续专注定制数据处理支持

在去年春节&#xff0c;Flutter 官方发布了宏&#xff08;Macros&#xff09;编程的原型支持&#xff0c; 同年的 5 月份在 Google I/O 发布的 Dart 3.4 宣布了宏的实验性支持&#xff0c;但是对于 Dart 内部来说&#xff0c;从启动宏编程实验开始已经过去了几年&#xff0c;但…...

Signature

打开得到加密脚本&#xff1a; import ecdsa import randomdef ecdsa_test(dA,k):sk ecdsa.SigningKey.from_secret_exponent(secexpdA,curveecdsa.SECP256k1)sig1 sk.sign(databHi., kk).hex()sig2 sk.sign(databhello., kk).hex()r1 int(sig1[:64], 16)s1 int(sig1[64:…...

UE求职Demo开发日志#18 数据表获取物品信息,添加背包模块

1 把获取物品信息改为读取数据表 先创建结构&#xff0c;暂时有这几个属性&#xff1a; USTRUCT(BlueprintType) struct ARPG_CPLUS_API FMyItemData:public FTableRowBase {GENERATED_USTRUCT_BODY()UPROPERTY(EditAnywhere, BlueprintReadWrite)int ItemId;//物品Id&#x…...

neo4j-community-5.26.0 create new database

1.edit neo4j.conf 把 # The name of the default database initial.dbms.default_databasehonglouneo4j # 写上自己的数据库名称 和 # Name of the service #5.0 server.windows_service_nameneo4j #4.0 dbms.default_databaseneo4j #dbms.default_databaseneo4jwind serve…...

项目中用的网关Gateway及SpringCloud

在现代微服务架构中&#xff0c;网关&#xff08;Gateway&#xff09;起到了至关重要的作用。它不仅负责路由请求&#xff0c;还提供了统一的认证、授权、负载均衡、限流等功能。Spring Cloud Gateway 是 Spring Cloud 生态系统中的一个重要组件&#xff0c;专门为微服务架构提…...

​《Ollama Python 库​》

Ollama Python 库 Ollama Python 库提供了将 Python 3.8 项目与 Ollama 集成的最简单方法。 先决条件 应该安装并运行 Ollama拉取一个模型以与库一起使用&#xff1a;例如ollama pull <model>ollama pull llama3.2 有关可用模型的更多信息&#xff0c;请参阅 Ollama.com。…...

大模型概述(方便不懂技术的人入门)

1 大模型的价值 LLM模型对人类的作用&#xff0c;就是一个百科全书级的助手。有多么地百科全书&#xff0c;则用参数的量来描述&#xff0c; 一般地&#xff0c;大模型的参数越多&#xff0c;则该模型越好。例如&#xff0c;GPT-3有1750亿个参数&#xff0c;GPT-4可能有超过1万…...

Ubuntu16.04编译安装Cartographer 1.0版本

说明 官方文档 由于Ubuntu16.04已经是很老的系统&#xff0c;如果直接按照Cartographer官方安装文档安装会出现代码编译失败的问题&#xff0c;本文给出了解决这些问题的办法。正常情况下执行本文给出的安装方法即可成功安装。 依赖安装 # 这里和官方一致 # Install the req…...

AI-ISP论文Learning to See in the Dark解读

论文地址&#xff1a;Learning to See in the Dark 图1. 利用卷积网络进行极微光成像。黑暗的室内环境。相机处的照度小于0.1勒克斯。索尼α7S II传感器曝光时间为1/30秒。(a) 相机在ISO 8000下拍摄的图像。(b) 相机在ISO 409600下拍摄的图像。该图像存在噪点和色彩偏差。©…...

2 MapReduce

2 MapReduce 1. MapReduce 介绍1.1 MapReduce 设计构思 2. MapReduce 编程规范3. Mapper以及Reducer抽象类介绍1.Mapper抽象类的基本介绍2.Reducer抽象类基本介绍 4. WordCount示例编写5. MapReduce程序运行模式6. MapReduce的运行机制详解6.1 MapTask 工作机制6.2 ReduceTask …...

OpenCV:SIFT关键点检测与描述子计算

目录 1. 什么是 SIFT&#xff1f; 2. SIFT 的核心步骤 2.1 尺度空间构建 2.2 关键点检测与精细化 2.3 方向分配 2.4 计算特征描述子 3. OpenCV SIFT API 介绍 3.1 cv2.SIFT_create() 3.2 sift.detect() 3.3 sift.compute() 3.4 sift.detectAndCompute() 4. SIFT 关…...

初识Cargo:Rust的强大构建工具与包管理器

初识Cargo&#xff1a;Rust的强大构建工具与包管理器 如果你刚刚开始学习Rust&#xff0c;一定会遇到一个名字&#xff1a;Cargo。Cargo是Rust的官方构建工具和包管理器&#xff0c;它让Rust项目的创建、编译、测试和依赖管理变得非常简单。本文将带你快速了解Cargo的基本用法…...

LightM-UNet(2024 CVPR)

论文标题LightM-UNet: Mamba Assists in Lightweight UNet for Medical Image Segmentation论文作者Weibin Liao, Yinghao Zhu, Xinyuan Wang, Chengwei Pan, Yasha Wang and Liantao Ma发表日期2024年01月01日GB引用> Weibin Liao, Yinghao Zhu, Xinyuan Wang, et al. Ligh…...

2025年02月01日Github流行趋势

项目名称&#xff1a;oumi 项目地址url&#xff1a;https://github.com/oumi-ai/oumi 项目语言&#xff1a;Python 历史star数&#xff1a;544 今日star数&#xff1a;103 项目维护者&#xff1a;xrdaukar, oelachqar, taenin, wizeng23, kaisopos 项目简介&#xff1a;一切你需…...

自动化测试框架搭建-封装requests-优化

目的 1、实际的使用场景&#xff0c;无法避免的需要区分GET、POST、PUT、PATCH、DELETE等不同的方式请求&#xff0c;以及不同请求的传参方式 2、python中requests中&#xff0c;session.request方法&#xff0c;GET请求&#xff0c;只支持params传递参数 session.request(me…...

什么是线性化PDF?

线性化PDF是一种特殊的PDF文件组织方式。 总体而言&#xff0c;PDF是一种极为优雅且设计精良的格式。PDF由大量PDF对象构成&#xff0c;这些对象用于创建页面。相关信息存储在一棵二叉树中&#xff0c;该二叉树同时记录文件中每个对象的位置。因此&#xff0c;打开文件时只需加…...

XML DOM 浏览器差异

DOM 解析中的浏览器差异 所有现代的浏览器都支持 W3C DOM 规范。 然而&#xff0c;浏览器之间是有差异的。一个重要的差异是&#xff1a; 处理空白和换行的方式 DOM - 空白和换行 XML 经常在节点之间包含换行或空白字符。这是在使用简单的编辑器&#xff08;比如记事本&…...

电子电气架构 --- 汽车电子拓扑架构的演进过程

我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的汽车电子工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 简单&#xff0c;单纯&#xff0c;喜欢独处&#xff0c;独来独往&#xff0c;不易合同频过着接地气的生活…...

01-六自由度串联机械臂(ABB)位置分析

ABB工业机器人&#xff08;IRB2600&#xff09;如下图所示&#xff08;d1444.8mm&#xff0c;a1150mm&#xff0c;a2700mm&#xff0c;a3115mm&#xff0c;d4795mm&#xff0c;d685mm&#xff09;&#xff0c;利用改进DH法建模&#xff0c;坐标系如下所示&#xff1a; 利用改进…...

04树 + 堆 + 优先队列 + 图(D1_树(D6_B树(B)))

目录 一、学习前言 二、基本介绍 三、特性 1. 从概念上说起 2. 举个例子 四、代码实现 节点准备 大体框架 实现分裂 实现新增 实现删除 五、完整源码 一、学习前言 前面我们已经讲解过了二叉树、二叉搜索树&#xff08;BST&#xff09;、平衡二叉搜索树&#xff08…...

350.两个数组的交集 ②

目录 题目过程解法 题目 给你两个整数数组 nums1 和 nums2 &#xff0c;请你以数组形式返回两数组的交集。返回结果中每个元素出现的次数&#xff0c;应与元素在两个数组中都出现的次数一致&#xff08;如果出现次数不一致&#xff0c;则考虑取较小值&#xff09;。可以不考虑…...

C#,入门教程(09)——运算符的基础知识

上一篇&#xff1a; C#&#xff0c;入门教程(08)——基本数据类型及使用的基础知识https://blog.csdn.net/beijinghorn/article/details/123906998 一、算术运算符号 算术运算符号包括&#xff1a;四则运算 加 , 减-, 乘*, 除/与取模%。 // 加法&#xff0c;运算 int va 1 …...

Python-基于PyQt5,wordcloud,pillow,numpy,os,sys等的智能词云生成器

前言&#xff1a;日常生活中&#xff0c;我们有时后就会遇见这样的情形&#xff1a;我们需要将给定的数据进行可视化处理&#xff0c;同时保证呈现比较良好的量化效果。这时候我们可能就会用到词云图。词云图&#xff08;Word cloud&#xff09;又称文字云&#xff0c;是一种文…...

海外问卷调查之渠道查,企业经营的指南针

海外问卷调查&#xff0c;是企业调研最常用到的方法&#xff0c;有目的、有计划、有系统地收集研究对象的现实状况或历史状况的一种有效手段&#xff0c;是指导企业经营的有效手段。 海外问卷调查充分运用历史法、观察法等方法&#xff0c;同时使用谈话、问卷、个案研究、测试…...

C++:虚函数与多态性习题

题目内容&#xff1a; 构建一个车&#xff08;vehicle&#xff09;基类&#xff0c;包含Run、Stop两个纯虚函数。由此基类&#xff0c;派生出&#xff08;Car&#xff09;轿车类&#xff0c;&#xff08;truck&#xff09;卡车类&#xff0c;在这两个类中别分定义Run和Stop两个…...

单片机基础模块学习——超声波传感器

一、超声波原理 左边发射超声波信号&#xff0c;右边接收超声波信号 左边的芯片用来处理超声波发射信号&#xff0c;中间的芯片用来处理接收的超声波信号 二、超声波原理图 T——transmit 发送R——Recieve 接收 U18芯片对输入的N_A1信号进行放大&#xff0c;然后输入给超声…...

通过protoc工具生成proto的pb.go文件以及使用protoc-go-inject-tag工具注入自定义标签

1.ProtoBuf认识,安装以及用法 参考:[golang 微服务] 3. ProtoBuf认识&#xff0c;安装以及golang 中ProtoBuf使用 2. 使用protoc-go-inject-tag工具注入自定义标签 这里有一个案例: syntaxproto3; package test;option go_package ".;test";message MyMessage {int6…...

42【语言的编码架构】

不同语言采用的编码架构不一样 火山采用&#xff1a;UTF-16 易语言采用&#xff1a;GBK php采用&#xff1a;UTF-8 这个编码架构指的就是文本所代表的字节集&#xff0c;比如易语言中“你好”表示的就是{196,227,186,195} 窗口程序集名保 留 保 留备 注窗口程序集_启动窗口 …...

TOF技术原理和静噪对策

本文章是笔者整理的备忘笔记。希望在帮助自己温习避免遗忘的同时&#xff0c;也能帮助其他需要参考的朋友。如有谬误&#xff0c;欢迎大家进行指正。 一、什么是TOF TOF 是Time of Flight的缩写&#xff0c;它是一种通过利用照射波和反射波之间的时间差来测量到物体的距离的测…...

ssh调试:fatal: Could not read from remote repository.

我遇到的原因和网上说的什么在生产密钥时没加邮箱&#xff0c;以及多个密钥的配置问题都不一样&#xff1b; 例如https://blog.csdn.net/baoyin0822/article/details/122584931 或https://blog.csdn.net/qq_55558061/article/details/124117445 我遇到的问题的原因跟他们都i不…...

win10部署本地deepseek-r1,chatbox,deepseek联网(谷歌网页插件)

win10部署本地deepseek-r1&#xff0c;chatbox&#xff0c;deepseek联网&#xff08;谷歌网页插件&#xff09; 前言一、本地部署DeepSeek-r1step1 安装ollamastep2 下载deepseek-r1step2.1 找到模型deepseek-r1step2.2 cmd里粘贴 后按回车&#xff0c;进行下载 step3 测试指令…...

SpringCloud系列教程:微服务的未来(十九)请求限流、线程隔离、Fallback、服务熔断

前言 前言 在现代微服务架构中&#xff0c;系统的高可用性和稳定性至关重要。为了解决系统在高并发请求或服务不可用时出现的性能瓶颈或故障&#xff0c;常常需要使用一些技术手段来保证服务的平稳运行。请求限流、线程隔离、Fallback 和服务熔断是微服务中常用的四种策略&…...

Hot100之子串

560和为K的子数组 题目 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列 思路解析 ps&#xff1a;我们的presum【0】就是0&#xff0c;如果没有这个0的话我们的第一个元素就无法减去上…...

SpringBoot笔记

1.创建 使用idea提供的脚手架创建springboot项目&#xff0c;选上需要的模块&#xff0c;会自动进行导包 打成jar包&#xff0c;之前直接用原生的maven打包的是一个瘦jar&#xff0c;不能直接跑&#xff0c;把服务器上部署的jar排除在外了&#xff0c;但是现在加上打包查件&am…...

一、TensorFlow的建模流程

1. 数据准备与预处理&#xff1a; 加载数据&#xff1a;使用内置数据集或自定义数据。 预处理&#xff1a;归一化、调整维度、数据增强。 划分数据集&#xff1a;训练集、验证集、测试集。 转换为Dataset对象&#xff1a;利用tf.data优化数据流水线。 import tensorflow a…...

4 Hadoop 面试真题

4 Hadoop 面试真题 1. Apache Hadoop 3.0.02. HDFS 3.x 数据存储新特性-纠删码Hadoop面试真题 1. Apache Hadoop 3.0.0 Apache Hadoop 3.0.0在以前的主要发行版本&#xff08;hadoop-2.x&#xff09;上进行了许多重大改进。 最低要求的Java版本从Java 7增加到Java 8 现在&…...

信息学奥赛一本通 ybt 1608:【 例 3】任务安排 3 | 洛谷 P5785 [SDOI2012] 任务安排

【题目链接】 ybt 1608&#xff1a;【 例 3】任务安排 3 洛谷 P5785 [SDOI2012] 任务安排 【题目考点】 1. 动态规划&#xff1a;斜率优化动规 2. 单调队列 3. 二分答案 【解题思路】 与本题题面相同但问题规模不同的题目&#xff1a; 信息学奥赛一本通 1607&#xff1a…...

实验六 项目二 简易信号发生器的设计与实现 (HEU)

声明&#xff1a;代码部分使用了AI工具 实验六 综合考核 Quartus 18.0 FPGA 5CSXFC6D6F31C6N 1. 实验项目 要求利用硬件描述语言Verilog&#xff08;或VHDL&#xff09;、图形描述方式、IP核&#xff0c;结合数字系统设计方法&#xff0c;在Quartus开发环境下&#xff…...

基于最近邻数据进行分类

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 完整代码&#xff1a; import torch import numpy as np from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import accuracy_score import matplotlib.pyplot as plt# 生成一个简单的数据…...

SpringSecurity:There is no PasswordEncoder mapped for the id “null“

文章目录 一、情景说明二、分析三、解决 一、情景说明 在整合SpringSecurity功能的时候 我先是去实现认证功能 也就是&#xff0c;去数据库比对用户名和密码 相关的类&#xff1a; UserDetailsServiceImpl implements UserDetailsService 用于SpringSecurity查询数据库 Logi…...

redex快速体验

第一步&#xff1a; 2.回调函数在每次state发生变化时候自动执行...