从零推导线性回归:最小二乘法与梯度下降的数学原理
欢迎来到我的主页:【Echo-Nie】
本篇文章收录于专栏【机器学习】
本文所有内容相关代码都可在以下仓库中找到:
Github-MachineLearning
1 线性回归
1.1 什么是线性回归
线性回归是一种用来预测和分析数据之间关系的工具。它的核心思想是找到一条直线(或者一个平面),让这条直线尽可能地“拟合”已有的数据点,通过这条直线,我们可以预测新的数据。
eg:
假设你想预测房价,你知道房子的大小(面积)和房价之间有关系。线性回归可以帮助你找到一条直线,表示“房子越大,房价越高”的关系。
这条直线的方程可以写成:
y = θ 0 + θ 1 x y = \theta_0 + \theta_1 x y=θ0+θ1x
其中:
- y y y 是房价(目标变量)。
- x x x 是房子的大小(特征)。
- θ 0 \theta_0 θ0 是截距(当房子大小为 0 时的房价)。
- θ 1 \theta_1 θ1 是斜率(房子每增加一平米,房价增加多少)。
简单线性回归:
如果只有一个特征(比如房子的大小),线性回归就是找到一条直线来拟合数据。公式是:
y = θ 0 + θ 1 x y = \theta_0 + \theta_1 x y=θ0+θ1x
- 目标:找到 θ 0 \theta_0 θ0 和 θ 1 \theta_1 θ1,使得这条直线最接近所有的数据点。
- 怎么找?通过调整 θ 0 \theta_0 θ0 和 θ 1 \theta_1 θ1,让预测值 y y y 和实际值之间的误差最小。
多元线性回归:
如果有多个特征(比如房子的大小、房间数量、地段等),线性回归就是找到一个“平面”来拟合数据。公式是:
y = θ 0 + θ 1 x 1 + θ 2 x 2 + ⋯ + θ n x n y = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \dots + \theta_n x_n y=θ0+θ1x1+θ2x2+⋯+θnxn
这里:
x i x_i xi 是多个特征(比如房子大小、房间数量等); θ i \theta_i θi 是模型的参数,表示每个特征对房价的影响
- 拟合数据:线性回归就像在一堆散点图中画一条直线,让这条直线尽可能靠近所有的点。
- 预测:有了这条直线后,如果有一个新的数据点(比如一个新房子的大小),我们可以用这条直线来预测它的房价。
- 误差:预测值和实际值之间的差距叫误差。线性回归的目标是让所有数据的误差最小。
1.2 年龄和金钱举例
x 1 x_1 x1、 x 2 x_2 x2就是我们的两个特征(年龄,工资); Y Y Y是银行最终会借给我们多少钱。找到最合适的一条线(想象一个高维)来最好的拟合我们的数据点。
假设 θ 1 \theta_1 θ1是年龄的参数, θ 2 \theta_2 θ2是工资的参数
-
拟合平面:
h θ ( x ) = θ 0 + θ 1 x 1 + θ 2 x 2 h_{\theta}(x) = \theta_0 + \theta_1 x_1 + \theta_2 x_2 hθ(x)=θ0+θ1x1+θ2x2
这里, θ 0 \theta_0 θ0 是偏置项, θ 1 \theta_1 θ1 和 θ 2 \theta_2 θ2 是权重, x 1 x_1 x1 和 x 2 x_2 x2 是输入特征。 -
引入 x 0 = 1 x_0 = 1 x0=1:
为了将偏置项 θ 0 \theta_0 θ0 也纳入向量化的形式,我们引入一个额外的特征 x 0 x_0 x0,并设 x 0 = 1 x_0 = 1 x0=1。相当于我们做数据预处理,矩阵最外围加了一圈1,方便计算。这样,公式可以写成:
h θ ( x ) = θ 0 x 0 + θ 1 x 1 + θ 2 x 2 h_{\theta}(x) = \theta_0 x_0 + \theta_1 x_1 + \theta_2 x_2 hθ(x)=θ0x0+θ1x1+θ2x2 -
向量化表示:
现在,我们可以将权重 θ \theta θ 和特征 x x x 表示为向量:
θ = [ θ 0 θ 1 θ 2 ] , x = [ x 0 x 1 x 2 ] \theta = \begin{bmatrix} \theta_0 \\ \theta_1 \\ \theta_2 \end{bmatrix}, \quad x = \begin{bmatrix} x_0 \\ x_1 \\ x_2 \end{bmatrix} θ= θ0θ1θ2 ,x= x0x1x2
这样,原始公式可以表示为这两个向量的点积:
h θ ( x ) = θ 0 x 0 + θ 1 x 1 + θ 2 x 2 = θ T x h_{\theta}(x) = \theta_0 x_0 + \theta_1 x_1 + \theta_2 x_2 = \theta^T x hθ(x)=θ0x0+θ1x1+θ2x2=θTx
其中, θ T \theta^T θT 是向量 θ \theta θ 的转置。 -
一般化到 n n n 个特征:
如果有 n n n 个特征,公式可以推广为:
h θ ( x ) = ∑ i = 0 n θ i x i = θ T x h_{\theta}(x) = \sum_{i=0}^{n} \theta_i x_i = \theta^T x hθ(x)=i=0∑nθixi=θTx
这里, θ \theta θ 和 x x x 都是 n + 1 n+1 n+1 维的向量。 -
误差表示:
误差真实值和预测值之间肯定是要存在差异的(用 ε \varepsilon ε 来表示该误差)。
对于每个样本:
y ( i ) = θ T x ( i ) + ε ( i ) y^{(i)} = \theta^T x^{(i)} + \varepsilon^{(i)} y(i)=θTx(i)+ε(i)
参数意义:
- y ( i ) y^{(i)} y(i) 是第 i i i 个样本的真实值。
- θ T x ( i ) \theta^T x^{(i)} θTx(i) 是模型对第 i i i 个样本的预测值。
- ε ( i ) \varepsilon^{(i)} ε(i) 是第 i i i 个样本的误差,表示真实值和预测值之间的差异。
1.3 误差
误差是模型预测值和真实值之间的差距。比如你预测房价是 10 万,但实际房价是 15 万,那么误差就是 5 万。
误差 ε ( i ) \varepsilon^{(i)} ε(i) 是独立并且具有相同的分布,并且服从均值为 0 0 0、方差为 σ 2 \sigma^2 σ2 的高斯分布。
- 独立:每个数据的误差是独立的,不会互相影响。比如张三贷款多了,不会影响李四的贷款。
- 同分布:所有数据的误差都来自同一个“规律”。比如所有贷款误差都来自同一家银行的规则。
- 高斯分布:误差大多数情况下比较小,偶尔会有一些大的误差,但这种情况很少。比如银行通常不会多给或少给太多钱,但偶尔可能会有一些特殊情况。
误差 ε ( i ) \varepsilon^{(i)} ε(i) 服从高斯分布:
ε ( i ) ∼ N ( 0 , σ 2 ) \varepsilon^{(i)} \sim \mathcal{N}(0, \sigma^2) ε(i)∼N(0,σ2)
为什么假设误差是高斯分布?
因为高斯分布(也叫正态分布)符合现实世界中很多现象的规律。比如大多数人的身高、体重等都集中在某个平均值附近,极端值很少。
1.4 误差的数学推导
线性回归模型假设目标变量 y y y 是输入特征 x x x 的线性组合加上一个误差项 ε \varepsilon ε :
y ( i ) = θ T x ( i ) + ε ( i ) y^{(i)} = \theta^T x^{(i)} + \varepsilon^{(i)} y(i)=θTx(i)+ε(i)
其中:
- y ( i ) y^{(i)} y(i) 是第 i i i个观测值的实际值。
- θ T x ( i ) \theta^T x^{(i)} θTx(i) 是模型对第 i i i个观测值的预测值。
- ε ( i ) \varepsilon^{(i)} ε(i) 是第 i i i个观测值的误差。
1.4.1 误差的高斯分布假设
假设误差 ε ( i ) \varepsilon^{(i)} ε(i) 服从均值为 0 0 0、方差为 σ 2 \sigma^2 σ2 的高斯分布(正态分布)。其概率密度函数为:
p ( ε ( i ) ) = 1 2 π σ exp ( − ( ε ( i ) ) 2 2 σ 2 ) p(\varepsilon^{(i)}) = \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(\varepsilon^{(i)})^{2}}{2\sigma^{2}}\right) p(ε(i))=2πσ1exp(−2σ2(ε(i))2)
这个假设的意义是:
- 误差是随机的,且大多数误差集中在 0 附近。
- 误差的分布是对称的,且较大的误差出现的概率较低。
1.4.2 将线性回归模型代入误差分布
从线性回归模型 y ( i ) = θ T x ( i ) + ε ( i ) y^{(i)} = \theta^T x^{(i)} + \varepsilon^{(i)} y(i)=θTx(i)+ε(i),我们可以将误差 ε ( i ) \varepsilon^{(i)} ε(i) 表示为:
ε ( i ) = y ( i ) − θ T x ( i ) \varepsilon^{(i)} = y^{(i)} - \theta^T x^{(i)} ε(i)=y(i)−θTx(i)
将这个表达式代入误差的高斯分布公式中,得到:
p ( ε ( i ) ) = 1 2 π σ exp ( − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) p(\varepsilon^{(i)}) = \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}}\right) p(ε(i))=2πσ1exp(−2σ2(y(i)−θTx(i))2)
1.4.3 条件概率分布
由于 ε ( i ) = y ( i ) − θ T x ( i ) \varepsilon^{(i)} = y^{(i)} - \theta^T x^{(i)} ε(i)=y(i)−θTx(i) ,我们可以将 p ( ε ( i ) ) p(\varepsilon^{(i)}) p(ε(i)) 重新表示为 y ( i ) y^{(i)} y(i) 的条件概率分布:
p ( y ( i ) ∣ x ( i ) ; θ ) = 1 2 π σ exp ( − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) p(y^{(i)}|x^{(i)};\theta) = \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}}\right) p(y(i)∣x(i);θ)=2πσ1exp(−2σ2(y(i)−θTx(i))2)
在给定输入 x ( i ) x^{(i)} x(i) 和参数 θ \theta θ 的情况下, y ( i ) y^{(i)} y(i) 的概率分布。 y ( i ) y^{(i)} y(i) 的分布是以( θ T \theta^T θT 和 x ( i ) x^{(i)} x(i) )的某个组合为均值、 σ 2 \sigma^2 σ2为方差的高斯分布。什么意思呢?就是说,让( θ T \theta^T θT 和 x ( i ) x^{(i)} x(i) )的组合成为 y y y 的可能性越大越好
1.4.4 最大似然估计
为了找到最优的参数 θ \theta θ ,我们使用最大似然估计(Maximum Likelihood Estimation, MLE)。具体步骤如下:
似然函数:
假设所有观测值 y ( i ) y^{(i)} y(i) 是独立同分布的,则整个数据集的联合概率(似然函数)为:
L ( θ ) = ∏ i = 1 n p ( y ( i ) ∣ x ( i ) ; θ ) L(\theta) = \prod_{i=1}^{n} p(y^{(i)}|x^{(i)};\theta) L(θ)=i=1∏np(y(i)∣x(i);θ)
在独立同分布的前提下,联合概率等于边缘密度的乘积。
将条件概率分布代入:
L ( θ ) = ∏ i = 1 n 1 2 π σ exp ( − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) L(\theta) = \prod_{i=1}^{n} \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}}\right) L(θ)=i=1∏n2πσ1exp(−2σ2(y(i)−θTx(i))2)
上述表示的是:什么样的参数与我们的数据组合之后恰好是真实值。
对数似然:
为了简化计算,取对数似然函数(乘法难解,改成加法):
ℓ ( θ ) = log L ( θ ) = ∑ i = 1 n log ( 1 2 π σ exp ( − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) ) \ell(\theta) = \log L(\theta) = \sum_{i=1}^{n} \log \left( \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}}\right) \right) ℓ(θ)=logL(θ)=i=1∑nlog(2πσ1exp(−2σ2(y(i)−θTx(i))2))
Q:为什么能这么做呢?
A:因为我们只需要求出极值点 θ \theta θ ,而不是极值,改成 l o g log log 之后,极值 m m m 会变为 l o g m log^{m} logm ,但是极值点不变啊。
展开:
ℓ ( θ ) = ∑ i = 1 n ( log 1 2 π σ − ( y ( i ) − θ T x ( i ) ) 2 2 σ 2 ) \ell(\theta) = \sum_{i=1}^{n} \left( \log \frac{1}{\sqrt{2\pi}\sigma} - \frac{(y^{(i)} - \theta^T x^{(i)})^{2}}{2\sigma^{2}} \right) ℓ(θ)=i=1∑n(log2πσ1−2σ2(y(i)−θTx(i))2)
去掉常数项(与 θ \theta θ 无关的项):
ℓ ( θ ) = − 1 2 σ 2 ∑ i = 1 n ( y ( i ) − θ T x ( i ) ) 2 + 常数 \ell(\theta) = -\frac{1}{2\sigma^{2}} \sum_{i=1}^{n} (y^{(i)} - \theta^T x^{(i)})^{2} + \text{常数} ℓ(θ)=−2σ21i=1∑n(y(i)−θTx(i))2+常数
最大化对数似然:
最大化对数似然函数等价于最小化以下目标函数:
J ( θ ) = 1 2 ∑ i = 1 n ( y ( i ) − θ T x ( i ) ) 2 J(\theta) = \frac{1}{2} \sum_{i=1}^{n} (y^{(i)} - \theta^T x^{(i)})^{2} J(θ)=21i=1∑n(y(i)−θTx(i))2
Q:什么越大什么越小,乱七八糟的?
A:因为 ℓ ( θ ) \ell(\theta) ℓ(θ) 是一个常数减去对数似然,我们要让 ℓ ( θ ) \ell(\theta) ℓ(θ) 最大的话,就是让对数似然最小,因为对数似然恒正。
这就是最小二乘法的目标函数,以上就是通过最大似然估计,推导出最小二乘法。
1.5 目标函数
目标函数 J ( θ ) J(\theta) J(θ) 定义为:
J ( θ ) = 1 2 ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 = 1 2 ( X θ − y ) T ( X θ − y ) J(\theta) = \frac{1}{2} \sum_{i=1}^m (h_\theta(x^{(i)}) - y^{(i)})^2 = \frac{1}{2} (X\theta - y)^T (X\theta - y) J(θ)=21i=1∑m(hθ(x(i))−y(i))2=21(Xθ−y)T(Xθ−y)
这里, h θ ( x ( i ) ) h_\theta(x^{(i)}) hθ(x(i)) 是假设函数, y ( i ) y^{(i)} y(i) 是实际值, X X X 是特征矩阵, θ \theta θ 是参数向量。目标函数衡量了预测值与实际值之间的误差。
这个目标函数实际上就是最小二乘法的目标函数,其中 θ T x ( i ) \theta^T x^{(i)} θTx(i) 就是假设函数 h θ ( x ( i ) ) h_\theta(x^{(i)}) hθ(x(i))
h θ ( x ( i ) ) h_\theta(x^{(i)}) hθ(x(i)) 是为了表示模型的预测值,它与实际值 y ( i ) y^{(i)} y(i) 之间的差异(误差)被用来构建目标函数。
求偏导:
为了找到最小化目标函数的参数 θ \theta θ,对 J ( θ ) J(\theta) J(θ) 求偏导并令其等于零。首先展开目标函数:
J ( θ ) = 1 2 ( X θ − y ) T ( X θ − y ) J(\theta) = \frac{1}{2} (X\theta - y)^T (X\theta - y) J(θ)=21(Xθ−y)T(Xθ−y)
展开后得到:
J ( θ ) = 1 2 ( θ T X T − y T ) ( X θ − y ) J(\theta) = \frac{1}{2} (\theta^T X^T - y^T)(X\theta - y) J(θ)=21(θTXT−yT)(Xθ−y)
进一步展开:
J ( θ ) = 1 2 ( θ T X T X θ − θ T X T y − y T X θ + y T y ) J(\theta) = \frac{1}{2} (\theta^T X^T X\theta - \theta^T X^T y - y^T X\theta + y^T y) J(θ)=21(θTXTXθ−θTXTy−yTXθ+yTy)
对 θ \theta θ 求偏导:
∇ θ J ( θ ) = ∇ θ ( 1 2 ( θ T X T X θ − θ T X T y − y T X θ + y T y ) ) \nabla_\theta J(\theta) = \nabla_\theta \left( \frac{1}{2} (\theta^T X^T X\theta - \theta^T X^T y - y^T X\theta + y^T y) \right) ∇θJ(θ)=∇θ(21(θTXTXθ−θTXTy−yTXθ+yTy))
矩阵微分,得到:
∇ θ J ( θ ) = 1 2 ( 2 X T X θ − X T y − ( y T X ) T ) = X T X θ − X T y \nabla_\theta J(\theta) = \frac{1}{2} (2X^T X\theta - X^T y - (y^T X)^T) = X^T X\theta - X^T y ∇θJ(θ)=21(2XTXθ−XTy−(yTX)T)=XTXθ−XTy
偏导等于零:
为了找到最小值,令偏导等于零:
X T X θ − X T y = 0 X^T X\theta - X^T y = 0 XTXθ−XTy=0
解这个方程得到 θ \theta θ 的解析解:
θ = ( X T X ) − 1 X T y \theta = (X^T X)^{-1} X^T y θ=(XTX)−1XTy
这个解称为正规方程Normal Equation,它给出了最小化目标函数的参数 θ \theta θ 的值
注:这里我们发现有逆矩阵,众所周知,不是所有矩阵均可逆,那么解不出 θ \theta θ 怎么办?接下来就引出了我们经常说的优化算法。
2 梯度下降
梯度下降是一种优化方法,用来找到目标函数的最小值。你可以把它想象成“下山”:你站在山顶上,目标是找到山谷的最低点。梯度下降就是一步步往下走,直到你到达最低点。
为什么要用梯度下降?
- 直接求解不一定可行:有时候目标函数很复杂,甚至没有直接的数学公式可以解出来(比如非线性问题)。线性回归是个特例,可以直接用公式求解,但大多数情况下不行。
- 机器学习的套路:我们给机器一堆数据,告诉它“什么样的学习方式是对的”(目标函数),然后让它自己慢慢调整参数,找到最好的结果。
梯度下降怎么工作?
- 一口吃不成胖子:你不能一步就从山顶跳到山谷,而是需要一步步往下走。每一步都根据当前的位置,找到最陡的下坡方向(梯度),然后往那个方向迈一小步。
- 迭代优化:每次优化一点点,累积起来就是一个很大的进步。比如,你第一次走一步,第二次再走一步,重复很多次后,你就离目标越来越近了。
举例
想象你在玩一个游戏,目标是找到地图上的最低点(比如宝藏的位置)。你每走一步,都会看看周围哪个方向是下坡,然后往那个方向走。梯度下降就是这个过程:
- 目标函数:地图的高度(越低越好)。
- 梯度:你每走一步时,判断哪个方向是下坡。
- 迭代:一步步走,直到找到最低点。
2.1 目标函数
J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)})^2 J(θ)=2m1i=1∑m(hθ(x(i))−y(i))2
目标:我们希望模型的预测值 h θ ( x ( i ) ) h_\theta(x^{(i)}) hθ(x(i)) 尽可能接近实际值 y ( i ) y^{(i)} y(i)。
平方误差: ( h θ ( x ( i ) ) − y ( i ) ) 2 (h_\theta(x^{(i)}) - y^{(i)})^2 (hθ(x(i))−y(i))2 表示单个样本的预测误差。平方的作用是:消除正负误差的抵消(比如 − 2 -2 −2 和 2 2 2 的误差平方后都是 4 4 4)。对大误差给予更大的惩罚。
均值: 1 m ∑ i = 1 m \frac{1}{m} \sum_{i=1}^{m} m1∑i=1m 表示对所有样本的误差取平均,反映模型的整体误差。
1 2 \frac{1}{2} 21:为了简化求导时的计算,因为对平方项求导后会多出一个 2 2 2, 1 2 \frac{1}{2} 21 可以抵消这个 2 2 2。
2.2 批量梯度下降(Batch Gradient Descent)
梯度公式:
∂ J ( θ ) ∂ θ j = − 1 m ∑ i = 1 m ( y ( i ) − h θ ( x ( i ) ) ) x j ( i ) \frac{\partial J(\theta)}{\partial \theta_j} = -\frac{1}{m} \sum_{i=1}^{m} (y^{(i)} - h_\theta(x^{(i)}))x_j^{(i)} ∂θj∂J(θ)=−m1i=1∑m(y(i)−hθ(x(i)))xj(i)
梯度:梯度是目标函数对参数 θ j \theta_j θj 的偏导数,表示目标函数在 θ j \theta_j θj 方向上的变化率。
误差项: ( y ( i ) − h θ ( x ( i ) ) ) (y^{(i)} - h_\theta(x^{(i)})) (y(i)−hθ(x(i))) 是第 i i i 个样本的预测误差。
特征值: x j ( i ) x_j^{(i)} xj(i) 是第 i i i 个样本的第 j j j 个特征值,表示误差对 θ j \theta_j θj 的影响。
均值: 1 m ∑ i = 1 m \frac{1}{m} \sum_{i=1}^{m} m1∑i=1m 表示对所有样本的梯度取平均,确保更新方向是全局最优的。
参数更新公式:
θ j : = θ j − α ⋅ ∂ J ( θ ) ∂ θ j \theta_j := \theta_j - \alpha \cdot \frac{\partial J(\theta)}{\partial \theta_j} θj:=θj−α⋅∂θj∂J(θ)
负梯度方向:梯度 ∂ J ( θ ) ∂ θ j \frac{\partial J(\theta)}{\partial \theta_j} ∂θj∂J(θ) 表示目标函数上升最快的方向,因此负梯度方向是下降最快的方向。
学习率 α \alpha α:控制每次更新的步长。如果 α \alpha α 太大,可能会跳过最优解;如果 α \alpha α 太小,收敛速度会变慢。
容易得到最优解,但是每次要考虑所有的样本,所以很慢。
2.3 随机梯度下降(Stochastic Gradient Descent, SGD)
参数更新公式:
θ j : = θ j − α ⋅ ( y ( i ) − h θ ( x ( i ) ) ) x j ( i ) \theta_j := \theta_j - \alpha \cdot (y^{(i)} - h_\theta(x^{(i)}))x_j^{(i)} θj:=θj−α⋅(y(i)−hθ(x(i)))xj(i)
单个样本:每次只使用一个样本 ( x ( i ) , y ( i ) ) (x^{(i)}, y^{(i)}) (x(i),y(i)) 计算梯度。
随机性:由于每次只用一个样本,梯度方向可能不完全准确,但计算速度快。
更新方向: ( y ( i ) − h θ ( x ( i ) ) ) x j ( i ) (y^{(i)} - h_\theta(x^{(i)}))x_j^{(i)} (y(i)−hθ(x(i)))xj(i) 是单个样本的梯度,表示当前样本对参数 θ j \theta_j θj 的影响。
每次找一个样本,迭代速度快,但不一定每次都朝着收敛的方向(说直白点就是看命)
2.4 小批量梯度下降(Mini-batch Gradient Descent)
参数更新公式:
θ j : = θ j − α ⋅ 1 b ∑ k = i i + b − 1 ( h θ ( x ( k ) ) − y ( k ) ) x j ( k ) \theta_j := \theta_j - \alpha \cdot \frac{1}{b} \sum_{k=i}^{i+b-1} (h_\theta(x^{(k)}) - y^{(k)})x_j^{(k)} θj:=θj−α⋅b1k=i∑i+b−1(hθ(x(k))−y(k))xj(k)
小批量样本:每次使用 b b b 个样本计算梯度,既不是全部样本,也不是单个样本。
均值: 1 b ∑ k = i i + b − 1 \frac{1}{b} \sum_{k=i}^{i+b-1} b1∑k=ii+b−1 表示对小批量样本的梯度取平均,平衡了批量梯度下降的稳定性和随机梯度下降的效率。
更新方向: ( h θ ( x ( k ) ) − y ( k ) ) x j ( k ) (h_\theta(x^{(k)}) - y^{(k)})x_j^{(k)} (hθ(x(k))−y(k))xj(k) 是小批量样本的梯度,表示当前小批量样本对参数 θ j \theta_j θj 的影响。
b a t c h batch batch 一般有32、64、128、256;batch越大结果越精确,但是会更慢。所以需要自己去权衡,一般越大越好,很少有低于64的。
每次更新选择一小部分数据来算,比较合理的
补充 对 θ \theta θ 求偏导
怕之后会忘记这部分的知识,笔者在这里就补充一下推导过程。
对 θ \theta θ 求偏导:
∇ θ J ( θ ) = ∇ θ ( 1 2 ( θ T X T X θ − θ T X T y − y T X θ + y T y ) ) \nabla_\theta J(\theta) = \nabla_\theta \left( \frac{1}{2} (\theta^T X^T X\theta - \theta^T X^T y - y^T X\theta + y^T y) \right) ∇θJ(θ)=∇θ(21(θTXTXθ−θTXTy−yTXθ+yTy))
矩阵微分,得到:
∇ θ J ( θ ) = 1 2 ( 2 X T X θ − X T y − ( y T X ) T ) = X T X θ − X T y \nabla_\theta J(\theta) = \frac{1}{2} (2X^T X\theta - X^T y - (y^T X)^T) = X^T X\theta - X^T y ∇θJ(θ)=21(2XTXθ−XTy−(yTX)T)=XTXθ−XTy
我们需要计算目标函数 J ( θ ) J(\theta) J(θ) 对 θ \theta θ 的偏导数 ∇ θ J ( θ ) \nabla_{\theta} J(\theta) ∇θJ(θ)。具体步骤如下:
对 J ( θ ) J(\theta) J(θ) 中的每一项分别求导:
第一项: θ T X T X θ \theta^T X^T X\theta θTXTXθ
关于 θ \theta θ 的二次型:
∇ θ ( θ T X T X θ ) = 2 X T X θ \nabla_{\theta} (\theta^T X^T X\theta) = 2X^T X\theta ∇θ(θTXTXθ)=2XTXθ
第二项: − θ T X T y -\theta^T X^T y −θTXTy
关于 θ \theta θ 的线性项:
∇ θ ( − θ T X T y ) = − X T y \nabla_{\theta} (-\theta^T X^T y) = -X^T y ∇θ(−θTXTy)=−XTy
第三项: − y T X θ -y^T X\theta −yTXθ
关于 θ \theta θ 的线性项:
∇ θ ( − y T X θ ) = − X T y \nabla_{\theta} (-y^T X\theta) = -X^T y ∇θ(−yTXθ)=−XTy
(因为 y T X θ y^T X\theta yTXθ 是一个标量,其转置等于自身,即 y T X θ = ( y T X θ ) T = θ T X T y y^T X\theta = (y^T X\theta)^T = \theta^T X^T y yTXθ=(yTXθ)T=θTXTy)
第四项: y T y y^T y yTy
这是一个常数项,与 θ \theta θ 无关:
∇ θ ( y T y ) = 0 \nabla_{\theta} (y^T y) = 0 ∇θ(yTy)=0
将上述各项的导数合并:
∇ θ J ( θ ) = 1 2 ( 2 X T X θ − X T y − X T y + 0 ) \nabla_{\theta} J(\theta) = \frac{1}{2} (2X^T X\theta - X^T y - X^T y + 0) ∇θJ(θ)=21(2XTXθ−XTy−XTy+0)
简化后:
∇ θ J ( θ ) = 1 2 ( 2 X T X θ − 2 X T y ) = X T X θ − X T y \nabla_{\theta} J(\theta) = \frac{1}{2} (2X^T X\theta - 2X^T y) = X^T X\theta - X^T y ∇θJ(θ)=21(2XTXθ−2XTy)=XTXθ−XTy
相关文章:
从零推导线性回归:最小二乘法与梯度下降的数学原理
欢迎来到我的主页:【Echo-Nie】 本篇文章收录于专栏【机器学习】 本文所有内容相关代码都可在以下仓库中找到: Github-MachineLearning 1 线性回归 1.1 什么是线性回归 线性回归是一种用来预测和分析数据之间关系的工具。它的核心思想是找到一条直…...
OpenSIPS-从安装部署开始认识一个组件
前期讲到了Kamailio,它是一个不错的开源SIP(Session Initiation Protocol)服务器,主要用于构建高效的VoIP(Voice over IP)平台以及即时通讯服务。但是在同根同源(OpenSER)的分支上&a…...
数据结构(树)
每一个节点包含:父节点地址 值 左子节点地址 右子节点地址 如果一个节点不含有:父节点地址或左子节点地址 右子节点地址就记为null 二叉树 度:每一个节点的子节点数量 二叉树中,任意节点的度<2 树的结构: 二叉查…...
[Dialog屏幕开发] 设置搜索帮助
阅读该篇文章之前,可先阅读下述资料 [Dialog屏幕开发] 屏幕绘制(使用向导创建Tabstrip Control标签条控件)https://blog.csdn.net/Hudas/article/details/145372195?spm1001.2014.3001.5501https://blog.csdn.net/Hudas/article/details/145372195?spm1001.2014.…...
C语言从入门到进阶
视频:https://www.bilibili.com/video/BV1Vm4y1r7jY?spm_id_from333.788.player.switch&vd_sourcec988f28ad9af37435316731758625407&p23 //枚举常量 enum Sex{MALE,FEMALE,SECRET };printf("%d\n", MALE);//0 printf("%d\n", FEMALE…...
Node.js下载安装及环境配置教程 (详细版)
Node.js:是一个基于 Chrome V8 引擎的 JavaScript 运行时,用于构建可扩展的网络应用程序。Node.js 使用事件驱动、非阻塞 I/O 模型,使其非常适合构建实时应用程序。 Node.js 提供了一种轻量、高效、可扩展的方式来构建网络应用程序࿰…...
Mac Electron 应用签名(signature)和公证(notarization)
在MacOS 10.14.5之后,如果应用没有在苹果官方平台进行公证notarization(我们可以理解为安装包需要审核,来判断是否存在病毒),那么就不能被安装。当然现在很多人的解决方案都是使用sudo spctl --master-disable,取消验证模式&#…...
redis安装 windows版本
下载 github下载5.x版本redis 安装以及启动 解压文件,目标如下 进入cmd至安装路径 执行如下命令启动redis redis-server.exe redis.windows.conf 进入redis,另外启动cmd在当前目录执行进入redis 服务 redis-cli 测试命令 至此安装成功,但是这只是…...
关联传播和 Python 和 Scikit-learn 实现
文章目录 一、说明二、什么是 Affinity Propagation。2.1 先说Affinity 传播的工作原理2.2 更多细节2.3 传播两种类型的消息2.4 计算责任和可用性的分数2.4.1 责任2.4.2 可用性分解2.4.3 更新分数:集群是如何形成的2.4.4 估计集群本身的数量。 三、亲和力传播的一些…...
若依基本使用及改造记录
若依框架想必大家都了解得不少,不可否认这是一款及其简便易用的框架。 在某种情况下(比如私活)使用起来可谓是快得一匹。 在这里小兵结合自身实际使用情况,记录一下我对若依框架的使用和改造情况。 一、源码下载 前往码云进行…...
c语言网 1127 尼科彻斯定理
原题 题目描述 验证尼科彻斯定理,即:任何一个整数m的立方都可以写成m个连续奇数之和。 输入格式 任一正整数 输出格式 该数的立方分解为一串连续奇数的和 样例输入 13 样例输出 13*13*132197157159161163165167169171173175177179181 #include<ios…...
能说说MyBatis的工作原理吗?
大家好,我是锋哥。今天分享关于【Redis为什么这么快?】面试题。希望对大家有帮助; 能说说MyBatis的工作原理吗? MyBatis 是一款流行的持久层框架,它通过简化数据库操作,帮助开发者更高效地与数据库进行交互。MyBatis…...
卡特兰数学习
1,概念 卡特兰数(英语:Catalan number),又称卡塔兰数,明安图数。是组合数学中一种常出现于各种计数问题中的数列。它在不同的计数问题中频繁出现。 2,公式 卡特兰数的递推公式为:f(…...
【算法】多源 BFS
多源 BFS 1.矩阵距离2.刺杀大使 单源最短路问题 vs 多源最短路问题 当问题中只存在一个起点时,这时的最短路问题就是单源最短路问题。当问题中存在多个起点而不是单一起点时,这时的最短路问题就是多源最短路问题。 多源 BFS:多源最短路问题…...
解锁数字经济新动能:探寻 Web3 核心价值
随着科技的快速发展,我们正迈入一个全新的数字时代,Web3作为这一时代的核心构成之一,正在为全球数字经济带来革命性的变革。本文将探讨Web3的核心价值,并如何推动数字经济的新动能。 Web3是什么? Web3,通常…...
CAN总线数据采集与分析
CAN总线数据采集与分析 目录 CAN总线数据采集与分析1. 引言2. 数据采集2.1 数据采集简介2.2 数据采集实现 3. 数据分析3.1 数据分析简介3.2 数据分析实现 4. 数据可视化4.1 数据可视化简介4.2 数据可视化实现 5. 案例说明5.1 案例1:数据采集实现5.2 案例2࿱…...
appium自动化环境搭建
一、appium介绍 appium介绍 appium是一个开源工具、支持跨平台、用于自动化ios、安卓手机和windows桌面平台上面的原生、移动web和混合应用,支持多种编程语言(python,java,Ruby,Javascript、PHP等) 原生应用和混合应用…...
二叉树高频题目——下——不含树型dp
一,普通二叉树上寻找两个节点的最近的公共祖先 1,介绍 LCA(Lowest Common Ancestor,最近公共祖先)是二叉树中经常讨论的一个问题。给定二叉树中的两个节点,它的LCA是指这两个节点的最低(最深&…...
Java并发学习:进程与线程的区别
进程的基本原理 一个进程是一个程序的一次启动和执行,是操作系统程序装入内存,给程序分配必要的系统资源,并且开始运行程序的指令。 同一个程序可以多次启动,对应多个进程,例如同一个浏览器打开多次。 一个进程由程…...
【ProxyBroker】用Python打破网络限制的利器
ProxyBroker 1. 什么是ProxyBroker2. ProxyBroker的功能3. ProxyBroker的优势4. ProxyBroker的使用方法5. ProxyBroker的应用场景6.结语项目地址: 1. 什么是ProxyBroker ProxyBroker是一个开源工具,它可以异步地从多个来源找到公共代理,并同…...
Gradle buildSrc模块详解:集中管理构建逻辑的利器
文章目录 buildSrc模块二 buildSrc的使命三 如何使用buildSrc1. 创建目录结构2. 配置buildSrc的构建脚本3. 编写共享逻辑4. 在模块中引用 四 典型使用场景1. 统一依赖版本管理2. 自定义Gradle任务 3. 封装通用插件4. 扩展Gradle API 五 注意事项六 与复合构建(Compo…...
2025数学建模美赛|F题成品论文
国家安全政策与网络安全 摘要 随着互联网技术的迅猛发展,网络犯罪问题已成为全球网络安全中的重要研究课题,且网络犯罪的形式和影响日益复杂和严重。本文针对网络犯罪中的问题,基于多元回归分析和差异中的差异(DiD)思…...
【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.10 文本数据炼金术:从CSV到结构化数组
1.10 《文本数据炼金术:从CSV到结构化数组》 目录 #mermaid-svg-TNkACjzvaSXnULaB {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-TNkACjzvaSXnULaB .error-icon{fill:#552222;}#mermaid-svg-TNkACjzva…...
「蓝桥杯题解」蜗牛(Java)
题目链接 这道题我感觉状态定义不太好想,需要一定的经验 import java.util.*; /*** 蜗牛* 状态定义:* dp[i][0]:到达(x[i],0)最小时间* dp[i][1]:到达 xi 上方的传送门最小时间*/public class Main {static Scanner in new Scanner(System.in);static f…...
基于springboot+vue的流浪动物救助系统的设计与实现
开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…...
51单片机开发:IO扩展(串转并)实验
实验目标:通过扩展口从下至上依次点亮点阵屏的行。 下图左边是74HC595 模块电路图,右边是点阵屏电图图。 SRCLK上升沿时,将SER输入的数据移送至内部的移位寄存器。 RCLK上升沿时,将数据从移位寄存器移动至存储寄存器,…...
JAVA实战开源项目:购物商城网站(Vue+SpringBoot) 附源码
本文项目编号 T 032 ,文末自助获取源码 \color{red}{T032,文末自助获取源码} T032,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…...
C++学习——认识和与C的区别
目录 前言 一、什么是C 二、C关键字 三、与C语言不同的地方 3.1头文件 四、命名空间 4.1命名空间的概念写法 4.2命名空间的访问 4.3命名空间的嵌套 4.4命名空间在实际中的几种写法 五、输入输出 5.1cout 5.2endl 5.3cin 总结 前言 开启新的篇章,这里…...
Open FPV VTX开源之ardupilot双OSD配置摄像头
Open FPV VTX开源之ardupilot双OSD配置 1 源由2. 分析3. 配置4. 解决办法5. 参考资料 1 源由 鉴于笔者这台Mark4 Copter已经具备一定的历史,目前机载了两个FPV摄像头: 模拟摄像头数字摄像头(OpenIPC) 测试场景: 从稳定性的角度࿱…...
基于微信小程序高校课堂教学管理系统 课堂管理系统微信小程序(源码+文档)
目录 一.研究目的 二.需求分析 三.数据库设计 四.系统页面展示 五.免费源码获取 一.研究目的 困扰管理层的许多问题当中,高校课堂教学管理也是不敢忽视的一块。但是管理好高校课堂教学又面临很多麻烦需要解决,如何在工作琐碎,记录繁多的情况下将高校课堂教学的当前情况反…...
unity商店插件A* Pathfinding Project如何判断一个点是否在导航网格上?
需要使用NavGraph.IsPointOnNavmesh(Vector3 point) 如果点位于导航网的可步行部分,则为真。 如果一个点在可步行导航网表面之上或之下,在任何距离,如果它不在更近的不可步行节点之上 / 之下,则认为它在导航网上。 使用方法 Ast…...
三星手机人脸识别解锁需要点击一下电源键,能够不用点击直接解锁吗
三星手机的人脸识别解锁功能默认需要滑动或点击屏幕来解锁。这是为了增强安全性,防止误解锁的情况。如果希望在检测到人脸后直接进入主界面,可以通过以下设置调整: 打开设置: 进入三星手机的【设置】应用。 进入生物识别和安全&a…...
read+write实现:链表放到文件+文件数据放到链表 的功能
思路 一、 定义链表: 1 节点结构(数据int型) 2 链表操作(创建节点、插入节点、释放链表、打印链表)。 二、链表保存到文件 1打开文件 2遍历链表、写文件: 遍历链表,write()将节点数据写入文件。…...
猫怎么分公的母的?
各位铲屎官们,是不是刚领养了一只小猫咪,却分不清它是公是母?别急,今天就来给大家好好揭秘,如何轻松辨别猫咪的性别,让你不再为“它”是“他”还是“她”而烦恼! 一、观察生殖器位置 最直接的方…...
为何SAP S4系统中要设置MRP区域?MD04中可否同时显示工厂级、库存地点级的数据?
【SAP系统PP模块研究】 一、物料主数据的MRP区域设置 SAP ECC系统中想要指定不影响MRP运算的库存地点,是针对库存地点设置MRP标识,路径为:SPRO->生产->物料需求计划->计划->定义每一个工厂的存储地点MRP,如下图所示: 另外,在给物料主数据MMSC扩充库存地点时…...
Redis for AI
Redis存储和索引语义上表示非结构化数据(包括文本通道、图像、视频或音频)的向量嵌入。将向量和关联的元数据存储在哈希或JSON文档中,用于索引和查询。 Redis包括一个高性能向量数据库,允许您对向量嵌入执行语义搜索。可以通过过…...
初阶2 类与对象
本章重点 上篇1.面向过程和面向对象初步认识2.类的引入---结构体3.类的定义3.1 语法3.2 组成3.3 定义类的两种方法: 4.类的访问限定符及封装4.1 访问限定符4.2封装---面向对象的三大特性之一 5.类的作用域6.类的实例化7.类对象模型7.1 如何计算类对象的大小 8.this指…...
kafka-部署安装
一. 简述: Kafka 是一个分布式流处理平台,常用于构建实时数据管道和流应用。 二. 安装部署: 1. 依赖: a). Java:Kafka 需要 Java 8 或更高版本。 b). zookeeper: #tar fxvz zookeeper-3.7.0.tar.gz #…...
深入探讨防抖函数中的 this 上下文
深入剖析防抖函数中的 this 上下文 最近我在研究防抖函数实现的时候,发现一个耗费脑子的问题,出现了令我困惑的问题。接下来,我将通过代码示例,深入探究这些现象背后的原理。 示例代码 function debounce(fn, delay) {let time…...
人工智能丨Midscene:让UI自动化测试变得更简单
在这个数字化时代,每一个细节的优化都可能成为产品脱颖而出的关键。而对于测试人员来说,确保产品界面的稳定性和用户体验的流畅性至关重要。今天,我要向大家介绍一款名为Midscene的工具,它利用自然语言处理(NLP&#x…...
【数据结构】_链表经典算法OJ(力扣版)
目录 1. 移除链表元素 1.1 题目描述及链接 1.2 解题思路 1.3 程序 2. 反转链表 2.1 题目描述及链接 2.2 解题思路 2.3 程序 3. 链表的中间结点 3.1 题目描述及链接 3.2 解题思路 3.3 程序 1. 移除链表元素 1.1 题目描述及链接 原题链接:203. 移除链表…...
DeepSeek-R1技术报告速读
春节将至,DeepSeek又出王炸!DeepSeek-R1系列重磅开源。本文对其技术报告做简单解读。 话不多说,show me the benchmark。从各个高难度benchmark结果来看,DeepSeek-R1已经比肩OpenAI-o1-1217,妥妥的第一梯队推理模型。…...
560. 和为 K 的子数组
【题目】:560. 和为 K 的子数组 方法1. 前缀和 class Solution { public:int subarraySum(vector<int>& nums, int k) {int res 0;int n nums.size();vector<int> preSum(n 1, 0); // 下标从1开始存储for(int i 0; i < n; i) {preSum[i 1]…...
鸿蒙仓颉环境配置(仓颉SDK下载,仓颉VsCode开发环境配置,仓颉DevEco开发环境配置)
目录 1)仓颉的SDK下载 1--进入仓颉的官网 2--点击图片中的下载按钮 3--在新跳转的页面点击即刻下载 4--下载 5--找到你们自己下载好的地方 6--解压软件 2)仓颉编程环境配置 1--找到自己的根目录 2--进入命令行窗口 3--输入 envsetup.bat 4--验证是否安…...
NodeJs / Bun 分析文件编码 并将 各种编码格式 转为 另一个编码格式 ( 比如: GB2312→UTF-8, UTF-8→GB2312)
版本号 "iconv-lite": "^0.6.3", "chardet": "^2.0.0",github.com/runk/node-chardet 可以识别文本是 哪种编码 ( 大文件截取一部分进行分析,速度比较快 ) let bun_file_obj Bun.file(full_file_path) let file_bytes await bun_f…...
Java学习笔记(二十五)
1 Kafka Raft 简单介绍 Kafka Raft (KRaft) 是 Kafka 引入的一种新的分布式共识协议,用于取代之前依赖的 Apache ZooKeeper 集群管理机制。从 Kafka 2.8 开始,Kafka 开始支持基于 KRaft 的独立模式,计划在未来完全移除 ZooKeeper 的依赖。 1…...
Baklib如何结合内容中台与人工智能技术实现数字化转型
内容概要 在当前快速发展的数字环境中,企业面临着转型的紧迫性与挑战,尤其是在内容管理和用户互动的领域。内容中台作为一种集成化的解决方案,不仅能够提高企业在资源管理方面的效率,还能够为企业提供一致性和灵活性的内容分发机…...
git困扰的问题
.gitignore中添加的某个忽略文件并不生效 把某些目录或文件加入忽略规则,按照上述方法定义后发现并未生效, gitignore只能忽略那些原来没有被追踪的文件,如果某些文件已经被纳入了版本管理中,则修改.gitignore是无效的。 解决方…...
第05章 12 可视化热量流线图一例
下面是一个使用VTK(Visualization Toolkit)和C编写的示例代码,展示如何在一个厨房模型中可视化热量流线图,并按照热量传递速度着色显示。这个示例假设你已经安装了VTK库,并且你的开发环境已经配置好来编译和运行VTK程序…...
Vue组件开发-使用 html2canvas 和 jspdf 库实现PDF文件导出 设置页面大小及方向
在 Vue 项目中实现导出 PDF 文件、调整文件页面大小和页面方向的功能,使用 html2canvas 将 HTML 内容转换为图片,再使用 jspdf 把图片添加到 PDF 文件中。以下是详细的实现步骤和代码示例: 步骤 1:安装依赖 首先,在项…...