当前位置: 首页 > news >正文

Pytorch基础教程:从零实现手写数字分类

文章目录

  • 1.Pytorch简介
  • 2.理解tensor
    • 2.1 一维矩阵
    • 2.2 二维矩阵
    • 2.3 三维矩阵
  • 3.创建tensor
    • 3.1 你可以直接从一个Python列表或NumPy数组创建一个tensor:
    • 3.2 创建特定形状的tensor
    • 3.3 创建三维tensor
    • 3.4 使用随机数填充tensor
    • 3.5 指定tensor的数据类型
  • 4.tensor基本运算
    • 4.1 ‌算术运算:
      • 4.1.1 ‌加法‌:
      • 4.1.2 ‌减法‌:
      • 4.1.3 ‌乘法‌(逐元素乘法,不是矩阵乘法):
      • 4.1.4 ‌除法‌:
      • 4.1.5 ‌幂运算‌:
    • 4.2 ‌‌矩阵运算:
      • 4.2.1 ‌‌矩阵乘法‌:
      • 4.2.2 ‌‌‌矩阵转置‌:
      • 4.2.3 ‌‌‌逐元素运算的广播机制:
    • 4.3 ‌‌‌‌统计运算:
      • 4.3.1 ‌‌‌‌求和‌:
      • 4.3.2 ‌‌‌‌平均值‌:
      • 4.3.3 ‌‌‌‌最大值和最小值‌:
    • 4.4 形状操作:
      • 4.4.1 ‌‌‌‌改变形状‌:
      • 4.4.2 ‌‌‌‌展平‌:
  • 5.理解神经网络
    • 5.1 什么是分类,什么是回归
      • 5.1.1 分类
      • 5.1.2 回归
    • 5.2 有什么函数可以实现分类和回归?答案:线性回归
      • 5.2.1 从二元一次方程组到Simple Linear Regression
      • 5.2.1 什么是线性关系(Linear Relationship)?
    • 5.3 线性回归到神经网络
      • 5.3.1 理解新模型 - 简易版神经网络
  • 6.定义网络结构
  • 7.整理数据集
  • 8.训练模型
  • 9.保存模型
  • 10.加载模型
  • 11.使用GPU加速
    • 11.1 要将model(model里面的参数w和b)从内存放到显存。
    • 11.2 把加载的数据从内存加载到显存
    • 11.3 只要将模型的数据,测试数据和训练数据加载到显存,自然会使用GPU进行处理。

1.Pytorch简介

‌PyTorch是一个开源的深度学习框架,由Facebook的人工智能研究院(FAIR)开发,并于2017年1月正式推出。‌ PyTorch以其灵活性和易用性著称,特别适合于深度学习模型的构建和训练。它基于Torch张量库开发,提供了动态计算图的功能,允许在运行时改变计算图,这使得模型构建更加灵活

2.理解tensor

Tensor 是PyTorch中最近本的数据结构,可以将其视为n维数组或者矩阵。n维矩阵在我们生活里非常常见。

2.1 一维矩阵

在这里插入图片描述
基本操作
在这里插入图片描述

2.2 二维矩阵

二维矩阵是一个表格,其中包含行和列。例如:
在这里插入图片描述
在这个矩阵中,
a11a 11 、a12a12 、a13a13 、a21a21 、a22a22 、a23a23 、a31a31 、a32a32 、a33a33 是矩阵的元素。第一个方括号内的元素属于第一行,第二个方括号内的元素属于第二行,第三个方括号内的元素属于第三行。

2.3 三维矩阵

三维矩阵可以看作是由多个二维矩阵组成的“矩阵的矩阵”,通常用于表示多维数据。例如,一个3x3x3的三维矩阵可以表示为:
在这里插入图片描述
在这个三维矩阵中,每一个二维矩阵(由方括号包围的部分)可以看作是一个“层”,整个三维矩阵由这些层组成。

3.创建tensor

3.1 你可以直接从一个Python列表或NumPy数组创建一个tensor:

import torch# 从Python列表创建tensor
data = [[1, 2], [3, 4]]
tensor_from_list = torch.tensor(data)
print(tensor_from_list)# 从NumPy数组创建tensor
import numpy as np
np_array = np.array(data)
tensor_from_np = torch.from_numpy(np_array)
print(tensor_from_np)

在这里插入图片描述

3.2 创建特定形状的tensor

你可以使用torch的内置函数来创建具有特定形状和值的tensor:

import torch# 创建一个全为零的tensor,形状为(2, 3)
zeros_tensor = torch.zeros((2, 3))
print(zeros_tensor)# 创建一个全为一的tensor,形状为(2, 3)
ones_tensor = torch.ones((2, 3))
print(ones_tensor)# 创建一个未初始化的tensor,形状为(2, 3),其值可能是随机的
uninit_tensor = torch.empty((2, 3))
print(uninit_tensor)# 创建一个具有指定值的tensor,所有元素都设为5,形状为(2, 3)
full_tensor = torch.full((2, 3), 5)
print(full_tensor)

在这里插入图片描述

3.3 创建三维tensor

要创建一个三维tensor,你只需指定三个维度的大小:

import torch# 创建一个形状为(2, 3, 4)的三维tensor,所有元素都初始化为0
three_dim_tensor = torch.zeros((2, 3, 4))
print(three_dim_tensor)

3.4 使用随机数填充tensor

你还可以使用随机数来填充tensor:

import torch# 创建一个形状为(2, 3)的tensor,其元素是从均匀分布[0, 1)中抽取的随机数
rand_tensor = torch.rand((2, 3))
print(rand_tensor)# 创建一个形状为(2, 3)的tensor,其元素是从标准正态分布中抽取的随机数
randn_tensor = torch.randn((2, 3))
print(randn_tensor)

在这里插入图片描述

3.5 指定tensor的数据类型

在创建tensor时,你还可以指定其数据类型(dtype):

import torch# 创建一个形状为(2, 3)的tensor,元素类型为浮点数(默认)
float_tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(float_tensor)
print(float_tensor.dtype)  # 输出: torch.float32# 创建一个形状为(2, 3)的tensor,元素类型为整数
int_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)
print(int_tensor)
print(int_tensor.dtype)  # 输出: torch.int32

在这里插入图片描述

4.tensor基本运算

4.1 ‌算术运算:

4.1.1 ‌加法‌:

import torch
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])
z = x + y  # 对应元素相加
print(z)  # 输出: tensor([5., 7., 9.])

4.1.2 ‌减法‌:

z = x - y  # 对应元素相减
print(z)  # 输出: tensor([-3., -3., -3.])

4.1.3 ‌乘法‌(逐元素乘法,不是矩阵乘法):

z = x * y  # 对应元素相乘
print(z)  # 输出: tensor([4., 10., 18.])

4.1.4 ‌除法‌:

z = x / y  # 对应元素相除
print(z)  # 输出: tensor([0.2500, 0.4000, 0.5000])

4.1.5 ‌幂运算‌:

z = x ** 2  # 每个元素求平方
print(z)  # 输出: tensor([1., 4., 9.])

4.2 ‌‌矩阵运算:

4.2.1 ‌‌矩阵乘法‌:

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
C = torch.mm(A, B)  # 或者使用 A @ B 在PyTorch 1.0及以上版本
print(C)  # 输出: tensor([[19, 22], [43, 50]])

4.2.2 ‌‌‌矩阵转置‌:

A_t = A.t()
print(A_t)  # 输出: tensor([[1, 3], [2, 4]])

4.2.3 ‌‌‌逐元素运算的广播机制:

PyTorch支持广播机制,当两个tensor的形状不完全相同时,较小的tensor会自动扩展以匹配较大的tensor的形状,然后进行逐元素运算。

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor(2.0)  # 这是一个标量,将被广播到与x相同的形状
z = x + y
print(z)  # 输出: tensor([3., 4., 5.])

4.3 ‌‌‌‌统计运算:

4.3.1 ‌‌‌‌求和‌:

sum_x = x.sum()  # 对所有元素求和
print(sum_x)  # 输出: tensor(6.)

4.3.2 ‌‌‌‌平均值‌:

mean_x = x.mean()  # 对所有元素求平均值
print(mean_x)  # 输出: tensor(2.)

4.3.3 ‌‌‌‌最大值和最小值‌:

max_x, _ = x.max()  # 返回最大值及其索引(这里只关心最大值)
min_x, _ = x.min()  # 返回最小值及其索引(这里只关心最小值)
print(max_x)  # 输出: tensor(3.)
print(min_x)  # 输出: tensor(1.)

4.4 形状操作:

4.4.1 ‌‌‌‌改变形状‌:

x_reshaped = x.view(-1, 1)  # 将x改变为列向量
print(x_reshaped)  # 输出: tensor([[1.], [2.], [3.]])

4.4.2 ‌‌‌‌展平‌:

x_flattened = x.flatten()  # 将x展平为一维数组
print(x_flattened)  # 输出: tensor([1., 2., 3.])

5.理解神经网络

在开始之前,我觉得有必要提个醒, 神经网络本质上是数学,但我们仅仅作为开发者,我们要做的事情是通过pytorch或类似的工具实现我们想要的功能,达到我们想要的目的。至于为什么里面的数学公式是这样子,为什么神经网络的架构是这样搭,为什么他这个结构这么好,我们根本不用去管里面的最底层机理。我们只需要知道这样子这样子就能达到效果,就可以了。我们不要被这些数学公式吓跑了。

5.1 什么是分类,什么是回归

‌分类和回归是机器学习中两种基本的预测方法,它们的核心区别在于预测的输出类型‌。

5.1.1 分类

分类,简单来说,就是给数据打上标签,预测它属于哪一个类别。比如,我们有一堆邮件,需要判断哪些是垃圾邮件,哪些是正常邮件。这就是一个典型的二分类问题,因为输出只有两种可能:垃圾邮件或正常邮件。再比如,我们有一张动物图片,需要判断它是猫、狗还是其他动物,这就是一个多分类问题,因为输出有多个可能的类别‌。

举例:想象你手里有一堆水果,你需要把它们分成苹果和橙子两类,这就是分类任务。你要看水果的颜色、形状等特征,然后决定它是苹果还是橙子。

5.1.2 回归

回归,则是预测一个具体的数值,这个数值可以是任何实数。比如,我们想知道一套房子的价格,根据它的面积、位置等特征来预测。这就是一个回归问题,因为输出是一个连续的数值,而不是一个类别标签。再比如,我们想预测明天的温度或者股票的价格,这些都是回归问题‌。

举例:想象你要预测一辆汽车的价格,你会考虑它的品牌、型号、年份、里程数等特征,然后给出一个具体的价格预测,比如10万元、15万元等,这就是回归任务。
图片转自‘人工智能教学实践’
上图转自‘人工智能教学实践’博客,我觉得这张图很好地解析什么是分类,什么是回归,预测天气的是分类,预测温度是多少,是回归。

5.2 有什么函数可以实现分类和回归?答案:线性回归

5.2.1 从二元一次方程组到Simple Linear Regression

初中时,我们通过两个点的坐标求解二元一次方程。例如:已知直线y = ax + b 经过点(1,1)和(3,2),求解a, b的值。
在这里插入图片描述

解法是将坐标值带入方程,得到一个二元一次方程组,并对其进行求解:

在这里插入图片描述
坐标图如下:
在这里插入图片描述
以上二元一次方程组的求解可以看作一个简单的线性回归问题,估算出x和y之间的线性关系,得到公式:y = 1/2*x + 1/2

5.2.1 什么是线性关系(Linear Relationship)?

在这里插入图片描述
在这里插入图片描述
需要注明的是,线性关系不仅仅指存在于两个变量之间,它也可以存在于三个或者更多变量之间。比如y = a + bx1 + cx2,这条直线可以在三维空间中表达。
但实际情况是,我们在真实世界的数据不会完美的落在一个直线上,即使两个数据存在线性关系,它们或多或少离完美的直线都还有一些偏差。图像表示如下:
在这里插入图片描述
以上直线表达的是predictor和outcome之间近似的线性关系:y ≈ ax + b

5.3 线性回归到神经网络

可以看出,线性回归非常直观且易于实现,同时也过于简单,只能适用于比较简单的模型,如果模型稍显复杂,则不能很好地反映数据的分布。
在这里插入图片描述
例如在该图中,红色线段为我们想要拟合的图形,蓝色线段为我们的线性回归模型,无论怎样调整斜率w 或截距 b,都无法与红色较好地匹配。说明在此例中,线性函数模型过于简单,我们需要一个稍微复杂一些的模型。这里我们引入一个分段函数作为模型:
在这里插入图片描述
可以看出,当 x<x1 或 x>x2 时,函数值恒等于一个值,而当 x∈[x1,x2] 时,函数值则是呈线性变化的。该函数图像如下:
在这里插入图片描述
所以如果用这个分段函数来模拟上面的红色线段的话,可以采用下面几个步骤:
在这里插入图片描述
step 0:取常数 b 作为红色线段在 y 轴上的截距;
step 1:令分段函数中间部分的斜率和长度与红色线段第一部分相同(w1,b1);
step 2:令分段函数中间部分的斜率和长度与红色线段第二部分相同(w2,b2);
step 3:令分段函数中间部分的斜率和长度与红色线段第三部分相同(w3,b3)。
所以红色线段可用几个不同的蓝色线段表示为:
在这里插入图片描述
问题又来了,这种分段函数看着非常复杂,而且计算不方便,能否再使用一个近似的函数进行替换呢?于是这里又引入了 Sigmoid 函数:
在这里插入图片描述
Sigmoid 函数可以很好的表现上面的蓝色线段,而且是非线性的,没有线性的那么直(或者说hard),所以也把上面蓝色的图形叫做hard sigmoid。
于是,我们的模型又等价于:
在这里插入图片描述
这里 n 的大小,取决于我们要模拟多复杂的函数,n越大,意味着要模拟的函数越复杂。
在这里插入图片描述

5.3.1 理解新模型 - 简易版神经网络

对于我们的新模型
在这里插入图片描述
来说,可以看做是先进行线性计算,然后放入Sigmoid函数计算后,再加上常数b的过程。可用下图表示:
在这里插入图片描述
如此一来,原本稍显繁琐的公式一下子就显得直观了不少。而且整个图看起来与神经网络非常相似,线性函数 wx+b 为输入层,Sigmoid 可看做隐藏层,加总后的y可看做输出层。如果让模型变得更复杂点,就更像了:
在这里插入图片描述
这里,我们增加了输入样本的复杂度(由单一连接变成了全连接),并且增加了多层 Sigmoid 函数,虽然模型整体更加复杂,但本质上还是没有变。所以我们完全可以从一个简单的线性模型过渡到一个复杂的神经网络。
再来看公式,看到这个加总符号 ∑,就说明里面进行的都是一系列相似的计算,所以用向量替换比较合适。
如果先约定好,统一使用列向量来表示数据,并使用 σ(x) 表示 Sigmoid 函数,则以上数据及参数可表示为:
在这里插入图片描述
然后我们就可以用这些向量来表示模型的公式了:
在这里插入图片描述

使用这种表示方式,不仅显得更简洁,而且计算速度也非常快,特别是当样本非常多的情况下。

6.定义网络结构

如果觉得上面讲解还是有点复杂,我们可以将神经网络简单粗暴地理解成是多对多的线性方程,比如说,我现在有个784像素的图片,图片里都是0-9的单独数字,想通过一张图片,预测是哪个数字。

输入: 784个参数
输出: 10个预测的概率

哪个概率最大哪个就是预测的数字。这就是我们要搭建的结构。
之前提过了, 一个神经网络中间是有几个隐藏层。而且隐藏层的参数个数我们可以自己定义,那上面的结构就可以这样表示。

输入: 784个参数
隐藏层1: 555个参数
隐藏层2: 888个参数
输出: 10个预测的概率

隐藏层的参数个数我们可以自己定义,这里隐藏层1参数个数我随便设置成555,隐藏层2参数个数设置成888.
但层与层之前的我们通过什么的方式连接呢?我们可以通过in_features, out_features的数值来设置,比如说这一层的输入的参数个数,是上一层的输出参数的个数, 这一层的输出的参数个数,是下一成的输入的参数个数。

输入: in_features = 784,out_features = 555
隐藏层1: in_features = 555,out_features = 888
隐藏层2: in_features = 888,out_features = 888
输出: in_features = 888,out_features = 10

到这里,大体的网络框架大概搭建了,但刚才说到了除了输入和输出外,我们还需要一个激活函数,激活函数我们之前学到了Sigmoid ,让线性的函数变得更加符合现实化,非线性化。也就是:

输入: in_features = 784,out_features = 555
激活函数: Sigmoid()
隐藏层1: in_features = 555,out_features = 888
激活函数: Sigmoid()
隐藏层2: in_features = 888,out_features = 888
激活函数: Sigmoid()
输出: in_features = 888,out_features = 10

到这里的话其实神经网络已经体现出来了,但这个输出的值是一个正无穷到负无穷的值。它还不是一个概率值。

比如: [[-0.0575, -0.0059, 0.0094, 0.0205, -0.0239, 0.0034, -0.0519, -0.0335, 0.0502, 0.0181]]

我们还需要对他进行归一化。我们会用到一个函数叫Softmax(), 这个函数的作用就是对这些正无穷到负无穷的值进行处理,使其变成0-1之间的数,变成0-1之间的数后,我们就可以大胆地称之为概率。
在这里插入图片描述
在这里插入图片描述
ok, 那经过上面的套路,总结起来,神经网络就是

输入: in_features = 784,out_features = 555
激活函数: Sigmoid()
隐藏层1: in_features = 555,out_features = 888
激活函数: Sigmoid()
隐藏层2: in_features = 888,out_features = 888
激活函数: Sigmoid()
输出: in_features = 888,out_features = 10
归一化: Softmax()

到此为止,我们已经手撕了一个简单的神经网络,回到我们的主题,如何通过Pytorch搭建神经网络。Here we go!

import torch
import torch.nn as nn#784个像素点构成的灰度图->函数->10个概念(0,1,2,3,4,5,6,7,8,9)#输入层 in_channel=784 out_channel=555
#隐藏层1 in_channel=555 out_channel=888
#隐藏层2 in_channel=888 out_channel=888
#输出层 in_channel=888 out_channel=10data = torch.rand(1, 784)model = nn.Sequential(nn.Linear(784, 555),nn.Sigmoid(),nn.Linear(555, 888),nn.Sigmoid(),nn.Linear(888, 888),nn.Sigmoid(),nn.Linear(888, 10),nn.Softmax()
)predict = model(data)
print(model)
print(predict)

在这里插入图片描述
Pytorch已经帮我们封装好了函数,我们更重要的是理解然后手撕了一个简单的神经网络,然后通过Pytorch提供的函数进行调用了完事了。

7.整理数据集

我们可以从kaggle比赛中下载数据集。
https://www.kaggle.com/competitions/digit-recognizer/data

在这里插入图片描述
在这里插入图片描述
又或者到我的CSDN资源直接下载: kaggle手写数字识别竞赛中用到的数据集


下载完后,我们主要来看train.csv
在这里插入图片描述
从train.csv可以看到,第一列是标签, 其他列就是像素,0,1,2。。。。783
我们需要python中pandas的工具包来帮我们分解train.csv,并且用来隔开特征和标签。

import pandas as pdraw_df = pd.read_csv('dataset/train.csv')
# 标签
label = raw_df['label'].values
# 特征
feature = raw_df.drop(['label'], axis=1)

特征这一列中,我们先删除label这一列,然后保留其他列,剩下的就是特征。
分割完标签和特征后,我们还要分测试集和训练集。我采取4:1的分割方式。

import torch
import torch.nn as nn
import pandas as pdraw_df = pd.read_csv('dataset/train.csv')
# 标签
label = raw_df['label'].values
# 特征
feature = raw_df.drop(['label'], axis=1).values#整个数据集划分成两个数据集 训练集 测试集, 4:1train_feature = feature[:int(len(feature)*0.8)]
train_label = label[:int(len(label)*0.8)]
test_feature = feature[int(len(feature)*0.8):]
test_label = label[int(len(label)*0.8):]print(len(train_feature),len(train_label),len(test_feature),len(test_label))

打印结果

33600 33600 8400 8400

至此,数据准备完成。

8.训练模型

模型架构有了, 数据也准备好了,我们可以训练起来了。
说到训练,我们就要进行梯度下降。找到一组合适的w和b,让损失值越小越好。
优先我们要准备一个损失函数,交叉熵损失函数。用来计算损失值。

lossfunction = nn.CrossEntropyLoss()

计算完损失值,我们下一步就是优化里面的w和b, pytorch给我们提供了很多优化函数,我们用Adam就足够了。

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

import torch
import torch.nn as nn
import pandas as pdraw_df = pd.read_csv('dataset/train.csv')
# 标签
label = raw_df['label'].values
# 特征
feature = raw_df.drop(['label'], axis=1).values#整个数据集划分成两个数据集 训练集 测试集, 4:1train_feature = feature[:int(len(feature)*0.8)]
train_label = label[:int(len(label)*0.8)]
test_feature = feature[int(len(feature)*0.8):]
test_label = label[int(len(label)*0.8):]train_feature = torch.tensor(train_feature).to(torch.float)
train_label = torch.tensor(train_label)
test_feature = torch.tensor(test_feature).to(torch.float)
test_label = torch.tensor(test_label)print(len(train_feature),len(train_label),len(test_feature),len(test_label))#784个像素点构成的灰度图->函数->10个概念(0,1,2,3,4,5,6,7,8,9)#输入层 in_channel=784 out_channel=555
#隐藏层1 in_channel=555 out_channel=888
#隐藏层2 in_channel=888 out_channel=888
#输出层 in_channel=888 out_channel=10data = torch.rand(1, 784)model = nn.Sequential(nn.Linear(784, 555),nn.Sigmoid(),nn.Linear(555, 888),nn.Sigmoid(),nn.Linear(888, 888),nn.Sigmoid(),nn.Linear(888, 10),nn.Softmax()
)# predict = model(data)
# print(model)
# print(predict)# 梯度下降。找到一组合适的w和b,让损失值越小越好。
lossfunction = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)# 训练的轮数
for i in range(100):# 清空优化器的梯度(偏导)optimizer.zero_grad()predict = model(train_feature)loss = lossfunction(predict, train_label)loss.backward()optimizer.step()print(loss.item())

至此,简单的神经网络训练程序已经搭建出来了, 我们可以开始运行,打印损失值。
在这里插入图片描述
我们可以看到损失值越来越小。意味着我们预测的准确率越来越高。
其实到这里我们已经结束了,但我们还想看看准确率,那我们可以在中间打印一下。首先

predict = model(train_feature)

这个predict 是一个数组,是包含所有概率的数组,我们简单理解成这样的数据。

0.3,0.1,0.2,0.1,0.0,0.0,0.0,0.0,0.0,0.0
0.1,0.1,0.2,0.1,0.0,0.0,0.0,0.0,0.5,0.0
0.2,0.1,0.2,0.1,0.0,0.0,0.0,3.0,0.3,0.0
0.1,0.1,0.2,0.1,0.0,0.0,0.0,0.1,0.0,0.0
0.1,0.1,0.2,0.1,0.0,0.0,0.0,0.3,0.0,0.0

我们首先要是每一行里面找到最大值,然后拿出它对应的数字,这里的数字其实就是对应的索引

result = torch.argmax(predict, axis=1)

然后准确率我们可以用这个result和train_label进行对比

train_acc = torch.mean((result==train_label).to(torch.float))

打印结果再看看

# 训练的轮数
for i in range(100):# 清空优化器的梯度(偏导)optimizer.zero_grad()predict = model(train_feature)result = torch.argmax(predict, axis=1)train_acc = torch.mean((result == train_label).to(torch.float))loss = lossfunction(predict, train_label)loss.backward()optimizer.step()print('train loss:{} train acc:{}'.format(loss.item(), train_acc.item()))

在这里插入图片描述

损失值越来越小,准确率越来越高,这就是我们想要的结果。

9.保存模型

torch.save(model.state_dict(), '/mymodel.pt')

就这么简单。model.state_dict()就是训练完的w和b, '/mymodel.pt’就是要保存的路径和名称。

10.加载模型

这里我新建了一个python文件check_model.py

parameters = torch.load('/mymodel.pt')

先拿到之前模型的参数。
但我们要先把之前定义的网络结构搬回来。

model = nn.Sequential(nn.Linear(784, 555),nn.Sigmoid(),nn.Linear(555, 888),nn.Sigmoid(),nn.Linear(888, 888),nn.Sigmoid(),nn.Linear(888, 10),nn.Softmax()
)

然后再将参数塞进model

model.load_state_dict(parameters)

然后我们再用pandas加载数据, 拿出对应的标签和我们预测的值进行对比。

import torch
import torch.nn as nn
import pandas as pdmodel = nn.Sequential(nn.Linear(784, 555),nn.Sigmoid(),nn.Linear(555, 888),nn.Sigmoid(),nn.Linear(888, 888),nn.Sigmoid(),nn.Linear(888, 10),nn.Softmax()
)parameters = torch.load('/mymodel.pt')model.load_state_dict(parameters)raw_df = pd.read_csv('dataset/train.csv')
# 标签
label = raw_df['label'].values
# 特征
feature = raw_df.drop(['label'], axis=1).values
test_feature = feature[int(len(feature)*0.8):]
test_label = label[int(len(label)*0.8):]
test_feature = torch.tensor(test_feature).to(torch.float)
test_label = torch.tensor(test_label)new_test_feature = test_feature[100:111]
new_test_label = test_label[100:111]predict = model(new_test_feature)
result = torch.argmax(predict, axis=1)
print(new_test_label)
print(result)

在这里插入图片描述
预测结果正常,有六个预测正确,因为之前的trainning里面只有0.5左右的正确率,有可能训练的次数有关,训练次数太少。所以这个结果符合我们预期。

11.使用GPU加速

有人会问,我们已经装好了GPU版本的pytorch,为什么还要GPU加速。因为我们默认是用CPU运行,我们需要对代码再进行GPU调用。
使用GPU加速主要可以用两个方向,model方向和数据方向

11.1 要将model(model里面的参数w和b)从内存放到显存。

model = model.cuda() 

11.2 把加载的数据从内存加载到显存

train_feature = torch.tensor(train_feature).to(torch.float).cuda() 
train_label = torch.tensor(train_label).cuda() 
test_feature = torch.tensor(test_feature).to(torch.float).cuda() 
test_label = torch.tensor(test_label).cuda() 

11.3 只要将模型的数据,测试数据和训练数据加载到显存,自然会使用GPU进行处理。

相关文章:

Pytorch基础教程:从零实现手写数字分类

文章目录 1.Pytorch简介2.理解tensor2.1 一维矩阵2.2 二维矩阵2.3 三维矩阵 3.创建tensor3.1 你可以直接从一个Python列表或NumPy数组创建一个tensor&#xff1a;3.2 创建特定形状的tensor3.3 创建三维tensor3.4 使用随机数填充tensor3.5 指定tensor的数据类型 4.tensor基本运算…...

使用Flink-JDBC将数据同步到Doris

在现代数据分析和处理环境中&#xff0c;数据同步是一个至关重要的环节。Apache Flink和Doris是两个强大的工具&#xff0c;分别用于实时数据处理和大规模并行处理&#xff08;MPP&#xff09;SQL数据库。本文将介绍如何使用Flink-JDBC连接器将数据同步到Doris。 一、背景介绍…...

【深度学习】自编码器(Autoencoder, AE)

自编码器&#xff08;Autoencoder, AE&#xff09;是一种无监督学习模型&#xff0c;主要用于特征提取、数据降维、去噪和生成模型等任务。它的核心思想是通过将输入压缩到一个低维的潜在空间表示&#xff08;编码过程&#xff09;&#xff0c;然后再从这个潜在表示重构输入&am…...

跨专业毕业论文写作

跨专业毕业论文写作是一项具有挑战性的任务&#xff0c;但通过合理的规划和方法&#xff0c;你可以顺利完成这篇论文。以下是一些关键步骤和建议&#xff0c;帮助你撰写一篇高质量的跨专业毕业论文。 一、确定研究方向和课题 选择与本科专业相关或感兴趣的研究方向&#xff1a;…...

在 Go语言中一个字段可以包含多种类型的值的设计与接种解决方案

在 Go 中&#xff0c;如果你希望一个字段可以包含多种类型的值&#xff0c;你可以使用以下几种方式来实现&#xff1a; ### 1. **使用空接口 (interface{})** Go 的空接口 interface{} 可以接受任何类型的值&#xff0c;因此&#xff0c;你可以将字段定义为一个空接口&#x…...

为AI聊天工具添加一个知识系统 之32 三“中”全“会”:推理式的ISA(父类)和IOS(母本)以及生成式CMN (双亲委派)之1

本文要点和问题 要点 三“中”全“会”&#xff1a;推理式的ISA的&#xff08;父类-父类源码&#xff09;和IOS的&#xff08;母本-母类脚本&#xff09;以及生成式 CMN &#xff08;双亲委派-子类实例&#xff09;。 数据中台三端架构的中间端(信息系统架构ISA &#xff1a…...

手撕Transformer -- Day6 -- DecoderBlock

手撕Transformer – Day6 – DecoderBlock 目录 手撕Transformer -- Day6 -- DecoderBlockTransformer 网络结构图DecoderBlock 代码Part1 库函数Part2 实现一个解码器Block&#xff0c;作为一个类Part3 测试 参考 Transformer 网络结构图 Transformer 网络结构 DecoderBlock 代…...

Docker常用命令大全

Docker容器相关命令&#xff1a; 创建并启动容器&#xff1a; docker run&#xff1a;创建一个新的容器并运行一个命令。例如&#xff1a;docker run -d -p 8080:80 nginx这将后台(-d)运行一个Nginx容器&#xff0c;并映射宿主机的8080端口到容器的80端口。 列出容器&#x…...

【Linux探索学习】第二十五弹——动静态库:Linux 中静态库与动态库的详细解析

Linux学习笔记&#xff1a; https://blog.csdn.net/2301_80220607/category_12805278.html?spm1001.2014.3001.5482 前言&#xff1a; 在 Linux 系统中&#xff0c;静态库和动态库是开发中常见的两种库文件类型。它们在编译、链接、内存管理以及程序的性能和可维护性方面有着…...

Vue 实现当前页面刷新的几种方法

以下是 Vue 中实现当前页面刷新的几种方法&#xff1a; 方法一&#xff1a;使用 $router.go(0) 方法 通过Vue Router进行重新导航&#xff0c;可以实现页面的局部刷新&#xff0c;而不丢失全局状态。具体实现方式有两种&#xff1a; 实现代码&#xff1a; <template&g…...

python mysql库的三个库mysqlclient mysql-connector-python pymysql如何选择,他们之间的区别

三者的区别 1. mysqlclient 特点&#xff1a; 是一个用于Python的MySQL数据库驱动程序&#xff0c;用于与MySQL数据库进行交互。 依赖于MySQL的本地库&#xff0c;因此在安装时需要确保系统上已安装了必要的依赖项&#xff0c;如libmysqlclient-dev等。 性能较好&#xff0c…...

【可持久化线段树】 [SDOI2009] HH的项链 主席树(两种解法)

文章目录 1.题目描述2.思路3.解法一解法一代码 4.解法二解法二代码&#xff08;版本一&#xff09;解法二代码&#xff08;版本二&#xff09; 1.题目描述 原题&#xff1a;https://www.luogu.com.cn/problem/P1972 [SDOI2009] HH的项链 题目描述 HH 有一串由各种漂亮的贝壳…...

【C语言】线程----同步、互斥、条件变量

目录 3. 同步 3.1 概念 3.2 同步机制 3.3 函数接口 1. 同步 1.1 概念 同步(synchronization)指的是多个任务(线程)按照约定的顺序相互配合完成一件事情 1.2 同步机制 通过信号量实现线程间的同步 信号量&#xff1a;通过信号量实现同步操作&#xff1b;由信号量来决定…...

15. 三数之和【力扣】--三指针

三数之和 已解答 中等 相关标签 相关企业 提示 给你一个整数数组 nums &#xff0c;判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k &#xff0c;同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三元组。 注意&#x…...

大数据学习(35)- spark- action算子

&&大数据学习&& &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 承认自己的无知&#xff0c;乃是开启智慧的大门 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4dd;支持一下博主哦&#x1f91…...

vim使用指南

&#x1f3dd;️专栏&#xff1a;计算机操作系统 &#x1f305;主页&#xff1a;猫咪-9527-CSDN博客 “欲穷千里目&#xff0c;更上一层楼。会当凌绝顶&#xff0c;一览众山小。” 目录 一、Vim 的基本概念 1.Vim 的主要模式&#xff1a; 1.1普通模式 (Normal Mode) 1.2插入…...

Docker 镜像制作原理 做一个自己的docker镜像

一.手动制作镜像 启动容器进入容器定制基于容器生成镜像 1.启动容器 启动容器之前我们首先要有一个镜像&#xff0c;这个镜像可以是从docker拉取&#xff0c;例如&#xff1a;现在pull一个ubuntu镜像到本机。 docker pull ubuntu:22.04 我们接下来可以基于这个容器进行容器…...

基于Java+SpringBoot+Vue的前后端分离的在线BLOG网

基于JavaSpringBootVue的前后端分离的在线BLOG网 前言 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN[新星计划]导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末附源码下载链接&#x1f3…...

Linux网络_套接字_UDP网络_TCP网络

一.UDP网络 1.socket()创建套接字 #include<sys/socket.h> int socket(int domain, int type, int protocol);domain (地址族): AF_INET网络 AF_UNIX本地 AF_INET&#xff1a;IPv4 地址族&#xff0c;适用于 IPv4 协议。用于网络通信AF_INET6&#xff1a;IPv6 地址族&a…...

Java学习教程,从入门到精通,JDBC驱动程序类型及语法知识点(91)

JDBC驱动程序类型及语法知识点 一、JDBC驱动程序类型 JDBC驱动程序主要有以下四种类型&#xff1a; 1. Type 1&#xff1a;JDBC - ODBC桥驱动程序&#xff08;JDBC - ODBC Bridge Driver&#xff09; 特点&#xff1a;这种驱动程序是Java与ODBC&#xff08;Open Database C…...

YOLOv8从菜鸟到精通(二):YOLOv8数据标注以及模型训练

数据标注 前期准备 先打开Anaconda Navigator&#xff0c;点击Environment&#xff0c;再点击new(new是我下载anaconda的文件夹名称)&#xff0c;然后点击创建 点击绿色按钮&#xff0c;并点击Open Terminal 输入labelimg便可打开它,labelimg是图像标注工具&#xff0c;在上篇…...

3D目标检测数据集——Nusence数据集

链接地址 [官网] nuScenes[arXiv] nuScenes: A multimodal dataset for autonomous driving[GitHub] nuScenes devkitnuScenes devkit教程数据集概述 2.1 数据采集 2.1.1 传感器配置 nuScenes的数据采集车辆为Renault Zoe迷你电动车,配备6个周视相机&#x...

网站收录入口提交的方法有哪些(网站收录的方式都有哪些)

网站被搜索引擎收录是获得流量和曝光的重要前提&#xff0c;以下为你介绍常见的网站收录方式&#xff1a; 搜索引擎提交入口 各大搜索引擎都设有专门的网站收录入口&#xff0c;供站长提交网站。例如百度搜索资源平台、谷歌搜索控制台等。以百度为例&#xff0c;在百度搜索资…...

移动端H5缓存问题

移动端页面缓存问题是指页面的静态资源&#xff08;如图片、JS 和 CSS 文件&#xff09;在浏览器中被缓存后&#xff0c;用户在下次访问时可以直接从本地获取缓存数据&#xff0c;而不需要每次都从服务器重新获取&#xff0c;不过这样可能会导致页面不能正确地更新或者加载最新…...

11-1.Android 项目结构 - androidTest 包与 test 包(单元测试与仪器化测试)

androidTest 包与 test 包 在 Android 项目中&#xff0c;androidTest 包与 test 包用于存放不同类型的测试代码的 1、测试类型 &#xff08;1&#xff09;androidTest 包 主要用于存放单元测试&#xff08;Unit Tests&#xff09;代码 单元测试是针对应用程序中的独立模块…...

计算机网络(五)——传输层

一、功能 传输层的主要功能是向两台主机进程之间的通信提供通用的数据传输服务。功能包括实现端到端的通信、多路复用和多路分用、差错控制、流量控制等。 复用&#xff1a;多个应用进程可以通过同一个传输层发送数据。 分用&#xff1a;传输层在接收数据后可以将这些数据正确分…...

ZCC9159 -7V 300mA 超低功耗高速 LDO

功能描述 ZCC9195是一款超低功耗并具有快速响应、关断快速放电功能的高速LDO。静态电流低至 0.8uA&#xff0c;输出电流最大为300mA。 ZCC9195具有输出过流保护、输出短路保护、温度保护等功能&#xff0c;确保芯片在异常工作条件 下不会损坏。 ZCC9195只需要1uF的陶瓷电容即…...

微信小程序实现个人中心页面

文章目录 1. 官方文档教程2. 编写静态页面3. 关于作者其它项目视频教程介绍 1. 官方文档教程 https://developers.weixin.qq.com/miniprogram/dev/framework/ 2. 编写静态页面 mine.wxml布局文件 <!--index.wxml--> <navigation-bar title"个人中心" ba…...

【C语言算法刷题】第7题

题目描述 一个XX产品行销总公司&#xff0c;只有一个boss&#xff0c;其有若干一级分销&#xff0c;一级分销又有若干二级分销&#xff0c;每个分销只有唯一的上级分销。 规定&#xff0c;每个月&#xff0c;下级分销需要将自己的总收入&#xff08;自己的下级上交的&#xf…...

BERT与CNN结合实现糖尿病相关医学问题多分类模型

完整源码项目包获取→点击文章末尾名片&#xff01; 使用HuggingFace开发的Transformers库&#xff0c;使用BERT模型实现中文文本分类&#xff08;二分类或多分类&#xff09; 首先直接利用transformer.models.bert.BertForSequenceClassification()实现文本分类 然后手动实现B…...

RocketMQ消息发送---源码解析

我们知道rocketMQ的消息发送支持很多特性&#xff0c;如同步发送&#xff0c;异步发送&#xff0c;oneWay发送&#xff0c;也支持超时机制&#xff0c;回调机制&#xff0c;并且能够保证消息的可靠性和消息发送的限流&#xff0c;底层使用netty框架等等&#xff0c;如此多的特性…...

机器学习06-正则化

机器学习06-正则化 文章目录 机器学习06-正则化0-核心逻辑脉络1-参考网址3-大模型训练中的正则化1.正则化的定义与作用2.常见的正则化方法及其应用场景2.1 L1正则化&#xff08;Lasso&#xff09;2.2 L2正则化&#xff08;Ridge&#xff09;2.3 弹性网络正则化&#xff08;Elas…...

如何开放2375和2376端口供Docker daemon监听

Linux (以 Ubuntu 为例) 1. 修改 Docker 配置文件 打开 Docker 的配置文件 /etc/docker/daemon.json。如果该文件不存在&#xff0c;则可以创建一个新的。 bash sudo nano /etc/docker/daemon.json在配置文件中添加以下内容&#xff1a; json {"hosts": ["un…...

Vue.js组件开发-如何实现路由懒加载

在Vue.js应用中&#xff0c;路由懒加载是一种优化性能的技术&#xff0c;它允许在需要时才加载特定的路由组件&#xff0c;而不是在应用启动时加载所有组件。这样可以显著减少初始加载时间&#xff0c;提高用户体验。在Vue Router中&#xff0c;实现路由懒加载非常简单&#xf…...

rclone,云存储备份和迁移的瑞士军刀,千字常文解析,附下载链接和安装操作步骤...

一、什么是rclone&#xff1f; rclone是一个命令行程序&#xff0c;全称&#xff1a;rsync for cloud storage。是用于将文件和目录同步到云存储提供商的工具。因其支持多种云存储服务的备份&#xff0c;如Google Drive、Amazon S3、Dropbox、Backblaze B2、One Drive、Swift、…...

集成学习算法

目录 1.必要的导入 2.Bagging集成 3.基于matplotlib写一个函数对决策边界做可视化 4.总结图中结论 5.扩展说明 1.必要的导入 # To support both python 2 and python 3 from __future__ import division, print_function, unicode_literals# Common imports import numpy as np…...

vue3之pinia学习

最近查看了pinia这个状态管理管理&#xff0c;想跟大家一起学习下&#xff0c;下面是我的个人理解&#xff0c;希望对大家有帮助&#xff0c;我们开始吧&#xff01; 第一步&#xff1a;安装pinia npm install pinia 第二步&#xff1a;创建pinia <script setup langts&…...

Flink (七): DataStream API (四) Watermarks

1. Event Time and Processing Time 1. 1 处理时间&#xff08;Processing time&#xff09; 处理时间是指执行相应操作的机器的系统时间。当流处理程序基于处理时间运行时&#xff0c;所有基于时间的操作&#xff08;如时间窗口&#xff09;将使用执行相应算子的机器的系统时…...

卷积神经05-GAN对抗神经网络

卷积神经05-GAN对抗神经网络 使用Python3.9CUDA11.8Pytorch实现一个CNN优化版的对抗神经网络 简单的GAN图片生成 CNN优化后的图片生成 优化模型代码对比 0-核心逻辑脉络 1&#xff09;Anacanda使用CUDAPytorch2&#xff09;使用本地MNIST进行手写图片训练3&#xff09;…...

【原创】大数据治理入门(2)《提升数据质量:质量评估与改进策略》入门必看 高赞实用

提升数据质量&#xff1a;质量评估与改进策略 引言&#xff1a;数据质量的概念 在大数据时代&#xff0c;数据的质量直接影响到数据分析的准确性和可靠性。数据质量是指数据在多大程度上能够满足其预定用途&#xff0c;确保数据的准确性、完整性、一致性和及时性是数据质量的…...

GLM: General Language Model Pretraining with Autoregressive Blank Infilling论文解读

论文地址&#xff1a;https://arxiv.org/abs/2103.10360 参考&#xff1a;https://zhuanlan.zhihu.com/p/532851481 GLM混合了自注意力和masked注意力&#xff0c;而且使用了2D位置编码。第一维的含义是在PartA中的位置&#xff0c;如5 5 5。第二维的含义是在Span内部的位置&a…...

总结SpringBoot项目中读取resource目录下的文件多种方法

系列文章目录 提示&#xff1a;这里可以添加系列文章的所有文章的目录&#xff0c;目录需要自己手动添加 例如&#xff1a;第一章 Python 机器学习入门之pandas的使用 提示&#xff1a;写完文章后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目…...

云原生第四次作业

下载 [rootopenEuler-1 ~]# wget https://archive.apache.org/dist/httpd/httpd-2.4.46.tar.gz 压缩 配置实验环境 [rootopenEuler-1 httpd-2.4.46]# yum -y install apr apr-devel cyrus-sasl-devel expat-devel libdb-devel openldap-devel apr-util-devel apr-util pcre-d…...

day10_Structured Steaming

文章目录 Structured Steaming一、结构化流介绍&#xff08;了解&#xff09;1、有界和无界数据2、基本介绍3、使用三大步骤(掌握)4.回顾sparkSQL的词频统计案例 二、结构化流的编程模型&#xff08;掌握&#xff09;1、数据结构2、读取数据源2.1 File Source2.2 Socket Source…...

设计模式-工厂模式/抽象工厂模式

工厂模式 定义 定义一个创建对象的接口&#xff0c;让子类决定实列化哪一个类&#xff0c;工厂模式使一个类的实例化延迟到其子类&#xff1b; 工厂方法模式是简单工厂模式的延伸。在工厂方法模式中&#xff0c;核心工厂类不在负责产品的创建&#xff0c;而是将具体的创建工作…...

【算法学习】——整数划分问题详解(动态规划)

&#x1f9ee;整数划分问题是一个较为常见的算法题&#xff0c;很多问题从整数划分这里出发&#xff0c;进行包装&#xff0c;形成新的题目&#xff0c;所以完全理解整数划分的解决思路对于之后的进一步学习算法是很有帮助的。 「整数划分」通常使用「动态规划」解决&#xff0…...

【新教程】Ubuntu 24.04 单节点安装slurm

背景 网上教程老旧&#xff0c;不适用。 详细步骤 1、安装slurm sudo apt install slurm-wlm slurm-wlm-doc -y检查是否安装成功&#xff1a; slurmd --version如果得到slurm-wlm 23.11.4&#xff0c;表明安装成功。 2、配置slurm。 使用命令&#xff1a; sudo vi /etc/s…...

window下用vim

Windows 默认不支持 vim 命令&#xff0c;需要手动安装后才能使用。以下是解决方案&#xff1a; 1. 安装 Vim 编辑器 方法 1&#xff1a;通过 Scoop 或 Chocolatey 安装 使用 Scoop&#xff1a; 安装 Scoop&#xff08;如果尚未安装&#xff09;&#xff1a;iwr -useb get.sco…...

citrix netscaler13.1 重写负载均衡响应头(基础版)

在 Citrix NetScaler 13.1 中&#xff0c;Rewrite Actions 用于对负载均衡响应进行修改&#xff0c;包括替换、删除和插入 HTTP 响应头。这些操作可以通过自定义策略来完成&#xff0c;帮助你根据需求调整请求内容。以下是三种常见的操作&#xff1a; 1. Replace (替换响应头)…...

使用PWM生成模式驱动BLDC三相无刷直流电机

引言 在 TI 的无刷直流 (BLDC) DRV8x 产品系列使用的栅极驱动器应用中&#xff0c;通常使用一些控制模式来切换MOSFET 开关的输出栅极。这些控制模式包括&#xff1a;1x、3x、6x 和独立脉宽调制 (PWM) 模式。   不过&#xff0c;DRV8x 产品系列&#xff08;例如 DRV8311&…...