Gradient Descent Algorithm
Proof in Matrix Form
Let the input \(X\in R^{m\times n}\) hold \(m\) samples, where each sample \(x_i \in R^{1\times n}\) (\(i\in[1,m]\)) has \(n\) features: \[ X=\begin{bmatrix} x_{11}&\cdots & x_{1n}\\ x_{21}&\cdots & x_{2n}\\ \vdots&\vdots&\vdots\\ x_{m1}&\cdots & x_{mn} \end{bmatrix}= \begin{bmatrix} x_1\\ x_2\\ \vdots\\ x_m\\ \end{bmatrix} \]
The weights are \(w \in R^{n \times 1}\) and the predicted output is \(\hat y \in R^{m\times 1}\) (for simplicity, the bias parameter is not considered here): \[ \hat y=Xw=\begin{bmatrix} x_1\\ x_2\\ \vdots\\ x_m\\ \end{bmatrix}w= \begin{bmatrix} x_1 w\\ x_2 w\\ \vdots\\ x_m w\\ \end{bmatrix} \]
With the actual output \(y \in R^{m\times 1}\), the mean squared error is written as follows (the factor \(\frac{1}{2}\) is there so that it cancels later; the \(\frac{1}{m}\) averaging factor is also left out, which only rescales the gradient by a constant): \[ \mathbb{MSE}=\frac{1}{2}(Xw-y)^{T}(Xw-y)=\frac{1}{2}(w^T X^T Xw-w^T X^T y-y^T Xw+y^T y) \]
Note that \(\mathbb{MSE}\in R^{1\times 1}\) is a scalar, so it equals its own trace: \[ \mathbb{MSE}=tr(\mathbb{MSE}) \]
Taking the gradient with respect to \(w\): \[ \frac{\partial \mathbb{MSE}}{\partial w}=\frac{\partial tr(\mathbb{MSE})}{\partial w} =\frac{1}{2}\frac{\partial(tr(w^T X^T Xw-w^T X^T y-y^T Xw+y^T y))}{\partial w} \]
\[ =\frac{1}{2}(\frac{\partial(tr(w^T X^T Xw))}{\partial w}-\frac{\partial(tr(w^T X^T y))}{\partial w} -\frac{\partial(tr(y^T Xw))}{\partial w}+\frac{\partial(tr(y^T y))}{\partial w}) \]
Using standard properties of the matrix trace, each term is easy to compute. Since \(X^TX\) is symmetric: \[ \frac{\partial(tr(w^T X^T Xw))}{\partial w} =\frac{\partial(tr(w^T (X^T X)w))}{\partial w} =X^TXw+(X^TX)^Tw=2X^TXw \]
\[ \frac{\partial(tr(w^T X^T y))}{\partial w}=X^Ty \]
\[ \frac{\partial(tr(y^T Xw))}{\partial w} =\frac{\partial(tr((y^T Xw)^T))}{\partial w} =\frac{\partial(tr(w^T X^Ty))}{\partial w}=X^Ty \]
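The trace properties relied on in these steps can be stated explicitly. As a short summary, assume \(A \in R^{n\times n}\) and \(b \in R^{n\times 1}\) do not depend on \(w\) (\(A\) and \(b\) are names introduced here only for illustration): \[ \frac{\partial\, tr(w^T A w)}{\partial w}=(A+A^T)w, \qquad \frac{\partial\, tr(w^T b)}{\partial w}=b, \qquad tr(b^T w)=tr((b^T w)^T)=tr(w^T b) \]
Substituting \(A=X^TX\) (symmetric, so \(A+A^T=2X^TX\)) and \(b=X^Ty\) reproduces the three results above.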
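Because \(y\) and \(y^T\) do not depend on \(w\): \[ \frac{\partial(tr(y^T y))}{\partial w}=0 \]
Therefore: \[ \frac{\partial \mathbb{MSE}}{\partial w}=\frac{1}{2}(2X^TXw-X^Ty-X^Ty)=X^T(Xw-y) \]
Setting \(\frac{\partial \mathbb{MSE}}{\partial w}=0\) lets us solve for \(w\) directly, provided \(X^TX\) is invertible (this is the Normal Equation mentioned earlier in the linear regression post): \[ X^T(Xw-y)=0 \Longrightarrow X^TXw=X^T y \Longrightarrow w=(X^{\top}X)^{-1}X^{\top}y \]
To solve for \(w\) by gradient descent instead, let the learning rate be \(\mathbb{LR}\) and the number of iterations be \(\mathbb{EPOCH}\). Then: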
\(\mathbb{INIT}:\mathbb{C}=1\)
\(\mathbb{LOOP(\mathbb{WHILE}:\mathbb{C} \le \mathbb{EPOCH})}:\)
\(\quad\mathbb{C}=\mathbb{C}+1\)
\(\quad w=w-\mathbb{LR} \cdot \frac{\partial \mathbb{MSE}}{\partial w}=w-\mathbb{LR}\cdot X^T(Xw-y)\)
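One way to gain confidence in the derived gradient \(X^T(Xw-y)\) is to compare it against a finite-difference approximation of the loss. The following is a minimal sketch, not part of the original post: the function names `loss`, `analytic_grad`, and `numeric_grad`, the step size `eps`, and the random test data are all assumptions made for illustration.

```python
import numpy as np


def loss(X, y, w):
    """Half the sum of squared residuals, matching the MSE definition above."""
    r = X @ w - y
    return 0.5 * np.sum(r * r)


def analytic_grad(X, y, w):
    """Gradient from the derivation: X^T (Xw - y)."""
    return X.T @ (X @ w - y)


def numeric_grad(X, y, w, eps=1e-6):
    """Central-difference approximation of the gradient, one coordinate at a time."""
    g = np.zeros_like(w)
    for j in range(w.shape[0]):
        step = np.zeros_like(w)
        step[j, 0] = eps
        g[j, 0] = (loss(X, y, w + step) - loss(X, y, w - step)) / (2 * eps)
    return g


# Hypothetical random data, used only to exercise the check.
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = rng.normal(size=(20, 1))
w = rng.normal(size=(3, 1))

diff = np.max(np.abs(analytic_grad(X, y, w) - numeric_grad(X, y, w)))
print("max abs difference:", diff)
```

If the derivation is right, the printed difference should be tiny compared with the size of the gradient entries, since the loss is quadratic in \(w\) and the central difference is exact up to rounding.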
Implementation
Here is a simple implementation of the algorithm using Python's numpy library:
import numpy as np

# np.asmatrix is used in place of the np.mat alias, which was removed in NumPy 2.0.


class Model:
    def __init__(self, X, y):
        """
        :param X: np.matrix, source data;
            if you want to calculate the normal equation, X^T X can't be a singular matrix
        :param y: np.matrix, target data
        """
        self.X = X
        self.y = y
        # Random initial weights as an (n, 1) column vector.
        self.w = np.asmatrix(np.random.normal(loc=0.0, scale=1.0, size=[self.X.shape[1]])).squeeze().transpose()
        # Placeholder, overwritten by calculate_normal_equation().
        self.normal_equation = np.asmatrix(np.random.normal(loc=0.0, scale=1.0, size=[self.X.shape[1], 1]))

    def calculate_gradient_down(self, epoch=100, lr=1.0e-5):
        """
        :param epoch: number of iterations
        :param lr: learning rate
        :return:
        """
        for i in range(epoch):
            # w = w - LR * X^T (Xw - y)
            self.w -= lr * self.X.transpose() * (self.X * self.w - self.y)

    def calculate_normal_equation(self):
        # w = (X^T X)^{-1} X^T y
        self.normal_equation = (self.X.transpose() * self.X).I * self.X.transpose() * self.y

    def predict(self, X, use_normal_equation=False):
        """
        :param X: np.matrix, the data for prediction
        :param use_normal_equation: use normal equation for prediction or not
        :return: prediction
        """
        if not use_normal_equation:
            return X * self.w
        else:
            return X * self.normal_equation


def test():
    X = np.asmatrix(np.random.randint(0, 100, size=[20, 3]), dtype=np.float64)
    w = np.asmatrix([3, 2, 1], dtype=np.float64).squeeze(axis=0).transpose()
    y = X * w
    Xtest = np.asmatrix(np.random.randint(0, 100, size=[20, 3]), dtype=np.float64)
    ytest = Xtest * w
    model = Model(X, y)
    model.calculate_gradient_down()
    model.calculate_normal_equation()
    print("w:", model.w.transpose())
    print("normal_equation:", model.normal_equation.transpose())
    err1 = model.predict(Xtest) - ytest
    err2 = model.predict(Xtest, use_normal_equation=True) - ytest
    print("w predict error:", err1.transpose() * err1)
    # Squared error of the normal-equation prediction (err2 on both sides).
    print("normal_equation predict error:", err2.transpose() * err2)


if __name__ == '__main__':
    """
    Sample output from one run (the data is random, so exact values vary between runs):
    w: [[3.00000005 1.99999983 1.00000014]]
    normal_equation: [[3. 2. 1.]]
    w predict error: [[6.36796808e-10]]
    normal_equation predict error: [[5.57834374e-18]]
    """
    test()
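NumPy's documentation recommends regular arrays over the matrix class, so the same two solvers can also be sketched with plain `ndarray` objects and the `@` operator. This is an illustrative variant under stated assumptions, not the author's implementation: the function names `gradient_descent` and `normal_equation`, the zero initialization of `w`, the features scaled to \([0, 1)\), and the values `lr=0.01` and `epochs=10000` are all choices made for this example. The normal equation is solved with `np.linalg.solve` rather than by forming an explicit inverse.

```python
import numpy as np


def gradient_descent(X, y, lr=0.01, epochs=10000):
    """Iterate w = w - lr * X^T (Xw - y), starting from zeros (an assumption)."""
    w = np.zeros((X.shape[1], 1))
    for _ in range(epochs):
        w -= lr * (X.T @ (X @ w - y))
    return w


def normal_equation(X, y):
    """Solve X^T X w = X^T y without forming the inverse explicitly."""
    return np.linalg.solve(X.T @ X, X.T @ y)


# Hypothetical data: features in [0, 1) so that a larger learning rate stays stable.
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(20, 3))
true_w = np.array([[3.0], [2.0], [1.0]])
y = X @ true_w

print("gradient descent:", gradient_descent(X, y).ravel())
print("normal equation:", normal_equation(X, y).ravel())
```

For this quadratic loss the update only converges when the learning rate is below \(2/\lambda_{max}\), where \(\lambda_{max}\) is the largest eigenvalue of \(X^TX\). That is why features in the range 0 to 100 call for a rate as small as `1.0e-5`, while features scaled to \([0, 1)\) tolerate a much larger one.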
- Permalink: https://morisa66.github.io/2021/02/02/grad_down/

