当前位置: 首页 > news >正文

从零推导线性回归:最小二乘法与梯度下降的数学原理


欢迎来到我的主页:【Echo-Nie】

本篇文章收录于专栏【机器学习】

在这里插入图片描述

本文所有内容相关代码都可在以下仓库中找到:
Github-MachineLearning


1 线性回归

1.1 什么是线性回归

线性回归是一种用来预测分析数据之间关系的工具。它的核心思想是找到一条直线(或者一个平面),让这条直线尽可能地“拟合”已有的数据点,通过这条直线,我们可以预测新的数据。

eg:

假设你想预测房价,你知道房子的大小(面积)和房价之间有关系。线性回归可以帮助你找到一条直线,表示“房子越大,房价越高”的关系。

这条直线的方程可以写成:
y = θ 0 + θ 1 x y = \theta_0 + \theta_1 x y=θ0+θ1x
其中:

  • y y y 是房价(目标变量)。
  • x x x 是房子的大小(特征)。
  • θ 0 \theta_0 θ0 是截距(当房子大小为 0 时的房价)。
  • θ 1 \theta_1 θ1 是斜率(房子每增加一平米,房价增加多少)。

简单线性回归:

如果只有一个特征(比如房子的大小),线性回归就是找到一条直线来拟合数据。公式是:
y = θ 0 + θ 1 x y = \theta_0 + \theta_1 x y=θ0+θ1x

  • 目标:找到 θ 0 \theta_0 θ0 θ 1 \theta_1 θ1,使得这条直线最接近所有的数据点。
  • 怎么找?通过调整 θ 0 \theta_0 θ0 θ 1 \theta_1 θ1,让预测值 y y y 和实际值之间的误差最小。

多元线性回归:

如果有多个特征(比如房子的大小、房间数量、地段等),线性回归就是找到一个“平面”来拟合数据。公式是:
y = θ 0 + θ 1 x 1 + θ 2 x 2 + ⋯ + θ n x n y = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \dots + \theta_n x_n y=θ0+θ1x1+θ2x2++θnxn
这里:

x i x_i xi 是多个特征(比如房子大小、房间数量等); θ i \theta_i θi 是模型的参数,表示每个特征对房价的影响

  1. 拟合数据:线性回归就像在一堆散点图中画一条直线,让这条直线尽可能靠近所有的点。
  2. 预测:有了这条直线后,如果有一个新的数据点(比如一个新房子的大小),我们可以用这条直线来预测它的房价。
  3. 误差:预测值和实际值之间的差距叫误差。线性回归的目标是让所有数据的误差最小。

1.2 年龄和金钱举例

x 1 x_1 x1 x 2 x_2 x2就是我们的两个特征(年龄,工资); Y Y Y是银行最终会借给我们多少钱。找到最合适的一条线(想象一个高维)来最好的拟合我们的数据点。

假设 θ 1 \theta_1 θ1是年龄的参数, θ 2 \theta_2 θ2是工资的参数

  1. 拟合平面
    h θ ( x ) = θ 0 + θ 1 x 1 + θ 2 x 2 h_{\theta}(x) = \theta_0 + \theta_1 x_1 + \theta_2 x_2 hθ(x)=θ0+θ1x1+θ2x2
    这里, θ 0 \theta_0 θ0 是偏置项, θ 1 \theta_1 θ1 θ 2 \theta_2 θ2 是权重, x 1 x_1 x1 x 2 x_2 x2 是输入特征。

  2. 引入 x 0 = 1 x_0 = 1 x0=1
    为了将偏置项 θ 0 \theta_0 θ0 也纳入向量化的形式,我们引入一个额外的特征 x 0 x_0 x0,并设 x 0 = 1 x_0 = 1 x0=1。相当于我们做数据预处理,矩阵最外围加了一圈1,方便计算。这样,公式可以写成:
    h θ ( x ) = θ 0 x 0 + θ 1 x 1 + θ 2 x 2 h_{\theta}(x) = \theta_0 x_0 + \theta_1 x_1 + \theta_2 x_2 hθ(x)=θ0x0+θ1x1+θ2x2

  3. 向量化表示
    现在,我们可以将权重 θ \theta θ 和特征 x x x 表示为向量:
    θ = [ θ 0 θ 1 θ 2 ] , x = [ x 0 x 1 x 2 ] \theta = \begin{bmatrix} \theta_0 \\ \theta_1 \\ \theta_2 \end{bmatrix}, \quad x = \begin{bmatrix} x_0 \\ x_1 \\ x_2 \end{bmatrix} θ= θ0θ1θ2 ,x= x0x1x2
    这样,原始公式可以表示为这两个向量的点积:
    h θ ( x ) = θ 0 x 0 + θ 1 x 1 + θ 2 x 2 = θ T x h_{\theta}(x) = \theta_0 x_0 + \theta_1 x_1 + \theta_2 x_2 = \theta^T x hθ(x)=θ0x0+θ1x1+θ2x2=θTx
    其中, θ T \theta^T θT 是向量 θ \theta θ 的转置。

  4. 一般化到 n n n 个特征

    如果有 n n n 个特征,公式可以推广为:
    h θ ( x ) = ∑ i = 0 n θ i x i = θ T x h_{\theta}(x) = \sum_{i=0}^{n} \theta_i x_i = \theta^T x hθ(x)=i=0nθixi=θTx
    这里, θ \theta θ x x x 都是 n + 1 n+1 n+1 维的向量。

  5. 误差表示:

    误差真实值和预测值之间肯定是要存在差异的(用 ε \varepsilon ε 来表示该误差)。

对于每个样本:
y ( i ) = θ T x ( i ) + ε ( i ) y^{(i)} = \theta^T x^{(i)} + \varepsilon^{(i)} y(i)=θTx(i)+ε(i)

参数意义:

  • y ( i ) y^{(i)} y(i) 是第 i i i 个样本的真实值。
  • θ T x ( i ) \theta^T x^{(i)} θTx(i) 是模型对第 i i i 个样本的预测值。
  • ε ( i ) \varepsilon^{(i)} ε(i) 是第 i i i 个样本的误差,表示真实值和预测值之间的差异。

1.3 误差

误差是模型预测值和真实值之间的差距。比如你预测房价是 10 万,但实际房价是 15 万,那么误差就是 5 万。

误差 ε ( i ) \varepsilon^{(i)} ε(i) 是独立并且具有相同的分布,并且服从均值为 0 0 0、方差为 σ 2 \sigma^2 σ2 的高斯分布。

  • 独立:每个数据的误差是独立的,不会互相影响。比如张三贷款多了,不会影响李四的贷款。
  • 同分布:所有数据的误差都来自同一个“规律”。比如所有贷款误差都来自同一家银行的规则。
  • 高斯分布:误差大多数情况下比较小,偶尔会有一些大的误差,但这种情况很少。比如银行通常不会多给或少给太多钱,但偶尔可能会有一些特殊情况。

误差 ε ( i ) \varepsilon^{(i)} ε(i)​ 服从高斯分布:
ε ( i ) ∼ N ( 0 , σ 2 ) \varepsilon^{(i)} \sim \mathcal{N}(0, \sigma^2) ε(i)N(0,σ2)

为什么假设误差是高斯分布?

因为高斯分布(也叫正态分布)符合现实世界中很多现象的规律。比如大多数人的身高、体重等都集中在某个平均值附近,极端值很少。


1.4 误差的数学推导

线性回归模型假设目标变量 y y y 是输入特征 x x x 的线性组合加上一个误差项 ε \varepsilon ε
y ( i ) = θ T x ( i ) + ε ( i ) y^{(i)} = \theta^T x^{(i)} + \varepsilon^{(i)} y(i)=θTx(i)+ε(i)
其中:

  • y ( i ) y^{(i)} y(i) 是第 i i i个观测值的实际值。
  • θ T x ( i ) \theta^T x^{(i)} θTx(i) 是模型对第 i i i个观测值的预测值。
  • ε ( i ) \varepsilon^{(i)} ε(i) 是第 i i i个观测值的误差。

1.4.1 误差的高斯分布假设

假设误差 ε ( i ) \varepsilon^{(i)} ε(i) 服从均值为 0 0 0、方差为 σ 2 \sigma^2 σ2 的高斯分布(正态分布)。其概率密度函数为:
p ( ε ( i ) ) = 1 2 π σ exp ⁡ ( − ( ε ( i ) ) 2 2 σ 2 ) p(\varepsilon^{(i)}) = \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(\varepsilon^{(i)})^{2}}{2\sigma^{2}}\right) p(ε(i))=2π σ1exp(2σ2(ε(i))2)
这个假设的意义是:

  1. 误差是随机的,且大多数误差集中在 0 附近。
  2. 误差的分布是对称的,且较大的误差出现的概率较低。

1.4.2 将线性回归模型代入误差分布

从线性回归模型 y ( i ) = θ T x ( i ) + ε ( i ) y^{(i)} = \theta^T x^{(i)} + \varepsilon^{(i)} y(i)=θTx(i)+ε(i),我们可以将误差 ε ( i ) \varepsilon^{(i)} ε(i) 表示为:
ε ( i ) = y ( i ) − θ T x ( i ) \varepsilon^{(i)} = y^{(i)} - \theta^T x^{(i)} ε(i)=y(i)θTx(i)

将这个表达式代入误差的高斯分布公式中,得到:
p ( ε ( i ) ) = 1 2 π σ exp ⁡ ( − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) p(\varepsilon^{(i)}) = \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}}\right) p(ε(i))=2π σ1exp(2σ2(y(i)θTx(i))2)


1.4.3 条件概率分布

由于 ε ( i ) = y ( i ) − θ T x ( i ) \varepsilon^{(i)} = y^{(i)} - \theta^T x^{(i)} ε(i)=y(i)θTx(i) ,我们可以将 p ( ε ( i ) ) p(\varepsilon^{(i)}) p(ε(i)) 重新表示为 y ( i ) y^{(i)} y(i) 的条件概率分布:
p ( y ( i ) ∣ x ( i ) ; θ ) = 1 2 π σ exp ⁡ ( − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) p(y^{(i)}|x^{(i)};\theta) = \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}}\right) p(y(i)x(i);θ)=2π σ1exp(2σ2(y(i)θTx(i))2)

在给定输入 x ( i ) x^{(i)} x(i) 和参数 θ \theta θ 的情况下, y ( i ) y^{(i)} y(i) 的概率分布。 y ( i ) y^{(i)} y(i) 的分布是以( θ T \theta^T θT x ( i ) x^{(i)} x(i) )的某个组合为均值、 σ 2 \sigma^2 σ2为方差的高斯分布。什么意思呢?就是说,让( θ T \theta^T θT x ( i ) x^{(i)} x(i) )的组合成为 y y y 的可能性越大越好


1.4.4 最大似然估计

为了找到最优的参数 θ \theta θ ,我们使用最大似然估计(Maximum Likelihood Estimation, MLE)。具体步骤如下:

似然函数:

假设所有观测值 y ( i ) y^{(i)} y(i) 是独立同分布的,则整个数据集的联合概率(似然函数)为:
L ( θ ) = ∏ i = 1 n p ( y ( i ) ∣ x ( i ) ; θ ) L(\theta) = \prod_{i=1}^{n} p(y^{(i)}|x^{(i)};\theta) L(θ)=i=1np(y(i)x(i);θ)

在独立同分布的前提下,联合概率等于边缘密度的乘积。

将条件概率分布代入:
L ( θ ) = ∏ i = 1 n 1 2 π σ exp ⁡ ( − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) L(\theta) = \prod_{i=1}^{n} \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}}\right) L(θ)=i=1n2π σ1exp(2σ2(y(i)θTx(i))2)

上述表示的是:什么样的参数与我们的数据组合之后恰好是真实值。

对数似然:

为了简化计算,取对数似然函数(乘法难解,改成加法):
ℓ ( θ ) = log ⁡ L ( θ ) = ∑ i = 1 n log ⁡ ( 1 2 π σ exp ⁡ ( − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) ) \ell(\theta) = \log L(\theta) = \sum_{i=1}^{n} \log \left( \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}}\right) \right) (θ)=logL(θ)=i=1nlog(2π σ1exp(2σ2(y(i)θTx(i))2))

Q:为什么能这么做呢?

A:因为我们只需要求出极值点 θ \theta θ ,而不是极值,改成 l o g log log 之后,极值 m m m 会变为 l o g m log^{m} logm ,但是极值点不变啊。

展开:
ℓ ( θ ) = ∑ i = 1 n ( log ⁡ 1 2 π σ − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) \ell(\theta) = \sum_{i=1}^{n} \left( \log \frac{1}{\sqrt{2\pi}\sigma} - \frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}} \right) (θ)=i=1n(log2π σ12σ2(y(i)θTx(i))2)

去掉常数项(与 θ \theta θ 无关的项):
ℓ ( θ ) = − 1 2 σ 2 ∑ i = 1 n ( y ( i ) − θ T x ( i ) ) 2 + 常数 \ell(\theta) = -\frac{1}{2\sigma^{2}} \sum_{i=1}^{n} (y^{(i)} - \theta^T x^{(i)})^{2} + \text{常数} (θ)=2σ21i=1n(y(i)θTx(i))2+常数

最大化对数似然:

最大化对数似然函数等价于最小化以下目标函数:
J ( θ ) = 1 2 ∑ i = 1 n ( y ( i ) − θ T x ( i ) ) 2 J(\theta) = \frac{1}{2} \sum_{i=1}^{n} (y^{(i)} - \theta^T x^{(i)})^{2} J(θ)=21i=1n(y(i)θTx(i))2

Q:什么越大什么越小,乱七八糟的?

A:因为 ℓ ( θ ) \ell(\theta) (θ) 是一个常数减去对数似然,我们要让 ℓ ( θ ) \ell(\theta) (θ) 最大的话,就是让对数似然最小,因为对数似然恒正。

这就是最小二乘法的目标函数,以上就是通过最大似然估计,推导出最小二乘法。


1.5 目标函数

目标函数 J ( θ ) J(\theta) J(θ) 定义为:

J ( θ ) = 1 2 ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 = 1 2 ( X θ − y ) T ( X θ − y ) J(\theta) = \frac{1}{2} \sum_{i=1}^m (h_\theta(x^{(i)}) - y^{(i)})^2 = \frac{1}{2} (X\theta - y)^T (X\theta - y) J(θ)=21i=1m(hθ(x(i))y(i))2=21(y)T(y)

这里, h θ ( x ( i ) ) h_\theta(x^{(i)}) hθ(x(i)) 是假设函数, y ( i ) y^{(i)} y(i) 是实际值, X X X 是特征矩阵, θ \theta θ 是参数向量。目标函数衡量了预测值与实际值之间的误差。

这个目标函数实际上就是最小二乘法的目标函数,其中 θ T x ( i ) \theta^T x^{(i)} θTx(i) 就是假设函数 h θ ( x ( i ) ) h_\theta(x^{(i)}) hθ(x(i))

h θ ( x ( i ) ) h_\theta(x^{(i)}) hθ(x(i)) 是为了表示模型的预测值,它与实际值 y ( i ) y^{(i)} y(i) 之间的差异(误差)被用来构建目标函数。

求偏导:

为了找到最小化目标函数的参数 θ \theta θ,对 J ( θ ) J(\theta) J(θ) 求偏导并令其等于零。首先展开目标函数:

J ( θ ) = 1 2 ( X θ − y ) T ( X θ − y ) J(\theta) = \frac{1}{2} (X\theta - y)^T (X\theta - y) J(θ)=21(y)T(y)

展开后得到:

J ( θ ) = 1 2 ( θ T X T − y T ) ( X θ − y ) J(\theta) = \frac{1}{2} (\theta^T X^T - y^T)(X\theta - y) J(θ)=21(θTXTyT)(y)

进一步展开:

J ( θ ) = 1 2 ( θ T X T X θ − θ T X T y − y T X θ + y T y ) J(\theta) = \frac{1}{2} (\theta^T X^T X\theta - \theta^T X^T y - y^T X\theta + y^T y) J(θ)=21(θTXTθTXTyyT+yTy)

θ \theta θ 求偏导:

∇ θ J ( θ ) = ∇ θ ( 1 2 ( θ T X T X θ − θ T X T y − y T X θ + y T y ) ) \nabla_\theta J(\theta) = \nabla_\theta \left( \frac{1}{2} (\theta^T X^T X\theta - \theta^T X^T y - y^T X\theta + y^T y) \right) θJ(θ)=θ(21(θTXTθTXTyyT+yTy))

矩阵微分,得到:

∇ θ J ( θ ) = 1 2 ( 2 X T X θ − X T y − ( y T X ) T ) = X T X θ − X T y \nabla_\theta J(\theta) = \frac{1}{2} (2X^T X\theta - X^T y - (y^T X)^T) = X^T X\theta - X^T y θJ(θ)=21(2XTXTy(yTX)T)=XTXTy

偏导等于零:

为了找到最小值,令偏导等于零:

X T X θ − X T y = 0 X^T X\theta - X^T y = 0 XTXTy=0

解这个方程得到 θ \theta θ 的解析解:

θ = ( X T X ) − 1 X T y \theta = (X^T X)^{-1} X^T y θ=(XTX)1XTy

这个解称为正规方程Normal Equation,它给出了最小化目标函数的参数 θ \theta θ 的值

注:这里我们发现有逆矩阵,众所周知,不是所有矩阵均可逆,那么解不出 θ \theta θ 怎么办?接下来就引出了我们经常说的优化算法。


2 梯度下降

梯度下降是一种优化方法,用来找到目标函数的最小值。你可以把它想象成“下山”:你站在山顶上,目标是找到山谷的最低点。梯度下降就是一步步往下走,直到你到达最低点。

为什么要用梯度下降?

  1. 直接求解不一定可行:有时候目标函数很复杂,甚至没有直接的数学公式可以解出来(比如非线性问题)。线性回归是个特例,可以直接用公式求解,但大多数情况下不行。
  2. 机器学习的套路:我们给机器一堆数据,告诉它“什么样的学习方式是对的”(目标函数),然后让它自己慢慢调整参数,找到最好的结果。

梯度下降怎么工作?

  1. 一口吃不成胖子:你不能一步就从山顶跳到山谷,而是需要一步步往下走。每一步都根据当前的位置,找到最陡的下坡方向(梯度),然后往那个方向迈一小步。
  2. 迭代优化:每次优化一点点,累积起来就是一个很大的进步。比如,你第一次走一步,第二次再走一步,重复很多次后,你就离目标越来越近了。

举例

想象你在玩一个游戏,目标是找到地图上的最低点(比如宝藏的位置)。你每走一步,都会看看周围哪个方向是下坡,然后往那个方向走。梯度下降就是这个过程:

  1. 目标函数:地图的高度(越低越好)。
  2. 梯度:你每走一步时,判断哪个方向是下坡。
  3. 迭代:一步步走,直到找到最低点。

梯度下降


2.1 目标函数

J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)})^2 J(θ)=2m1i=1m(hθ(x(i))y(i))2

目标:我们希望模型的预测值 h θ ( x ( i ) ) h_\theta(x^{(i)}) hθ(x(i)) 尽可能接近实际值 y ( i ) y^{(i)} y(i)

平方误差: ( h θ ( x ( i ) ) − y ( i ) ) 2 (h_\theta(x^{(i)}) - y^{(i)})^2 (hθ(x(i))y(i))2 表示单个样本的预测误差。平方的作用是:消除正负误差的抵消(比如 − 2 -2 2 2 2 2 的误差平方后都是 4 4 4)。对大误差给予更大的惩罚。

均值: 1 m ∑ i = 1 m \frac{1}{m} \sum_{i=1}^{m} m1i=1m 表示对所有样本的误差取平均,反映模型的整体误差。

1 2 \frac{1}{2} 21:为了简化求导时的计算,因为对平方项求导后会多出一个 2 2 2 1 2 \frac{1}{2} 21 可以抵消这个 2 2 2


2.2 批量梯度下降(Batch Gradient Descent)

梯度公式:
∂ J ( θ ) ∂ θ j = − 1 m ∑ i = 1 m ( y ( i ) − h θ ( x ( i ) ) ) x j ( i ) \frac{\partial J(\theta)}{\partial \theta_j} = -\frac{1}{m} \sum_{i=1}^{m} (y^{(i)} - h_\theta(x^{(i)}))x_j^{(i)} θjJ(θ)=m1i=1m(y(i)hθ(x(i)))xj(i)

梯度:梯度是目标函数对参数 θ j \theta_j θj 的偏导数,表示目标函数在 θ j \theta_j θj 方向上的变化率。

误差项: ( y ( i ) − h θ ( x ( i ) ) ) (y^{(i)} - h_\theta(x^{(i)})) (y(i)hθ(x(i))) 是第 i i i 个样本的预测误差。

特征值: x j ( i ) x_j^{(i)} xj(i) 是第 i i i 个样本的第 j j j 个特征值,表示误差对 θ j \theta_j θj 的影响。

均值: 1 m ∑ i = 1 m \frac{1}{m} \sum_{i=1}^{m} m1i=1m 表示对所有样本的梯度取平均,确保更新方向是全局最优的。

参数更新公式:
θ j : = θ j − α ⋅ ∂ J ( θ ) ∂ θ j \theta_j := \theta_j - \alpha \cdot \frac{\partial J(\theta)}{\partial \theta_j} θj:=θjαθjJ(θ)

负梯度方向:梯度 ∂ J ( θ ) ∂ θ j \frac{\partial J(\theta)}{\partial \theta_j} θjJ(θ) 表示目标函数上升最快的方向,因此负梯度方向是下降最快的方向。

学习率 α \alpha α:控制每次更新的步长。如果 α \alpha α 太大,可能会跳过最优解;如果 α \alpha α 太小,收敛速度会变慢。

容易得到最优解,但是每次要考虑所有的样本,所以很慢。


2.3 随机梯度下降(Stochastic Gradient Descent, SGD)

参数更新公式:
θ j : = θ j − α ⋅ ( y ( i ) − h θ ( x ( i ) ) ) x j ( i ) \theta_j := \theta_j - \alpha \cdot (y^{(i)} - h_\theta(x^{(i)}))x_j^{(i)} θj:=θjα(y(i)hθ(x(i)))xj(i)

单个样本:每次只使用一个样本 ( x ( i ) , y ( i ) ) (x^{(i)}, y^{(i)}) (x(i),y(i)) 计算梯度。

随机性:由于每次只用一个样本,梯度方向可能不完全准确,但计算速度快。

更新方向: ( y ( i ) − h θ ( x ( i ) ) ) x j ( i ) (y^{(i)} - h_\theta(x^{(i)}))x_j^{(i)} (y(i)hθ(x(i)))xj(i) 是单个样本的梯度,表示当前样本对参数 θ j \theta_j θj 的影响。

每次找一个样本,迭代速度快,但不一定每次都朝着收敛的方向(说直白点就是看命)


2.4 小批量梯度下降(Mini-batch Gradient Descent)

参数更新公式:
θ j : = θ j − α ⋅ 1 b ∑ k = i i + b − 1 ( h θ ( x ( k ) ) − y ( k ) ) x j ( k ) \theta_j := \theta_j - \alpha \cdot \frac{1}{b} \sum_{k=i}^{i+b-1} (h_\theta(x^{(k)}) - y^{(k)})x_j^{(k)} θj:=θjαb1k=ii+b1(hθ(x(k))y(k))xj(k)

小批量样本:每次使用 b b b 个样本计算梯度,既不是全部样本,也不是单个样本。

均值: 1 b ∑ k = i i + b − 1 \frac{1}{b} \sum_{k=i}^{i+b-1} b1k=ii+b1 表示对小批量样本的梯度取平均,平衡了批量梯度下降的稳定性和随机梯度下降的效率。

更新方向: ( h θ ( x ( k ) ) − y ( k ) ) x j ( k ) (h_\theta(x^{(k)}) - y^{(k)})x_j^{(k)} (hθ(x(k))y(k))xj(k) 是小批量样本的梯度,表示当前小批量样本对参数 θ j \theta_j θj 的影响。

b a t c h batch batch 一般有32、64、128、256;batch越大结果越精确,但是会更慢。所以需要自己去权衡,一般越大越好,很少有低于64的。

每次更新选择一小部分数据来算,比较合理的


补充 对 θ \theta θ 求偏导

怕之后会忘记这部分的知识,笔者在这里就补充一下推导过程。

θ \theta θ 求偏导:

∇ θ J ( θ ) = ∇ θ ( 1 2 ( θ T X T X θ − θ T X T y − y T X θ + y T y ) ) \nabla_\theta J(\theta) = \nabla_\theta \left( \frac{1}{2} (\theta^T X^T X\theta - \theta^T X^T y - y^T X\theta + y^T y) \right) θJ(θ)=θ(21(θTXTθTXTyyT+yTy))

矩阵微分,得到:

∇ θ J ( θ ) = 1 2 ( 2 X T X θ − X T y − ( y T X ) T ) = X T X θ − X T y \nabla_\theta J(\theta) = \frac{1}{2} (2X^T X\theta - X^T y - (y^T X)^T) = X^T X\theta - X^T y θJ(θ)=21(2XTXTy(yTX)T)=XTXTy

我们需要计算目标函数 J ( θ ) J(\theta) J(θ) θ \theta θ 的偏导数 ∇ θ J ( θ ) \nabla_{\theta} J(\theta) θJ(θ)。具体步骤如下:

J ( θ ) J(\theta) J(θ) 中的每一项分别求导:

第一项: θ T X T X θ \theta^T X^T X\theta θTXT

关于 θ \theta θ 的二次型:
∇ θ ( θ T X T X θ ) = 2 X T X θ \nabla_{\theta} (\theta^T X^T X\theta) = 2X^T X\theta θ(θTXT)=2XT
第二项: − θ T X T y -\theta^T X^T y θTXTy

关于 θ \theta θ 的线性项:
∇ θ ( − θ T X T y ) = − X T y \nabla_{\theta} (-\theta^T X^T y) = -X^T y θ(θTXTy)=XTy
第三项: − y T X θ -y^T X\theta yT

关于 θ \theta θ 的线性项:
∇ θ ( − y T X θ ) = − X T y \nabla_{\theta} (-y^T X\theta) = -X^T y θ(yT)=XTy
(因为 y T X θ y^T X\theta yT 是一个标量,其转置等于自身,即 y T X θ = ( y T X θ ) T = θ T X T y y^T X\theta = (y^T X\theta)^T = \theta^T X^T y yT=(yT)T=θTXTy

第四项: y T y y^T y yTy

这是一个常数项,与 θ \theta θ 无关:
∇ θ ( y T y ) = 0 \nabla_{\theta} (y^T y) = 0 θ(yTy)=0
将上述各项的导数合并:

∇ θ J ( θ ) = 1 2 ( 2 X T X θ − X T y − X T y + 0 ) \nabla_{\theta} J(\theta) = \frac{1}{2} (2X^T X\theta - X^T y - X^T y + 0) θJ(θ)=21(2XTXTyXTy+0)

简化后:

∇ θ J ( θ ) = 1 2 ( 2 X T X θ − 2 X T y ) = X T X θ − X T y \nabla_{\theta} J(\theta) = \frac{1}{2} (2X^T X\theta - 2X^T y) = X^T X\theta - X^T y θJ(θ)=21(2XT2XTy)=XTXTy

相关文章:

从零推导线性回归:最小二乘法与梯度下降的数学原理

​ 欢迎来到我的主页:【Echo-Nie】 本篇文章收录于专栏【机器学习】 本文所有内容相关代码都可在以下仓库中找到: Github-MachineLearning 1 线性回归 1.1 什么是线性回归 线性回归是一种用来预测和分析数据之间关系的工具。它的核心思想是找到一条直…...

OpenSIPS-从安装部署开始认识一个组件

前期讲到了Kamailio,它是一个不错的开源SIP(Session Initiation Protocol)服务器,主要用于构建高效的VoIP(Voice over IP)平台以及即时通讯服务。但是在同根同源(OpenSER)的分支上&a…...

数据结构(树)

每一个节点包含&#xff1a;父节点地址 值 左子节点地址 右子节点地址 如果一个节点不含有&#xff1a;父节点地址或左子节点地址 右子节点地址就记为null 二叉树 度&#xff1a;每一个节点的子节点数量 二叉树中&#xff0c;任意节点的度<2 树的结构&#xff1a; 二叉查…...

[Dialog屏幕开发] 设置搜索帮助

阅读该篇文章之前&#xff0c;可先阅读下述资料 [Dialog屏幕开发] 屏幕绘制(使用向导创建Tabstrip Control标签条控件)https://blog.csdn.net/Hudas/article/details/145372195?spm1001.2014.3001.5501https://blog.csdn.net/Hudas/article/details/145372195?spm1001.2014.…...

C语言从入门到进阶

视频&#xff1a;https://www.bilibili.com/video/BV1Vm4y1r7jY?spm_id_from333.788.player.switch&vd_sourcec988f28ad9af37435316731758625407&p23 //枚举常量 enum Sex{MALE,FEMALE,SECRET };printf("%d\n", MALE);//0 printf("%d\n", FEMALE…...

Node.js下载安装及环境配置教程 (详细版)

Node.js&#xff1a;是一个基于 Chrome V8 引擎的 JavaScript 运行时&#xff0c;用于构建可扩展的网络应用程序。Node.js 使用事件驱动、非阻塞 I/O 模型&#xff0c;使其非常适合构建实时应用程序。 Node.js 提供了一种轻量、高效、可扩展的方式来构建网络应用程序&#xff0…...

Mac Electron 应用签名(signature)和公证(notarization)

在MacOS 10.14.5之后&#xff0c;如果应用没有在苹果官方平台进行公证notarization(我们可以理解为安装包需要审核&#xff0c;来判断是否存在病毒)&#xff0c;那么就不能被安装。当然现在很多人的解决方案都是使用sudo spctl --master-disable&#xff0c;取消验证模式&#…...

redis安装 windows版本

下载 github下载5.x版本redis 安装以及启动 解压文件&#xff0c;目标如下 进入cmd至安装路径 执行如下命令启动redis redis-server.exe redis.windows.conf 进入redis,另外启动cmd在当前目录执行进入redis 服务 redis-cli 测试命令 至此安装成功&#xff0c;但是这只是…...

关联传播和 Python 和 Scikit-learn 实现

文章目录 一、说明二、什么是 Affinity Propagation。2.1 先说Affinity 传播的工作原理2.2 更多细节2.3 传播两种类型的消息2.4 计算责任和可用性的分数2.4.1 责任2.4.2 可用性分解2.4.3 更新分数&#xff1a;集群是如何形成的2.4.4 估计集群本身的数量。 三、亲和力传播的一些…...

若依基本使用及改造记录

若依框架想必大家都了解得不少&#xff0c;不可否认这是一款及其简便易用的框架。 在某种情况下&#xff08;比如私活&#xff09;使用起来可谓是快得一匹。 在这里小兵结合自身实际使用情况&#xff0c;记录一下我对若依框架的使用和改造情况。 一、源码下载 前往码云进行…...

c语言网 1127 尼科彻斯定理

原题 题目描述 验证尼科彻斯定理&#xff0c;即&#xff1a;任何一个整数m的立方都可以写成m个连续奇数之和。 输入格式 任一正整数 输出格式 该数的立方分解为一串连续奇数的和 样例输入 13 样例输出 13*13*132197157159161163165167169171173175177179181 ​ #include<ios…...

能说说MyBatis的工作原理吗?

大家好&#xff0c;我是锋哥。今天分享关于【Redis为什么这么快?】面试题。希望对大家有帮助&#xff1b; 能说说MyBatis的工作原理吗&#xff1f; MyBatis 是一款流行的持久层框架&#xff0c;它通过简化数据库操作&#xff0c;帮助开发者更高效地与数据库进行交互。MyBatis…...

卡特兰数学习

1&#xff0c;概念 卡特兰数&#xff08;英语&#xff1a;Catalan number&#xff09;&#xff0c;又称卡塔兰数&#xff0c;明安图数。是组合数学中一种常出现于各种计数问题中的数列。它在不同的计数问题中频繁出现。 2&#xff0c;公式 卡特兰数的递推公式为&#xff1a;f(…...

【算法】多源 BFS

多源 BFS 1.矩阵距离2.刺杀大使 单源最短路问题 vs 多源最短路问题 当问题中只存在一个起点时&#xff0c;这时的最短路问题就是单源最短路问题。当问题中存在多个起点而不是单一起点时&#xff0c;这时的最短路问题就是多源最短路问题。 多源 BFS&#xff1a;多源最短路问题…...

解锁数字经济新动能:探寻 Web3 核心价值

随着科技的快速发展&#xff0c;我们正迈入一个全新的数字时代&#xff0c;Web3作为这一时代的核心构成之一&#xff0c;正在为全球数字经济带来革命性的变革。本文将探讨Web3的核心价值&#xff0c;并如何推动数字经济的新动能。 Web3是什么&#xff1f; Web3&#xff0c;通常…...

CAN总线数据采集与分析

CAN总线数据采集与分析 目录 CAN总线数据采集与分析1. 引言2. 数据采集2.1 数据采集简介2.2 数据采集实现 3. 数据分析3.1 数据分析简介3.2 数据分析实现 4. 数据可视化4.1 数据可视化简介4.2 数据可视化实现 5. 案例说明5.1 案例1&#xff1a;数据采集实现5.2 案例2&#xff1…...

appium自动化环境搭建

一、appium介绍 appium介绍 appium是一个开源工具、支持跨平台、用于自动化ios、安卓手机和windows桌面平台上面的原生、移动web和混合应用&#xff0c;支持多种编程语言(python&#xff0c;java&#xff0c;Ruby&#xff0c;Javascript、PHP等) 原生应用和混合应用&#xf…...

二叉树高频题目——下——不含树型dp

一&#xff0c;普通二叉树上寻找两个节点的最近的公共祖先 1&#xff0c;介绍 LCA&#xff08;Lowest Common Ancestor&#xff0c;最近公共祖先&#xff09;是二叉树中经常讨论的一个问题。给定二叉树中的两个节点&#xff0c;它的LCA是指这两个节点的最低&#xff08;最深&…...

Java并发学习:进程与线程的区别

进程的基本原理 一个进程是一个程序的一次启动和执行&#xff0c;是操作系统程序装入内存&#xff0c;给程序分配必要的系统资源&#xff0c;并且开始运行程序的指令。 同一个程序可以多次启动&#xff0c;对应多个进程&#xff0c;例如同一个浏览器打开多次。 一个进程由程…...

【ProxyBroker】用Python打破网络限制的利器

ProxyBroker 1. 什么是ProxyBroker2. ProxyBroker的功能3. ProxyBroker的优势4. ProxyBroker的使用方法5. ProxyBroker的应用场景6.结语项目地址&#xff1a; 1. 什么是ProxyBroker ProxyBroker是一个开源工具&#xff0c;它可以异步地从多个来源找到公共代理&#xff0c;并同…...

Gradle buildSrc模块详解:集中管理构建逻辑的利器

文章目录 buildSrc模块二 buildSrc的使命三 如何使用buildSrc1. 创建目录结构2. 配置buildSrc的构建脚本3. 编写共享逻辑4. 在模块中引用 四 典型使用场景1. 统一依赖版本管理2. 自定义Gradle任务 3. 封装通用插件4. 扩展Gradle API 五 注意事项六 与复合构建&#xff08;Compo…...

2025数学建模美赛|F题成品论文

国家安全政策与网络安全 摘要 随着互联网技术的迅猛发展&#xff0c;网络犯罪问题已成为全球网络安全中的重要研究课题&#xff0c;且网络犯罪的形式和影响日益复杂和严重。本文针对网络犯罪中的问题&#xff0c;基于多元回归分析和差异中的差异&#xff08;DiD&#xff09;思…...

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.10 文本数据炼金术:从CSV到结构化数组

1.10 《文本数据炼金术&#xff1a;从CSV到结构化数组》 目录 #mermaid-svg-TNkACjzvaSXnULaB {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-TNkACjzvaSXnULaB .error-icon{fill:#552222;}#mermaid-svg-TNkACjzva…...

「蓝桥杯题解」蜗牛(Java)

题目链接 这道题我感觉状态定义不太好想&#xff0c;需要一定的经验 import java.util.*; /*** 蜗牛* 状态定义&#xff1a;* dp[i][0]:到达(x[i],0)最小时间* dp[i][1]:到达 xi 上方的传送门最小时间*/public class Main {static Scanner in new Scanner(System.in);static f…...

基于springboot+vue的流浪动物救助系统的设计与实现

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…...

51单片机开发:IO扩展(串转并)实验

实验目标&#xff1a;通过扩展口从下至上依次点亮点阵屏的行。 下图左边是74HC595 模块电路图&#xff0c;右边是点阵屏电图图。 SRCLK上升沿时&#xff0c;将SER输入的数据移送至内部的移位寄存器。 RCLK上升沿时&#xff0c;将数据从移位寄存器移动至存储寄存器&#xff0c…...

JAVA实战开源项目:购物商城网站(Vue+SpringBoot) 附源码

本文项目编号 T 032 &#xff0c;文末自助获取源码 \color{red}{T032&#xff0c;文末自助获取源码} T032&#xff0c;文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…...

C++学习——认识和与C的区别

目录 前言 一、什么是C 二、C关键字 三、与C语言不同的地方 3.1头文件 四、命名空间 4.1命名空间的概念写法 4.2命名空间的访问 4.3命名空间的嵌套 4.4命名空间在实际中的几种写法 五、输入输出 5.1cout 5.2endl 5.3cin 总结 前言 开启新的篇章&#xff0c;这里…...

Open FPV VTX开源之ardupilot双OSD配置摄像头

Open FPV VTX开源之ardupilot双OSD配置 1 源由2. 分析3. 配置4. 解决办法5. 参考资料 1 源由 鉴于笔者这台Mark4 Copter已经具备一定的历史&#xff0c;目前机载了两个FPV摄像头&#xff1a; 模拟摄像头数字摄像头(OpenIPC) 测试场景&#xff1a; 从稳定性的角度&#xff1…...

基于微信小程序高校课堂教学管理系统 课堂管理系统微信小程序(源码+文档)

目录 一.研究目的 二.需求分析 三.数据库设计 四.系统页面展示 五.免费源码获取 一.研究目的 困扰管理层的许多问题当中,高校课堂教学管理也是不敢忽视的一块。但是管理好高校课堂教学又面临很多麻烦需要解决,如何在工作琐碎,记录繁多的情况下将高校课堂教学的当前情况反…...

unity商店插件A* Pathfinding Project如何判断一个点是否在导航网格上?

需要使用NavGraph.IsPointOnNavmesh(Vector3 point) 如果点位于导航网的可步行部分&#xff0c;则为真。 如果一个点在可步行导航网表面之上或之下&#xff0c;在任何距离&#xff0c;如果它不在更近的不可步行节点之上 / 之下&#xff0c;则认为它在导航网上。 使用方法 Ast…...

三星手机人脸识别解锁需要点击一下电源键,能够不用点击直接解锁吗

三星手机的人脸识别解锁功能默认需要滑动或点击屏幕来解锁。这是为了增强安全性&#xff0c;防止误解锁的情况。如果希望在检测到人脸后直接进入主界面&#xff0c;可以通过以下设置调整&#xff1a; 打开设置&#xff1a; 进入三星手机的【设置】应用。 进入生物识别和安全&a…...

read+write实现:链表放到文件+文件数据放到链表 的功能

思路 一、 定义链表&#xff1a; 1 节点结构&#xff08;数据int型&#xff09; 2 链表操作&#xff08;创建节点、插入节点、释放链表、打印链表&#xff09;。 二、链表保存到文件 1打开文件 2遍历链表、写文件&#xff1a; 遍历链表,write()将节点数据写入文件。…...

猫怎么分公的母的?

各位铲屎官们&#xff0c;是不是刚领养了一只小猫咪&#xff0c;却分不清它是公是母&#xff1f;别急&#xff0c;今天就来给大家好好揭秘&#xff0c;如何轻松辨别猫咪的性别&#xff0c;让你不再为“它”是“他”还是“她”而烦恼&#xff01; 一、观察生殖器位置 最直接的方…...

为何SAP S4系统中要设置MRP区域?MD04中可否同时显示工厂级、库存地点级的数据?

【SAP系统PP模块研究】 一、物料主数据的MRP区域设置 SAP ECC系统中想要指定不影响MRP运算的库存地点,是针对库存地点设置MRP标识,路径为:SPRO->生产->物料需求计划->计划->定义每一个工厂的存储地点MRP,如下图所示: 另外,在给物料主数据MMSC扩充库存地点时…...

Redis for AI

Redis存储和索引语义上表示非结构化数据&#xff08;包括文本通道、图像、视频或音频&#xff09;的向量嵌入。将向量和关联的元数据存储在哈希或JSON文档中&#xff0c;用于索引和查询。 Redis包括一个高性能向量数据库&#xff0c;允许您对向量嵌入执行语义搜索。可以通过过…...

初阶2 类与对象

本章重点 上篇1.面向过程和面向对象初步认识2.类的引入---结构体3.类的定义3.1 语法3.2 组成3.3 定义类的两种方法&#xff1a; 4.类的访问限定符及封装4.1 访问限定符4.2封装---面向对象的三大特性之一 5.类的作用域6.类的实例化7.类对象模型7.1 如何计算类对象的大小 8.this指…...

kafka-部署安装

一. 简述&#xff1a; Kafka 是一个分布式流处理平台&#xff0c;常用于构建实时数据管道和流应用。 二. 安装部署&#xff1a; 1. 依赖&#xff1a; a). Java&#xff1a;Kafka 需要 Java 8 或更高版本。 b). zookeeper&#xff1a; #tar fxvz zookeeper-3.7.0.tar.gz #…...

深入探讨防抖函数中的 this 上下文

深入剖析防抖函数中的 this 上下文 最近我在研究防抖函数实现的时候&#xff0c;发现一个耗费脑子的问题&#xff0c;出现了令我困惑的问题。接下来&#xff0c;我将通过代码示例&#xff0c;深入探究这些现象背后的原理。 示例代码 function debounce(fn, delay) {let time…...

人工智能丨Midscene:让UI自动化测试变得更简单

在这个数字化时代&#xff0c;每一个细节的优化都可能成为产品脱颖而出的关键。而对于测试人员来说&#xff0c;确保产品界面的稳定性和用户体验的流畅性至关重要。今天&#xff0c;我要向大家介绍一款名为Midscene的工具&#xff0c;它利用自然语言处理&#xff08;NLP&#x…...

【数据结构】_链表经典算法OJ(力扣版)

目录 1. 移除链表元素 1.1 题目描述及链接 1.2 解题思路 1.3 程序 2. 反转链表 2.1 题目描述及链接 2.2 解题思路 2.3 程序 3. 链表的中间结点 3.1 题目描述及链接 3.2 解题思路 3.3 程序 1. 移除链表元素 1.1 题目描述及链接 原题链接&#xff1a;203. 移除链表…...

DeepSeek-R1技术报告速读

春节将至&#xff0c;DeepSeek又出王炸&#xff01;DeepSeek-R1系列重磅开源。本文对其技术报告做简单解读。 话不多说&#xff0c;show me the benchmark。从各个高难度benchmark结果来看&#xff0c;DeepSeek-R1已经比肩OpenAI-o1-1217&#xff0c;妥妥的第一梯队推理模型。…...

560. 和为 K 的子数组

【题目】&#xff1a;560. 和为 K 的子数组 方法1. 前缀和 class Solution { public:int subarraySum(vector<int>& nums, int k) {int res 0;int n nums.size();vector<int> preSum(n 1, 0); // 下标从1开始存储for(int i 0; i < n; i) {preSum[i 1]…...

鸿蒙仓颉环境配置(仓颉SDK下载,仓颉VsCode开发环境配置,仓颉DevEco开发环境配置)

目录 ​1&#xff09;仓颉的SDK下载 1--进入仓颉的官网 2--点击图片中的下载按钮 3--在新跳转的页面点击即刻下载 4--下载 5--找到你们自己下载好的地方 6--解压软件 2&#xff09;仓颉编程环境配置 1--找到自己的根目录 2--进入命令行窗口 3--输入 envsetup.bat 4--验证是否安…...

NodeJs / Bun 分析文件编码 并将 各种编码格式 转为 另一个编码格式 ( 比如: GB2312→UTF-8, UTF-8→GB2312)

版本号 "iconv-lite": "^0.6.3", "chardet": "^2.0.0",github.com/runk/node-chardet 可以识别文本是 哪种编码 ( 大文件截取一部分进行分析,速度比较快 ) let bun_file_obj Bun.file(full_file_path) let file_bytes await bun_f…...

Java学习笔记(二十五)

1 Kafka Raft 简单介绍 Kafka Raft (KRaft) 是 Kafka 引入的一种新的分布式共识协议&#xff0c;用于取代之前依赖的 Apache ZooKeeper 集群管理机制。从 Kafka 2.8 开始&#xff0c;Kafka 开始支持基于 KRaft 的独立模式&#xff0c;计划在未来完全移除 ZooKeeper 的依赖。 1…...

Baklib如何结合内容中台与人工智能技术实现数字化转型

内容概要 在当前快速发展的数字环境中&#xff0c;企业面临着转型的紧迫性与挑战&#xff0c;尤其是在内容管理和用户互动的领域。内容中台作为一种集成化的解决方案&#xff0c;不仅能够提高企业在资源管理方面的效率&#xff0c;还能够为企业提供一致性和灵活性的内容分发机…...

git困扰的问题

.gitignore中添加的某个忽略文件并不生效 把某些目录或文件加入忽略规则&#xff0c;按照上述方法定义后发现并未生效&#xff0c; gitignore只能忽略那些原来没有被追踪的文件&#xff0c;如果某些文件已经被纳入了版本管理中&#xff0c;则修改.gitignore是无效的。 解决方…...

第05章 12 可视化热量流线图一例

下面是一个使用VTK&#xff08;Visualization Toolkit&#xff09;和C编写的示例代码&#xff0c;展示如何在一个厨房模型中可视化热量流线图&#xff0c;并按照热量传递速度着色显示。这个示例假设你已经安装了VTK库&#xff0c;并且你的开发环境已经配置好来编译和运行VTK程序…...

Vue组件开发-使用 html2canvas 和 jspdf 库实现PDF文件导出 设置页面大小及方向

在 Vue 项目中实现导出 PDF 文件、调整文件页面大小和页面方向的功能&#xff0c;使用 html2canvas 将 HTML 内容转换为图片&#xff0c;再使用 jspdf 把图片添加到 PDF 文件中。以下是详细的实现步骤和代码示例&#xff1a; 步骤 1&#xff1a;安装依赖 首先&#xff0c;在项…...